
Geometric Deep Learning Extension Library for TensorFlow

Project description

GraphGallery




A gallery of state-of-the-art graph neural networks, implemented with TensorFlow 2.x.

This repo has four goals:

  • Similar or higher performance
  • Faster training and testing
  • Simple and convenient to use, with high scalability
  • Easy-to-read source code

Requirements

  • python>=3.6
  • tensorflow>=2.1 (2.1 is recommended)
  • networkx==2.3
  • scipy>=1.4.1
  • scikit_learn>=0.22
  • numpy>=1.17.4
  • numba>=0.48
  • gensim>=3.8.1

Optional packages:

  • metis==0.2a4 (required for ClusterGCN)
  • texttable
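A quick way to see which optional packages are available is to probe for them before use. The sketch below relies only on the standard library and assumes the import names match the package names listed above (`metis` and `texttable`); the helper `is_installed` is a hypothetical name, not part of GraphGallery.

```python
import importlib.util


def is_installed(module_name):
    """Return True if a top-level module can be found, without importing it."""
    return importlib.util.find_spec(module_name) is not None


# Assumption: the import names are the same as the package names.
optional = {name: is_installed(name) for name in ('metis', 'texttable')}

for name, available in optional.items():
    status = 'available' if available else 'missing'
    print(f'{name}: {status}')
```

If `metis` is reported as missing, ClusterGCN cannot be used; if `texttable` is missing, only the tabular hyper-parameter display is affected.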

Install

pip install -U graphgallery

Implementation

General models

  • GCN from Semi-Supervised Classification with Graph Convolutional Networks 📝Paper
  • GAT from Graph Attention Networks 📝Paper
  • SGC from Simplifying Graph Convolutional Networks 📝Paper
  • GraphSAGE from Inductive Representation Learning on Large Graphs 📝Paper
  • GWNN from Graph Wavelet Neural Network 📝Paper
  • GMNN from Graph Markov Neural Networks 📝Paper
  • ChebyNet from Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering 📝Paper
  • ClusterGCN from Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks 📝Paper
  • FastGCN from FastGCN: Fast Learning with Graph Convolutional Networks via Importance Sampling 📝Paper
  • LGCN from Large-Scale Learnable Graph Convolutional Networks 📝Paper

Defense models

  • RobustGCN from Robust Graph Convolutional Networks Against Adversarial Attacks 📝Paper
  • SBVAT/OBVAT from Batch Virtual Adversarial Training for Graph Convolutional Networks 📝Paper

Other custom models

  • GCN_MIX: Mixture of GCN and MLP
  • GCNF: GCN + feature
  • DenseGCN: Dense version of GCN
  • EdgeGCN: GCN using message passing framework
  • MedianSAGE: GraphSAGE using Median aggregation
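To illustrate what median aggregation means for MedianSAGE, here is a minimal pure-Python sketch that replaces the usual mean over neighbor features with an element-wise median. It is not GraphGallery's implementation: the function `median_aggregate`, the adjacency-list input, and the choice to include each node in its own neighborhood are all assumptions made for illustration.

```python
from statistics import median


def median_aggregate(features, neighbors):
    """Element-wise median of each node's neighbor feature vectors.

    features:  list of per-node feature vectors (lists of floats)
    neighbors: list where neighbors[i] holds the node indices aggregated for node i
    A node with no neighbors yields an empty vector.
    """
    aggregated = []
    for node_neighbors in neighbors:
        # Transpose the selected rows so each tuple holds one feature dimension.
        columns = zip(*(features[n] for n in node_neighbors))
        aggregated.append([median(col) for col in columns])
    return aggregated


# Made-up toy data: 3 nodes with 2 features each.
features = [[1.0, 10.0], [2.0, 20.0], [9.0, 0.0]]
neighbors = [[0, 1, 2], [0, 1], [2]]
result = median_aggregate(features, neighbors)
print(result)
```

Compared with a mean, the median is less sensitive to a single neighbor with extreme feature values, such as node 2 in the toy data.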

Quick Start

Example of GCN model

from graphgallery.nn.models import GCN
# adj is a SciPy sparse adjacency matrix, x is a NumPy feature matrix
model = GCN(adj, x, labels, device='GPU', seed=123)
# build the GCN model (default hyper-parameters here; pass arguments to customize)
model.build()
# train the model; idx_train and idx_val are NumPy index arrays
his = model.train(idx_train, idx_val, verbose=1, epochs=100)
# test the model; idx_test is a NumPy index array
loss, accuracy = model.test(idx_test)
print(f'Test loss {loss:.5}, Test accuracy {accuracy:.2%}')

On Cora dataset:

loss 1.02, acc 95.00%, val_loss 1.41, val_acc 77.40%: 100%|██████████| 100/100 [00:02<00:00, 37.07it/s]
Test loss 1.4123, Test accuracy 81.20%
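The example above expects `adj`, `x`, `labels`, and the index arrays to exist already. The following sketch shows one way to build them for a tiny made-up graph with NumPy and SciPy. It assumes that `labels` is a 1-D integer array with one class per node and that the adjacency matrix should be symmetric; neither detail is stated above, so check the library's docstrings before relying on it.

```python
import numpy as np
import scipy.sparse as sp

# Toy undirected graph with 4 nodes and edges 0-1, 1-2, 2-3 (made-up data).
edges = np.array([[0, 1], [1, 2], [2, 3]])
num_nodes = 4

# Add each edge in both directions so the adjacency matrix is symmetric.
rows = np.concatenate([edges[:, 0], edges[:, 1]])
cols = np.concatenate([edges[:, 1], edges[:, 0]])
data = np.ones(rows.shape[0], dtype=np.float32)
adj = sp.csr_matrix((data, (rows, cols)), shape=(num_nodes, num_nodes))

# Node features: one row per node (identity features, for illustration only).
x = np.eye(num_nodes, dtype=np.float32)

# Assumption: one integer class label per node.
labels = np.array([0, 0, 1, 1])

# Disjoint index arrays for training, validation, and testing.
idx_train = np.array([0, 3])
idx_val = np.array([1])
idx_test = np.array([2])
```

A real dataset such as Cora would be loaded from disk instead, but the resulting objects would have the same types and shapes relative to the number of nodes.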

Build your model

Use statements like the following to build your model:

# one hidden layer with 32 units and ReLU activation
>>> model.build(hiddens=32, activations='relu')

# two hidden layers with 32 and 64 units, both with ReLU activation
>>> model.build(hiddens=[32, 64], activations='relu')

# two hidden layers with 32 and 64 units, with ReLU and ELU activations respectively
>>> model.build(hiddens=[32, 64], activations=['relu', 'elu'])

# other parameters such as `dropouts` and `l2_norms` (where supported) follow the same convention.
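The scalar-or-list convention above can be pictured as a small normalization step that turns every argument into one value per hidden layer. The sketch below is an illustration of that convention only, not GraphGallery's internal code; `as_list` and `expand_per_layer` are hypothetical helpers, and the dropout value 0.5 is an arbitrary example.

```python
def as_list(value):
    """Wrap a scalar in a list; copy lists and tuples into a list."""
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def expand_per_layer(value, num_layers):
    """Repeat a scalar once per layer, or check that a list has one entry per layer."""
    if isinstance(value, (list, tuple)):
        if len(value) != num_layers:
            raise ValueError(
                f'expected {num_layers} values, got {len(value)}')
        return list(value)
    return [value] * num_layers


# hiddens=32 means one hidden layer; hiddens=[32, 64] means two.
hiddens = as_list([32, 64])
activations = expand_per_layer('relu', len(hiddens))
dropouts = expand_per_layer(0.5, len(hiddens))

for units, activation, dropout in zip(hiddens, activations, dropouts):
    print(f'units={units}, activation={activation}, dropout={dropout}')
```

Under this reading, a list whose length differs from the number of hidden layers is an error rather than something to be silently truncated or padded.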

Train or test your model

For details, see the documentation of the methods model.train and model.test.

Hyper-parameters

Call model.show() to display all hyper-parameters. To display only the model parameters or only the training parameters, call model.show('model') or model.show('train'). NOTE: texttable must be installed first.

Visualization

  • Accuracy
import matplotlib.pyplot as plt
plt.plot(his.history['acc'])
plt.plot(his.history['val_acc'])
plt.legend(['Accuracy', 'Val Accuracy'])
plt.xlabel('Epochs')
plt.show()


  • Loss
import matplotlib.pyplot as plt
plt.plot(his.history['loss'])
plt.plot(his.history['val_loss'])
plt.legend(['Loss', 'Val Loss'])
plt.xlabel('Epochs')
plt.show()
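Besides plotting, the same history can be inspected numerically, for example to find the epoch with the best validation accuracy. The sketch below assumes `his.history` is a dictionary mapping metric names to per-epoch lists, as the plotting code suggests; `best_epoch` is a hypothetical helper and the numbers are made up.

```python
def best_epoch(history, metric='val_acc', mode='max'):
    """Return (1-based epoch, value) of the best entry for a metric.

    mode='max' suits accuracies, mode='min' suits losses.
    Ties resolve to the earliest epoch.
    """
    values = history[metric]
    if not values:
        raise ValueError(f'no values recorded for {metric!r}')
    pick = max if mode == 'max' else min
    index = pick(range(len(values)), key=values.__getitem__)
    return index + 1, values[index]


# Made-up history standing in for his.history.
history = {
    'acc': [0.50, 0.80, 0.95],
    'val_acc': [0.45, 0.78, 0.77],
    'loss': [1.90, 1.40, 1.02],
    'val_loss': [1.95, 1.50, 1.41],
}

epoch, value = best_epoch(history)
print(f'Best val_acc {value:.2%} at epoch {epoch}')
```

With a trained model, `best_epoch(his.history)` would be the equivalent call under the same assumption.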


Download files


Source Distribution

graphgallery-0.1.6.tar.gz (52.1 kB)

Uploaded Source

Built Distribution

graphgallery-0.1.6-py3-none-any.whl (111.0 kB)

Uploaded Python 3
