Fast Euclid-equivariant operations for JAX
e3j
Euclid-equivariant operations and harmonic polynomials for JAX.
This library is intended as a faster and more reliable substitute for the e3nn and e3x Euclid-equivariant backends. Our longer-term goal is to replace slow components in equivariant MLIPs, as shipped with the mlip-jax repo.
For now, e3j is meant to operate alongside e3nn,
on which it depends for a few non-critical operations (e.g. Irreps manipulations and Clebsch-Gordan coefficients).
This dependency may be dropped in the future.
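e3nn supplies the Clebsch-Gordan coefficients that couple two irreps into a third. For two l=1 inputs the coupling is small enough to write out by hand, which makes the idea concrete. In a Cartesian basis, and up to normalization and the choice of basis, the l=0, l=1 and l=2 outputs are the dot product, the cross product and the traceless symmetric part of the outer product. The sketch below uses only NumPy, and its function names are hypothetical rather than part of e3j or e3nn. It checks that each output transforms correctly under a random rotation.

```python
import numpy as np


def random_rotation(seed: int = 0) -> np.ndarray:
    """Draw a random proper rotation matrix (det = +1) via a QR decomposition."""
    rng = np.random.default_rng(seed)
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))  # fix column signs so the factorization is unique
    if np.linalg.det(q) < 0:     # flip one axis to turn a reflection into a rotation
        q[:, 0] = -q[:, 0]
    return q


def tensor_product_1x1(a: np.ndarray, b: np.ndarray):
    """Hypothetical helper: split a (x) b for two l=1 vectors into l=0, 1, 2 parts.

    Cartesian basis (x, y, z): the l=1 part transforms with R itself and the
    l=2 part, a traceless symmetric 3x3 matrix, transforms as R @ S @ R.T.
    """
    l0 = np.dot(a, b)                            # 1 component, rotation invariant
    l1 = np.cross(a, b)                          # 3 components, rotates like a and b
    outer = np.outer(a, b)
    sym = 0.5 * (outer + outer.T)
    l2 = sym - np.trace(sym) / 3.0 * np.eye(3)   # 5 independent components
    return l0, l1, l2
```

The component counts 1 + 3 + 5 = 9 match the 3 x 3 entries of the outer product, which is the decomposition 1 (x) 1 = 0 (+) 1 (+) 2 that Clebsch-Gordan coefficients encode.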
Project structure:
- src/ : Python source
  - e3j/core : framework-agnostic implementations of equivariant operations
  - e3j/linen : flax.linen.Module wrappers
- lib/ : C++/CUDA source for the `e3j_ops` subpackage
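To illustrate what a framework-agnostic equivariant operation can look like, here is a minimal sketch of the harmonic polynomials of degree 0, 1 and 2. It is written against an array module passed in as `xp`, so no neural-network framework is involved. This is not e3j's API: the name `harmonic_polynomials` and the `xp` parameter are hypothetical. It also uses a Cartesian basis, which may differ from the basis and normalization e3j uses.

```python
import numpy as np


def harmonic_polynomials(x, xp=np):
    """Hypothetical helper: degree-0, 1 and 2 harmonic polynomials of x, shape (..., 3).

    The entries of the degree-2 output, x x^T - |x|^2 / 3 * I, are homogeneous
    quadratics with zero Laplacian; together they span the 5-dimensional space
    of degree-2 harmonic polynomials. Passing xp=jax.numpy should work the same
    way, since only basic array operations are used.
    """
    ones = xp.ones(x.shape[:-1] + (1,), dtype=x.dtype)        # degree 0: invariant
    deg1 = x                                                   # degree 1: rotates with R
    outer = x[..., :, None] * x[..., None, :]                  # (..., 3, 3)
    sq_norm = xp.sum(x * x, axis=-1)[..., None, None]
    deg2 = outer - sq_norm / 3.0 * xp.eye(3, dtype=x.dtype)   # traceless symmetric
    return ones, deg1, deg2
```

Rotating the input points by R leaves the degree-0 output unchanged, rotates the degree-1 output by R, and conjugates the degree-2 output as R @ deg2 @ R.T.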
Installation
As a dependency
In the MLIP dependencies, managed with uv, we use the following setup:
```toml
[dependency-groups]
e3j = ["e3j"]
e3j_ops = ["e3j_ops"]

[tool.uv.sources]
e3j = {git="https://github.com/instadeepai/e3j", branch="main"}
#e3j = {git = "https://github.com/instadeepai/e3j", rev="fea221c4191204c87a960153738da38cc1923a6b"}
```
For development
Dependencies are managed with uv. After cloning the repository, run one of:
```bash
# CPU install
uv sync --group cpu
# Existing CUDA 12 install with `e3j_ops` kernels:
uv sync --group cuda_local --extra ops
# Install CUDA 12 via pip and the `exp` group for benchmarks:
uv sync --group cuda --group exp --extra ops
```
You should now be able to run the tests with:
```bash
uv run pytest
```
Note: if you change the lib/ code and want to rebuild the CUDA kernels for the uv environment, run:
```bash
make uv      # forces `uv cache clean` of the e3j_ops shared object
make pytest  # pytest -m "e3j_ops" tests/test_ops
```
Building e3j_ops
The e3j.ops subpackage optionally provides JAX bindings to our
sparse kernels via the XLA FFI.
There are currently two ways to build or test the e3j_ops shared object:
1. Installing `e3j` with the `ops` extra, which internally relies on `cmake` and `scikit-build`:
   ```bash
   # Editable dev install (use --group cuda or --group cuda_local)
   uv sync --group cuda --extra ops
   # Dependency install (once uploaded to PyPI, add git url/token for now)
   pip install "e3j[ops]"
   ```
2. Building any of the recipes from the Makefile, e.g.:
   ```bash
   # Build the Pybind11 shared object
   make e3j_ops && export PYTHONPATH=$PWD/bin:$PYTHONPATH
   # Run a specific test
   make test_tensor_product
   # Build tests and run all e3j_ops dependent tests (C++ and Python)
   make test
   ```
Both (1) and (2) require nvcc, g++ and possibly pybind11 to be installed.
You should then be able to import e3j_ops from Python, which exposes
the raw XLA custom calls. However, kernels should be called from Python
through the e3j.ops namespace API.
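When experimenting with the Makefile build, it can help to check whether the `e3j_ops` shared object is actually on the import path, for example after exporting `PYTHONPATH`, before relying on it. The following is a minimal, standard-library-only sketch. The helper names `ops_available` and `load_ops` are hypothetical and not part of e3j, and kernels should still be called through e3j.ops.

```python
import importlib
import importlib.util
from types import ModuleType
from typing import Optional


def ops_available(module_name: str = "e3j_ops") -> bool:
    """Return True if a top-level module with this name can be found on sys.path.

    find_spec only locates the module without executing it, so problems such
    as a missing CUDA runtime would only surface on the actual import.
    """
    return importlib.util.find_spec(module_name) is not None


def load_ops(module_name: str = "e3j_ops") -> Optional[ModuleType]:
    """Import the extension if it can be found, else return None so callers can fall back."""
    if not ops_available(module_name):
        return None
    return importlib.import_module(module_name)
```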
Documentation
Docs are hosted on GCP and deployed by the docs.yaml workflow.
They can also be built locally with:
```bash
cd docs && uv run make html
# => file:///<path-to-e3j>/docs/_build/html/index.html
```
Running benchmarks
See the scripts directory for benchmarks of e3x, e3nn and e3j.
```bash
# run a collection of benchmarks
uv run python scripts/benchmark_main.py
# run individual benchmarks
uv run python -m scripts.benchmarks.benchmark_tensor_product
uv run python -m scripts.benchmarks.benchmark_harmonics
...
```
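Outside these scripts, two details matter when timing JAX code by hand. First, the first call includes jit compilation. Second, JAX dispatches work asynchronously, so the clock must wait for the result to be ready. The following is a minimal sketch of a helper that handles both. The name `time_call` is hypothetical and not taken from scripts/.

```python
import time
from typing import Any, Callable


def time_call(fn: Callable[..., Any], *args: Any, warmup: int = 1, repeats: int = 10) -> float:
    """Return the mean wall-clock time of fn(*args) in seconds (hypothetical helper).

    Warm-up calls absorb one-off costs such as jit compilation. If the output
    has a block_until_ready() method (as JAX arrays do), it is called so the
    asynchronous computation finishes before the clock is read. Only a single
    array output is handled here.
    """
    def run() -> None:
        out = fn(*args)
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()

    for _ in range(warmup):
        run()
    start = time.perf_counter()
    for _ in range(repeats):
        run()
    return (time.perf_counter() - start) / repeats
```

For functions that return pytrees of arrays, `jax.block_until_ready(out)` is the more general way to wait.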
The main benchmarking script can be configured by the
scripts/config.yaml file.
For each of the E3Benchmark subclasses listed in it:
- the `skip` field controls whether the benchmark case is executed,
- the `grid` field is expanded into a Cartesian product of hyper-parameters,
- the hyper-parameters are converted by `BENCHMARK_PARSERS[name]` into a tuple of constructor arguments (e.g. `l_max` determines the source and target representations).
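The skip/grid/parser flow described above can be sketched as follows. This is a minimal illustration, not the code in scripts/. The config contents, the class names `TensorProductCase` and `HarmonicsCase`, and the `PARSERS` mapping are all hypothetical stand-ins for config.yaml and `BENCHMARK_PARSERS`, and the real schema may differ.

```python
import itertools
from typing import Any, Callable, Dict, Iterator, List, Tuple

# Hypothetical stand-in for scripts/config.yaml, already loaded into a dict.
CONFIG: Dict[str, Dict[str, Any]] = {
    "TensorProductCase": {"skip": False, "grid": {"l_max": [1, 2], "batch": [128, 1024]}},
    "HarmonicsCase": {"skip": True, "grid": {"l_max": [3]}},
}

# Hypothetical parsers: hyper-parameter dict -> tuple of constructor arguments.
# Here l_max is used for both the source and the target representation.
PARSERS: Dict[str, Callable[[Dict[str, Any]], Tuple[Any, ...]]] = {
    "TensorProductCase": lambda hp: (hp["l_max"], hp["l_max"], hp["batch"]),
    "HarmonicsCase": lambda hp: (hp["l_max"],),
}


def expand_grid(grid: Dict[str, List[Any]]) -> Iterator[Dict[str, Any]]:
    """Yield one hyper-parameter dict per point of the Cartesian product of the grid."""
    keys = list(grid)
    for values in itertools.product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))


def benchmark_cases(config, parsers) -> List[Tuple[str, Tuple[Any, ...]]]:
    """Return (case name, constructor arguments) for every grid point of non-skipped cases."""
    cases = []
    for name, entry in config.items():
        if entry.get("skip", False):
            continue
        for hp in expand_grid(entry["grid"]):
            cases.append((name, parsers[name](hp)))
    return cases
```

Moving each parser into its benchmark class, as a classmethod for example, would remove the separate lookup table.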
TODO: clean up the benchmarks, automate the definition of backward cases, and replace the hacky BENCHMARK_PARSERS with parsing implemented in the classes themselves.
File details
Details for the file e3j-0.1.0b0-py3-none-any.whl.
File metadata
- Download URL: e3j-0.1.0b0-py3-none-any.whl
- Upload date:
- Size: 77.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b2f5138a2aff27a03ba5ec9b6d3d6d63a4d02b68b46d61f4ea91132395eb486a |
| MD5 | d593b2d87aabb58310d0b512a66da4c8 |
| BLAKE2b-256 | cefbede32366026c40a59074723cda81c4fe270ad3697657eec98720a1d25ebf |