Steerable E(3) GNN in jax

Reimplementation of SEGNN in jax. Original work by Johannes Brandstetter, Rob Hesselink, Elise van der Pol, Erik Bekkers and Max Welling.

Why jax?

Inference and training are 40-50% faster than in the original torch implementation. A jax implementation can also be used with JAX-MD.

Installation

python -m pip install segnn-jax

Or clone this repository and build locally

python -m pip install -e .

GPU support

Upgrade jax to the gpu version

pip install --upgrade "jax[cuda]>=0.4.6" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
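After upgrading, it can be useful to check which backend jax actually picked up. The snippet below is a minimal sketch, not part of segnn-jax: `describe_jax_backend` is a hypothetical helper name, and it relies only on `jax.default_backend()` and `jax.devices()`. If the GPU build is not active, jax normally falls back to the cpu backend.

```python
import importlib.util


def describe_jax_backend():
    """Return a short description of the backend jax will run on.

    Hypothetical helper for illustration; it is not part of segnn-jax.
    """
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"
    import jax

    # default_backend() is typically "cpu", "gpu" or "tpu"
    return f"backend={jax.default_backend()}, devices={jax.devices()}"


print(describe_jax_backend())
```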

Validation

The N-body (charged and gravity) and QM9 experiments from the original paper are included for completeness.

Results

Charged is on 5 bodies, gravity on 100 bodies. QM9 has graphs of variable sizes, so in jax samples are padded to the maximum size. Loss is MSE for Charged and Gravity and MAE for QM9.
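To illustrate what padding to the maximum size means, here is a minimal sketch in NumPy. It is not the code used in the experiments: `pad_graphs` and `masked_sum_pool` are hypothetical names, and only node features are padded (edges are left out). The idea is that every graph in a batch gets the same shape, which is what jit-compiled functions need, while a boolean mask keeps the padded rows out of any aggregation.

```python
import numpy as np


def pad_graphs(node_features, max_nodes=None):
    """Pad a list of (n_i, d) arrays into one (batch, max_nodes, d) array.

    Returns the padded batch and a boolean mask that is True for real nodes.
    Hypothetical helper for illustration; not taken from segnn-jax.
    """
    if max_nodes is None:
        max_nodes = max(x.shape[0] for x in node_features)
    feature_dim = node_features[0].shape[1]
    batch = np.zeros((len(node_features), max_nodes, feature_dim), dtype=np.float32)
    mask = np.zeros((len(node_features), max_nodes), dtype=bool)
    for i, x in enumerate(node_features):
        n = x.shape[0]
        batch[i, :n] = x      # copy the real nodes
        mask[i, :n] = True    # the remaining rows stay zero and masked out
    return batch, mask


def masked_sum_pool(batch, mask):
    """Sum node features per graph, ignoring padded rows."""
    return (batch * mask[..., None]).sum(axis=1)


# Two toy graphs with 3 and 5 nodes and 2 features per node
graphs = [np.ones((3, 2), dtype=np.float32), np.ones((5, 2), dtype=np.float32)]
batch, mask = pad_graphs(graphs)
pooled = masked_sum_pool(batch, mask)
```

Padding every sample to the largest graph wastes compute on small graphs, which is the likely reason the padded QM9 timing is marked as naive in the results.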

Times are remeasured on a Quadro RTX 4000, for the model only, on batches of 100 graphs and in (global) single precision.

                      torch (original)          jax (ours)
                      Loss    Inference [ms]    Loss    Inference [ms]
charged (position)    .0043   21.22             .0045   3.77
gravity (position)    .265    60.55             .264    41.72
QM9 (alpha)           .066*   82.53             .082    105.98**

* rerun on same conditions
** padded (naive)
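For readers who want to take comparable timings, the sketch below shows one way to do it. It is an assumption about a reasonable procedure, not the script behind the numbers above, and `time_ms` is a hypothetical helper. Two points matter with jax: the first calls include compilation and should be discarded as warmup, and dispatch is asynchronous, so the result has to be synchronized with `block_until_ready()` before the clock is stopped. The helper only synchronizes a single returned array; a model returning a pytree would need each leaf synchronized.

```python
import statistics
import time


def time_ms(fn, *args, warmup=3, repeats=20):
    """Return (mean, stdev) of the wall-clock time of fn(*args) in ms.

    Hypothetical helper for illustration; not taken from segnn-jax.
    """

    def run():
        out = fn(*args)
        # jax arrays expose block_until_ready(); other return values are left as-is
        sync = getattr(out, "block_until_ready", None)
        if callable(sync):
            sync()
        return out

    for _ in range(warmup):
        run()  # warmup calls absorb jit compilation time

    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        run()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.mean(samples), statistics.stdev(samples)


# Stand-in workload; replace with a jitted model call on one batch
mean_ms, std_ms = time_ms(lambda n: sum(range(n)), 10_000)
print(f"{mean_ms:.3f} +- {std_ms:.3f} ms")
```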

Validation install

The experiments are only included in the github repo, so it needs to be cloned first.

git clone https://github.com/gerkone/segnn-jax

They are adapted from the original implementation, so additionally torch and torch_geometric are needed (cpu versions are enough).

python -m pip install -r experiments/requirements.txt

Datasets

QM9 is automatically downloaded and processed when running the respective experiment.

The N-body datasets have to be generated locally from the directory experiments/nbody/data. This takes some time, especially for the gravity dataset.

Charged dataset (5 bodies, 10000 training samples)

python3 -u generate_dataset.py --simulation=charged --seed=43

Gravity dataset (100 bodies, 10000 training samples)

python3 -u generate_dataset.py --simulation=gravity --n-balls=100 --seed=43

Notes

On jax<=0.4.6, the jit-pjit merge can be deactivated, which makes training faster (on nbody). This looks like an issue with data loading and the implementation of the validation training loop; it does not affect SEGNN itself.

export JAX_JIT_PJIT_API_MERGE=0
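When the shell environment cannot be changed (for example in a notebook), the same variable can be set from Python. This is a sketch under the assumption that jax reads this flag from the environment when it is first imported, so the assignment has to come before any `import jax`. On jax versions newer than the ones mentioned above the flag may have no effect.

```python
import os

# Assumption: jax reads this flag at import time, so set it before importing jax
os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"

# import jax  # import jax only after the variable has been set
```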

Usage

N-body (charged)

python validate.py --dataset=charged --epochs=200 --max-samples=3000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12

N-body (gravity)

python validate.py --dataset=gravity --epochs=100 --target=pos --max-samples=10000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12 --neighbours=5 --n-bodies=100

QM9

python validate.py --dataset=qm9 --epochs=1000 --target=alpha --lmax-hidden=2 --lmax-attributes=3 --layers=7 --units=128 --norm=instance --batch-size=128 --lr=5e-4 --weight-decay=1e-8 --lr-scheduling

(configurations used in validation)

Acknowledgments
