Steerable E(3) GNN in jax
Reimplementation of SEGNN in jax. Original work by Johannes Brandstetter, Rob Hesselink, Elise van der Pol, Erik Bekkers and Max Welling.
Why jax?
40-50% faster inference and training compared to the original torch implementation. A jax implementation can also be combined with JAX-MD.
Installation
python -m pip install segnn-jax
Or clone this repository and build locally
python -m pip install -e .
GPU support
Upgrade jax to the gpu version
pip install --upgrade "jax[cuda]>=0.4.6" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
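After upgrading, it can be useful to check that jax actually picks up the GPU. The snippet below is a minimal sketch that only uses `jax.default_backend()` and `jax.devices()`; the variable names are illustrative.

```python
import jax

# Name of the backend jax selected by default for this process.
backend = jax.default_backend()
# All devices visible to that backend.
devices = jax.devices()

print(f"default backend: {backend}")
for device in devices:
    print(f"  {device.platform}: {device}")
```

If only CPU devices are listed, the CUDA build was most likely not installed or not detected.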
Validation
The N-body (charged and gravity) and QM9 experiments from the original paper are included for completeness.
Results
Charged is on 5 bodies, gravity on 100 bodies. QM9 has graphs of variable sizes, so in jax samples are padded to the maximum size. Loss is MSE for Charged and Gravity and MAE for QM9.
Times are remeasured on Quadro RTX 4000, model only on batches of 100 graphs, in (global) single precision.
Dataset (target) | torch (original): Loss | torch (original): Inference [ms] | jax (ours): Loss | jax (ours): Inference [ms] |
---|---|---|---|---|
charged (position) | .0043 | 21.22 | .0045 | 3.77 |
gravity (position) | .265 | 60.55 | .264 | 41.72 |
QM9 (alpha) | .066* | 82.53 | .082 | 105.98** |

** padded (naive)
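The padding of variable-size QM9 graphs mentioned above can be sketched as follows. This is a minimal illustration in plain numpy, not the code used in the experiments: `pad_graph`, its arguments and the returned masks are hypothetical names. The idea is that every graph is brought to the same node and edge count, so a jitted function sees one fixed input shape and does not need to be recompiled for each graph size.

```python
import numpy as np


def pad_graph(nodes, senders, receivers, max_nodes, max_edges):
    """Pad one graph to fixed node/edge counts and return validity masks."""
    n_nodes, n_edges = nodes.shape[0], senders.shape[0]
    # Keep at least one spare node so padding edges have somewhere to point.
    if n_nodes >= max_nodes or n_edges > max_edges:
        raise ValueError("graph does not fit in the requested padded size")

    padded_nodes = np.zeros((max_nodes,) + nodes.shape[1:], dtype=nodes.dtype)
    padded_nodes[:n_nodes] = nodes

    # Padding edges are self-loops on the last (dummy) node,
    # so they never send messages to or from real nodes.
    dummy = max_nodes - 1
    padded_senders = np.full(max_edges, dummy, dtype=senders.dtype)
    padded_receivers = np.full(max_edges, dummy, dtype=receivers.dtype)
    padded_senders[:n_edges] = senders
    padded_receivers[:n_edges] = receivers

    # Boolean masks marking which entries belong to the original graph.
    node_mask = np.arange(max_nodes) < n_nodes
    edge_mask = np.arange(max_edges) < n_edges
    return padded_nodes, padded_senders, padded_receivers, node_mask, edge_mask


# A toy graph with 3 nodes (4 features each) and 3 directed edges.
nodes = np.ones((3, 4), dtype=np.float32)
senders = np.array([0, 1, 2], dtype=np.int32)
receivers = np.array([1, 2, 0], dtype=np.int32)

padded = pad_graph(nodes, senders, receivers, max_nodes=6, max_edges=8)
print([p.shape for p in padded])
```

The masks let a loss or a readout ignore the padded entries. Padding every sample to the global maximum is simple but wasteful for small graphs, which is consistent with the "padded (naive)" note on the QM9 timing.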
Validation install
The experiments are only included in the github repo, so it needs to be cloned first.
git clone https://github.com/gerkone/segnn-jax
They are adapted from the original implementation, so additionally torch and torch_geometric are needed (cpu versions are enough).
python -m pip install -r experiments/requirements.txt
Datasets
QM9 is automatically downloaded and processed when running the respective experiment.
The N-body datasets have to be generated locally from the directory experiments/nbody/data (it will take some time, especially n-body gravity)
Charged dataset (5 bodies, 10000 training samples)
python3 -u generate_dataset.py --simulation=charged --seed=43
Gravity dataset (100 bodies, 10000 training samples)
python3 -u generate_dataset.py --simulation=gravity --n-balls=100 --seed=43
Notes
On jax<=0.4.6, the jit-pjit merge can be deactivated, making training faster (on nbody). This looks like an issue with data loading and the validation training loop implementation, and it does not affect SEGNN.
export JAX_JIT_PJIT_API_MERGE=0
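The same variable can also be set from Python instead of the shell. The sketch below assumes that jax reads `JAX_JIT_PJIT_API_MERGE` when it is first imported, so the variable has to be set before that point; `disable_jit_pjit_merge` is a hypothetical helper, not part of segnn-jax or jax.

```python
import os
import sys
import warnings


def disable_jit_pjit_merge() -> None:
    """Set JAX_JIT_PJIT_API_MERGE=0 for the current process."""
    if "jax" in sys.modules:
        # Assumption: the variable is only read at import time,
        # so setting it now may have no effect.
        warnings.warn("jax is already imported; the setting may be ignored")
    os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"


disable_jit_pjit_merge()
# Import jax (and anything that imports jax) only after this point.
```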
Usage
N-body (charged)
python validate.py --dataset=charged --epochs=200 --max-samples=3000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12
N-body (gravity)
python validate.py --dataset=gravity --epochs=100 --target=pos --max-samples=10000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12 --neighbours=5 --n-bodies=100
QM9
python validate.py --dataset=qm9 --epochs=1000 --target=alpha --lmax-hidden=2 --lmax-attributes=3 --layers=7 --units=128 --norm=instance --batch-size=128 --lr=5e-4 --weight-decay=1e-8 --lr-scheduling
(configurations used in validation)
Acknowledgments
- e3nn_jax made this reimplementation possible.
- Artur Toshev and Johannes Brandstetter, for support.