Equivariant convolutional neural networks for the group E(3) of three-dimensional rotations, translations, and mirrors.
e3nn-jax
Documentation
import jax
import e3nn_jax as e3nn
# Create a random array made of a scalar (0e) and a vector (1o)
array = e3nn.normal("0e + 1o", jax.random.PRNGKey(0))
print(array)
# 1x0e+1x1o [ 1.8160863 -0.75488514 0.33988908 -0.53483534]
# Compute the norms
norms = e3nn.norm(array)
print(norms)
# 1x0e+1x0e [1.8160863 0.98560894]
# Compute the norm of the full array
total_norm = e3nn.norm(array, per_irrep=False)
print(total_norm)
# 1x0e [2.0662997]
# Compute the tensor product of the array with itself
tp = e3nn.tensor_square(array)
print(tp)
# 2x0e+1x1o+1x2e
# [ 1.9041989 0.25082085 -1.3709364 0.61726785 -0.97130704 0.40373924
# -0.25657722 -0.18037902 -0.18178469 -0.14190137]
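The printed norms above can be cross-checked by hand. The following is a minimal sketch in plain Python, not e3nn-jax code: the `layout` list and the `per_irrep_norms` helper are hypothetical names introduced here for illustration, and the values are copied from the printed output above. It assumes the per-irrep norm is the Euclidean norm of each irrep's components, which agrees with the numbers shown.

```python
import math

# Values copied from the printed array above: one scalar (0e), then one vector (1o).
values = [1.8160863, -0.75488514, 0.33988908, -0.53483534]

# Hypothetical layout description: (label, number of components) for each irrep.
# An irrep of degree l has 2l + 1 components, so 0e has 1 and 1o has 3.
layout = [("0e", 1), ("1o", 3)]


def per_irrep_norms(values, layout):
    """Return the Euclidean norm of each irrep's slice of `values`."""
    norms, start = [], 0
    for _label, dim in layout:
        chunk = values[start:start + dim]
        norms.append(math.sqrt(sum(x * x for x in chunk)))
        start += dim
    return norms


norms = per_irrep_norms(values, layout)              # one norm per irrep
total_norm = math.sqrt(sum(x * x for x in values))   # norm over all components
print(norms)
print(total_norm)
```

Up to floating-point rounding, the two entries of `norms` and the value of `total_norm` should agree with the `e3nn.norm` outputs printed above.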
44% faster than PyTorch*
*Speed comparison performed with a full model (MACE) during training (revMD-17) on a GPU (NVIDIA RTX A5000).
Please always check the CHANGELOG for breaking changes.
Installation
To install the latest released version:
pip install --upgrade e3nn-jax
To install the latest GitHub version:
pip install git+https://github.com/e3nn/e3nn-jax.git
Need Help?
Ask a question in the discussions tab.
What is different from the PyTorch version?
The main difference is the presence of the class IrrepsArray. IrrepsArray contains the irreps (Irreps) along with the data array.
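To illustrate the idea of keeping the irreps next to the data, here is a toy sketch in plain Python. It is not the e3nn-jax implementation: `irreps_dim` and `ToyIrrepsArray` are hypothetical names made up for this example, and the parser only handles simple strings such as "0e + 1o" or "2x0e+1x1o+1x2e". It relies on the fact that an irrep of degree l has 2l + 1 components.

```python
from dataclasses import dataclass
from typing import List


def irreps_dim(irreps: str) -> int:
    """Total number of components described by a simple irreps string."""
    total = 0
    for term in irreps.split("+"):
        term = term.strip()
        # "2x0e" -> multiplicity "2", irrep "0e"; "1o" -> multiplicity omitted (1).
        mul, _, ir = term.rpartition("x")
        multiplicity = int(mul) if mul else 1
        degree = int(ir[:-1])  # drop the trailing parity letter (e or o)
        total += multiplicity * (2 * degree + 1)
    return total


@dataclass
class ToyIrrepsArray:
    """Hypothetical container pairing an irreps string with its data."""
    irreps: str
    array: List[float]

    def __post_init__(self):
        # The data must have exactly as many entries as the irreps describe.
        if len(self.array) != irreps_dim(self.irreps):
            raise ValueError("array length does not match irreps dimension")


x = ToyIrrepsArray("0e + 1o", [1.8160863, -0.75488514, 0.33988908, -0.53483534])
print(x.irreps, irreps_dim(x.irreps))
```

Because the container carries its own irreps, an operation can check or derive the output irreps from the input alone, which is the convenience the paragraph above attributes to IrrepsArray.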
Citing
@misc{e3nn_paper,
doi = {10.48550/ARXIV.2207.09453},
url = {https://arxiv.org/abs/2207.09453},
author = {Geiger, Mario and Smidt, Tess},
keywords = {Machine Learning (cs.LG), Artificial Intelligence (cs.AI), Neural and Evolutionary Computing (cs.NE), FOS: Computer and information sciences, FOS: Computer and information sciences},
title = {e3nn: Euclidean Neural Networks},
publisher = {arXiv},
year = {2022},
copyright = {Creative Commons Attribution 4.0 International}
}