
E(3) GNN in jax

Project description

E(n) Equivariant GNN in jax

Reimplementation of EGNN (E(n) Equivariant Graph Neural Networks) in jax. Original work by Victor Garcia Satorras, Emiel Hoogeboom and Max Welling.
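For orientation, one EGNN layer can be sketched in plain JAX on a fully connected graph. It follows the layer structure of the original paper: edge messages are built from invariant node features and squared distances, coordinates move along relative positions, and features are updated from aggregated messages. This is not this package's API. The names `init_mlp`, `mlp`, `init_egcl` and `egcl` are hypothetical. The MLP sizes, SiLU activations, residual feature update, omitted edge attributes and the constant C = 1/(N-1) are all assumptions chosen for illustration.

```python
import math

import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Hypothetical helper: random (W, b) pairs for a small MLP."""
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (d_in, d_out)) / jnp.sqrt(d_in)
        params.append((w, jnp.zeros(d_out)))
    return params


def mlp(params, z):
    """Apply the MLP with SiLU between layers and a linear output layer."""
    for i, (w, b) in enumerate(params):
        z = z @ w + b
        if i < len(params) - 1:
            z = jax.nn.silu(z)
    return z


def init_egcl(key, hidden):
    """Parameters for one equivariant layer: phi_e, phi_x and phi_h."""
    k_e, k_x, k_h = jax.random.split(key, 3)
    return {
        "phi_e": init_mlp(k_e, [2 * hidden + 1, hidden, hidden]),
        "phi_x": init_mlp(k_x, [hidden, hidden, 1]),
        "phi_h": init_mlp(k_h, [2 * hidden, hidden, hidden]),
    }


def egcl(params, h, x):
    """One E(n)-equivariant layer on a fully connected graph.

    h: [N, F] invariant node features (F = hidden), x: [N, D] coordinates.
    """
    n, f = h.shape
    diff = x[:, None, :] - x[None, :, :]              # [N, N, D]: x_i - x_j
    dist2 = jnp.sum(diff**2, axis=-1, keepdims=True)  # [N, N, 1]: rotation invariant
    h_i = jnp.broadcast_to(h[:, None, :], (n, n, f))
    h_j = jnp.broadcast_to(h[None, :, :], (n, n, f))
    mask = 1.0 - jnp.eye(n)[:, :, None]               # drop self-edges (j == i)
    # Edge messages m_ij = phi_e(h_i, h_j, ||x_i - x_j||^2)
    m_ij = mlp(params["phi_e"], jnp.concatenate([h_i, h_j, dist2], axis=-1)) * mask
    # Coordinates: x_i + C * sum_j (x_i - x_j) * phi_x(m_ij), with C = 1 / (N - 1)
    x_new = x + jnp.sum(diff * mlp(params["phi_x"], m_ij), axis=1) / (n - 1)
    # Features: h_i + phi_h(h_i, sum_j m_ij)  (residual form)
    m_i = jnp.sum(m_ij, axis=1)
    h_new = h + mlp(params["phi_h"], jnp.concatenate([h, m_i], axis=-1))
    return h_new, x_new


# Usage: 5 nodes with 16 features and 3D positions.
k_p, k_h, k_x = jax.random.split(jax.random.PRNGKey(0), 3)
params = init_egcl(k_p, hidden=16)
h = jax.random.normal(k_h, (5, 16))
x = jax.random.normal(k_x, (5, 3))
h_out, x_out = egcl(params, h, x)

# Rotating and translating the input should leave h unchanged and move x the same way.
c, s = math.cos(0.7), math.sin(0.7)
R = jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
t = jnp.array([1.0, -2.0, 0.5])
h_rot, x_rot = egcl(params, h, x @ R.T + t)
print(jnp.allclose(h_out, h_rot, atol=1e-3), jnp.allclose(x_out @ R.T + t, x_rot, atol=1e-3))
```

Checking equivariance numerically like this is a cheap sanity test for any layer written in this style.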

Installation

python -m pip install egnn-jax

Or clone this repository and install it locally in editable mode:

python -m pip install -e .

GPU support

For GPU support, upgrade jax to the CUDA build:

pip install --upgrade "jax[cuda]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
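After upgrading, a quick way to confirm that JAX actually picked up the GPU is to query its default backend and visible devices. Both are standard JAX functions. The exact device strings they print depend on the JAX version and hardware.

```python
import jax

# "gpu" means the CUDA build found a device; "cpu" means JAX fell back to the CPU.
print(jax.default_backend())
# Lists every device JAX can place computations on.
print(jax.devices())
```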

Validation

The charged N-body experiment from the original paper is included for validation. Times cover only the model itself, measured on batches of 100 graphs in (global) single precision.

| Implementation   | MSE    | Inference [ms]* |
|------------------|--------|-----------------|
| torch (original) | 0.0071 | 8.27            |
| jax (ours)       | 0.011  | 0.94            |

* Remeasured on a Quadro RTX 4000.
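Model-only inference times like these are usually measured in JAX by jitting the model and running a few warm-up calls so that XLA compilation is excluded. Each call then has to be blocked on, because JAX dispatches work asynchronously. The sketch below shows that pattern under assumptions. `time_inference_ms` and `toy_model` are hypothetical names, the toy model only stands in for the EGNN, and this is not necessarily how the numbers in the table were obtained. Arrays are float32 by default unless `jax_enable_x64` is turned on, which matches single-precision timing.

```python
import time

import jax
import jax.numpy as jnp


def time_inference_ms(fn, *args, n_warmup=3, n_runs=20):
    """Median wall-clock time (ms) of a jitted call, excluding compilation."""
    jitted = jax.jit(fn)

    def run():
        out = jitted(*args)
        # JAX dispatches asynchronously: wait until every output array is ready.
        jax.tree_util.tree_map(lambda a: a.block_until_ready(), out)

    for _ in range(n_warmup):  # the first call triggers XLA compilation
        run()
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        run()
        times.append((time.perf_counter() - start) * 1e3)
    return sorted(times)[len(times) // 2]


# Hypothetical stand-in model on a batch of 100 graphs, 5 nodes each, 3D positions.
def toy_model(x):
    return jnp.tanh(x @ jnp.swapaxes(x, -1, -2)).sum(axis=(-1, -2))


x = jax.random.normal(jax.random.PRNGKey(0), (100, 5, 3))  # float32 by default
print(x.dtype, time_inference_ms(toy_model, x))
```

Taking the median rather than the mean keeps a single slow run (for example, one interrupted by the OS) from skewing the result.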

Validation install

The N-body experiments are only included in the GitHub repository, so the repository has to be cloned first.

git clone https://github.com/gerkone/egnn-jax

They are adapted from the original implementation, so torch and torch_geometric are also required (the CPU versions are sufficient).

pip3 install torch==1.12.1 --extra-index-url https://download.pytorch.org/whl/cpu
python -m pip install -r nbody/requirements.txt

Validation usage

The charged N-body dataset has to be generated locally, inside the nbody/data directory of the repository:

python3 -u generate_dataset.py --num-train=3000
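For intuition about the kind of data a charged N-body generator produces, here is a simplified simulation sketch in JAX. It is not the repository's generate_dataset.py. The particle count (5), charges of ±1, the softening `eps`, the step size, and the semi-implicit Euler integrator are assumptions chosen for illustration, and `coulomb_forces` and `simulate` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def coulomb_forces(x, q, eps=1e-2):
    """Pairwise Coulomb-like forces (like charges repel); eps softens close encounters."""
    diff = x[:, None, :] - x[None, :, :]               # [N, N, 3]: x_i - x_j
    r2 = jnp.sum(diff**2, axis=-1) + eps               # softened squared distance
    strength = q[:, None] * q[None, :] * r2**-1.5      # q_i q_j / r^3
    strength = strength * (1.0 - jnp.eye(x.shape[0]))  # no self-interaction
    return jnp.sum(strength[:, :, None] * diff, axis=1)  # [N, 3]


def simulate(x, v, q, dt=1e-3, steps=1000):
    """Semi-implicit Euler with unit masses; returns final x, v and the position trajectory."""

    def step(carry, _):
        x, v = carry
        v = v + dt * coulomb_forces(x, q)  # update velocity first...
        x = x + dt * v                     # ...then position with the new velocity
        return (x, v), x

    (x, v), traj = jax.lax.scan(step, (x, v), None, length=steps)
    return x, v, traj


k_x, k_v, k_q = jax.random.split(jax.random.PRNGKey(0), 3)
x0 = jax.random.normal(k_x, (5, 3))                              # 5 particles in 3D
v0 = 0.5 * jax.random.normal(k_v, (5, 3))
q = jnp.where(jax.random.bernoulli(k_q, 0.5, (5,)), 1.0, -1.0)   # charges of +1 or -1
x_T, v_T, traj = simulate(x0, v0, q)
```

Because the pairwise forces are antisymmetric, the total momentum (the sum of velocities, with unit masses) stays constant up to rounding, which is a useful check on any such generator.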

Then the model can be trained and evaluated from the repository root with

python main.py --epochs=500 --batch-size=100 --lr=1e-4 --weight-decay=1e-8
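The validation table reports MSE, so a single JAX training step with an MSE objective and the learning rate and weight decay from the command above might look like the sketch below. Everything else is an assumption for illustration. The linear stand-in model, the toy data, and the plain gradient-descent update with L2 weight decay are not taken from the repository, whose model and optimizer may differ. `predict`, `mse_loss` and `train_step` are hypothetical names.

```python
import jax
import jax.numpy as jnp

LR = 1e-4            # --lr
WEIGHT_DECAY = 1e-8  # --weight-decay


def predict(params, x):
    """Hypothetical stand-in model: an affine map on flattened coordinates."""
    return x @ params["w"] + params["b"]


def mse_loss(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)


@jax.jit
def train_step(params, x, y):
    """One gradient-descent step with L2 weight decay added to the gradient."""
    loss, grads = jax.value_and_grad(mse_loss)(params, x, y)
    params = jax.tree_util.tree_map(
        lambda p, g: p - LR * (g + WEIGHT_DECAY * p), params, grads
    )
    return params, loss


# Toy batch of 100 samples (mirrors --batch-size=100); 15 = 5 particles x 3 coordinates.
k_x, k_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k_x, (100, 15))
y = x @ jax.random.normal(k_w, (15, 15))
params = {"w": jnp.zeros((15, 15)), "b": jnp.zeros(15)}
losses = []
for _ in range(200):
    params, loss = train_step(params, x, y)
    losses.append(float(loss))
```

Jitting the whole step, including the parameter update, lets XLA fuse the gradient computation and the update into one compiled function.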

Acknowledgements

This implementation borrows heavily from the original PyTorch code.


Download files


Source Distribution

egnn_jax-0.2.tar.gz (5.3 kB)

Uploaded Source

Built Distribution

egnn_jax-0.2-py3-none-any.whl (6.2 kB)

Uploaded Python 3

File details

Details for the file egnn_jax-0.2.tar.gz.

File metadata

  • Download URL: egnn_jax-0.2.tar.gz
  • Upload date:
  • Size: 5.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.11.2

File hashes

Hashes for egnn_jax-0.2.tar.gz
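| Algorithm   | Hash digest                                                      |
|-------------|------------------------------------------------------------------|
| SHA256      | 935c3be1b3223a1b43df3d7557bc4c6ebca168406aca4be7d63d2eb8c82d9ecd |
| MD5         | c61ad9069f470fdfda978515c54ffc87                                 |
| BLAKE2b-256 | 9da1f97ad3155330067450dc5b4004dac8337b739265d9bf7a75403124d27e90 |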
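A downloaded file can be checked against these digests with Python's standard hashlib module. `sha256_of` and `verify` are hypothetical helper names. The expected value is the SHA256 listed for egnn_jax-0.2.tar.gz.

```python
import hashlib
from pathlib import Path

# SHA256 digest listed for egnn_jax-0.2.tar.gz
EXPECTED_SHA256 = "935c3be1b3223a1b43df3d7557bc4c6ebca168406aca4be7d63d2eb8c82d9ecd"


def sha256_of(path, chunk_size=1 << 16):
    """Stream the file in chunks so large archives need not fit in memory."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected=EXPECTED_SHA256):
    """True if the file's SHA256 matches the expected hex digest."""
    return sha256_of(path) == expected.lower()
```

For example, `verify("egnn_jax-0.2.tar.gz")` should return True for an intact download. pip can also enforce digests itself through its hash-checking mode (`--hash=sha256:...` entries in a requirements file).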


File details

Details for the file egnn_jax-0.2-py3-none-any.whl.

File metadata

  • Download URL: egnn_jax-0.2-py3-none-any.whl
  • Upload date:
  • Size: 6.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.11.2

File hashes

Hashes for egnn_jax-0.2-py3-none-any.whl
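| Algorithm   | Hash digest                                                      |
|-------------|------------------------------------------------------------------|
| SHA256      | 0ad68a5901025c459e4dd08e6739855b0478bd09be4329f0fcffb8ca6d957f7f |
| MD5         | 0b7f5a4c11950f65eb6d4d2c0e259005                                 |
| BLAKE2b-256 | c6b4139aeb264a8b45bab8bf1028014c43ca41d3fa764eb5d22ccd62081b64ba |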

