
Geometric Deep Learning Extension Library for TensorFlow


GraphGallery



Python 3.6 | TensorFlow >= 2.1

GraphGallery is a gallery of state-of-the-art graph neural networks for TensorFlow 2.x.

This repository aims to achieve four goals:

  • Similar or better performance compared with the original implementations
  • Faster training and testing
  • Simple and convenient to use, with high scalability
  • Easy-to-read source code

Installation

pip install -U graphgallery

Implementations

In detail, the following methods are currently implemented:

Semi-supervised models

General

Defense models

Unsupervised models

Quick Start

Datasets

from graphgallery.data import Planetoid
# set `verbose=False` to suppress the printed dataset tables
data = Planetoid('cora', verbose=False)
adj, x, labels = data.graph.unpack()
idx_train, idx_val, idx_test = data.split()
# adj:  adjacency matrix: 2D Scipy sparse matrix
# x:  feature matrix: 2D Numpy array
# labels:  class labels: 1D Numpy array
# idx_train:  training indices: 1D Numpy array
# idx_val:  validation indices: 1D Numpy array
# idx_test:  testing indices: 1D Numpy array
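data.split() returns three disjoint arrays of node indices. How GraphGallery chooses them is not described here; the sketch below shows one common strategy, a per-class (stratified) split in the style of the Planetoid benchmark, where a fixed number of labeled nodes per class is used for training and validation and test nodes are drawn from the rest. The helper stratified_split and its parameters are hypothetical and only illustrate the idea.

```python
import numpy as np

def stratified_split(labels, train_per_class=20, num_val=500, num_test=1000, seed=None):
    """Hypothetical helper: pick `train_per_class` nodes of every class for training,
    then draw disjoint validation and test nodes from the remaining nodes."""
    rng = np.random.RandomState(seed)
    labels = np.asarray(labels)
    idx_train = []
    for c in np.unique(labels):
        nodes_c = np.flatnonzero(labels == c)  # all nodes belonging to class c
        idx_train.extend(rng.choice(nodes_c, size=train_per_class, replace=False))
    idx_train = np.sort(np.array(idx_train))
    # shuffle the unused nodes and slice them into validation and test sets
    rest = rng.permutation(np.setdiff1d(np.arange(labels.size), idx_train))
    idx_val = np.sort(rest[:num_val])
    idx_test = np.sort(rest[num_val:num_val + num_test])
    return idx_train, idx_val, idx_test

# Toy usage: 3 classes with 30 nodes each (90 nodes in total)
toy_labels = np.repeat([0, 1, 2], 30)
idx_train, idx_val, idx_test = stratified_split(
    toy_labels, train_per_class=5, num_val=20, num_test=40, seed=0)
print(len(idx_train), len(idx_val), len(idx_test))  # 15 20 40
```

Sampling the training nodes per class keeps every class represented even when only a handful of labels are available, which is the usual semi-supervised setting for these citation datasets.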

Currently, the supported datasets are:

>>> data.supported_datasets
('citeseer', 'cora', 'pubmed')

Example of a GCN model

from graphgallery.nn.models import GCN
# adj is a SciPy sparse matrix, x is a NumPy array
model = GCN(adj, x, labels, device='GPU', norm_x='l1', seed=123)
# build the GCN model with default hyper-parameters
model.build()
# train the model; idx_train and idx_val are NumPy arrays
his = model.train(idx_train, idx_val, verbose=1, epochs=100)
# test your model
loss, accuracy = model.test(idx_test)
print(f'Test loss {loss:.5}, Test accuracy {accuracy:.2%}')
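GCN (Kipf and Welling) propagates node features with the renormalized adjacency matrix D^-1/2 (A + I) D^-1/2, where D is the degree matrix of A + I, and each layer computes activation(A_norm H W). The sketch below reproduces that preprocessing and one layer with dense NumPy arrays; GraphGallery itself works on SciPy sparse matrices and TensorFlow layers. It also assumes, based only on the parameter name, that norm_x='l1' row-normalizes the feature matrix so each row sums to 1. All function names here are hypothetical.

```python
import numpy as np

def normalize_features_l1(x):
    """Scale each row of x so its absolute values sum to 1 (all-zero rows stay zero)."""
    row_sums = np.abs(x).sum(axis=1, keepdims=True)
    row_sums[row_sums == 0] = 1.0  # avoid division by zero for isolated feature rows
    return x / row_sums

def normalize_adj(adj):
    """Renormalization trick on a dense adjacency matrix: D^-1/2 (A + I) D^-1/2."""
    a_hat = adj + np.eye(adj.shape[0])  # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    return a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def gcn_layer(a_norm, h, w, activation=lambda z: np.maximum(z, 0)):
    """One GCN layer: activation(A_norm @ H @ W), ReLU by default."""
    return activation(a_norm @ h @ w)

# Toy graph: a path 0 - 1 - 2 with 2 input features and 4 hidden units
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=float)
x = np.array([[1., 3.], [2., 2.], [0., 0.]])
w = np.random.RandomState(123).randn(2, 4)
h = gcn_layer(normalize_adj(adj), normalize_features_l1(x), w)
print(h.shape)  # (3, 4)
```

The symmetric normalization keeps the propagated features on a comparable scale regardless of node degree, and the self-loops let each node keep part of its own representation.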

Output on the Cora dataset:

<Loss = 1.0161 Acc = 0.9500 Val_Loss = 1.4101 Val_Acc = 0.7740 >: 100%|██████████| 100/100 [00:01<00:00, 118.02it/s]
Test loss 1.4123, Test accuracy 81.20%
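The test accuracy reported above is the fraction of test nodes whose predicted class (the argmax of the model output) matches the true label. The sketch below computes both numbers, assuming the model outputs one row of class scores (logits) per node and is trained with softmax cross-entropy, the usual GCN objective. The helper evaluate is hypothetical and not part of GraphGallery.

```python
import numpy as np

def evaluate(logits, labels, idx):
    """Hypothetical helper: mean softmax cross-entropy and accuracy on the nodes in idx."""
    z = logits[idx] - logits[idx].max(axis=1, keepdims=True)  # shift for numerical stability
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    y = labels[idx]
    loss = -log_probs[np.arange(len(idx)), y].mean()
    accuracy = (log_probs.argmax(axis=1) == y).mean()
    return loss, accuracy

logits = np.array([[2.0, 0.1, 0.1],
                   [0.2, 1.5, 0.3],
                   [0.1, 0.2, 3.0],
                   [1.0, 0.9, 0.1]])
labels = np.array([0, 1, 2, 1])
idx_test = np.array([0, 1, 2, 3])
loss, accuracy = evaluate(logits, labels, idx_test)
print(f'Test loss {loss:.5}, Test accuracy {accuracy:.2%}')  # accuracy prints as 75.00%
```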

Customization

  • Build your model: the following statements show how to configure the hidden layers
# one hidden layer with 32 hidden units and ReLU activation
>>> model.build(hiddens=32, activations='relu')

# two hidden layers with 32 and 64 hidden units, both using ReLU
>>> model.build(hiddens=[32, 64], activations='relu')

# two hidden layers with 32 and 64 hidden units, using ReLU and ELU respectively
>>> model.build(hiddens=[32, 64], activations=['relu', 'elu'])

# other parameters such as `dropouts` and `l2_norms` (if the model has them) work the same way.
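The rule above, where a single value is reused for every hidden layer while a list gives one value per layer, can be sketched as a small normalization step. expand_layer_config is a hypothetical helper that mimics the described behavior; it is not GraphGallery's internal code.

```python
def expand_layer_config(hiddens, activations):
    """Hypothetical helper: turn scalar or list settings into one entry per hidden layer."""
    hiddens = [hiddens] if isinstance(hiddens, int) else list(hiddens)
    if isinstance(activations, str) or activations is None:
        activations = [activations] * len(hiddens)  # reuse one activation for every layer
    else:
        activations = list(activations)
    if len(activations) != len(hiddens):
        raise ValueError(f'got {len(hiddens)} hidden layers but {len(activations)} activations')
    return list(zip(hiddens, activations))

print(expand_layer_config(32, 'relu'))                 # [(32, 'relu')]
print(expand_layer_config([32, 64], 'relu'))           # [(32, 'relu'), (64, 'relu')]
print(expand_layer_config([32, 64], ['relu', 'elu']))  # [(32, 'relu'), (64, 'elu')]
```

Raising an error on a length mismatch surfaces configuration mistakes early instead of silently dropping or repeating layers.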
  • Train your model
# train with validation
>>> his = model.train(idx_train, idx_val, verbose=1, epochs=100)
# train without validation
# his = model.train(idx_train, verbose=1, epochs=100)

Here, his is a TensorFlow History-like object (or a History instance itself).
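Assuming his.history maps each metric name to a list of per-epoch values, as Keras's History.history does, it can be post-processed directly, for example to find the epoch with the best validation score. The metric names 'acc', 'val_acc', 'loss' and 'val_loss' are taken from the plotting examples in this README; the numbers below are made up for illustration, and best_epoch is a hypothetical helper.

```python
from types import SimpleNamespace

# Stand-in for the object returned by model.train(); values are illustrative only
his = SimpleNamespace(history={
    'loss': [1.90, 1.52, 1.21, 1.05],
    'acc': [0.40, 0.71, 0.88, 0.95],
    'val_loss': [1.85, 1.60, 1.41, 1.43],
    'val_acc': [0.52, 0.70, 0.78, 0.77],
})

def best_epoch(history, monitor='val_acc', mode='max'):
    """Return (0-based epoch, value) of the best score recorded for `monitor`."""
    scores = history[monitor]
    pick = max if mode == 'max' else min
    epoch = pick(range(len(scores)), key=scores.__getitem__)
    return epoch, scores[epoch]

print(best_epoch(his.history))                                  # (2, 0.78)
print(best_epoch(his.history, monitor='val_loss', mode='min'))  # (2, 1.41)
```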

  • Test your model
loss, accuracy = model.test(idx_test)
print(f'Test loss {loss:.5}, Test accuracy {accuracy:.2%}')
  • Display hyper-parameters

Use model.show() to display all hyper-parameters, or use model.show('model') or model.show('train') to display only the model parameters or the training parameters.

NOTE: this requires the texttable package to be installed.

Visualization

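NOTE: the plotting examples use the 'science' style from the SciencePlots package, so you must install it first. Recent SciencePlots releases (2.0 and later) also require import scienceplots before the style can be used.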

  • Accuracy
import matplotlib.pyplot as plt
with plt.style.context(['science', 'no-latex']):
    plt.plot(his.history['acc'])
    plt.plot(his.history['val_acc'])
    plt.legend(['Train Accuracy', 'Val Accuracy'])
    plt.ylabel('Accuracy')
    plt.xlabel('Epochs')
    plt.autoscale(tight=True)
    plt.show()    


  • Loss
import matplotlib.pyplot as plt
with plt.style.context(['science', 'no-latex']):
    plt.plot(his.history['loss'])
    plt.plot(his.history['val_loss'])
    plt.legend(['Train Loss', 'Val Loss'])
    plt.ylabel('Loss')
    plt.xlabel('Epochs')
    plt.autoscale(tight=True)
    plt.show()    


More Examples

Please refer to the examples directory.

TODO List

  • Add docstrings and documentation
  • Add support for PyTorch models
  • Support graph classification and link prediction tasks
  • Support heterogeneous graphs

Acknowledgement

This project is inspired by PyTorch Geometric, TensorFlow Geometric, StellarGraph, and the authors' original implementations. Thanks for their excellent work!



Download files

Source distribution: graphgallery-0.2.0.tar.gz (67.7 kB)

Built distribution: graphgallery-0.2.0-py3-none-any.whl (132.3 kB, Python 3)
