Steerable E(3) GNN in jax

Reimplementation of SEGNN in jax. Original work by Johannes Brandstetter, Rob Hesselink, Elise van der Pol, Erik Bekkers and Max Welling.

Installation

python -m pip install segnn-jax

Or clone this repository and install it locally in editable mode:

python -m pip install -e .

GPU support

Upgrade jax to the GPU (CUDA) version:

pip install --upgrade "jax[cuda]==0.4.1" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
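After upgrading, it can be worth confirming that jax actually picks up the GPU. The following is a minimal sketch that relies only on the public jax.default_backend() and jax.devices() calls; the helper name describe_backend is hypothetical and not part of segnn-jax.

```python
def describe_backend() -> str:
    """Return a one-line summary of the backend jax will run on."""
    try:
        import jax  # imported lazily so the check also reports a missing install
    except ImportError:
        return "jax is not installed"
    # default_backend() names the platform in use (e.g. "cpu" or "gpu");
    # devices() lists the devices visible to that backend.
    return f"backend={jax.default_backend()}, devices={jax.devices()}"


if __name__ == "__main__":
    print(describe_backend())
```

If the reported backend is still the CPU after the upgrade, the installed CUDA build most likely does not match the local driver or CUDA toolkit.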

Validation

The N-body (charged and gravity) and QM9 experiments from the original paper are included for completeness. The implementation is validated on all three of them, obtaining results close to those of the original implementation with considerably faster inference.

Results

                     torch (original)           jax (ours)
                     MSE     Inference [ms]*    MSE     Inference [ms]
charged (position)   .0043   40.76              .0047   28.67
gravity (position)   .265    392.20             .28     240.34
QM9 (alpha)          .06     159.17                     109.58**
* remeasured (Quadro RTX 4000), batch of 100 graphs, single precision

** padded
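To make the runtime comparison concrete, the speedup of the jax version can be computed directly from the inference times reported above. This is only arithmetic on the published numbers, not a new measurement; the names INFERENCE_MS and speedup are introduced here for illustration.

```python
# Inference times in milliseconds, copied from the results above: (torch, jax).
# The jax QM9 figure was measured on padded graphs.
INFERENCE_MS = {
    "charged (position)": (40.76, 28.67),
    "gravity (position)": (392.20, 240.34),
    "QM9 (alpha)": (159.17, 109.58),
}


def speedup(torch_ms: float, jax_ms: float) -> float:
    """Ratio of torch to jax inference time; values above 1 mean jax is faster."""
    return torch_ms / jax_ms


SPEEDUPS = {name: speedup(*times) for name, times in INFERENCE_MS.items()}

if __name__ == "__main__":
    for name, factor in SPEEDUPS.items():
        print(f"{name}: {factor:.2f}x")
```

On these numbers the jax implementation is roughly 1.4x to 1.6x faster at inference, depending on the dataset.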

Validation install

The experiments are included only in the GitHub repository, which has to be cloned first.

git clone https://github.com/gerkone/segnn-jax

The experiments are adapted from the original implementation, so torch and torch_geometric are needed as well (the CPU versions are enough).

pip3 install torch==1.12.1 --extra-index-url https://download.pytorch.org/whl/cpu
python -m pip install -r experiments/requirements.txt
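Before launching an experiment, a quick check that the extra dependencies are importable can save a failed run. The sketch below uses only the standard library; the module names come from the text above, while REQUIRED_MODULES and missing_modules are hypothetical names chosen for this example.

```python
from importlib.util import find_spec

# jax is needed by segnn-jax itself; torch and torch_geometric only by the experiments.
REQUIRED_MODULES = ("jax", "torch", "torch_geometric")


def missing_modules(names=REQUIRED_MODULES):
    """Return the top-level modules in `names` that cannot be found in this environment."""
    return [name for name in names if find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    print("all set" if not missing else f"missing: {', '.join(missing)}")
```

The check only looks the modules up without importing them, so it stays fast even for heavy packages such as torch.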

Datasets

QM9 is automatically downloaded and processed when running the respective experiment.

The N-body datasets have to be generated locally by running the commands below from the directory experiments/nbody/data. Generation takes some time, especially for the gravity dataset.

Charged dataset (5 bodies, 10000 training samples)

python3 -u generate_dataset.py --simulation=charged

Gravity dataset (100 bodies, 10000 training samples)

python3 -u generate_dataset.py --simulation=gravity --n-balls=100

Usage

N-body (charged)

python main.py --dataset=charged --epochs=200 --max-samples=3000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-4 --weight-decay=1e-8

N-body (gravity)

python main.py --dataset=gravity --epochs=100 --target=pos --max-samples=10000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=1e-4 --weight-decay=1e-8 --neighbours=5 --n-bodies=100

QM9

python main.py --dataset=qm9 --epochs=1000 --target=alpha --lmax-hidden=2 --lmax-attributes=3 --layers=7 --units=128 --norm=instance --batch-size=128 --lr=5e-4 --weight-decay=1e-8 --lr-scheduling

(configurations used in validation)
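When rerunning several of these experiments, it may be convenient to keep the validated settings in one place and build the command line from them. The sketch below copies the flags verbatim from the three commands above; VALIDATION_CONFIGS and build_command are hypothetical helpers and not part of the repository.

```python
# Flags copied from the commands above. Values are kept as strings so that
# spellings such as "5e-4" are passed through unchanged; True marks a switch
# that takes no value.
VALIDATION_CONFIGS = {
    "charged": {
        "dataset": "charged", "epochs": "200", "max-samples": "3000",
        "lmax-hidden": "1", "lmax-attributes": "1", "layers": "4",
        "units": "64", "norm": "none", "batch-size": "100",
        "lr": "5e-4", "weight-decay": "1e-8",
    },
    "gravity": {
        "dataset": "gravity", "epochs": "100", "target": "pos",
        "max-samples": "10000", "lmax-hidden": "1", "lmax-attributes": "1",
        "layers": "4", "units": "64", "norm": "none", "batch-size": "100",
        "lr": "1e-4", "weight-decay": "1e-8", "neighbours": "5",
        "n-bodies": "100",
    },
    "qm9": {
        "dataset": "qm9", "epochs": "1000", "target": "alpha",
        "lmax-hidden": "2", "lmax-attributes": "3", "layers": "7",
        "units": "128", "norm": "instance", "batch-size": "128",
        "lr": "5e-4", "weight-decay": "1e-8", "lr-scheduling": True,
    },
}


def build_command(name, script="main.py"):
    """Return the argument list for one validated configuration."""
    args = ["python", script]
    for flag, value in VALIDATION_CONFIGS[name].items():
        args.append(f"--{flag}" if value is True else f"--{flag}={value}")
    return args


if __name__ == "__main__":
    print(" ".join(build_command("qm9")))
```

The returned list can be passed to subprocess.run(..., check=True) from the directory that contains main.py.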

Acknowledgments
