Geometric Deep Learning Extension Library for TensorFlow
GraphGallery
GraphGallery is a gallery of state-of-the-art graph neural networks for TensorFlow 2.x.
This repo aims to achieve four goals:
- Similar or higher performance
- Faster training and testing
- Simple and convenient to use, with high scalability
- Easy-to-read source code
Installation
pip install -U graphgallery
Implementations
In detail, the following methods are currently implemented:
Semi-supervised models
General
- ChebyNet from Michaël Defferrard et al, 📝Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering, NIPS'16. [:octocat:Official Codes], [🌈 GraphGallery Example]
- GCN from Thomas N. Kipf et al, 📝Semi-Supervised Classification with Graph Convolutional Networks, ICLR'17. [:octocat:Official Codes], [🌈 GraphGallery Example]
- GraphSAGE from William L. Hamilton et al, 📝Inductive Representation Learning on Large Graphs, NIPS'17. [:octocat:Official Codes], [🌈 GraphGallery Example]
- FastGCN from Jie Chen et al, 📝FastGCN: Fast Learning with Graph Convolutional Networks via Importance Sampling, ICLR'18. [:octocat:Official Codes], [🌈 GraphGallery Example]
- LGCN from Hongyang Gao et al, 📝Large-Scale Learnable Graph Convolutional Networks, KDD'18. [:octocat:Official Codes], [🌈 GraphGallery Example]
- GAT from Petar Veličković et al, 📝Graph Attention Networks, ICLR'18. [:octocat:Official Codes], [🌈 GraphGallery Example]
- SGC from Felix Wu et al, 📝Simplifying Graph Convolutional Networks, ICML'19. [:octocat:Official Codes], [🌈 GraphGallery Example]
- GWNN from Bingbing Xu et al, 📝Graph Wavelet Neural Network, ICLR'19. [:octocat:Official Codes], [🌈 GraphGallery Example]
- GMNN from Meng Qu et al, 📝Graph Markov Neural Networks, ICML'19. [:octocat:Official Codes], [🌈 GraphGallery Example]
- ClusterGCN from Wei-Lin Chiang et al, 📝Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks, KDD'19. [:octocat:Official Codes], [🌈 GraphGallery Example]
- DAGNN from Meng Liu et al, 📝Towards Deeper Graph Neural Networks, KDD'20. [:octocat:Official Codes], [🌈 GraphGallery Example]
Defense models
- RobustGCN from Dingyuan Zhu et al, 📝Robust Graph Convolutional Networks Against Adversarial Attacks, KDD'19. [:octocat:Official Codes], [🌈 GraphGallery Example]
- SBVAT/OBVAT from Zhijie Deng et al, 📝Batch Virtual Adversarial Training for Graph Convolutional Networks, ICML'19. [:octocat:Official Codes]
Unsupervised models
- DeepWalk from Bryan Perozzi et al, 📝DeepWalk: Online Learning of Social Representations, KDD'14. [:octocat:Official Codes], [🌈 GraphGallery Example]
- Node2vec from Aditya Grover et al, 📝node2vec: Scalable Feature Learning for Networks, KDD'16. [:octocat:Official Codes], [🌈 GraphGallery Example]
Quick Start
Datasets
from graphgallery.data import Planetoid
# set `verbose=False` to suppress the printed dataset summary tables
data = Planetoid('cora', verbose=False)
adj = data.adj
x = data.x
labels = data.labels
idx_train = data.idx_train
idx_val = data.idx_val
idx_test = data.idx_test
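To run a model on your own graph, you would need objects of the same kind as those returned by `Planetoid`. The GCN example states that `adj` is a scipy sparse matrix and `x` is a numpy array; the sketch below builds a tiny hypothetical 4-node graph in that form. Integer class labels and disjoint index arrays for the splits are assumptions about what the models expect, not something confirmed by the library.

```python
import numpy as np
import scipy.sparse as sp

# Hypothetical toy graph: 4 nodes with undirected edges 0-1, 1-2, 2-3.
edges = np.array([[0, 1], [1, 2], [2, 3]])
num_nodes = 4

# Symmetric adjacency matrix in CSR format (each undirected edge stored in both directions).
rows = np.concatenate([edges[:, 0], edges[:, 1]])
cols = np.concatenate([edges[:, 1], edges[:, 0]])
adj = sp.csr_matrix(
    (np.ones(len(rows), dtype=np.float32), (rows, cols)),
    shape=(num_nodes, num_nodes),
)

# Dense node feature matrix: one row per node (3 random features here).
rng = np.random.default_rng(0)
x = rng.random((num_nodes, 3), dtype=np.float32)

# One integer class label per node (assumed format).
labels = np.array([0, 0, 1, 1])

# Disjoint index arrays selecting the train / validation / test nodes.
idx_train = np.array([0, 2])
idx_val = np.array([1])
idx_test = np.array([3])
```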
Currently, the supported datasets are:
>>> data.supported_datasets
('citeseer', 'cora', 'pubmed')
Example of GCN model
from graphgallery.nn.models import GCN
# adj is a scipy sparse matrix, x is a numpy array
model = GCN(adj, x, labels, device='GPU', norm_x='l1', seed=123)
# build the model; pass hyper-parameters here to customize it (see Customization)
model.build()
# train your model. here idx_train and idx_val are numpy arrays
his = model.train(idx_train, idx_val, verbose=1, epochs=100)
# test your model
loss, accuracy = model.test(idx_test)
print(f'Test loss {loss:.5}, Test accuracy {accuracy:.2%}')
On the Cora dataset:
<Loss = 1.0161 Acc = 0.9500 Val_Loss = 1.4101 Val_Acc = 0.7740 >: 100%|██████████| 100/100 [00:01<00:00, 68.02it/s]
Test loss 1.4123, Test accuracy 81.20%
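The `norm_x='l1'` argument presumably row-normalizes the feature matrix so that the absolute values in each node's feature row sum to 1, which is the feature preprocessing used in the original GCN code. How GraphGallery applies it internally is an assumption here; the sketch below only shows what L1 row normalization does with plain NumPy (the helper name is hypothetical).

```python
import numpy as np

def l1_normalize_rows(x: np.ndarray) -> np.ndarray:
    """Scale each row so that its absolute values sum to 1 (L1 row normalization).

    All-zero rows are left unchanged instead of dividing by zero.
    """
    x = np.asarray(x, dtype=np.float64)
    row_sums = np.abs(x).sum(axis=1, keepdims=True)
    row_sums[row_sums == 0] = 1.0  # avoid division by zero for empty rows
    return x / row_sums

features = np.array([[1.0, 3.0, 0.0],
                     [0.0, 0.0, 0.0],
                     [2.0, 2.0, 4.0]])
# rows become [0.25, 0.75, 0.0], [0.0, 0.0, 0.0] and [0.25, 0.25, 0.5]
print(l1_normalize_rows(features))
```

This matches `sklearn.preprocessing.normalize(x, norm='l1')`, which also normalizes rows and leaves zero rows untouched.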
Customization
- Build your model
Use the following statements to build your model with custom hyper-parameters:
# one hidden layer with 32 hidden units and ReLU activation
>>> model.build(hiddens=32, activations='relu')
# two hidden layers with 32 and 64 hidden units, both with ReLU activation
>>> model.build(hiddens=[32, 64], activations='relu')
# two hidden layers with 32 and 64 hidden units, with ReLU and ELU activations respectively
>>> model.build(hiddens=[32, 64], activations=['relu', 'elu'])
# other parameters such as `dropouts` and `l2_norms` (if the model has them) follow the same convention
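The three `build` calls above follow one rule: a single number means one hidden layer, a single activation string is reused for every layer, and a list gives one value per layer. The helper below is a hypothetical sketch of that rule, not GraphGallery's own code; in particular, raising an error on mismatched lengths is an assumption about how such a mismatch should be handled.

```python
from typing import List, Sequence, Tuple, Union

def expand_layer_config(
    hiddens: Union[int, Sequence[int]],
    activations: Union[str, Sequence[str]],
) -> List[Tuple[int, str]]:
    """Pair each hidden layer size with an activation (hypothetical helper)."""
    # A single int means exactly one hidden layer.
    sizes = [hiddens] if isinstance(hiddens, int) else list(hiddens)
    # A single activation string is repeated for every hidden layer.
    if isinstance(activations, str):
        acts = [activations] * len(sizes)
    else:
        acts = list(activations)
    if len(acts) != len(sizes):
        raise ValueError(f"got {len(sizes)} hidden layers but {len(acts)} activations")
    return list(zip(sizes, acts))

print(expand_layer_config(32, 'relu'))                 # [(32, 'relu')]
print(expand_layer_config([32, 64], 'relu'))           # [(32, 'relu'), (64, 'relu')]
print(expand_layer_config([32, 64], ['relu', 'elu']))  # [(32, 'relu'), (64, 'elu')]
```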
- Train your model
# train with validation
>>> his = model.train(idx_train, idx_val, verbose=1, epochs=100)
# train without validation
# his = model.train(idx_train, verbose=1, epochs=100)
Here `his` is a TensorFlow `History`-like object (or a `History` instance itself).
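A Keras `History` object keeps the per-epoch metrics in its `.history` dict, keyed by metric name with one value per epoch. The sketch below picks the epoch with the best validation accuracy from such a dict; the `'acc'`/`'val_acc'` key names are taken from the plotting code in the Visualization section, and the numbers are made up for illustration.

```python
from typing import Dict, List, Tuple

def best_epoch(history: Dict[str, List[float]], metric: str = 'val_acc') -> Tuple[int, float]:
    """Return (1-based epoch, value) with the highest `metric` in a Keras-style history dict."""
    values = history[metric]
    # max() over indices returns the first index holding the largest value.
    idx = max(range(len(values)), key=values.__getitem__)
    return idx + 1, values[idx]

# Made-up numbers for illustration; in practice pass his.history.
history = {'acc': [0.40, 0.80, 0.95], 'val_acc': [0.55, 0.78, 0.77]}
print(best_epoch(history))  # (2, 0.78)
```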
- Test your model
loss, accuracy = model.test(idx_test)
print(f'Test loss {loss:.5}, Test accuracy {accuracy:.2%}')
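The index arrays decide which nodes a metric is computed on. The sketch below shows how an accuracy over `idx_test` can be computed from per-node class scores; `masked_accuracy` is a hypothetical helper and the scores are made up, so this illustrates the metric rather than what `model.test` does internally.

```python
import numpy as np

def masked_accuracy(logits: np.ndarray, labels: np.ndarray, idx: np.ndarray) -> float:
    """Fraction of nodes in `idx` whose arg-max prediction equals the true label."""
    predictions = logits[idx].argmax(axis=1)
    return float((predictions == labels[idx]).mean())

# Toy scores for 4 nodes and 2 classes (made up for illustration).
logits = np.array([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2], [0.4, 0.6]])
labels = np.array([0, 1, 1, 1])
idx_test = np.array([2, 3])
accuracy = masked_accuracy(logits, labels, idx_test)
print(f'Test accuracy {accuracy:.2%}')  # Test accuracy 50.00%
```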
- Display hyper-parameters
Use `model.show()` to show all of your hyper-parameters. Alternatively, use `model.show('model')` or `model.show('train')` to show only the model parameters or only the training parameters.
NOTE: you should install texttable first.
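Since `model.show()` relies on texttable, here is a rough sketch of rendering a parameter dict as a two-column table with it, plus a plain-text fallback when texttable is not installed. `format_params` and the example values are hypothetical; this is not the code behind `model.show()`.

```python
from typing import Any, Dict

def format_params(params: Dict[str, Any]) -> str:
    """Render a name -> value mapping as a two-column text table."""
    try:
        from texttable import Texttable
        table = Texttable()
        # Treat both columns as text so values such as 0.01 are not reformatted.
        table.set_cols_dtype(['t', 't'])
        table.add_rows([['Parameter', 'Value']] + [[k, str(v)] for k, v in params.items()])
        return table.draw()
    except ImportError:
        # Fallback without texttable: left-aligned plain-text columns.
        width = max([len('Parameter')] + [len(k) for k in params])
        lines = [f"{'Parameter':<{width}}  Value"]
        lines += [f'{k:<{width}}  {v}' for k, v in params.items()]
        return '\n'.join(lines)

# Hypothetical hyper-parameters, for illustration only.
print(format_params({'hiddens': 16, 'lr': 0.01}))
```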
Visualization
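NOTE: the `science` and `no-latex` styles used below come from the SciencePlots package, so install it first; with SciencePlots 2.0 or later you also need `import scienceplots` to register these styles.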
- Accuracy
import matplotlib.pyplot as plt
with plt.style.context(['science', 'no-latex']):
    plt.plot(his.history['acc'])
    plt.plot(his.history['val_acc'])
    plt.legend(['Train Accuracy', 'Val Accuracy'])
    plt.ylabel('Accuracy')
    plt.xlabel('Epochs')
    plt.autoscale(tight=True)
    plt.show()
- Loss
import matplotlib.pyplot as plt
with plt.style.context(['science', 'no-latex']):
    plt.plot(his.history['loss'])
    plt.plot(his.history['val_loss'])
    plt.legend(['Train Loss', 'Val Loss'])
    plt.ylabel('Loss')
    plt.xlabel('Epochs')
    plt.autoscale(tight=True)
    plt.show()
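The two plots above can also be drawn side by side in one figure. The sketch below is a hypothetical helper that assumes the same history keys as above (`acc`, `val_acc`, `loss`, `val_loss`), skips the validation curve when you trained without validation, and falls back to Matplotlib's default style if SciencePlots (assumed 2.0 or later, which registers its styles on import) is not installed.

```python
import matplotlib.pyplot as plt

def plot_history(history):
    """Plot train/val accuracy and loss side by side from a Keras-style history dict."""
    try:
        import scienceplots  # noqa: F401  (registers 'science' and 'no-latex' on import)
        styles = ['science', 'no-latex']
    except ImportError:
        styles = ['default']  # fall back to Matplotlib's default style
    with plt.style.context(styles):
        fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(8, 3))
        for ax, key, name in ((ax_acc, 'acc', 'Accuracy'), (ax_loss, 'loss', 'Loss')):
            ax.plot(history[key], label=f'Train {name}')
            # Validation curves exist only if you trained with idx_val.
            if f'val_{key}' in history:
                ax.plot(history[f'val_{key}'], label=f'Val {name}')
            ax.set_xlabel('Epochs')
            ax.set_ylabel(name)
            ax.legend()
            ax.autoscale(tight=True)
        fig.tight_layout()
    return fig
```

Usage: `fig = plot_history(his.history)` followed by `plt.show()`.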
More Examples
Please refer to the examples directory.
TODO Lists
- Add Docstrings and Documentation
- Support for graph classification and link prediction tasks
- Support for heterogeneous graphs
- Add PyTorch models support
Acknowledgement
This project is motivated by PyTorch Geometric, TensorFlow Geometric and StellarGraph, as well as the original implementations from the authors. Thanks for their excellent work!
Hashes for graphgallery-0.1.11-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | 302bb5763869731219c492dbd0b3542c5810f53e8158890967e4eaea0b9d06fd |
| MD5 | 9e5f97664f9d2f03997f8c65ebf6a67c |
| BLAKE2b-256 | 6f135113f89547801bc6e7436434343d86f0cb1b8a6b0140046561831f694983 |