tf_gnns - A Hackable GraphNets library
A library for easy construction of message-passing networks in Keras 3.
It is largely inspired by DeepMind's Graph Networks paper and the corresponding open-source library (the original graph_nets library). It also contains tested baseline implementations of GCNs.
The tf_gnns library is backend-agnostic through Keras 3. TensorFlow,
PyTorch, and JAX backends are tested in CI for core tensor operations and
high-level MPNN/GCN smoke tests. The TensorFlow compatibility matrix currently
covers TensorFlow 2.17 through 2.21, and the Torch backend smoke matrix covers
Torch 2.10, 2.11, and 2.12 CPU wheels.
Initial motivation
This library was initially implemented for GraphNet-style MPNNs and related architectures that can be seen as special cases of GraphNets.
The tf_gnns computations are structured to make them amenable to backend graph compilers such as XLA for TensorFlow/JAX, and torch.compile where applicable.
Note that the Torch code paths are currently slower than PyTorch Geometric in eager mode.
tf_gnns is built to support arbitrary node/edge/global attributes and update functions.
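To make "arbitrary attributes and update functions" concrete, here is a minimal NumPy sketch of one full GraphNet block in the style of the original graph_nets paper: an edge update, sum aggregation of incoming edges, a node update, then a global update. It is not the tf_gnns implementation. The dictionary keys (`nodes`, `edges`, `globals`, `senders`, `receivers`) follow the graph_nets convention and are an assumption here, the sketch handles a single graph with globals of shape (1, G), and `edge_fn`, `node_fn`, and `global_fn` are hypothetical placeholders for any update function (for example an MLP).

```python
import numpy as np


def segment_sum(data, segment_ids, num_segments):
    # Sum rows of `data` that share a segment id (a scatter-add).
    out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    np.add.at(out, segment_ids, data)
    return out


def graph_net_block(graph, edge_fn, node_fn, global_fn):
    # Assumed layout: a single graph, globals of shape (1, G).
    nodes, edges, u = graph["nodes"], graph["edges"], graph["globals"]
    s, r = graph["senders"], graph["receivers"]
    n = nodes.shape[0]

    # 1. Edge update: each edge sees its attributes, both endpoints, and the globals.
    u_e = np.repeat(u, edges.shape[0], axis=0)
    new_edges = edge_fn(np.concatenate([edges, nodes[s], nodes[r], u_e], axis=-1))

    # 2. Aggregate updated edges into their receiver nodes.
    agg = segment_sum(new_edges, r, n)

    # 3. Node update from aggregated messages, old node state, and globals.
    u_n = np.repeat(u, n, axis=0)
    new_nodes = node_fn(np.concatenate([agg, nodes, u_n], axis=-1))

    # 4. Global update from summed nodes, summed edges, and old globals.
    new_u = global_fn(
        np.concatenate(
            [new_nodes.sum(0, keepdims=True), new_edges.sum(0, keepdims=True), u],
            axis=-1,
        )
    )
    return {**graph, "nodes": new_nodes, "edges": new_edges, "globals": new_u}
```

Swapping the identity-like update functions for learned ones, or sum aggregation for mean or max, gives the usual GraphNet variants.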
Philosophy and performance
tf_gnns deliberately avoids custom low-level kernels, since built-in framework operations are usually easier to maintain and perform well.
The tf_gnns computations are structured so that graph compilers can generate optimized GPU code, which can reasonably be expected to perform as well as the underlying Keras-compiled code allows.
This is often close to the limits of what the accelerator can deliver.
The framework takes advantage of graph computation where available (mainly JAX and TensorFlow) to create fused networks. In preliminary benchmarks, torch.compile did not yield speedups.
A set of utility functions for building MLPs with Keras is also provided (e.g., handling input/output sizes so that the resulting networks are valid), replacing Sonnet.
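The size bookkeeping such MLP helpers take care of amounts to chaining layer widths so that each layer's fan-in equals the previous layer's fan-out. A minimal sketch under that assumption: `mlp_weight_shapes` and `count_params` are hypothetical names, not tf_gnns functions, and the sketch assumes plain dense layers with biases.

```python
def mlp_weight_shapes(input_dim, hidden_units, output_dim):
    """Return (fan_in, fan_out) for each dense layer so consecutive layers chain."""
    dims = [input_dim, *hidden_units, output_dim]
    # Pair each width with the next one: layer k maps dims[k] -> dims[k + 1].
    return list(zip(dims[:-1], dims[1:]))


def count_params(shapes):
    # Each dense layer has a fan_in x fan_out kernel plus a fan_out bias.
    return sum(fan_in * fan_out + fan_out for fan_in, fan_out in shapes)


shapes = mlp_weight_shapes(8, [32, 32], 4)
print(shapes, count_params(shapes))
```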
TensorFlow compatibility and test status
| TensorFlow | TensorFlow Probability |
|---|---|
| 2.17.x | 0.24.x |
| 2.18.x | 0.25.x |
| 2.19.x | 0.25.x |
| 2.20.x | 0.25.x |
| 2.21.x | 0.25.x |
The matrix above is validated by scripts/run_tf_matrix_tests.sh and in CI (.github/workflows/tests.yml).
Installing tf_gnns
NOTE
The currently tested matrix is TensorFlow 2.17 through 2.21, with the matching TensorFlow Probability versions shown above.
Install from a source checkout with uv (recommended):
uv sync
Or install with pip:
# optional - recommended: pin a tested TensorFlow / TensorFlow Probability pair, e.g.
# pip install "tensorflow==2.17.*"
# pip install "tensorflow_probability==0.24.*"
pip install tf_gnns
Run tests:
uv sync --group dev
uv run pytest -v
Run tests with coverage and update badge payload:
scripts/run_coverage.sh
Run compatibility tests across TensorFlow versions:
scripts/run_tf_matrix_tests.sh 2.17 2.18 2.19 2.20 2.21
Execution and compilation
tf_gnns execution paths are eager by default so they can remain backend-portable with Keras 3.
If you are using the TensorFlow backend and want graph compilation, compile at the application level:
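```python
import tensorflow as tf
from tf_gnns.models.graphnet import GraphNetMLP

model = GraphNetMLP(units=32, core_steps=2)


@tf.function  # traces the step into a TensorFlow graph on the first call
def train_step(graph_tensor_dict):
    with tf.GradientTape() as tape:
        out = model(graph_tensor_dict)
        loss = tf.reduce_mean(out["nodes"])  # example loss, computed inside the tape
    grads = tape.gradient(loss, model.trainable_variables)
    # Apply `grads` with an optimizer of your choice (e.g. optimizer.apply_gradients).
    return loss, grads
```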
This keeps library internals backend-agnostic while still allowing TensorFlow users to optimize execution.
Backend support note
tf_gnns targets Keras 3 backend portability. The main CI matrix validates:
- TensorFlow backend across TensorFlow 2.17, 2.18, 2.19, 2.20, and 2.21.
- Torch backend smoke tests across Torch 2.10, 2.11, and 2.12 CPU wheels.
- JAX execution through backend-agnostic smoke tests when JAX is available.
For Torch, the CI job installs the selected CPU Torch wheel last and runs with
uv run --no-sync so the version under test is not replaced by uv's lockfile
sync. GPU-enabled Torch/Triton stacks can be more fragile because Keras, Torch,
TensorFlow, and JAX may pull different CUDA dependency versions.
Recommended Torch CPU smoke setup:
pip install --index-url https://download.pytorch.org/whl/cpu "torch==2.12.0"
KERAS_BACKEND=torch pytest -q tests/test_torch_backend_runtime.py tests/test_notebook_torch_backend_flow.py
If you are using a GPU Torch/Triton combo and hit import-time crashes in
triton / torch._dynamo, first reproduce with CPU wheels or isolate the
backend in a clean environment.
Build the Docker test image for a specific TensorFlow version:
docker build --build-arg TENSORFLOW_VERSION=2.17 -t tf-gnns:test .
Use through Docker
You can build an Ubuntu 22-based Docker image with tf_gnns using the following command:
docker build . -t tf_gnns_215 --network host --build-arg TENSORFLOW_VERSION=2.15
The container includes logic to resolve the required dependencies:
- NumPy 1.x is required for TensorFlow <= 2.14.
- Keras 2 support needs to be enabled for TensorFlow >= 2.16.
- The tensorflow_probability version is selected from a mapping based on the TensorFlow version.
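The three rules above can be expressed as a small lookup. This sketch is hypothetical and is not the container's actual logic: `resolve_deps` and `TFP_FOR_TF` are illustrative names, and the mapping only contains the TensorFlow / TensorFlow Probability pairs from the compatibility table, returning `None` for versions outside it.

```python
# Hypothetical mapping copied from the compatibility table (TF minor -> TFP minor).
TFP_FOR_TF = {
    (2, 17): "0.24",
    (2, 18): "0.25",
    (2, 19): "0.25",
    (2, 20): "0.25",
    (2, 21): "0.25",
}


def resolve_deps(tf_version):
    """Apply the README's dependency rules to a version string like '2.17' or '2.17.1'."""
    major, minor = (int(part) for part in tf_version.split(".")[:2])
    key = (major, minor)
    return {
        # None when the TensorFlow version is not in the tested table.
        "tensorflow_probability": TFP_FOR_TF.get(key),
        # NumPy 1.x is required for TensorFlow <= 2.14.
        "numpy_1x_required": key <= (2, 14),
        # Keras 2 support must be enabled for TensorFlow >= 2.16.
        "enable_keras2": key >= (2, 16),
    }


print(resolve_deps("2.17"))
```

Comparing `(major, minor)` tuples keeps the version checks correct across minor versions with different digit counts (e.g. 2.9 vs 2.17), which plain string comparison would get wrong.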
Examples
tf_gnns basics
You can inspect some basic functionality in the following Colab notebook:
List sorting example
(Example from the original deepmind/graph_nets library)
If you are familiar with the original graph_nets library, this example will help you understand how you can transition to tf_gnns.
Sort a list of elements. This notebook and the accompanying code demonstrate how to use the Graph Nets library to learn to sort a list of elements.
A list of elements is treated as a fully connected graph between the elements. The network is trained to label the start node and, for each node, the (directed) edge that links it to the next-largest element.
After training, prediction ability is tested by comparing output to true sorted lists. Then the network's ability to generalize is tested by using it to sort larger lists.
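The input graph and training targets described above can be sketched in NumPy. This is an illustrative reconstruction, not the notebook's code: `sort_graph_targets` is a hypothetical helper that builds all directed edges between distinct elements, marks the smallest element as the start node, and labels an edge (i, j) with 1 when element j is the next-largest value after element i.

```python
import numpy as np


def sort_graph_targets(values):
    """Fully connected directed graph over `values` plus node/edge sort labels."""
    values = np.asarray(values)
    n = len(values)
    # Every ordered pair (i, j) with i != j is an edge.
    senders, receivers = np.where(~np.eye(n, dtype=bool))
    order = np.argsort(values, kind="stable")

    # Node target: 1 for the start node (the smallest element), 0 otherwise.
    node_labels = np.zeros(n, dtype=np.float32)
    node_labels[order[0]] = 1.0

    # successor[i] is the index of the next-largest element; -1 for the largest.
    successor = np.full(n, -1)
    successor[order[:-1]] = order[1:]

    # Edge target: 1 if the edge points from an element to its successor.
    edge_labels = (successor[senders] == receivers).astype(np.float32)
    return senders, receivers, node_labels, edge_labels


s, r, nodes_y, edges_y = sort_graph_targets([0.3, 0.1, 0.2])
```

A list of n elements yields n(n - 1) candidate edges, of which exactly n - 1 carry a positive label, so the edge classification task is heavily imbalanced for long lists.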
Protein-Protein Interaction example
This example shows how to adapt torch_geometric (aka PyG) inputs to tf_gnns inputs.
The notebook can be run end-to-end in Google Colab and, out of the box, reaches a test-set F1 score that is competitive with the state of the art.
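PyG stores node features as `x` with shape [num_nodes, F] and connectivity as `edge_index` with shape [2, num_edges], where row 0 holds the source nodes and row 1 the target nodes. A minimal sketch of mapping that onto a tensor dictionary follows. The key names (`nodes`, `edges`, `senders`, `receivers`, `globals`, `n_node`, `n_edge`) are borrowed from the original graph_nets convention and are an assumption about what tf_gnns expects, as are the zero-filled placeholder edge and global features; `pyg_to_graph_dict` is a hypothetical name.

```python
import numpy as np


def pyg_to_graph_dict(x, edge_index, edge_attr=None, global_dim=1):
    """Convert PyG-style arrays (already moved to NumPy) into a graph_nets-style dict."""
    x = np.asarray(x, dtype=np.float32)
    edge_index = np.asarray(edge_index, dtype=np.int64)
    n_edges = edge_index.shape[1]
    if edge_attr is None:
        # Placeholder edge features when the dataset has none.
        edge_attr = np.zeros((n_edges, 1), dtype=np.float32)
    return {
        "nodes": x,
        "edges": np.asarray(edge_attr, dtype=np.float32),
        "senders": edge_index[0],    # PyG row 0: source nodes
        "receivers": edge_index[1],  # PyG row 1: target nodes
        "globals": np.zeros((1, global_dim), dtype=np.float32),  # placeholder
        "n_node": np.array([x.shape[0]], dtype=np.int32),
        "n_edge": np.array([n_edges], dtype=np.int32),
    }
```

With real PyG `Data` objects, the tensors would first be converted with `.numpy()` (on CPU) before calling the helper.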
Keras 3 + Torch backend example
This example demonstrates using the higher-level GraphNet constructs with Keras 3 configured for the PyTorch backend.
GCN models
tf_gnns includes sparse GCN implementations for node-classification workloads:
- SparseGCNConv: low-level sparse graph convolution layer.
- SparseGCN: stacked sparse GCN model.
- GCNv2: tunedGNN-style high-level GCN stack with residual paths, normalization, and configurable dropout.
See the OGBN-Arxiv examples:
- notebooks/06_gcn_ogbn_arxiv_tfgnns.ipynb (tf_gnns GCN training workflow, including tunedGNN-style configuration)
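The propagation rule behind a standard sparse GCN layer (Kipf and Welling) can be sketched in NumPy on an edge list: add self-loops, weight each edge by 1/sqrt(d_i d_j), and scatter-add the transformed sender features into the receivers. This is the textbook rule written for illustration; `gcn_propagate` is a hypothetical name, and SparseGCNConv may differ in details such as bias, activation, or normalization.

```python
import numpy as np


def gcn_propagate(x, senders, receivers, weight):
    """One GCN step: D^{-1/2} (A + I) D^{-1/2} X W, computed on an edge list."""
    n = x.shape[0]
    # Self-loops (A + I) so each node keeps its own features.
    loops = np.arange(n)
    s = np.concatenate([senders, loops])
    r = np.concatenate([receivers, loops])
    # Degree of each receiver, counting its self-loop.
    deg = np.bincount(r, minlength=n).astype(x.dtype)
    # Symmetric normalization coefficient for every edge.
    norm = 1.0 / np.sqrt(deg[s] * deg[r])
    h = x @ weight
    out = np.zeros((n, weight.shape[1]), dtype=x.dtype)
    # Scatter-add normalized messages from senders into receivers.
    np.add.at(out, r, norm[:, None] * h[s])
    return out
```

The degree here is the in-degree of each receiver, which equals the usual node degree when every undirected edge is stored in both directions, as is common for node-classification datasets.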
Performance
From initial tests, the performance of tf_gnns seems to be at least as good as deepmind/graph_nets when using tensor dictionaries.
Publications using tf_gnns
The library has been used so far in the following publications:
[1] Bayesian graph neural networks for strain-based crack localization
[2] Remaining Useful Life Estimation Under Uncertainty with Causal GraphNets
[3] Relational VAE: A Continuous Latent Variable Model for Graph Structured Data
Download files
File details
Details for the file tf_gnns-0.3.0.tar.gz.
File metadata
- Download URL: tf_gnns-0.3.0.tar.gz
- Upload date:
- Size: 42.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 96140354f1b804a77b6ff2a75a114bf90a8ef01d0a63a7bdc5734d76026db44b |
| MD5 | d68ba7a267e626367c92804eb690639c |
| BLAKE2b-256 | 5b1890b08da3ce66cf915e2f0802a4e10be45365787a52be38b011996db659c7 |
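A downloaded file can be checked against the SHA256 digest listed above with Python's standard `hashlib` module. A minimal sketch; `verify_sha256` is a hypothetical helper name.

```python
import hashlib


def verify_sha256(path, expected_hex, chunk_size=1 << 20):
    """Return True if the file's SHA256 digest matches `expected_hex`."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read in chunks so large files need not fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_hex.lower()
```

For example, `verify_sha256("tf_gnns-0.3.0.tar.gz", "96140354f1b804a77b6ff2a75a114bf90a8ef01d0a63a7bdc5734d76026db44b")` should return True for an intact download of the source distribution.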
Provenance
The following attestation bundles were made for tf_gnns-0.3.0.tar.gz:
Publisher: release-pypi.yml on mylonasc/tf_gnns
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: tf_gnns-0.3.0.tar.gz
- Subject digest: 96140354f1b804a77b6ff2a75a114bf90a8ef01d0a63a7bdc5734d76026db44b
- Sigstore transparency entry: 1850162997
- Sigstore integration time:
- Permalink: mylonasc/tf_gnns@bb1ea48ec019e4ea3a3e7c526744f8db7e14e4c1
- Branch / Tag: refs/heads/main
- Owner: https://github.com/mylonasc
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release-pypi.yml@bb1ea48ec019e4ea3a3e7c526744f8db7e14e4c1
- Trigger Event: workflow_dispatch
File details
Details for the file tf_gnns-0.3.0-py3-none-any.whl.
File metadata
- Download URL: tf_gnns-0.3.0-py3-none-any.whl
- Upload date:
- Size: 45.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 087cd15ded9aa4e48f5e503fdb3c7d111b87a7d3a4d180deaf77dae8e6b49b40 |
| MD5 | 94460b9fe33da94f1641d991de7197da |
| BLAKE2b-256 | ad30ca63f9dcda3bc26a4e20c4152420d30f9f2cea23eb84a218f4114c519936 |
Provenance
The following attestation bundles were made for tf_gnns-0.3.0-py3-none-any.whl:
Publisher: release-pypi.yml on mylonasc/tf_gnns
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: tf_gnns-0.3.0-py3-none-any.whl
- Subject digest: 087cd15ded9aa4e48f5e503fdb3c7d111b87a7d3a4d180deaf77dae8e6b49b40
- Sigstore transparency entry: 1850163171
- Sigstore integration time:
- Permalink: mylonasc/tf_gnns@bb1ea48ec019e4ea3a3e7c526744f8db7e14e4c1
- Branch / Tag: refs/heads/main
- Owner: https://github.com/mylonasc
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release-pypi.yml@bb1ea48ec019e4ea3a3e7c526744f8db7e14e4c1
- Trigger Event: workflow_dispatch