E(3) GNN in jax
E(n) Equivariant GNN in jax
Reimplementation of EGNN in jax. Original work by Victor Garcia Satorras, Emiel Hoogeboom and Max Welling.
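For orientation, the core EGNN update from the original paper can be sketched in plain JAX. This is a minimal illustration, not the egnn-jax API. The helpers init_mlp, mlp, init_egnn_layer and egnn_layer are hypothetical. The MLPs use SiLU activations. The coordinate update averages over neighbours, which is the 1/(M-1) normalisation for a fully connected graph. The feature update adds a residual connection. The last lines check E(n) equivariance numerically by applying a random orthogonal transform and a translation to the input positions.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Initialise an MLP as a list of (weight, bias) pairs."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (d_in, d_out)) / jnp.sqrt(d_in), jnp.zeros(d_out))
        for k, d_in, d_out in zip(keys, sizes[:-1], sizes[1:])
    ]


def mlp(params, x):
    """SiLU between hidden layers, linear output layer."""
    for w, b in params[:-1]:
        x = jax.nn.silu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def init_egnn_layer(key, hidden):
    k_e, k_x, k_h = jax.random.split(key, 3)
    return {
        "phi_e": init_mlp(k_e, [2 * hidden + 1, hidden, hidden]),  # edge messages m_ij
        "phi_x": init_mlp(k_x, [hidden, hidden, 1]),               # scalar weight per edge
        "phi_h": init_mlp(k_h, [2 * hidden, hidden, hidden]),      # node feature update
    }


def egnn_layer(params, h, x, senders, receivers):
    """One EGNN step; edge e carries a message from senders[e] (j) to receivers[e] (i)."""
    n = h.shape[0]
    diff = x[receivers] - x[senders]                  # x_i - x_j
    dist2 = jnp.sum(diff**2, axis=-1, keepdims=True)  # ||x_i - x_j||^2, E(n)-invariant
    m_ij = mlp(params["phi_e"], jnp.concatenate([h[receivers], h[senders], dist2], axis=-1))
    # Positions move along relative vectors weighted by an invariant scalar, averaged per node.
    deg = jax.ops.segment_sum(jnp.ones((diff.shape[0], 1)), receivers, num_segments=n)
    shift = jax.ops.segment_sum(diff * mlp(params["phi_x"], m_ij), receivers, num_segments=n)
    x_new = x + shift / jnp.maximum(deg, 1.0)
    # Features are updated from the summed (invariant) messages, with a residual connection.
    m_i = jax.ops.segment_sum(m_ij, receivers, num_segments=n)
    h_new = h + mlp(params["phi_h"], jnp.concatenate([h, m_i], axis=-1))
    return h_new, x_new


# Fully connected graph without self-loops on 5 nodes.
key = jax.random.PRNGKey(0)
k_p, k_h, k_x, k_q, k_t = jax.random.split(key, 5)
n, hidden = 5, 16
params = init_egnn_layer(k_p, hidden)
h = jax.random.normal(k_h, (n, hidden))
x = jax.random.normal(k_x, (n, 3))
s, r = (a.ravel() for a in jnp.meshgrid(jnp.arange(n), jnp.arange(n)))
senders, receivers = s[s != r], r[s != r]

# Random orthogonal matrix (rotation or reflection) and random translation.
q, _ = jnp.linalg.qr(jax.random.normal(k_q, (3, 3)))
t = jax.random.normal(k_t, (3,))

h1, x1 = egnn_layer(params, h, x, senders, receivers)
h2, x2 = egnn_layer(params, h, x @ q.T + t, senders, receivers)
print("h invariant:", bool(jnp.allclose(h1, h2, atol=1e-3)))
print("x equivariant:", bool(jnp.allclose(x1 @ q.T + t, x2, atol=1e-3)))
```

Only the invariant squared distances enter the MLPs, and positions are updated along relative vectors. This is why the output features are unchanged and the output positions transform exactly like the inputs.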
Installation
python -m pip install egnn-jax
Or clone this repository and install it locally in editable mode:
git clone https://github.com/gerkone/egnn-jax
cd egnn-jax
python -m pip install -e .
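Either way, you can quickly check that pip registered the package by reading its distribution metadata. The sketch below uses only the standard library and the distribution name egnn-jax from the commands above. It makes no assumption about the module's import name, and the helper installed_version is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name: str = "egnn-jax"):
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


print(installed_version())
```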
GPU support
Upgrade jax to the GPU version:
pip install --upgrade "jax[cuda]==0.4.10" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
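To confirm that jax picked up the GPU after the upgrade, you can query the active backend and the visible devices. Here is a minimal sketch (the helper describe_backend is hypothetical):

```python
import jax


def describe_backend():
    """Report the default backend jax will use and the devices it can see."""
    backend = jax.default_backend()  # platform name such as "cpu", "gpu" or "tpu"
    return backend, [str(d) for d in jax.devices()]


backend, devices = describe_backend()
print(backend, devices)
if backend != "gpu":
    print("jax is running without a GPU; check the CUDA installation.")
```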
Validation
The charged N-body experiment from the original paper is included for validation. Times measure the model alone, on batches of 100 graphs, in (global) single precision.
| | MSE | Inference [ms]* |
|---|---|---|
| torch (original) | 0.0071 | 8.27 |
| jax (ours) | 0.0093 | 0.94 |

\* Remeasured on a Quadro RTX 4000.
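The timing convention (model only, batches of 100 graphs, single precision) can be reproduced in spirit with the sketch below. It is not the repository's benchmark script:

- toy_model is a hypothetical stand-in for the network.
- The batch of 100 disjoint fully connected graphs with 5 nodes each assumes the 5-particle systems of the original charged N-body setup.
- time_ms excludes compilation by making a warm-up call.
- time_ms waits on block_until_ready, because jax dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp


def batched_fully_connected(num_graphs, nodes_per_graph):
    """Edge lists for a batch of disjoint fully connected graphs without self-loops."""
    idx = jnp.arange(nodes_per_graph)
    s, r = (a.ravel() for a in jnp.meshgrid(idx, idx))
    s, r = s[s != r], r[s != r]
    # Shift node ids so graph g owns nodes [g * nodes_per_graph, (g + 1) * nodes_per_graph).
    offsets = jnp.arange(num_graphs)[:, None] * nodes_per_graph
    return (s[None, :] + offsets).ravel(), (r[None, :] + offsets).ravel()


@jax.jit
def toy_model(x, senders, receivers):
    """Hypothetical stand-in for the network: one distance-weighted aggregation."""
    diff = x[receivers] - x[senders]
    w = jnp.exp(-jnp.sum(diff**2, axis=-1, keepdims=True))
    return x + jax.ops.segment_sum(diff * w, receivers, num_segments=x.shape[0])


def time_ms(fn, *args, repeats=100):
    """Mean wall-clock time per call in milliseconds, excluding compilation."""
    fn(*args).block_until_ready()  # warm-up call triggers jit compilation
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()  # dispatch is asynchronous, so wait for the result
    return (time.perf_counter() - start) / repeats * 1e3


senders, receivers = batched_fully_connected(num_graphs=100, nodes_per_graph=5)
x = jax.random.normal(jax.random.PRNGKey(0), (100 * 5, 3), dtype=jnp.float32)
print(f"{time_ms(toy_model, x, senders, receivers):.3f} ms per batch")
```

Batching disjoint graphs by offsetting node indices turns a batch into one large graph. Each batch is then a single jitted call with fixed shapes.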
Validation install
The N-body experiments are only included in the GitHub repository, so it has to be cloned first.
git clone https://github.com/gerkone/egnn-jax
They are adapted from the original implementation, so torch and torch_geometric are also required (the CPU versions are sufficient).
python -m pip install -r nbody/requirements.txt
Validation usage
The charged N-body dataset has to be generated locally in the nbody/data directory:
python -u generate_dataset.py --num-train 3000 --seed 43 --sufix small
Then the model can be trained and evaluated with:
python validate.py --epochs=1000 --batch-size=100 --lr=1e-4 --weight-decay=1e-12
Acknowledgements
This implementation heavily borrows from the original PyTorch code.
Download files
File details
Details for the file egnn_jax-0.3.tar.gz.
File metadata
- Download URL: egnn_jax-0.3.tar.gz
- Upload date:
- Size: 5.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.11.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 143f957f76d79574c6dae476195a1f65e08e82b1f400cc4071ea65e7ea9765b1 |
| MD5 | 0398ab74178da4e21ceee23f3d92a2cc |
| BLAKE2b-256 | d40a0f7274851cb5d07ae73527ae105197f72b14a242f17b3f1aedbb9e1a866f |
File details
Details for the file egnn_jax-0.3-py3-none-any.whl.
File metadata
- Download URL: egnn_jax-0.3-py3-none-any.whl
- Upload date:
- Size: 7.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.11.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 21abca26b84a7ee9a048444fdcbc6cc54f3833133c26d9ea70d3c9c03f17159b |
| MD5 | a01f6f6253e156889d3e0fe97061c5da |
| BLAKE2b-256 | d39a0a97608371395e01cb5487dfb4b8a8b9468aaf3a8ee6052df3db2bbc0cae |