Equivariant convolutional neural networks for the group E(3) of 3-dimensional rotations, translations, and mirrors.
e3nn-jax
💥 Warning 💥
Please always check the ChangeLog for breaking changes.
Installation
To install the latest released version:
pip install --upgrade e3nn-jax
To install the latest GitHub version:
pip install git+https://github.com/e3nn/e3nn-jax.git
To install from a local copy for development, we recommend creating a virtual environment:
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
To check that the tests pass:
pip install pytest
pytest e3nn_jax/_src/tensor_products_test.py
What is different from the PyTorch version?
- No more shared_weights and internal_weights in TensorProduct. jax.vmap is used extensively instead (see the examples in the documentation).
- Support for the Python structure IrrepsArray, which contains a contiguous version of the data and a list of jnp.ndarray for the data. This avoids an unnecessary jnp.concatenate followed by indexing to reverse the concatenation (even though jax.jit is probably able to unroll the concatenations).
- Support for None in the list of jnp.ndarray to avoid unnecessary computation with zeros (basically imposing 0 * x = 0, which JAX does not simplify by default because 0 * nan = nan).
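The first point can be illustrated in plain JAX. This is a minimal sketch, not the e3nn-jax API: weighted_product is a hypothetical stand-in for a tensor product that takes its weights as an explicit argument and handles a single sample. Once the weights are an argument, the in_axes of jax.vmap decides whether one weight tensor is shared across the batch (None) or each sample gets its own (0), which is the choice a shared_weights flag would otherwise encode.

```python
import jax
import jax.numpy as jnp


def weighted_product(w, x, y):
    """Hypothetical toy bilinear product for ONE sample.

    out[k] = sum_ij w[k, i, j] * x[i] * y[j]; batching is left to jax.vmap.
    """
    return jnp.einsum("kij,i,j->k", w, x, y)


key = jax.random.PRNGKey(0)
k_w, k_x, k_y, k_ws = jax.random.split(key, 4)

batch, dim_x, dim_y, dim_out = 8, 3, 4, 5
x = jax.random.normal(k_x, (batch, dim_x))
y = jax.random.normal(k_y, (batch, dim_y))

# Shared weights: a single weight tensor, not mapped over (in_axes=None for w).
w_shared = jax.random.normal(k_w, (dim_out, dim_x, dim_y))
out_shared = jax.vmap(weighted_product, in_axes=(None, 0, 0))(w_shared, x, y)

# Per-sample weights: w has a leading batch axis that vmap maps over (in_axes=0).
w_per_sample = jax.random.normal(k_ws, (batch, dim_out, dim_x, dim_y))
out_per_sample = jax.vmap(weighted_product, in_axes=(0, 0, 0))(w_per_sample, x, y)

print(out_shared.shape, out_per_sample.shape)  # (8, 5) (8, 5)
```

Because batching lives outside the function, the same single-sample code also covers cases such as weights produced per sample by another network, without any extra flag.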
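The second and third points can be sketched together in plain JAX. This is not the IrrepsArray implementation: scale_blocks and to_contiguous are hypothetical helpers that keep the data as a list of per-irrep blocks, use None for a block known to be zero, and build the contiguous array only once, instead of concatenating and then slicing the result back apart. It also shows why JAX cannot simply fold away multiplications by zero.

```python
import jax.numpy as jnp


def scale_blocks(chunks, factors):
    """Multiply each block by its factor, skipping blocks marked as None.

    In this sketch None stands for a block known to be zero, so no
    multiplication is traced for it and a NaN factor cannot leak into it.
    """
    return [None if c is None else c * f for c, f in zip(chunks, factors)]


def to_contiguous(chunks, dims, batch):
    """Build the contiguous array once, materializing None blocks as zeros."""
    parts = [jnp.zeros((batch, d)) if c is None else c for c, d in zip(chunks, dims)]
    return jnp.concatenate(parts, axis=-1)


# JAX cannot rewrite 0 * x as 0 in general, because 0 * nan is nan.
print(jnp.float32(0.0) * jnp.float32(jnp.nan))

batch = 2
dims = [1, 3]                           # e.g. a scalar block and a vector block
chunks = [jnp.ones((batch, 1)), None]   # the vector block is known to be zero
factors = [jnp.float32(2.0), jnp.float32(jnp.nan)]

scaled = scale_blocks(chunks, factors)
print(scaled[1])                        # None
print(to_contiguous(scaled, dims, batch))
```

Operating on the list keeps each block addressable without index arithmetic, while the contiguous form remains available for operations that need a single array.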
Examples
The examples have moved to the documentation.
Citing
@misc{e3nn_paper,
doi = {10.48550/ARXIV.2207.09453},
url = {https://arxiv.org/abs/2207.09453},
author = {Geiger, Mario and Smidt, Tess},
keywords = {Machine Learning (cs.LG), Artificial Intelligence (cs.AI), Neural and Evolutionary Computing (cs.NE), FOS: Computer and information sciences},
title = {e3nn: Euclidean Neural Networks},
publisher = {arXiv},
year = {2022},
copyright = {Creative Commons Attribution 4.0 International}
}
Download files
Source Distribution
e3nn_jax-0.11.0.tar.gz (87.6 kB)
Built Distribution
e3nn_jax-0.11.0-py3-none-any.whl (107.8 kB)
Hashes for e3nn_jax-0.11.0-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | dc49ea2fead084cd62e85a849a430dd0243d0178d8fe5c0d0226d97d0497df78
MD5 | b97581e9239c05cc2f29351c970c755f
BLAKE2b-256 | f737a5d75a1ec80bf51c879e5c79c43639048856d151584a0de4f7bcae68130b