
Geometric Deep Learning Extension Library for TensorFlow


GraphGallery



License: MIT

A gallery of state-of-the-art graph neural networks, implemented with TensorFlow 2.x.

This repo aims to achieve four goals:

  • Similar or higher performance
  • Faster training and testing
  • Simple and convenient to use, with high scalability
  • Easy-to-read source code

Requirements

  • python>=3.6
  • tensorflow>=2.1 (2.1 is recommended)
  • networkx==2.3
  • scipy
  • scikit_learn
  • numpy
  • numba
  • gensim

Optional packages:

  • metis==0.2a4 (required for ClusterGCN)
  • texttable (required for model.show())

Installation

pip install -U graphgallery
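After installing, it can help to confirm that the packages listed under Requirements are importable. The following is a minimal sketch using only the standard library; the import names are assumptions (for example, scikit_learn is imported as `sklearn`), and `check_modules` is a hypothetical helper, not part of GraphGallery.

```python
import importlib.util
from typing import Dict, Iterable

# Assumed import names for the required packages (scikit_learn -> "sklearn").
REQUIRED = ["tensorflow", "networkx", "scipy", "sklearn", "numpy", "numba", "gensim"]
# Optional: metis (ClusterGCN) and texttable (model.show()).
OPTIONAL = ["metis", "texttable"]


def check_modules(names: Iterable[str]) -> Dict[str, bool]:
    """Hypothetical helper: map each top-level module name to whether it can be found.

    find_spec only locates the Python module; it does not verify native
    dependencies (e.g. the METIS shared library behind the metis wrapper).
    """
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for group, names in (("required", REQUIRED), ("optional", OPTIONAL)):
        for name, found in check_modules(names).items():
            print(f"[{group}] {name}: {'ok' if found else 'MISSING'}")
```

Running the script prints one line per package, so a missing dependency shows up before any model is built.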

Implementation

General models

Defense models

Quick Start

Example of GCN model

from graphgallery.nn.models import GCN
# adj: scipy sparse adjacency matrix; x: numpy node-feature matrix; labels: numpy array of node labels
model = GCN(adj, x, labels, device='GPU', norm_x='l1', seed=123)
# build the GCN model with default hyper-parameters (see "Build your model" below)
model.build()
# train the model; idx_train and idx_val are numpy arrays of node indices
his = model.train(idx_train, idx_val, verbose=1, epochs=100)
# evaluate the model on the test nodes
loss, accuracy = model.test(idx_test)
print(f'Test loss {loss:.5}, Test accuracy {accuracy:.2%}')
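The example above assumes that adj, x, labels and the index arrays already exist. The following minimal sketch builds hypothetical toy inputs in the formats the comments describe: a SciPy sparse adjacency matrix, a NumPy feature matrix, and NumPy index arrays. The ring graph, the sizes, the random split, and the assumption that labels are integer class ids are all illustrative; this is not the Cora dataset.

```python
import numpy as np
import scipy.sparse as sp

rng = np.random.default_rng(123)
num_nodes, num_features, num_classes = 10, 5, 3

# Toy undirected graph: a ring over the nodes, stored as a symmetric CSR matrix.
rows = np.arange(num_nodes)
cols = (rows + 1) % num_nodes
weights = np.ones(num_nodes, dtype=np.float32)
adj = sp.csr_matrix((weights, (rows, cols)), shape=(num_nodes, num_nodes))
adj = adj + adj.T  # symmetrize so that i -> j implies j -> i

# Dense node features (num_nodes x num_features) and one integer label per node.
x = rng.random((num_nodes, num_features), dtype=np.float32)
labels = rng.integers(0, num_classes, size=num_nodes)

# Disjoint integer index arrays selecting the train / validation / test nodes.
perm = rng.permutation(num_nodes)
idx_train, idx_val, idx_test = perm[:6], perm[6:8], perm[8:]
```

These arrays have the shapes that GCN(adj, x, labels, ...), model.train(idx_train, idx_val) and model.test(idx_test) expect in the example above; for real experiments, replace them with an actual dataset.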

Output on the Cora dataset:

loss 1.02, acc 95.00%, val_loss 1.41, val_acc 77.40%: 100%|██████████| 100/100 [00:02<00:00, 37.07it/s]
Test loss 1.4123, Test accuracy 81.20%

Build your model

You can pass hyper-parameters to model.build() to customize the architecture, for example:

# one hidden layer with 32 hidden units and ReLU activation
>>> model.build(hiddens=32, activations='relu')

# two hidden layers with 32 and 64 hidden units, both using ReLU
>>> model.build(hiddens=[32, 64], activations='relu')

# two hidden layers with 32 and 64 hidden units, using ReLU and ELU respectively
>>> model.build(hiddens=[32, 64], activations=['relu', 'elu'])

# other parameters such as `dropouts` and `l2_norms` (if the model has them) are specified in the same way.
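The comments above describe how build() pairs hiddens with activations: a single value is shared by every layer, while a list supplies one value per layer. The helper below is a hypothetical sketch of that rule, written for illustration only; it is not GraphGallery's own code, and raising an error when the list lengths differ is an assumption.

```python
from typing import List, Sequence, Tuple, Union


def expand_hparams(hiddens: Union[int, Sequence[int]],
                   activations: Union[str, Sequence[str]]) -> List[Tuple[int, str]]:
    """Hypothetical helper: pair each hidden layer size with its activation.

    A scalar applies to every layer; a list gives one value per layer.
    """
    hiddens = [hiddens] if isinstance(hiddens, int) else list(hiddens)
    if isinstance(activations, str):
        activations = [activations] * len(hiddens)  # broadcast one activation to all layers
    else:
        activations = list(activations)
    if len(activations) != len(hiddens):
        raise ValueError(
            f"got {len(hiddens)} hidden layers but {len(activations)} activations")
    return list(zip(hiddens, activations))
```

For instance, `expand_hparams([32, 64], 'relu')` gives `[(32, 'relu'), (64, 'relu')]`, matching the second build() call above.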

Train or test your model

For more details, see the methods model.train and model.test.

Hyper-parameters

Use model.show() to display all hyper-parameters, or model.show('model') and model.show('train') to display only the model parameters or the training parameters, respectively. NOTE: this requires the texttable package.

Visualization

NOTE: the plots below use the 'science' style from the SciencePlots package, which must be installed to run them as written.

  • Accuracy
import matplotlib.pyplot as plt
# his.history holds the per-epoch values of each logged metric
with plt.style.context(['science', 'no-latex']):
    plt.plot(his.history['acc'])
    plt.plot(his.history['val_acc'])
    plt.legend(['Train Accuracy', 'Val Accuracy'])
    plt.ylabel('Accuracy')
    plt.xlabel('Epochs')
    plt.autoscale(tight=True)
    plt.show()


  • Loss
import matplotlib.pyplot as plt
# his.history holds the per-epoch values of each logged metric
with plt.style.context(['science', 'no-latex']):
    plt.plot(his.history['loss'])
    plt.plot(his.history['val_loss'])
    plt.legend(['Train Loss', 'Val Loss'])
    plt.ylabel('Loss')
    plt.xlabel('Epochs')
    plt.autoscale(tight=True)
    plt.show()

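The accuracy and loss snippets differ only in the metric they plot. A minimal sketch, assuming his.history is a Keras-style dict with 'acc'/'val_acc' and 'loss'/'val_loss' keys (as in the training log above), folds them into one hypothetical plot_history helper that falls back to matplotlib's default style when SciencePlots is not installed. The `import scienceplots` step reflects SciencePlots 2.0 and later, where importing the package registers its styles; if the import fails, the code simply checks whether the styles are already available.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line to display plots with plt.show()
import matplotlib.pyplot as plt


def plot_history(history, metric, label):
    """Hypothetical helper: plot train/val curves of `metric` from a Keras-style history dict."""
    try:
        import scienceplots  # noqa: F401  (SciencePlots >= 2.0 registers its styles on import)
    except ImportError:
        pass
    # Use the SciencePlots styles if they are registered, else matplotlib's default style.
    styles = ["science", "no-latex"] if "science" in plt.style.available else ["default"]
    with plt.style.context(styles):
        fig, ax = plt.subplots()
        ax.plot(history[metric])
        ax.plot(history[f"val_{metric}"])
        ax.legend([f"Train {label}", f"Val {label}"])
        ax.set_ylabel(label)
        ax.set_xlabel("Epochs")
        ax.autoscale(tight=True)
    return fig
```

For example, `plot_history(his.history, 'loss', 'Loss')` reproduces the loss plot; with an interactive backend, call `plt.show()` afterwards.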

Acknowledgement

This project is motivated by PyTorch Geometric, TensorFlow Geometric, StellarGraph, and the original implementations by the authors. Thanks for their excellent work!



Download files

Source distribution: graphgallery-0.1.10.tar.gz (58.3 kB)

Built distribution: graphgallery-0.1.10-py3-none-any.whl (114.4 kB, Python 3)
