Haiku Geometric
Overview | Installation | Quickstart | Examples | Documentation | License
Overview
Haiku Geometric is a collection of graph neural networks (GNNs) implemented using JAX. It tries to provide object-oriented and easy-to-use modules for GNNs.
Haiku Geometric is built on top of Haiku and Jraph. It is deeply inspired by PyTorch Geometric. In most cases, Haiku Geometric tries to replicate the API of PyTorch Geometric to allow code sharing between the two.
Haiku Geometric is still under development and I would advise against using it in production.
Installation
The latest development version of Haiku Geometric can be installed directly from the GitHub repository:
pip install git+https://github.com/alexOarga/haiku-geometric.git
Alternatively, you can install the released package from PyPI:
pip install haiku-geometric
Quickstart
For instance, we can create a simple graph convolutional network (GCN) of 2 layers as follows:
import jax
import haiku as hk
from haiku_geometric.nn import GCNConv


class GCN(hk.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # Two graph convolutions followed by a linear output layer.
        self.conv1 = GCNConv(hidden_channels)
        self.conv2 = GCNConv(hidden_channels)
        self.linear = hk.Linear(out_channels)

    def __call__(self, nodes, senders, receivers):
        x = self.conv1(nodes, senders, receivers)
        x = jax.nn.relu(x)
        x = self.conv2(x, senders, receivers)
        # Project the convolved features (not the raw input nodes).
        x = self.linear(x)
        return x


def forward(nodes, senders, receivers):
    # Haiku modules must be built inside the function passed to hk.transform.
    gcn = GCN(16, 7)
    return gcn(nodes, senders, receivers)
The GNN that we have defined is a Haiku Module.
To convert our module into a function that can be used with JAX, we transform it using hk.transform, as described in the Haiku documentation.
model = hk.transform(forward)
model = hk.without_apply_rng(model)
rng = jax.random.PRNGKey(42)
params = model.init(rng, nodes=nodes, senders=senders, receivers=receivers)
We can now run a forward pass on the model:
output = model.apply(params=params, nodes=nodes, senders=senders, receivers=receivers)
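The snippets above assume that `nodes`, `senders` and `receivers` already exist. As a minimal sketch, assuming the Jraph-style edge-list convention in which edge `i` runs from `senders[i]` to `receivers[i]`, the following builds a toy 4-node path graph and applies one GCN-style propagation step in plain JAX. The helper `gcn_propagate` is hypothetical: it only illustrates the symmetric normalization commonly used in GCNs, omits the learned weight matrix, and is not a description of how `GCNConv` is implemented.

```python
import jax
import jax.numpy as jnp

# Toy undirected path graph 0 - 1 - 2 - 3; each edge is listed in both directions.
senders = jnp.array([0, 1, 1, 2, 2, 3])
receivers = jnp.array([1, 0, 2, 1, 3, 2])
nodes = jnp.eye(4)  # one-hot node features, shape (num_nodes, num_features)


def gcn_propagate(nodes, senders, receivers):
    """Hypothetical helper: one step of D^-1/2 (A + I) D^-1/2 X, without weights."""
    num_nodes = nodes.shape[0]
    # Add a self-loop for every node.
    loops = jnp.arange(num_nodes)
    s = jnp.concatenate([senders, loops])
    r = jnp.concatenate([receivers, loops])
    # Node degree (including the self-loop), counted over incoming edges.
    deg = jax.ops.segment_sum(jnp.ones(s.shape[0]), r, num_segments=num_nodes)
    # Per-edge symmetric normalization coefficient.
    coeff = jax.lax.rsqrt(deg[s]) * jax.lax.rsqrt(deg[r])
    # Scale each sender's features and sum them at the receiver.
    messages = nodes[s] * coeff[:, None]
    return jax.ops.segment_sum(messages, r, num_segments=num_nodes)


out = gcn_propagate(nodes, senders, receivers)
print(out.shape)  # (4, 4)
```

Arrays shaped like these should be usable as the `nodes`, `senders` and `receivers` arguments of `model.init` and `model.apply`, although that has not been checked against the library here.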
Documentation
The documentation for Haiku Geometric can be found here.
Examples
Haiku Geometric comes with a few examples that showcase the usage of the library. The following examples are available:
Example |
---|
Quickstart Example |
Graph Convolution Networks with Karate Club dataset |
Graph Attention Networks with CORA dataset |
TopKPooling and GraphConv with PROTEINS dataset |
Implemented GNN modules
Currently, Haiku Geometric includes the following GNN modules:
Implemented positional encodings
The following positional encodings are currently available:
Model | Description |
---|---|
LaplacianEncoder | Laplacian positional encoding from the Rethinking Graph Transformers with Spectral Attention paper. |
MagLaplacianEncoder | Magnetic Laplacian positional encoding from the Transformers Meet Directed Graphs paper. |
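To give a rough idea of what a Laplacian positional encoding starts from, the sketch below computes the lowest non-trivial eigenvectors of the graph Laplacian `L = D - A` in plain JAX. The function `laplacian_pe` is a hypothetical helper written for this illustration, assuming an undirected, connected graph given as `senders`/`receivers` index arrays. It is not the `LaplacianEncoder` API, and it leaves out any learned processing of the eigenvectors.

```python
import jax.numpy as jnp


def laplacian_pe(senders, receivers, num_nodes, k):
    """Hypothetical helper: k lowest non-trivial eigenpairs of L = D - A."""
    # Dense adjacency matrix built from the edge list.
    adj = jnp.zeros((num_nodes, num_nodes)).at[senders, receivers].set(1.0)
    # Unnormalized graph Laplacian.
    lap = jnp.diag(adj.sum(axis=1)) - adj
    # eigh returns eigenvalues in ascending order for a symmetric matrix.
    eigvals, eigvecs = jnp.linalg.eigh(lap)
    # Skip the constant eigenvector that belongs to eigenvalue 0.
    return eigvals[1:k + 1], eigvecs[:, 1:k + 1]


# Path graph 0 - 1 - 2 - 3, with each edge listed in both directions.
senders = jnp.array([0, 1, 1, 2, 2, 3])
receivers = jnp.array([1, 0, 2, 1, 3, 2])
eigvals, pe = laplacian_pe(senders, receivers, num_nodes=4, k=2)
print(pe.shape)  # (4, 2)
```

Each row of `pe` is a `k`-dimensional position vector for one node. The sign of every eigenvector is arbitrary, so any downstream use has to tolerate sign flips.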
Issues
If you encounter any issue, please open an issue.
Running tests
Haiku Geometric can be tested using pytest by running the following command:
python -m pytest test/
License