Geometric Deep Learning Extension Library for TensorFlow
GraphGallery
A gallery of state-of-the-art graph neural networks, implemented with TensorFlow 2.x.
This repository aims to achieve four goals:
- Similar or higher performance
- Faster training and testing
- Simple and convenient to use, with high scalability
- Easy-to-read source code
Requirements
- python>=3.6
- tensorflow>=2.1 (2.1 is recommended)
- networkx==2.3
- scipy
- scikit_learn
- numpy
- numba
- gensim
Optional packages:
- metis==0.2a4 (required for ClusterGCN)
- texttable (required for model.show())
Installation
pip install -U graphgallery
Implementation
General models
- ChebyNet from Michaël Defferrard et al., 📝Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering, NIPS'16, :octocat:Codes
- GCN from Thomas N. Kipf et al., 📝Semi-Supervised Classification with Graph Convolutional Networks, ICLR'17, :octocat:Codes
- GraphSAGE from William L. Hamilton et al., 📝Inductive Representation Learning on Large Graphs, NIPS'17, :octocat:Codes
- FastGCN from Jie Chen et al., 📝FastGCN: Fast Learning with Graph Convolutional Networks via Importance Sampling, ICLR'18, :octocat:Codes
- LGCN from Hongyang Gao et al., 📝Large-Scale Learnable Graph Convolutional Networks, KDD'18, :octocat:Codes
- GAT from Petar Veličković et al., 📝Graph Attention Networks, ICLR'18, :octocat:Codes
- SGC from Felix Wu et al., 📝Simplifying Graph Convolutional Networks, ICML'19, :octocat:Codes
- GWNN from Bingbing Xu et al., 📝Graph Wavelet Neural Network, ICLR'19, :octocat:Codes
- GMNN from Meng Qu et al., 📝Graph Markov Neural Networks, ICML'19, :octocat:Codes
- ClusterGCN from Wei-Lin Chiang et al., 📝Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks, KDD'19, :octocat:Codes
- DAGNN from Meng Liu et al., 📝Towards Deeper Graph Neural Networks, KDD'20, :octocat:Codes
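Many of the models above build on the GCN propagation rule H' = σ(D̃^{-1/2} Ã D̃^{-1/2} H W), where Ã = A + I and D̃ is the degree matrix of Ã. SGC, for example, drops the nonlinearity and stacks that normalized propagation. The NumPy sketch below implements the rule on a dense toy graph to show what a single layer computes. It illustrates the formula from the GCN paper and is not GraphGallery's implementation, which is built on TensorFlow. The function names are our own.

```python
import numpy as np

def normalize_adjacency(adj: np.ndarray) -> np.ndarray:
    """Symmetric normalization D^-1/2 (A + I) D^-1/2 as in the GCN paper."""
    a_hat = adj + np.eye(adj.shape[0])        # add self-loops
    deg = a_hat.sum(axis=1)                    # node degrees including the self-loop
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    # Elementwise: a_hat[i, j] * d_i^-1/2 * d_j^-1/2
    return a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def gcn_layer(a_norm, h, w, activation=lambda z: np.maximum(z, 0.0)):
    """One graph convolution: activation(A_norm @ H @ W), ReLU by default."""
    return activation(a_norm @ h @ w)

# Toy graph: a path 0 - 1 - 2 with 2-dimensional node features.
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=float)
x = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])
rng = np.random.default_rng(123)
w = rng.normal(size=(2, 4))                   # 4 hidden units, chosen arbitrarily

a_norm = normalize_adjacency(adj)
h = gcn_layer(a_norm, x, w)
print(h.shape)  # (3, 4)
```

A real graph would be stored as a sparse matrix. The dense version is used here only so that each step stays visible.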
Defense models
- RobustGCN from Dingyuan Zhu et al., 📝Robust Graph Convolutional Networks Against Adversarial Attacks, KDD'19, :octocat:Codes
- SBVAT/OBVAT from Zhijie Deng et al., 📝Batch Virtual Adversarial Training for Graph Convolutional Networks, ICML'19, :octocat:Codes
Quick Start
Example of a GCN model
from graphgallery.nn.models import GCN
# adj is a SciPy sparse adjacency matrix, x is a NumPy feature matrix
model = GCN(adj, x, labels, device='GPU', norm_x='l1', seed=123)
# build your GCN model; pass arguments here to customize hyper-parameters
model.build()
# train your model; idx_train and idx_val are NumPy arrays of node indices
his = model.train(idx_train, idx_val, verbose=1, epochs=100)
# test your model
loss, accuracy = model.test(idx_test)
print(f'Test loss {loss:.5}, Test accuracy {accuracy:.2%}')
On the Cora dataset:
loss 1.02, acc 95.00%, val_loss 1.41, val_acc 77.40%: 100%|██████████| 100/100 [00:02<00:00, 37.07it/s]
Test loss 1.4123, Test accuracy 81.20%
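The quick-start snippet assumes that x and the index arrays already exist. The sketch below prepares toy versions with NumPy. It applies row-wise L1 feature normalization and makes a random train/validation/test split of node indices. We assume that norm_x='l1' means scaling each feature row to unit L1 norm. That is our reading, not something stated here, so check the library source. random_split is a hypothetical helper, and a real adj would additionally be converted to a SciPy sparse matrix (for example with scipy.sparse.csr_matrix).

```python
import numpy as np

def l1_normalize_rows(x: np.ndarray) -> np.ndarray:
    """Scale each row so its absolute values sum to 1; all-zero rows stay zero."""
    row_sums = np.abs(x).sum(axis=1, keepdims=True)
    row_sums[row_sums == 0] = 1.0            # avoid division by zero
    return x / row_sums

def random_split(num_nodes, train_size, val_size, seed=123):
    """Hypothetical helper: shuffle node ids and cut them into train/val/test arrays."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(num_nodes)
    idx_train = perm[:train_size]
    idx_val = perm[train_size:train_size + val_size]
    idx_test = perm[train_size + val_size:]
    return idx_train, idx_val, idx_test

x = np.array([[1.0, 3.0, 0.0],
              [0.0, 0.0, 0.0],
              [2.0, 2.0, 4.0]])
x_norm = l1_normalize_rows(x)
idx_train, idx_val, idx_test = random_split(num_nodes=10, train_size=4, val_size=3)
```

Benchmarks often use a fixed, published split instead of a random one. A random split is shown here only because it needs no external data.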
Build your model
You can pass hyper-parameters to model.build() to customize your model:
# one hidden layer with 32 hidden units and ReLU activation
>>> model.build(hiddens=32, activations='relu')
# two hidden layers with 32 and 64 hidden units, both with ReLU activation
>>> model.build(hiddens=[32, 64], activations='relu')
# two hidden layers with 32 and 64 hidden units, with ReLU and ELU activations respectively
>>> model.build(hiddens=[32, 64], activations=['relu', 'elu'])
# other parameters such as `dropouts` and `l2_norms` (if the model has them) work the same way.
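The build() examples suggest a broadcasting rule. A scalar hiddens value means a single hidden layer, and a single activation string is reused for every layer. The plain-Python sketch below spells out that expansion. expand_layer_config is a hypothetical name, and the code reflects our reading of the examples rather than GraphGallery's own code. The same rule would presumably apply to dropouts and l2_norms.

```python
from typing import List, Sequence, Tuple, Union

def expand_layer_config(
    hiddens: Union[int, Sequence[int]],
    activations: Union[str, Sequence[str]],
) -> List[Tuple[int, str]]:
    """Hypothetical helper: pair each hidden layer size with its activation.

    A scalar `hiddens` means one hidden layer; a single activation string
    is reused for every layer.
    """
    if isinstance(hiddens, int):
        hiddens = [hiddens]
    hiddens = list(hiddens)
    if isinstance(activations, str):
        activations = [activations] * len(hiddens)
    activations = list(activations)
    if len(activations) != len(hiddens):
        raise ValueError(
            f"got {len(hiddens)} hidden layers but {len(activations)} activations"
        )
    return list(zip(hiddens, activations))

print(expand_layer_config(32, "relu"))                 # [(32, 'relu')]
print(expand_layer_config([32, 64], "relu"))           # [(32, 'relu'), (64, 'relu')]
print(expand_layer_config([32, 64], ["relu", "elu"]))  # [(32, 'relu'), (64, 'elu')]
```

Raising an error on a length mismatch is a design choice of this sketch. It catches a list of activations that does not line up with the list of layer sizes.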
Train or test your model
For more details, see the methods model.train and model.test.
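model.test(idx_test) returns a loss and an accuracy. In semi-supervised node classification, these metrics are usually computed only over the nodes in the given index array. The NumPy sketch below shows that computation for softmax cross-entropy. masked_metrics is a hypothetical function. We assume, without having confirmed it, that GraphGallery's loss is the mean cross-entropy over the indexed nodes.

```python
import numpy as np

def masked_metrics(logits, labels, idx):
    """Hypothetical helper: mean cross-entropy loss and accuracy over nodes in `idx`."""
    z = logits[idx]
    y = labels[idx]
    z = z - z.max(axis=1, keepdims=True)                  # numerically stable softmax
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    loss = -log_probs[np.arange(len(idx)), y].mean()
    accuracy = (z.argmax(axis=1) == y).mean()
    return float(loss), float(accuracy)

# Toy logits for 4 nodes and 3 classes; only nodes 1, 2 and 3 are evaluated.
logits = np.array([[2.0, 0.5, 0.1],
                   [0.2, 1.5, 0.3],
                   [1.0, 1.0, 3.0],
                   [0.0, 2.0, 0.0]])
labels = np.array([0, 1, 0, 1])
idx_test = np.array([1, 2, 3])
loss, accuracy = masked_metrics(logits, labels, idx_test)
print(f'Test loss {loss:.5}, Test accuracy {accuracy:.2%}')
```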
Hyper-parameters
You can simply use model.show() to show all your hyper-parameters. Alternatively, use model.show('model') or model.show('train') to show only your model parameters or your training parameters, respectively.
NOTE: you should install texttable first.
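To give an idea of the kind of listing model.show() produces, here is a stdlib-only sketch that renders a parameter dict as a two-column text table. format_params and its layout are our own invention. The real output is drawn by texttable and may look different. The parameter names are taken from the examples in this README.

```python
def format_params(params: dict, title: str = "Parameters") -> str:
    """Hypothetical helper: render a dict as a two-column plain-text table."""
    rows = [(str(k), str(v)) for k, v in params.items()]
    key_w = max([len("Name")] + [len(k) for k, _ in rows])
    val_w = max([len("Value")] + [len(v) for _, v in rows])
    sep = "+" + "-" * (key_w + 2) + "+" + "-" * (val_w + 2) + "+"
    lines = [title, sep, f"| {'Name':<{key_w}} | {'Value':<{val_w}} |", sep]
    lines += [f"| {k:<{key_w}} | {v:<{val_w}} |" for k, v in rows]
    lines.append(sep)
    return "\n".join(lines)

table = format_params({"hiddens": [32, 64], "activations": "relu", "epochs": 100})
print(table)
```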
Visualization
- Accuracy
import matplotlib.pyplot as plt
plt.plot(his.history['acc'])
plt.plot(his.history['val_acc'])
plt.legend(['Accuracy', 'Val Accuracy'])
plt.xlabel('Epochs')
plt.show()
- Loss
import matplotlib.pyplot as plt
plt.plot(his.history['loss'])
plt.plot(his.history['val_loss'])
plt.legend(['Loss', 'Val Loss'])
plt.xlabel('Epochs')
plt.show()
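In the plots above, his.history is used as a dict that maps metric names to per-epoch lists, which appears to follow the Keras History convention. A common next step is to find the epoch with the best validation accuracy. The sketch below does that with plain Python. best_epoch is a hypothetical helper, and the small history dict holds made-up illustrative numbers that stand in for a real training run.

```python
def best_epoch(history: dict, key: str = "val_acc") -> tuple:
    """Return (epoch, value) with the highest `key`; epochs are counted from 1."""
    values = history[key]
    idx = max(range(len(values)), key=values.__getitem__)
    return idx + 1, values[idx]

# Illustrative numbers only, standing in for his.history after training.
history = {
    "acc": [0.40, 0.70, 0.90, 0.95],
    "val_acc": [0.50, 0.74, 0.78, 0.77],
}
epoch, val_acc = best_epoch(history)
print(f"best val_acc {val_acc:.2%} at epoch {epoch}")
```

If several epochs tie, max() returns the first one, so the earliest best epoch is reported.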
Acknowledgement
This project is motivated by PyTorch Geometric, TensorFlow Geometric, StellarGraph, and the original implementations by the authors. Thanks for their excellent work!
Hashes for graphgallery-0.1.8-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 8f471b1833daf64058ea9f57b165987b077efe72ae3deafb2bbec0cdc6e6d371
MD5 | a36d24c8ba5a4a2ec826f8dc08e457ba
BLAKE2b-256 | 65618b90ae1f0ccfbabeec0ffbbc056e5dd1a37f44dff0b69ce82b3c0fdedd1c