
Geometric Deep Learning Extension Library for TensorFlow


GraphGallery


A gallery of state-of-the-art graph neural networks, implemented with TensorFlow 2.x.

This repository aims to achieve four goals:

  • Similar or higher performance
  • Faster training and testing
  • Simple, convenient to use, and highly scalable
  • Easy-to-read source code

Requirements

  • python>=3.6
  • tensorflow>=2.1 (2.1 is recommended)
  • networkx==2.3
  • scipy>=1.4.1
  • scikit_learn>=0.22
  • numpy>=1.17.4
  • numba>=0.48
  • gensim>=3.8.1

Optional packages:

  • metis==0.2a4 (used for ClusterGCN)
  • texttable
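ClusterGCN uses METIS (through the optional metis package) to partition the graph into clusters, then trains on the induced subgraph of a cluster at a time, which drops the edges that run between clusters. The sketch below shows only that batching step, assuming the partition is already available as an integer array. The hard-coded partition stands in for METIS output, and `cluster_batches` is a hypothetical helper, not part of GraphGallery. The ClusterGCN paper also merges several clusters into one batch; that variant is omitted here.

```python
import numpy as np
import scipy.sparse as sp


def cluster_batches(adj, x, labels, parts):
    """Yield (sub_adj, sub_x, sub_labels, nodes) for every cluster id in `parts`.

    `parts[i]` is the cluster assigned to node i (hypothetically from METIS).
    Only edges inside a cluster survive; edges between clusters are dropped.
    """
    adj = sp.csr_matrix(adj)
    for c in np.unique(parts):
        nodes = np.flatnonzero(parts == c)   # node ids belonging to cluster c
        sub_adj = adj[nodes][:, nodes]       # induced subgraph adjacency
        yield sub_adj, x[nodes], labels[nodes], nodes


# Toy graph: two triangles joined by a single bridge edge (2, 3).
edges = [(0, 1), (1, 2), (2, 0), (3, 4), (4, 5), (5, 3), (2, 3)]
rows, cols = zip(*edges)
adj = sp.coo_matrix((np.ones(len(edges)), (rows, cols)), shape=(6, 6))
adj = adj + adj.T                              # make the adjacency symmetric
x = np.arange(12, dtype=float).reshape(6, 2)
labels = np.array([0, 0, 0, 1, 1, 1])
parts = np.array([0, 0, 0, 1, 1, 1])           # stand-in for a METIS partition

batches = list(cluster_batches(adj, x, labels, parts))
for sub_adj, sub_x, sub_labels, nodes in batches:
    print(nodes, sub_adj.nnz)                  # the bridge edge is not in either batch
```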

Install

pip install -U graphgallery

Implementation

General models

  • GCN from Semi-Supervised Classification with Graph Convolutional Networks 🌐Paper
  • GAT from Graph Attention Networks 🌐Paper
  • SGC from Simplifying Graph Convolutional Networks 🌐Paper
  • GraphSAGE from Inductive Representation Learning on Large Graphs 🌐Paper
  • GWNN from Graph Wavelet Neural Network 🌐Paper
  • GMNN from Graph Markov Neural Networks 🌐Paper
  • ChebyNet from Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering 🌐Paper
  • ClusterGCN from Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks 🌐Paper
  • FastGCN from FastGCN: Fast Learning with Graph Convolutional Networks via Importance Sampling 🌐Paper
  • LGCN from Large-Scale Learnable Graph Convolutional Networks 🌐Paper
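GCN propagates node features with the renormalized adjacency matrix A_norm = D^(-1/2) (A + I) D^(-1/2), where D is the degree matrix of A + I, and applies H' = ReLU(A_norm H W) in each layer. SGC removes the nonlinearities between layers, so A_norm^K X can be precomputed once and fed to a linear classifier. The following is a minimal NumPy sketch of these two ideas on a dense toy graph; the function names are hypothetical and this is not GraphGallery's internal implementation.

```python
import numpy as np


def normalize_adj(adj: np.ndarray) -> np.ndarray:
    """Return D^(-1/2) (A + I) D^(-1/2), the renormalized adjacency used by GCN."""
    a_hat = adj + np.eye(adj.shape[0])        # add self-loops
    deg = a_hat.sum(axis=1)                   # degrees of A + I (always >= 1)
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    # element (i, j) becomes a_hat[i, j] / sqrt(deg[i] * deg[j])
    return a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]


def gcn_layer(a_norm: np.ndarray, h: np.ndarray, w: np.ndarray) -> np.ndarray:
    """One GCN layer: ReLU(A_norm @ H @ W)."""
    return np.maximum(a_norm @ h @ w, 0.0)


def sgc_features(a_norm: np.ndarray, x: np.ndarray, k: int = 2) -> np.ndarray:
    """SGC-style preprocessing: compute A_norm^K @ X once, with no weights involved."""
    out = x
    for _ in range(k):
        out = a_norm @ out
    return out


# Path graph 0 - 1 - 2 with one-hot features
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=float)
x = np.eye(3)
rng = np.random.default_rng(0)
w = rng.normal(size=(3, 2))                   # random weights for illustration only

a_norm = normalize_adj(adj)
h1 = gcn_layer(a_norm, x, w)                  # shape (3, 2)
x_sgc = sgc_features(a_norm, x, k=2)          # shape (3, 3)
```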

Defense models

  • RGCN from Robust Graph Convolutional Networks Against Adversarial Attacks 🌐Paper
  • SBVAT/OBVAT from Batch Virtual Adversarial Training for Graph Convolutional Networks 🌐Paper

Other models

  • GCN_MIX: Mixture of GCN and MLP
  • GCNF: GCN + feature
  • DenseGCN: Dense version of GCN
  • EdgeGCN: GCN implemented with a message-passing framework
  • MedianSAGE: GraphSAGE using median aggregation
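The difference between GraphSAGE's mean aggregation and the median aggregation used by MedianSAGE can be sketched in NumPy. The update below, ReLU(x W_self + agg(neighbors) W_neigh), is the GraphSAGE concatenate-then-linear step with the weight matrix split into two halves. This is an illustration under those assumptions; the helper names are hypothetical and GraphGallery's MedianSAGE may differ in details such as neighbor sampling and normalization.

```python
import numpy as np


def neighbor_lists(edges, num_nodes):
    """Build adjacency lists from an undirected edge list."""
    nbrs = [[] for _ in range(num_nodes)]
    for u, v in edges:
        nbrs[u].append(v)
        nbrs[v].append(u)
    return nbrs


def aggregate(x, nbrs, reduce=np.median):
    """Reduce each node's neighbor features element-wise (median or mean).

    Nodes without neighbors receive a zero vector.
    """
    out = np.zeros_like(x, dtype=float)
    for i, ns in enumerate(nbrs):
        if ns:
            out[i] = reduce(x[ns], axis=0)
    return out


def sage_layer(x, nbrs, w_self, w_neigh, reduce=np.median):
    """GraphSAGE-style update: ReLU(x @ W_self + agg(neighbors) @ W_neigh)."""
    h_neigh = aggregate(x, nbrs, reduce)
    return np.maximum(x @ w_self + h_neigh @ w_neigh, 0.0)


# Node 0 has three neighbors, one of which carries an outlier feature (100.0).
x = np.array([[0.0], [1.0], [2.0], [100.0]])
nbrs = neighbor_lists([(0, 1), (0, 2), (0, 3)], num_nodes=4)

median_agg = aggregate(x, nbrs, reduce=np.median)  # node 0 -> [2.0]
mean_agg = aggregate(x, nbrs, reduce=np.mean)      # node 0 -> [34.33...]

rng = np.random.default_rng(0)
h = sage_layer(x, nbrs, rng.normal(size=(1, 3)), rng.normal(size=(1, 3)))
```

The median ignores the single outlier neighbor, while the mean is pulled far toward it, which is the usual motivation for a median aggregator.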

Quick Start

Train a GCN model

from graphgallery.nn.models import GCN
# adj: adjacency matrix (scipy sparse matrix); x: node feature matrix (numpy array)
model = GCN(adj, x, labels, device='GPU', seed=123)
# build the model; custom hyper-parameters can be passed to build()
model.build()
# train the model; idx_train and idx_val are numpy arrays of node indices
his = model.train(idx_train, idx_val, verbose=1, epochs=100)
# evaluate on the test nodes
loss, accuracy = model.test(idx_test)
print(f'Test loss {loss:.5}, Test accuracy {accuracy:.2%}')

Output on the Cora dataset:

loss 1.02, acc 95.00%, val_loss 1.41, val_acc 77.40%: 100%|██████████| 100/100 [00:02<00:00, 37.07it/s]
Test loss 1.4123, Test accuracy 81.20%
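The Quick Start assumes that adj, x, labels and the index arrays already exist. A minimal sketch of building toy inputs in the formats the comments describe (a scipy sparse adjacency matrix, a NumPy feature matrix, and NumPy index arrays) follows. The CSR format, the float32 dtype, and integer class labels are assumptions about what GraphGallery accepts, not documented requirements.

```python
import numpy as np
import scipy.sparse as sp

num_nodes, num_features = 6, 4

# Undirected edges: two triangles {0, 1, 2} and {3, 4, 5}
edges = np.array([[0, 1], [1, 2], [2, 0], [3, 4], [4, 5], [5, 3]])

# Store each edge in both directions so the adjacency matrix is symmetric
row = np.concatenate([edges[:, 0], edges[:, 1]])
col = np.concatenate([edges[:, 1], edges[:, 0]])
data = np.ones(len(row), dtype=np.float32)
adj = sp.csr_matrix((data, (row, col)), shape=(num_nodes, num_nodes))

rng = np.random.default_rng(123)
x = rng.random((num_nodes, num_features)).astype(np.float32)  # node features
labels = np.array([0, 0, 0, 1, 1, 1])                          # one class per node

# Random disjoint train / validation / test split over node indices
perm = rng.permutation(num_nodes)
idx_train, idx_val, idx_test = perm[:2], perm[2:4], perm[4:]
```

These objects can then be passed to GCN(adj, x, labels, ...) and model.train(idx_train, idx_val, ...) as in the Quick Start.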

Visualization

  • Accuracy
import matplotlib.pyplot as plt
# his is the training history returned by model.train() above
plt.plot(his.history['acc'])
plt.plot(his.history['val_acc'])
plt.legend(['Accuracy', 'Val Accuracy'])
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.show()


  • Loss
import matplotlib.pyplot as plt
# his is the training history returned by model.train() above
plt.plot(his.history['loss'])
plt.plot(his.history['val_loss'])
plt.legend(['Loss', 'Val Loss'])
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()


Download files


  • Source distribution: graphgallery-0.1.5.tar.gz (50.5 kB)
  • Built distribution (Python 3): graphgallery-0.1.5-py3-none-any.whl (121.7 kB)
