
Geometric Deep Learning Extension Library for TensorFlow and PyTorch

Project description


TensorFlow or PyTorch? Both!


GraphGallery

GraphGallery is a gallery of benchmark graph neural networks (GNNs) with TensorFlow 2.x and PyTorch backends. GraphGallery 0.5.x is a complete rewrite of previous versions, and some things have changed.

👀 What's important

Differences between GraphGallery and PyTorch Geometric (PyG), Deep Graph Library (DGL), etc.:

  • PyG and DGL are like TensorFlow, while GraphGallery is more like Keras
  • GraphGallery is more extensible and user-friendly
  • GraphGallery offers high scalability for researchers

🚀 Installation

  • Build from source (latest version)
git clone https://github.com/EdisonLeeeee/GraphGallery.git
cd GraphGallery
python setup.py install
  • Or using pip (stable version)
pip install -U graphgallery
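GraphGallery can run on several frameworks (see the backend examples in this README), so before picking a backend it can be handy to check which of them are importable in the current environment. The snippet below is a minimal sketch using only the standard library; the helper name available_backends is hypothetical, and it only checks whether each framework's top-level package can be found, not whether its version is compatible with GraphGallery.

```python
import importlib.util

# Top-level package names of the frameworks mentioned in this README.
# (The mapping is for illustration; it is not read from GraphGallery.)
BACKEND_MODULES = {
    "TensorFlow": "tensorflow",
    "PyTorch": "torch",
    "PyTorch Geometric": "torch_geometric",
    "DGL": "dgl",
}


def available_backends():
    """Hypothetical helper: map each framework to True if its package can be found."""
    return {
        name: importlib.util.find_spec(module) is not None
        for name, module in BACKEND_MODULES.items()
    }


if __name__ == "__main__":
    for name, ok in available_backends().items():
        print(f"{name:20s} {'installed' if ok else 'missing'}")
```

find_spec only locates the package without importing it, so the check stays fast even when heavy frameworks such as TensorFlow are installed.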

🤖 Implementations

In detail, the following methods are currently implemented:

Semi-supervised models

General models

ChebyNet from Michaël Defferrard et al, 📝Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering (NeurIPS'16) [:octocat:TensorFlow Example], [🔥PyTorch Example], [🔥PyG Example]
GCN from Thomas N. Kipf et al, 📝Semi-Supervised Classification with Graph Convolutional Networks (ICLR'17) [:octocat:TensorFlow Example], [🔥PyTorch Example], [🔥PyG Example], [🔥DGL-PyTorch Example], [:octocat:DGL-TensorFlow Example]
GraphSAGE from William L. Hamilton et al, 📝Inductive Representation Learning on Large Graphs (NeurIPS'17) [:octocat:TensorFlow Example], [🔥PyTorch Example], [🔥PyG Example]
FastGCN from Jie Chen et al, 📝FastGCN: Fast Learning with Graph Convolutional Networks via Importance Sampling (ICLR'18) [:octocat:TensorFlow Example], [🔥PyTorch Example], [🔥PyG Example]
LGCN from Hongyang Gao et al, 📝Large-Scale Learnable Graph Convolutional Networks (KDD'18) [:octocat:TensorFlow Example], [🔥PyTorch Example], [🔥PyG Example]
GAT from Petar Veličković et al, 📝Graph Attention Networks (ICLR'18) [:octocat:TensorFlow Example], [🔥PyTorch Example], [🔥PyG Example]
SGC from Felix Wu et al, 📝Simplifying Graph Convolutional Networks (ICML'19) [:octocat:TensorFlow Example], [🔥PyTorch Example], [🔥PyG Example]
GWNN from Bingbing Xu et al, 📝Graph Wavelet Neural Network (ICLR'19) [:octocat:TensorFlow Example], [🔥PyTorch Example], [🔥PyG Example]
GMNN from Meng Qu et al, 📝GMNN: Graph Markov Neural Networks (ICML'19) [:octocat:TensorFlow Example], [🔥PyTorch Example], [🔥PyG Example]
ClusterGCN from Wei-Lin Chiang et al, 📝Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks (KDD'19) [:octocat:TensorFlow Example], [🔥PyTorch Example], [🔥PyG Example]
DAGNN from Meng Liu et al, 📝Towards Deeper Graph Neural Networks (KDD'20) [:octocat:TensorFlow Example], [🔥PyTorch Example], [🔥PyG Example]

Defense models

RobustGCN from Dingyuan Zhu et al, 📝Robust Graph Convolutional Networks Against Adversarial Attacks (KDD'19) [:octocat:TensorFlow Example], [🔥PyTorch Example], [🔥PyG Example]
SBVAT from Zhijie Deng et al, 📝Batch Virtual Adversarial Training for Graph Convolutional Networks (ICML'19) [:octocat:TensorFlow Example], [🔥PyTorch Example], [🔥PyG Example]
OBVAT from Zhijie Deng et al, 📝Batch Virtual Adversarial Training for Graph Convolutional Networks (ICML'19) [:octocat:TensorFlow Example], [🔥PyTorch Example], [🔥PyG Example]

Unsupervised models

DeepWalk from Bryan Perozzi et al, 📝DeepWalk: Online Learning of Social Representations (KDD'14) [:octocat:TensorFlow Example], [🔥PyTorch Example], [🔥PyG Example]
Node2vec from Aditya Grover et al, 📝node2vec: Scalable Feature Learning for Networks (KDD'16) [:octocat:TensorFlow Example], [🔥PyTorch Example], [🔥PyG Example]

⚡ Quick Start

Datasets

For more details, please refer to GraphData.

Planetoid

Datasets with fixed (predefined) splits

from graphgallery.data import Planetoid
# set `verbose=False` to avoid additional outputs 
data = Planetoid('cora', verbose=False)
graph = data.graph
idx_train, idx_val, idx_test = data.split_nodes()
# idx_train:  training indices: 1D Numpy array
# idx_val:  validation indices: 1D Numpy array
# idx_test:  testing indices: 1D Numpy array
>>> graph
Graph(adj_matrix(2708, 2708), attr_matrix(2708, 1433), labels(2708,))

Currently supported datasets:

>>> data.supported_datasets
('citeseer', 'cora', 'pubmed')
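The comments above describe idx_train, idx_val and idx_test as 1D NumPy index arrays. A quick sanity check on any such split is that the indices are integers, lie in range and never overlap. Below is a minimal NumPy sketch; check_split is a hypothetical helper (not part of GraphGallery), and the 10-node labels are made up for illustration rather than taken from Cora.

```python
import numpy as np


def check_split(idx_train, idx_val, idx_test, num_nodes):
    """Hypothetical helper: True if the three index arrays are 1D integer
    arrays within [0, num_nodes) and no node index appears more than once
    (within or across the three arrays)."""
    parts = [np.asarray(idx) for idx in (idx_train, idx_val, idx_test)]
    for idx in parts:
        if idx.ndim != 1 or not np.issubdtype(idx.dtype, np.integer):
            return False
        if idx.size and (idx.min() < 0 or idx.max() >= num_nodes):
            return False
    merged = np.concatenate(parts)
    return np.unique(merged).size == merged.size


# Toy example with 10 nodes and made-up labels (not Cora).
labels = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2, 0])
idx_train = np.array([0, 1, 2])
idx_val = np.array([3, 4])
idx_test = np.array([5, 6, 7, 8, 9])

print(check_split(idx_train, idx_val, idx_test, num_nodes=labels.size))  # True
print(labels[idx_train])  # [0 1 2]
```

Because the splits are plain index arrays, they can be used directly to slice per-node arrays such as labels, as the last line shows.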

NPZDataset

More scalable datasets (stored as .npz files)

from graphgallery.data import NPZDataset
# set `verbose=False` to avoid additional outputs
data = NPZDataset('cora', verbose=False, standardize=False)
graph = data.graph
idx_train, idx_val, idx_test = data.split_nodes(random_state=42)
>>> graph
Graph(adj_matrix(2708, 2708), attr_matrix(2708, 1433), labels(2708,))

Currently supported datasets:

>>> data.supported_datasets
('citeseer','citeseer_full','cora','cora_ml','cora_full',
 'amazon_cs','amazon_photo','coauthor_cs','coauthor_phy', 
 'polblogs', 'pubmed', 'flickr','blogcatalog','dblp')
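Unlike Planetoid, NPZDataset.split_nodes takes a random_state, so the split is random but reproducible. The sketch below shows the general idea with a plain uniform random split in NumPy. It is not GraphGallery's split_nodes: the helper name random_split_nodes and its train_size/val_size defaults are hypothetical, and the library's own implementation may differ (for example, it may stratify by class).

```python
import numpy as np


def random_split_nodes(num_nodes, train_size=0.1, val_size=0.1, random_state=None):
    """Hypothetical helper: shuffle node ids with a seeded RNG and cut them
    into train/val/test index arrays (the test set gets the remainder)."""
    rng = np.random.RandomState(random_state)
    perm = rng.permutation(num_nodes)
    n_train = int(round(train_size * num_nodes))
    n_val = int(round(val_size * num_nodes))
    idx_train = np.sort(perm[:n_train])
    idx_val = np.sort(perm[n_train:n_train + n_val])
    idx_test = np.sort(perm[n_train + n_val:])
    return idx_train, idx_val, idx_test


# 2708 nodes, as in the Cora graph above.
idx_train, idx_val, idx_test = random_split_nodes(2708, random_state=42)
print([len(p) for p in (idx_train, idx_val, idx_test)])  # [271, 271, 2166]
```

Calling it again with the same random_state returns identical arrays, which is what makes a seeded split reproducible across runs.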

Tensor

  • Strided (dense) Tensor
>>> from graphgallery import backend
>>> backend()
TensorFlow 2.1.2 Backend

>>> from graphgallery import functional as F
>>> arr = [1, 2, 3]
>>> F.astensor(arr)
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>
  • Sparse Tensor
>>> import scipy.sparse as sp
>>> sp_matrix = sp.eye(3)
>>> F.astensor(sp_matrix)
<tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7f1bbc205dd8>
  • This also works with the PyTorch backend:
>>> from graphgallery import set_backend
>>> set_backend('torch') # torch, pytorch or th
PyTorch 1.6.0+cu101 Backend

>>> F.astensor(arr)
tensor([1, 2, 3])

>>> F.astensor(sp_matrix)
tensor(indices=tensor([[0, 1, 2],
                       [0, 1, 2]]),
       values=tensor([1., 1., 1.]),
       size=(3, 3), nnz=3, layout=torch.sparse_coo)
  • Convert back to a NumPy array or SciPy sparse matrix
>>> tensor = F.astensor(arr)
>>> F.tensoras(tensor)
array([1, 2, 3])

>>> sp_tensor = F.astensor(sp_matrix)
>>> F.tensoras(sp_tensor)
<3x3 sparse matrix of type '<class 'numpy.float32'>'
    with 3 stored elements in Compressed Sparse Row format>
  • Or even convert a tensor from one backend to another (here, from TensorFlow to the current PyTorch backend)
>>> tensor = F.astensor(arr, backend="tensorflow") # or "tf" in short
>>> tensor
<tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>
>>> F.tensor2tensor(tensor)
tensor([1, 2, 3])

>>> sp_tensor = F.astensor(sp_matrix, backend="tensorflow") # explicitly create a TensorFlow tensor while the current backend is PyTorch
>>> sp_tensor
<tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7efb6836a898>
>>> F.tensor2tensor(sp_tensor)
tensor(indices=tensor([[0, 1, 2],
                       [0, 1, 2]]),
       values=tensor([1., 1., 1.]),
       size=(3, 3), nnz=3, layout=torch.sparse_coo)
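The sparse tensors printed above are stored in COO form: a 2 x nnz array of (row, column) indices plus a vector of values. To see what that means without any deep learning framework, here is a minimal SciPy sketch that splits a sparse matrix into those parts and rebuilds a CSR matrix from them (CSR being the format F.tensoras returned above). The helper names to_coo_parts and from_coo_parts are hypothetical; this is not how F.astensor or F.tensoras are implemented.

```python
import numpy as np
import scipy.sparse as sp


def to_coo_parts(matrix):
    """Hypothetical helper: return (indices, values, shape) of a sparse matrix.

    indices has shape (2, nnz): row indices in the first row, column indices
    in the second, the same layout as the torch.sparse_coo output above.
    """
    coo = sp.coo_matrix(matrix)
    indices = np.vstack([coo.row, coo.col]).astype(np.int64)
    return indices, coo.data, coo.shape


def from_coo_parts(indices, values, shape):
    """Hypothetical helper: rebuild a CSR matrix from (indices, values, shape)."""
    return sp.csr_matrix((values, (indices[0], indices[1])), shape=shape)


indices, values, shape = to_coo_parts(sp.eye(3))
print(indices.tolist())  # [[0, 1, 2], [0, 1, 2]]
print(values.tolist())   # [1.0, 1.0, 1.0]
print(shape)             # (3, 3)
```

The indices and values match the torch.sparse_coo tensor shown for sp.eye(3) above.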

Example of GCN model

from graphgallery.nn.gallery import GCN

model = GCN(graph, attr_transform="normalize_attr", device="CPU", seed=123)
# build your GCN model with default hyper-parameters
model.build()
# train your model (here idx_train and idx_val are NumPy arrays)
# verbose takes 0, 1, 2, 3, 4
his = model.train(idx_train, idx_val, verbose=1, epochs=100)
# test your model
# verbose takes 0, 1, 2
loss, accuracy = model.test(idx_test, verbose=1)
print(f'Test loss {loss:.5}, Test accuracy {accuracy:.2%}')

On the Cora dataset:

Training...
100/100 [==============================] - 1s 14ms/step - loss: 1.0161 - acc: 0.9500 - val_loss: 1.4101 - val_acc: 0.7740 - time: 1.4180
Testing...
1/1 [==============================] - 0s 62ms/step - test_loss: 1.4123 - test_acc: 0.8120 - time: 0.0620
Test loss 1.4123, Test accuracy 81.20%
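Under the hood, a GCN layer is just a few matrix products. As a framework-free sketch (NumPy/SciPy only, not GraphGallery's code), the snippet below implements the propagation rule from Kipf and Welling's paper: H' = σ(D̃^(-1/2) (A + I) D̃^(-1/2) H W), where D̃ is the degree matrix of A + I. We assume here that attr_transform="normalize_attr" row-normalizes the attribute matrix; that interpretation, the helper names, the toy graph and the random (untrained) weights are all hypothetical, and no training is done.

```python
import numpy as np
import scipy.sparse as sp


def normalize_adj(adj):
    """Renormalization trick: D^-1/2 (A + I) D^-1/2, D = degree matrix of A + I."""
    adj = sp.csr_matrix(adj) + sp.eye(adj.shape[0], format="csr")
    deg = np.asarray(adj.sum(axis=1)).ravel()
    d_inv_sqrt = sp.diags(1.0 / np.sqrt(deg))
    return (d_inv_sqrt @ adj @ d_inv_sqrt).tocsr()


def row_normalize(x):
    """Scale each row of a dense feature matrix to sum to 1 (all-zero rows stay zero)."""
    x = np.asarray(x, dtype=float)
    rowsum = x.sum(axis=1, keepdims=True)
    rowsum[rowsum == 0] = 1.0
    return x / rowsum


def gcn_forward(adj_norm, x, weights):
    """Two-layer GCN forward pass: softmax(A_hat relu(A_hat X W0) W1)."""
    h = np.maximum(adj_norm @ (x @ weights[0]), 0.0)  # hidden layer with ReLU
    logits = adj_norm @ (h @ weights[1])               # output layer
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    return probs / probs.sum(axis=1, keepdims=True)


rng = np.random.RandomState(0)
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=float)      # toy 3-node path graph
x = row_normalize(rng.rand(3, 4))             # 3 nodes, 4 made-up features
weights = [rng.randn(4, 8), rng.randn(8, 2)]  # random weights: 4 -> 8 -> 2 classes
a_hat = normalize_adj(adj)
probs = gcn_forward(a_hat, x, weights)
print(probs.shape)  # (3, 2)
```

Each row of probs is a class distribution for one node; a real model would learn the weights by minimizing cross-entropy on the training nodes.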

Customization

  • Build your model: you can use the following statements to build your model
# one hidden layer with 32 hidden units and ReLU activation
>>> model.build(hiddens=32, activations='relu')

# two hidden layers with 32 and 64 hidden units, both using ReLU activation
>>> model.build(hiddens=[32, 64], activations='relu')

# two hidden layers with 32 and 64 hidden units, using ReLU and ELU activations respectively
>>> model.build(hiddens=[32, 64], activations=['relu', 'elu'])
  • Train your model
# train with validation
>>> his = model.train(idx_train, idx_val, verbose=1, epochs=100)
# train without validation
>>> his = model.train(idx_train, verbose=1, epochs=100)

Here, his (the return value of model.train) is a TensorFlow History instance.

  • Test your model
>>> loss, accuracy = model.test(idx_test, verbose=1)
Testing...
1/1 [==============================] - 0s 62ms/step - test_loss: 1.4123 - test_acc: 0.8120 - time: 0.0620
>>> print(f'Test loss {loss:.5}, Test accuracy {accuracy:.2%}')
Test loss 1.4123, Test accuracy 81.20%
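The three build() calls above follow a simple rule: an int for hiddens means one hidden layer, a single activation string is reused for every hidden layer, and a list of activations gives one activation per layer. The hypothetical helper below mirrors that rule in plain Python so it can be checked in isolation; it is not GraphGallery's code, and raising ValueError on mismatched lengths is an assumption about how such input would be handled.

```python
def expand_layer_config(hiddens, activations):
    """Hypothetical helper: return one (units, activation) pair per hidden layer.

    - an int `hiddens` means a single hidden layer;
    - a single activation string is reused for every hidden layer;
    - a list of activations must have one entry per hidden layer.
    """
    if isinstance(hiddens, int):
        hiddens = [hiddens]
    hiddens = list(hiddens)
    if isinstance(activations, str):
        activations = [activations] * len(hiddens)
    activations = list(activations)
    if len(activations) != len(hiddens):
        raise ValueError(
            f"got {len(hiddens)} hidden layers but {len(activations)} activations")
    return list(zip(hiddens, activations))


print(expand_layer_config(32, "relu"))                 # [(32, 'relu')]
print(expand_layer_config([32, 64], "relu"))           # [(32, 'relu'), (64, 'relu')]
print(expand_layer_config([32, 64], ["relu", "elu"]))  # [(32, 'relu'), (64, 'elu')]
```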

Visualization

NOTE: you must install the SciencePlots package (pip install SciencePlots) to use the 'science' plot style below.

import matplotlib.pyplot as plt
with plt.style.context(['science', 'no-latex']):
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    axes[0].plot(his.history['acc'], label='Training accuracy', linewidth=3)
    axes[0].plot(his.history['val_acc'], label='Validation accuracy', linewidth=3)
    axes[0].legend(fontsize=20)
    axes[0].set_title('Accuracy', fontsize=20)
    axes[0].set_xlabel('Epochs', fontsize=20)
    axes[0].set_ylabel('Accuracy', fontsize=20)

    axes[1].plot(his.history['loss'], label='Training loss', linewidth=3)
    axes[1].plot(his.history['val_loss'], label='Validation loss', linewidth=3)
    axes[1].legend(fontsize=20)
    axes[1].set_title('Loss', fontsize=20)
    axes[1].set_xlabel('Epochs', fontsize=20)
    axes[1].set_ylabel('Loss', fontsize=20)

    plt.autoscale(tight=True)
    plt.show()        


Using TensorFlow/PyTorch Backend

>>> import graphgallery
>>> graphgallery.backend()
TensorFlow 2.1.0 Backend

>>> graphgallery.set_backend("pytorch")
PyTorch 1.6.0+cu101 Backend
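set_backend accepts several spellings for the same framework: this README uses "torch", "pytorch" and "th" for PyTorch, "tensorflow" and "tf" for TensorFlow, and "pyg" for PyTorch Geometric. The sketch below shows one way such alias resolution can work. The alias table only contains the names that appear in this README, and the function resolve_backend as well as its case-insensitive matching are hypothetical; GraphGallery's real lookup may accept more or fewer names.

```python
# Hypothetical alias table built only from the names used in this README.
BACKEND_ALIASES = {
    "tensorflow": "TensorFlow",
    "tf": "TensorFlow",
    "torch": "PyTorch",
    "pytorch": "PyTorch",
    "th": "PyTorch",
    "pyg": "PyTorch Geometric",
}


def resolve_backend(name):
    """Map a user-supplied backend name (case-insensitive) to a canonical name."""
    key = name.strip().lower()
    try:
        return BACKEND_ALIASES[key]
    except KeyError:
        raise ValueError(
            f"unknown backend {name!r}; expected one of {sorted(BACKEND_ALIASES)}"
        ) from None


print(resolve_backend("th"))  # PyTorch
print(resolve_backend("TF"))  # TensorFlow
```

Keeping all aliases in a single dictionary means adding a new backend (for example DGL) only requires new entries, not new branching logic.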

GCN using PyTorch backend

# The following code is the same as with the TensorFlow backend
>>> from graphgallery.nn.gallery import GCN
>>> model = GCN(graph, attr_transform="normalize_attr", device="GPU", seed=123)
>>> model.build()
>>> his = model.train(idx_train, idx_val, verbose=1, epochs=100)
Training...
100/100 [==============================] - 0s 5ms/step - loss: 0.6813 - acc: 0.9214 - val_loss: 1.0506 - val_acc: 0.7820 - time: 0.4734
>>> loss, accuracy = model.test(idx_test, verbose=1)
Testing...
1/1 [==============================] - 0s 1ms/step - test_loss: 1.0131 - test_acc: 0.8220 - time: 0.0013
>>> print(f'Test loss {loss:.5}, Test accuracy {accuracy:.2%}')
Test loss 1.0131, Test accuracy 82.20%

❓ How to add your datasets

This is inspired by gnn-benchmark.

from graphgallery.data import Graph

# Load the adjacency matrix A, attribute matrix X and labels vector y
# A - scipy.sparse.csr_matrix of shape [n_nodes, n_nodes]
# X - scipy.sparse.csr_matrix or np.ndarray of shape [n_nodes, n_atts]
# y - np.ndarray of shape [n_nodes]

mydataset = Graph(adj_matrix=A, attr_matrix=X, labels=y)
# save dataset
mydataset.to_npz('path/to/mydataset.npz')
# load dataset
mydataset = Graph.from_npz('path/to/mydataset.npz')
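To make the storage format less opaque, here is a minimal NumPy/SciPy sketch of writing A, X and y into a single .npz file and reading them back, storing each sparse matrix as its CSR components in the spirit of gnn-benchmark. The helpers save_graph_npz and load_graph_npz and all key names are hypothetical and may not match what Graph.to_npz writes, so use Graph.to_npz and Graph.from_npz for files GraphGallery itself should read.

```python
import io

import numpy as np
import scipy.sparse as sp


def save_graph_npz(file, adj, attr, labels):
    """Hypothetical helper: store CSR adjacency, dense or sparse attributes and labels."""
    adj = sp.csr_matrix(adj)
    arrays = {
        "adj_data": adj.data, "adj_indices": adj.indices,
        "adj_indptr": adj.indptr, "adj_shape": np.array(adj.shape),
        "labels": np.asarray(labels),
    }
    if sp.issparse(attr):
        attr = sp.csr_matrix(attr)
        arrays.update(attr_data=attr.data, attr_indices=attr.indices,
                      attr_indptr=attr.indptr, attr_shape=np.array(attr.shape))
    else:
        arrays["attr_matrix"] = np.asarray(attr)
    np.savez_compressed(file, **arrays)


def load_graph_npz(file):
    """Inverse of save_graph_npz: return (adj, attr, labels)."""
    with np.load(file) as f:
        adj = sp.csr_matrix((f["adj_data"], f["adj_indices"], f["adj_indptr"]),
                            shape=tuple(f["adj_shape"]))
        if "attr_data" in f.files:
            attr = sp.csr_matrix((f["attr_data"], f["attr_indices"], f["attr_indptr"]),
                                 shape=tuple(f["attr_shape"]))
        else:
            attr = f["attr_matrix"]
        labels = f["labels"]
    return adj, attr, labels


buf = io.BytesIO()  # stands in for 'path/to/mydataset.npz'
A = sp.csr_matrix(np.array([[0, 1], [1, 0]], dtype=np.float32))
X = np.eye(2, dtype=np.float32)
y = np.array([0, 1])
save_graph_npz(buf, A, X, y)
buf.seek(0)
A2, X2, y2 = load_graph_npz(buf)
print((A2 != A).nnz == 0, np.array_equal(X2, X), np.array_equal(y2, y))  # True True True
```

Storing CSR components instead of a dense adjacency matrix keeps files small for sparse graphs, which is why the same idea is used by the .npz datasets listed above.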

❓ How to define your models

You can follow the code in graphgallery.nn.gallery and write your models based on:

  • TensorFlow
  • PyTorch
  • PyTorch Geometric (PyG)
  • Deep Graph Library (DGL)

NOTE: PyG and DGL backends are now supported in GraphGallery!

>>> import graphgallery
>>> graphgallery.set_backend("pyg")
PyTorch Geometric 1.6.1 (PyTorch 1.6.0+cu101) Backend

GCN using PyG backend

# The following code is the same as with the TensorFlow or PyTorch backend
>>> from graphgallery.nn.gallery import GCN
>>> model = GCN(graph, attr_transform="normalize_attr", device="GPU", seed=123)
>>> model.build()
>>> his = model.train(idx_train, idx_val, verbose=1, epochs=100)
Training...
100/100 [==============================] - 0s 3ms/step - loss: 0.5325 - acc: 0.9643 - val_loss: 1.0034 - val_acc: 0.7980 - time: 0.3101
>>> loss, accuracy = model.test(idx_test, verbose=1)
Testing...
1/1 [==============================] - 0s 834us/step - test_loss: 0.9733 - test_acc: 0.8130 - time: 8.2737e-04
>>> print(f'Test loss {loss:.5}, Test accuracy {accuracy:.2%}')
Test loss 0.97332, Test accuracy 81.30%

😎 More Examples

Please refer to the examples directory.

⭐ Road Map

  • Add PyTorch models support
  • Add other frameworks (PyG and DGL) support
  • Add more GNN models (TF and Torch backend)
  • Support for more tasks, e.g., graph classification and link prediction
  • Support for more types of graphs, e.g., heterogeneous graphs
  • Add docstrings and documentation (in progress)

😘 Acknowledgement

This project is inspired by PyTorch Geometric, TensorFlow Geometric, StellarGraph, DGL and others, as well as by the authors' original implementations. Thanks for their excellent work!



Download files


Source Distribution

graphgallery-0.5.0.tar.gz (93.1 kB view details)

Uploaded Source

Built Distribution

graphgallery-0.5.0-py3-none-any.whl (187.7 kB view details)

Uploaded Python 3

File details

Details for the file graphgallery-0.5.0.tar.gz.

File metadata

  • Download URL: graphgallery-0.5.0.tar.gz
  • Upload date:
  • Size: 93.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.6.0 requests-toolbelt/0.9.1 tqdm/4.32.1 CPython/3.7.3

File hashes

Hashes for graphgallery-0.5.0.tar.gz
Algorithm Hash digest
SHA256 edac7ae9ebe51b57de46aeb2522124873e7b8f85338371011d1baffe9eb8ceac
MD5 e737d76ff4c588b65fe05a43de018d14
BLAKE2b-256 960c891c953ffd6d23b318e4890bb789ef55c27fe0c24f5efa139a26b070d63b


File details

Details for the file graphgallery-0.5.0-py3-none-any.whl.

File metadata

  • Download URL: graphgallery-0.5.0-py3-none-any.whl
  • Upload date:
  • Size: 187.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.6.0 requests-toolbelt/0.9.1 tqdm/4.32.1 CPython/3.7.3

File hashes

Hashes for graphgallery-0.5.0-py3-none-any.whl
Algorithm Hash digest
SHA256 8befe18e2ed894362be74dc00ca58fc506855e126f7c296550c63794612ad262
MD5 2d01fe0cfe8058c6321ed7bbed3c3730
BLAKE2b-256 981fc2f4c8287b51d0d9159bc38696e649f015a8d5981c69d11dd37040c51cde

