Geometric Deep Learning Extension Library for TensorFlow
GraphGallery
A gallery of state-of-the-art graph neural networks, implemented with TensorFlow 2.x.
This repo aims to achieve four goals:
- Similar or higher performance
- Faster training and testing
- Simple and convenient to use, with high scalability
- Easy-to-read source code
Requirements
- python>=3.6
- tensorflow>=2.1 (2.1 is recommended)
- networkx==2.3
- scipy>=1.4.1
- scikit_learn>=0.22
- numpy>=1.17.4
- numba>=0.48
- gensim>=3.8.1
Optional packages:
- metis==0.2a4 (used for ClusterGCN)
- texttable
Install
pip install -U graphgallery
Implementation
General models
- GCN from Semi-Supervised Classification with Graph Convolutional Networks 🌐Paper
- GAT from Graph Attention Networks 🌐Paper
- SGC from Simplifying Graph Convolutional Networks 🌐Paper
- GraphSAGE from Inductive Representation Learning on Large Graphs 🌐Paper
- GWNN from Graph Wavelet Neural Network 🌐Paper
- GMNN from Graph Markov Neural Networks 🌐Paper
- ChebyNet from Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering 🌐Paper
- ClusterGCN from Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks 🌐Paper
- FastGCN from FastGCN: Fast Learning with Graph Convolutional Networks via Importance Sampling 🌐Paper
- LGCN from Large-Scale Learnable Graph Convolutional Networks 🌐Paper
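The GCN entry above follows the propagation rule of Kipf and Welling: add self-loops, normalize symmetrically as Â = D^-1/2 (A + I) D^-1/2, then compute ReLU(Â H W). SGC drops the nonlinearities and precomputes Â^K X. The sketch below shows those two computations in NumPy/SciPy. It is an illustration under those assumptions, not GraphGallery's TensorFlow code, and the helper names `normalize_adj`, `gcn_layer` and `sgc_features` are hypothetical.

```python
import numpy as np
import scipy.sparse as sp


def normalize_adj(adj):
    """Return the symmetrically normalized adjacency D^-1/2 (A + I) D^-1/2."""
    a_tilde = sp.csr_matrix(adj, dtype=np.float64) + sp.eye(adj.shape[0], format="csr")
    # Self-loops guarantee every degree is >= 1, so the inverse square root is finite.
    deg = np.asarray(a_tilde.sum(axis=1)).ravel()
    d_inv_sqrt = sp.diags(1.0 / np.sqrt(deg))
    return (d_inv_sqrt @ a_tilde @ d_inv_sqrt).tocsr()


def gcn_layer(adj_norm, h, weight):
    """One GCN layer with ReLU: H' = ReLU(A_hat @ H @ W)."""
    return np.maximum(adj_norm @ (h @ weight), 0.0)


def sgc_features(adj_norm, x, k=2):
    """SGC-style precomputation: A_hat^K @ X, with no nonlinearity between hops."""
    out = x
    for _ in range(k):
        out = adj_norm @ out
    return out


# Example: a 3-node path graph 0 - 1 - 2 with one-hot node features.
adj = sp.csr_matrix(np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=np.float64))
x = np.eye(3)
adj_norm = normalize_adj(adj)                     # adj_norm[0, 0] == 0.5
h1 = gcn_layer(adj_norm, x, np.ones((3, 2)))      # shape (3, 2)
x_sgc = sgc_features(adj_norm, x, k=2)            # equals A_hat^2 here, since x = I
```

The normalization is computed once and reused by every layer. This is also why SGC can push all K propagation steps into a single preprocessing pass.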
Defense models
- RGCN from Robust Graph Convolutional Networks Against Adversarial Attacks 🌐Paper
- SBVAT/OBVAT from Batch Virtual Adversarial Training for Graph Convolutional Networks 🌐Paper
Other models
- GCN_MIX: Mixture of GCN and MLP
- GCNF: GCN + feature
- DenseGCN: Dense version of GCN
- EdgeGCN: GCN using message passing framework
- MedianSAGE: GraphSAGE using median aggregation
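To show what "median aggregation" changes compared with GraphSAGE's mean aggregator, the sketch below computes the element-wise median of each node's neighbor features. It sets this next to a plain mean. This is a standalone NumPy/SciPy illustration, not MedianSAGE's actual layer. The function names are hypothetical, the adjacency is assumed to be unweighted, and giving isolated nodes a zero vector is a choice made only for this sketch.

```python
import numpy as np
import scipy.sparse as sp


def median_aggregate(adj, x):
    """Element-wise median of each node's neighbor features.

    Nodes without neighbors get a zero vector (a choice made for this sketch).
    """
    adj = sp.csr_matrix(adj, copy=True)
    adj.eliminate_zeros()  # only structurally nonzero entries count as edges
    out = np.zeros(x.shape, dtype=np.float64)
    for node in range(adj.shape[0]):
        neighbors = adj.indices[adj.indptr[node]:adj.indptr[node + 1]]
        if neighbors.size > 0:
            out[node] = np.median(x[neighbors], axis=0)
    return out


def mean_aggregate(adj, x):
    """Mean of neighbor features, assuming an unweighted (0/1) adjacency."""
    adj = sp.csr_matrix(adj, dtype=np.float64)
    deg = np.asarray(adj.sum(axis=1)).ravel()[:, None]
    summed = adj @ x
    return np.divide(summed, deg, out=np.zeros(summed.shape), where=deg > 0)


# Star graph: node 0 is connected to nodes 1, 2 and 3; node 3 carries an outlier.
rows = [0, 0, 0, 1, 2, 3]
cols = [1, 2, 3, 0, 0, 0]
adj = sp.csr_matrix((np.ones(6), (rows, cols)), shape=(4, 4))
x = np.array([[0.0], [1.0], [2.0], [100.0]])
med = median_aggregate(adj, x)   # med[0] is 2.0
mean = mean_aggregate(adj, x)    # mean[0] is about 34.33
```

The median of node 0's neighbors ignores the single extreme value, while the mean is pulled far toward it. That insensitivity to outliers is the general statistical property of the median.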
Quick Start
Train a GCN model
from graphgallery.nn.models import GCN
# adj is a scipy sparse adjacency matrix, x is a numpy node-feature matrix,
# and labels holds the node labels
model = GCN(adj, x, labels, device='GPU', seed=123)
# build the GCN model (custom hyper-parameters can be passed to build())
model.build()
# train the model; idx_train and idx_val are numpy arrays of node indices
his = model.train(idx_train, idx_val, verbose=1, epochs=100)
# evaluate the model on the test nodes
loss, accuracy = model.test(idx_test)
print(f'Test loss {loss:.5}, Test accuracy {accuracy:.2%}')
On the Cora dataset:
loss 1.02, acc 95.00%, val_loss 1.41, val_acc 77.40%: 100%|██████████| 100/100 [00:02<00:00, 37.07it/s]
Test loss 1.4123, Test accuracy 81.20%
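The quick-start snippet assumes `adj`, `x`, `labels`, `idx_train`, `idx_val` and `idx_test` already exist. The sketch below builds inputs of those types from a random graph, which is handy for a smoke test. The helper `make_toy_inputs` and all of its sizes are hypothetical, and the data is synthetic, so it will not reproduce the Cora numbers above. Encoding `labels` as a 1-D array of integer class ids is an assumption about what `GCN` expects.

```python
import numpy as np
import scipy.sparse as sp


def make_toy_inputs(num_nodes=100, num_features=16, num_classes=3,
                    density=0.05, train_frac=0.1, val_frac=0.2, seed=123):
    """Hypothetical helper: random graph, features, labels and index splits."""
    rng = np.random.default_rng(seed)
    # Random undirected graph: keep the strict upper triangle of a random
    # sparse matrix, set every stored entry to 1, then mirror it so adj is symmetric.
    upper = sp.triu(sp.random(num_nodes, num_nodes, density=density,
                              random_state=seed), k=1)
    upper.data[:] = 1.0
    adj = (upper + upper.T).tocsr()                  # scipy sparse matrix
    x = rng.random((num_nodes, num_features)).astype(np.float32)  # numpy array
    labels = rng.integers(0, num_classes, size=num_nodes)  # assumed: integer class ids
    # Disjoint train/validation/test node indices from one random permutation.
    perm = rng.permutation(num_nodes)
    n_train = round(train_frac * num_nodes)
    n_val = round(val_frac * num_nodes)
    idx_train, idx_val, idx_test = np.split(perm, [n_train, n_train + n_val])
    return adj, x, labels, idx_train, idx_val, idx_test


adj, x, labels, idx_train, idx_val, idx_test = make_toy_inputs()
```

These variables match the types described in the quick-start comments, so they could be passed to `GCN(adj, x, labels, ...)` the same way. Use a real dataset such as Cora for any meaningful accuracy comparison.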
Visualization
- Accuracy
import matplotlib.pyplot as plt
plt.plot(his.history['acc'])
plt.plot(his.history['val_acc'])
plt.legend(['Accuracy', 'Val Accuracy'])
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.show()
- Loss
import matplotlib.pyplot as plt
plt.plot(his.history['loss'])
plt.plot(his.history['val_loss'])
plt.legend(['Loss', 'Val Loss'])
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()
Download files
Source Distribution
- graphgallery-0.1.5.tar.gz (50.5 kB)
Built Distribution
- graphgallery-0.1.5-py3-none-any.whl (121.7 kB)
Hashes for graphgallery-0.1.5-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 3d8a5a9dcf4bc33434f051374be17838ea06feeaa3e0c69f96979ab5805820bd
MD5 | 45054b8bd7c65cf1dbcef46ea6f7de9a
BLAKE2b-256 | e1e9651f9516a9e82407233889170b78ba20560fd69927464204ebe5d8b6ed85