
E(3) GNN in jax


E(n) Equivariant GNN in jax

Reimplementation of EGNN (E(n) Equivariant Graph Neural Networks) in jax. Original work by Victor Garcia Satorras, Emiel Hoogeboom and Max Welling.
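For orientation, the layer update from the original paper can be sketched in plain JAX as shown below. Messages are built from invariant features and squared distances. Coordinates move along relative position vectors, and features are updated from the aggregated messages. This is a minimal illustration under stated assumptions, not the egnn_jax package API. The helpers `init_mlp`, `mlp` and `egnn_layer`, the MLP sizes, the SiLU activations, the residual feature update and the normalisation C = 1/(N-1) are all hypothetical choices made for the example.

```python
import jax
import jax.numpy as jnp
import numpy as np


def init_mlp(key, sizes):
    """Hypothetical helper: plain MLP parameters as a list of (weight, bias) pairs."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (din, dout)) / jnp.sqrt(din), jnp.zeros(dout))
        for k, din, dout in zip(keys, sizes[:-1], sizes[1:])
    ]


def mlp(params, x):
    """SiLU hidden activations, linear output layer."""
    for w, b in params[:-1]:
        x = jax.nn.silu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def egnn_layer(params, h, x, senders, receivers):
    """One E(n)-equivariant message-passing step (sketch).

    h: [N, F] invariant features, x: [N, D] coordinates,
    senders/receivers: [E] edge endpoints, message flows j -> i.
    """
    n = h.shape[0]
    rel = x[receivers] - x[senders]                   # x_i - x_j, shape [E, D]
    dist2 = jnp.sum(rel**2, axis=-1, keepdims=True)  # ||x_i - x_j||^2, rotation invariant
    # m_ij = phi_e(h_i, h_j, ||x_i - x_j||^2)
    m = mlp(params["phi_e"], jnp.concatenate([h[receivers], h[senders], dist2], axis=-1))
    # x_i <- x_i + C * sum_j (x_i - x_j) * phi_x(m_ij), assuming C = 1 / (N - 1)
    shift = jax.ops.segment_sum(rel * mlp(params["phi_x"], m), receivers, num_segments=n)
    x_new = x + shift / (n - 1)
    # h_i <- h_i + phi_h(h_i, sum_j m_ij), residual form assumed here
    agg = jax.ops.segment_sum(m, receivers, num_segments=n)
    h_new = h + mlp(params["phi_h"], jnp.concatenate([h, agg], axis=-1))
    return h_new, x_new


# Fully connected graph of 5 particles in 3D (hypothetical sizes).
N, F, HID = 5, 4, 16
receivers, senders = np.nonzero(~np.eye(N, dtype=bool))
k_e, k_x, k_h, k_feat, k_pos, k_rot = jax.random.split(jax.random.PRNGKey(0), 6)
params = {
    "phi_e": init_mlp(k_e, [2 * F + 1, HID, HID]),
    "phi_x": init_mlp(k_x, [HID, HID, 1]),
    "phi_h": init_mlp(k_h, [F + HID, HID, F]),
}
h = jax.random.normal(k_feat, (N, F))
x = jax.random.normal(k_pos, (N, 3))

h_out, x_out = egnn_layer(params, h, x, senders, receivers)

# Equivariance check: apply an orthogonal transform plus a translation to the inputs.
Q, _ = jnp.linalg.qr(jax.random.normal(k_rot, (3, 3)))
t = jnp.array([1.0, -2.0, 0.5])
h_rot, x_rot = egnn_layer(params, h, x @ Q.T + t, senders, receivers)
```

If the sketch is correct, `h_rot` matches `h_out` and `x_rot` matches `x_out @ Q.T + t` up to float32 rounding. That is the E(n) equivariance property the model is named after.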

Installation

python -m pip install egnn-jax

Alternatively, clone this repository and install it locally in editable mode:

git clone https://github.com/gerkone/egnn-jax
cd egnn-jax
python -m pip install -e .

GPU support

Upgrade jax to the GPU version:

pip install --upgrade "jax[cuda]==0.4.10" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
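To confirm that the upgrade took effect, a generic JAX check (not part of egnn-jax) is to ask which backend and devices JAX sees. After a working CUDA install with a compatible driver, the backend should be reported as `gpu`.

```python
import jax

# Backend JAX selected by default: "cpu", "gpu" or "tpu".
backend = jax.default_backend()
# All devices visible to that backend.
devices = jax.devices()

print(f"backend: {backend}")
for d in devices:
    print(d.platform, d)
```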

Validation

The charged N-body experiment from the original paper is included for validation. Times are for the model only, measured on batches of 100 graphs in (global) single precision.

| Implementation | MSE | Inference [ms]* |
|---|---|---|
| torch (original) | .0071 | 8.27 |
| jax (ours) | .0093 | 0.94 |

* Remeasured on a Quadro RTX 4000.
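JAX inference timings are only meaningful if two conditions hold: compilation is excluded from the timed region, and asynchronous dispatch has finished before the clock stops. The sketch below shows a measurement of that kind. It is not the authors' benchmark script. `time_jitted`, `toy_model` and the batch shape (100 graphs of 5 nodes in 3D) are hypothetical placeholders.

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, n_warmup=3, n_runs=20):
    """Average wall-clock time in ms of a jitted call, excluding compilation."""
    jitted = jax.jit(fn)
    for _ in range(n_warmup):
        # The first call triggers XLA compilation; wait for it here, untimed.
        jitted(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(n_runs):
        # Dispatch is asynchronous, so block on the result before stopping the clock.
        jitted(*args).block_until_ready()
    return (time.perf_counter() - start) / n_runs * 1e3


def toy_model(x):
    """Hypothetical stand-in for the model's forward pass."""
    return x + 0.1 * jnp.tanh(x)


# float32 is JAX's default precision unless jax_enable_x64 is set.
batch = jnp.ones((100, 5, 3), dtype=jnp.float32)
ms = time_jitted(toy_model, batch)
print(f"{ms:.3f} ms per batch")
```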

Validation install

The N-body experiments are only included in the GitHub repository, so it has to be cloned first:

git clone https://github.com/gerkone/egnn-jax

Since the experiments are adapted from the original implementation, torch and torch_geometric are also required (the CPU versions are sufficient):

python -m pip install -r nbody/requirements.txt

Validation usage

The charged N-body dataset has to be generated locally, in the /nbody/data directory:

python -u generate_dataset.py --num-train 3000 --seed 43 --sufix small
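To give a rough idea of what a charged N-body trajectory is, the sketch below simulates point charges in JAX. Like charges repel and opposite charges attract. It is a simplified stand-in and not generate_dataset.py. The softening `eps`, unit masses, semi-implicit (symplectic) Euler integration, random ±1 charges, step size and particle count are hypothetical choices, and the sketch has no box or velocity clipping.

```python
import jax
import jax.numpy as jnp


def coulomb_forces(pos, charges, eps=1e-2):
    """Softened pairwise forces F_i = sum_j q_i q_j (x_i - x_j) / |r_ij|^3."""
    diff = pos[:, None, :] - pos[None, :, :]            # x_i - x_j, [N, N, D]
    dist3 = (jnp.sum(diff**2, axis=-1) + eps**2) ** 1.5  # softened |r|^3, [N, N]
    qq = charges[:, None] * charges[None, :]             # q_i q_j
    # On the diagonal diff is zero, so self-interaction contributes nothing.
    return jnp.sum((qq / dist3)[..., None] * diff, axis=1)


@jax.jit
def step(pos, vel, charges):
    """One semi-implicit Euler step with unit masses and dt = 1e-3."""
    dt = 1e-3
    vel = vel + dt * coulomb_forces(pos, charges)
    pos = pos + dt * vel
    return pos, vel


key_p, key_v, key_q = jax.random.split(jax.random.PRNGKey(0), 3)
n = 5
pos = jax.random.normal(key_p, (n, 3))
vel = 0.5 * jax.random.normal(key_v, (n, 3))
charges = jax.random.choice(key_q, jnp.array([-1.0, 1.0]), (n,))
vel_init = vel

traj = [pos]
for _ in range(100):
    pos, vel = step(pos, vel, charges)
    traj.append(pos)
traj = jnp.stack(traj)  # [T, N, 3] positions over time
```

Because the pairwise forces are antisymmetric, total momentum (the sum of velocities with unit masses) is conserved up to rounding. This gives a quick sanity check on such a simulator.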

Then the model can be trained and evaluated with:

python validate.py --epochs=1000 --batch-size=100 --lr=1e-4 --weight-decay=1e-12
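To show how the `--lr` and `--weight-decay` values enter an MSE training step in JAX, here is a minimal sketch. It is not validate.py. The toy linear predictor, the synthetic targets and the batch shape are hypothetical, and plain gradient descent with an L2 penalty is swapped in for whatever optimizer the script actually uses.

```python
import jax
import jax.numpy as jnp

LR = 1e-4             # mirrors --lr=1e-4
WEIGHT_DECAY = 1e-12  # mirrors --weight-decay=1e-12


def predict(params, x, v):
    # Hypothetical toy model standing in for the EGNN: x_next = x + a * v + b
    return x + params["a"] * v + params["b"]


def mse_loss(params, x, v, target):
    return jnp.mean((predict(params, x, v) - target) ** 2)


@jax.jit
def train_step(params, x, v, target):
    """One gradient-descent step on the MSE with an L2 weight-decay term."""
    loss, grads = jax.value_and_grad(mse_loss)(params, x, v, target)
    params = jax.tree_util.tree_map(
        lambda p, g: p - LR * (g + WEIGHT_DECAY * p), params, grads
    )
    return params, loss


kx, kv = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (100, 5, 3))  # batch of 100 graphs, 5 nodes, 3D
v = jax.random.normal(kv, (100, 5, 3))
target = x + 0.5 * v                    # synthetic "future positions"

params = {"a": jnp.array(0.0), "b": jnp.zeros(3)}
losses = []
for _ in range(100):
    params, loss = train_step(params, x, v, target)
    losses.append(float(loss))
```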

Acknowledgements

This implementation heavily borrows from the original PyTorch code.

Download files


Source Distribution

egnn_jax-0.3.tar.gz (5.8 kB)

Built Distribution

egnn_jax-0.3-py3-none-any.whl (7.0 kB)

File details

Details for the file egnn_jax-0.3.tar.gz.

File metadata

  • Download URL: egnn_jax-0.3.tar.gz
  • Size: 5.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.11.3

File hashes

Hashes for egnn_jax-0.3.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 143f957f76d79574c6dae476195a1f65e08e82b1f400cc4071ea65e7ea9765b1 |
| MD5 | 0398ab74178da4e21ceee23f3d92a2cc |
| BLAKE2b-256 | d40a0f7274851cb5d07ae73527ae105197f72b14a242f17b3f1aedbb9e1a866f |
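A downloaded file can be checked against the published SHA256 digest with Python's standard `hashlib`. The sketch below streams the file in chunks. The local filename is an assumption about where the sdist was saved.

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 20):
    """Stream a file through SHA-256 and return the hex digest."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


EXPECTED = "143f957f76d79574c6dae476195a1f65e08e82b1f400cc4071ea65e7ea9765b1"
path = Path("egnn_jax-0.3.tar.gz")  # assumed download location
if path.exists():
    print("OK" if sha256_of(path) == EXPECTED else "MISMATCH")
```

pip can also enforce this automatically. In hash-checking mode, a requirements entry pins the version and carries the expected digest through its `--hash=sha256:...` option.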

File details

Details for the file egnn_jax-0.3-py3-none-any.whl.

File metadata

  • Download URL: egnn_jax-0.3-py3-none-any.whl
  • Size: 7.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.11.3

File hashes

Hashes for egnn_jax-0.3-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 21abca26b84a7ee9a048444fdcbc6cc54f3833133c26d9ea70d3c9c03f17159b |
| MD5 | a01f6f6253e156889d3e0fe97061c5da |
| BLAKE2b-256 | d39a0a97608371395e01cb5487dfb4b8a8b9468aaf3a8ee6052df3db2bbc0cae |
