E(3) GNN in jax
E(n) Equivariant GNN in jax
Reimplementation of EGNN in jax. Original work by Victor Garcia Satorras, Emiel Hoogeboom and Max Welling.
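To make the idea concrete, here is a minimal pure-jax sketch of a single E(n)-equivariant message-passing step, loosely following the update equations of the EGNN paper. It is not the API of this package: the names `init_mlp`, `mlp` and `egnn_layer`, the parameter layout, and the mean aggregation of coordinate updates are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Create weights for a small MLP (hypothetical helper, not from egnn-jax)."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
        for k, m, n in zip(keys, sizes[:-1], sizes[1:])
    ]


def mlp(params, x):
    """Apply the MLP with SiLU activations on all but the last layer."""
    for w, b in params[:-1]:
        x = jax.nn.silu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def egnn_layer(params, h, x, senders, receivers):
    """One equivariant update of node features h and coordinates x (sketch)."""
    n_nodes = h.shape[0]
    diff = x[receivers] - x[senders]                  # x_i - x_j for every edge
    dist2 = jnp.sum(diff**2, axis=-1, keepdims=True)  # invariant under E(n)
    # Edge message m_ij = phi_e(h_i, h_j, ||x_i - x_j||^2)
    edge_in = jnp.concatenate([h[receivers], h[senders], dist2], axis=-1)
    m = mlp(params["phi_e"], edge_in)
    # Coordinate update: x_i + mean over neighbours of (x_i - x_j) * phi_x(m_ij)
    shift = jax.ops.segment_sum(diff * mlp(params["phi_x"], m), receivers, n_nodes)
    degree = jax.ops.segment_sum(jnp.ones_like(dist2), receivers, n_nodes)
    x_new = x + shift / jnp.maximum(degree, 1.0)
    # Feature update: phi_h(h_i, sum_j m_ij)
    m_i = jax.ops.segment_sum(m, receivers, n_nodes)
    h_new = mlp(params["phi_h"], jnp.concatenate([h, m_i], axis=-1))
    return h_new, x_new


# Usage: a fully connected graph with 5 nodes in 3D and 8 hidden features.
key = jax.random.PRNGKey(0)
k_e, k_x, k_h, k_feat, k_pos, k_rot = jax.random.split(key, 6)
hidden, n = 8, 5
params = {
    "phi_e": init_mlp(k_e, [2 * hidden + 1, hidden, hidden]),
    "phi_x": init_mlp(k_x, [hidden, hidden, 1]),
    "phi_h": init_mlp(k_h, [2 * hidden, hidden, hidden]),
}
senders, receivers = jnp.nonzero(~jnp.eye(n, dtype=bool))
h = jax.random.normal(k_feat, (n, hidden))
x = jax.random.normal(k_pos, (n, 3))

# Random orthogonal matrix (via QR) and a translation, i.e. an element of E(3).
rot, _ = jnp.linalg.qr(jax.random.normal(k_rot, (3, 3)))
shift_vec = jnp.array([1.0, -2.0, 0.5])

h1, x1 = egnn_layer(params, h, x, senders, receivers)
h2, x2 = egnn_layer(params, h, x @ rot.T + shift_vec, senders, receivers)

features_invariant = bool(jnp.allclose(h1, h2, atol=1e-4))
coords_equivariant = bool(jnp.allclose(x1 @ rot.T + shift_vec, x2, atol=1e-4))
print(features_invariant, coords_equivariant)
```

The final check transforms the input coordinates with a random orthogonal matrix and a translation: node features should be unchanged, and output coordinates should transform in the same way as the inputs.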
Installation
python -m pip install egnn-jax
Or clone this repository and install it locally in editable mode
python -m pip install -e .
GPU support
Upgrade jax to the GPU version
pip install --upgrade "jax[cuda]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
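After upgrading, a quick way to confirm that jax actually picked up the GPU is to list the visible devices. This is a generic jax check, not something specific to this package; if only CPU devices are listed, the CUDA build was probably not installed correctly.

```python
import jax

# Devices jax can see; with a working CUDA install this should include GPUs.
devices = jax.devices()
# Name of the default platform, e.g. "cpu" or "gpu".
backend = jax.default_backend()
print(f"backend: {backend}, devices: {devices}")
```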
Validation
The charged N-body experiment from the original paper is included for validation. Times cover the model only (no data loading), on batches of 100 graphs, in (global) single precision.
Implementation | MSE | Inference [ms]* |
---|---|---|
torch (original) | 0.0071 | 8.27 |
jax (ours) | 0.011 | 0.94 |
* Remeasured on a Quadro RTX 4000.
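When timing a jitted jax model, compilation has to be excluded and asynchronous dispatch has to be accounted for. The sketch below shows one common way to do this; it is not necessarily how the numbers above were measured, and `time_jitted` and `toy_model` are hypothetical stand-ins for the actual model.

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, n_runs=100):
    """Average wall time in ms of a jitted call, excluding compilation."""
    jitted = jax.jit(fn)
    # Warm-up call: triggers tracing and compilation, so it is not timed.
    jax.block_until_ready(jitted(*args))
    start = time.perf_counter()
    for _ in range(n_runs):
        # jax dispatches asynchronously; block so the work is really finished.
        jax.block_until_ready(jitted(*args))
    return 1e3 * (time.perf_counter() - start) / n_runs


def toy_model(a):
    """Stand-in for a model forward pass (hypothetical)."""
    return jnp.tanh(a @ a).sum()


# Single-precision input, matching the default jax dtype.
inputs = jnp.ones((256, 256), dtype=jnp.float32)
elapsed_ms = time_jitted(toy_model, inputs, n_runs=20)
print(f"{elapsed_ms:.3f} ms per call")
```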
Validation install
The N-body experiments are only included in the GitHub repo, so it needs to be cloned first.
git clone https://github.com/gerkone/egnn-jax
They are adapted from the original implementation, so torch and torch_geometric are additionally needed (CPU versions are enough).
pip3 install torch==1.12.1 --extra-index-url https://download.pytorch.org/whl/cpu
python -m pip install -r nbody/requirements.txt
Validation usage
The charged N-body dataset has to be generated locally in the directory /nbody/data.
python3 -u generate_dataset.py --num-train=3000
Then, the model can be trained and evaluated (from the repo root) with
python main.py --epochs=500 --batch-size=100 --lr=1e-4 --weight-decay=1e-8
Acknowledgements
This implementation borrows heavily from the original PyTorch code.