Haiku Geometric

Overview | Installation | Quickstart | Examples | Documentation | License

Overview

Haiku Geometric is a collection of graph neural networks (GNNs) implemented using JAX. It tries to provide object-oriented and easy-to-use modules for GNNs.

Haiku Geometric is built on top of Haiku and Jraph. It is deeply inspired by PyTorch Geometric. In most cases, Haiku Geometric tries to replicate the API of PyTorch Geometric to allow code sharing between the two.

Haiku Geometric is still under development and I would advise against using it in production.

Installation

The latest development version can be installed directly from the GitHub repository:

pip install git+https://github.com/alexOarga/haiku-geometric.git

Alternatively, you can install the released package from PyPI:

pip install haiku-geometric
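
To check which version, if any, ended up in the current environment, a small standard-library sketch such as the following may help. The helper name installed_version is hypothetical; the distribution name haiku-geometric is taken from the pip command above.

```python
from importlib import metadata


def installed_version(dist_name="haiku-geometric"):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        # The distribution is not installed in this environment.
        return None


# Prints a version string if the package is installed, otherwise None.
print(installed_version())
```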

Quickstart

The following example defines a simple two-layer graph convolutional network (GCN):

import jax
import haiku as hk
from haiku_geometric.nn import GCNConv

class GCN(hk.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(hidden_channels)
        self.conv2 = GCNConv(hidden_channels)
        self.linear = hk.Linear(out_channels)

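    def __call__(self, nodes, senders, receivers):
        # First graph convolution followed by a ReLU non-linearity.
        x = self.conv1(nodes, senders, receivers)
        x = jax.nn.relu(x)
        # Second graph convolution.
        x = self.conv2(x, senders, receivers)
        # Project the convolved features (not the raw input nodes) to out_channels.
        x = self.linear(x)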
        return x

def forward(nodes, senders, receivers):
    gcn = GCN(16, 7)
    return gcn(nodes, senders, receivers)

The GNN defined above is a Haiku module. To convert it into a function that can be used with JAX, we transform it using hk.transform, as described in the Haiku documentation. Because the forward pass uses no randomness, hk.without_apply_rng removes the rng argument from apply.

model = hk.transform(forward)
model = hk.without_apply_rng(model)
rng = jax.random.PRNGKey(42)
params = model.init(rng, nodes=nodes, senders=senders, receivers=receivers)

We can now run a forward pass on the model:

output = model.apply(params=params, nodes=nodes, senders=senders, receivers=receivers)
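
The snippets above assume that nodes, senders and receivers already exist. Below is a minimal sketch of such inputs, assuming the Jraph-style convention in which nodes is a [num_nodes, num_features] array and senders/receivers hold the source and target node index of each edge. The toy graph and its values are made up for illustration. Plain NumPy arrays are used on the assumption that the layers accept them as most JAX functions do; wrap them in jnp.asarray if they do not.

```python
import numpy as np

num_nodes, num_features = 4, 3

# One feature vector per node (random placeholder values).
rng_np = np.random.default_rng(0)
nodes = rng_np.normal(size=(num_nodes, num_features)).astype(np.float32)

# Undirected cycle 0-1-2-3-0, stored as two directed edges per link.
senders = np.array([0, 1, 1, 2, 2, 3, 3, 0], dtype=np.int32)
receivers = np.array([1, 0, 2, 1, 3, 2, 0, 3], dtype=np.int32)
```

With these inputs and GCN(16, 7), the forward pass should return one row per node and seven columns, that is, an array of shape (4, 7).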

Documentation

The full documentation for Haiku Geometric is available online.

Examples

Haiku Geometric comes with a few examples that showcase the usage of the library. The following examples are available as Colab notebooks:

Quickstart Example
Graph Convolution Networks with Karate Club dataset
Graph Attention Networks with CORA dataset
TopKPooling and GraphConv with PROTEINS dataset

Implemented GNN modules

Currently, Haiku Geometric includes the following GNN modules:

Model | Description
GCNConv | Graph convolution layer from the "Semi-Supervised Classification with Graph Convolutional Networks" paper.
GATConv | Graph attention layer from the "Graph Attention Networks" paper.
SAGEConv | Graph convolution layer from the "Inductive Representation Learning on Large Graphs" paper.
GINConv | Graph isomorphism network layer from the "How Powerful are Graph Neural Networks?" paper.
GINEConv | Graph isomorphism network layer with edge features from the "Strategies for Pre-training Graph Neural Networks" paper.
GraphConv | Graph convolution layer from the "Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks" paper.
GeneralConv | A general GNN layer adapted from the "Design Space for Graph Neural Networks" paper.
GatedGraphConv | Graph convolution layer from the "Gated Graph Sequence Neural Networks" paper.
EdgeConv | Edge convolution layer from the "Dynamic Graph CNN for Learning on Point Clouds" paper.
PNAConv | Principal Neighbourhood Aggregation layer from the "Principal Neighbourhood Aggregation for Graph Nets" paper.
MetaLayer | Meta layer from the "Relational Inductive Biases, Deep Learning, and Graph Networks" paper.
GPSLayer | Graph transformer layer from the "Recipe for a General, Powerful, Scalable Graph Transformer" paper.
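
To make the first entry in the table concrete, the sketch below computes one GCN propagation step, D^{-1/2} (A + I) D^{-1/2} X W, in plain NumPy. It is an illustration of the rule from the GCN paper, not the library's GCNConv implementation. The function name gcn_propagate and the explicit weight argument are hypothetical, and the graph is assumed to be undirected with each link stored as two directed edges.

```python
import numpy as np


def gcn_propagate(nodes, senders, receivers, weight):
    """One GCN propagation step: D^{-1/2} (A + I) D^{-1/2} X W."""
    num_nodes = nodes.shape[0]
    # Add a self-loop for every node, as in the GCN paper.
    loops = np.arange(num_nodes)
    s = np.concatenate([senders, loops])
    r = np.concatenate([receivers, loops])
    # Degree of each node, counting incoming edges and the self-loop.
    deg = np.bincount(r, minlength=num_nodes).astype(nodes.dtype)
    # Symmetric normalisation coefficient for each edge.
    norm = 1.0 / np.sqrt(deg[s] * deg[r])
    # Transform features, then sum the normalised messages at each receiver.
    messages = (nodes @ weight)[s] * norm[:, None]
    out = np.zeros((num_nodes, weight.shape[1]), dtype=nodes.dtype)
    np.add.at(out, r, messages)
    return out


# Path graph 0-1-2 with identity features and weights, so the output
# equals the normalised adjacency matrix itself.
senders = np.array([0, 1, 1, 2])
receivers = np.array([1, 0, 2, 1])
result = gcn_propagate(np.eye(3), senders, receivers, np.eye(3))
```

np.add.at is used instead of fancy-index assignment because several edges can share a receiver and their contributions must accumulate.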

Implemented positional encodings

The following positional encodings are currently available:

Model | Description
LaplacianEncoder | Laplacian positional encoding from the "Rethinking Graph Transformers with Spectral Attention" paper.
MagLaplacianEncoder | Magnetic Laplacian positional encoding from the "Transformers Meet Directed Graphs" paper.
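
Laplacian positional encodings are built from the eigenvalues and eigenvectors of the graph Laplacian. The NumPy sketch below computes only that raw eigendecomposition for a small undirected graph. It is not the library's LaplacianEncoder, whose API is not shown here; the function name laplacian_eigenvectors and its parameters are hypothetical, and a dense adjacency matrix is assumed for simplicity.

```python
import numpy as np


def laplacian_eigenvectors(senders, receivers, num_nodes, k):
    """Return the k smallest non-trivial eigenpairs of the normalised Laplacian."""
    # Dense adjacency matrix; edges are assumed to be listed in both directions.
    adj = np.zeros((num_nodes, num_nodes))
    adj[senders, receivers] = 1.0
    deg = adj.sum(axis=1)
    # Guard against isolated nodes (degree 0).
    inv_sqrt_deg = np.where(deg > 0, 1.0 / np.sqrt(np.maximum(deg, 1.0)), 0.0)
    # Symmetric normalised Laplacian: I - D^{-1/2} A D^{-1/2}.
    lap = np.eye(num_nodes) - inv_sqrt_deg[:, None] * adj * inv_sqrt_deg[None, :]
    # eigh returns eigenvalues in ascending order for symmetric matrices.
    eigvals, eigvecs = np.linalg.eigh(lap)
    # Skip the first (trivial, eigenvalue ~0) eigenvector and keep the next k.
    return eigvals[1:k + 1], eigvecs[:, 1:k + 1]


# Undirected 4-cycle 0-1-2-3-0.
senders = np.array([0, 1, 1, 2, 2, 3, 3, 0])
receivers = np.array([1, 0, 2, 1, 3, 2, 0, 3])
vals, vecs = laplacian_eigenvectors(senders, receivers, num_nodes=4, k=2)
```

Each row of vecs can then serve as a k-dimensional positional feature for the corresponding node. Eigenvector signs are arbitrary, which is one reason learned encoders are usually applied on top.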

Issues

If you encounter a problem, please open an issue on the GitHub repository.

Running tests

Haiku Geometric can be tested using pytest by running the following command:

python -m pytest test/

License

License

