
Fast Euclid-equivariant operations for JAX

Project description

e3j

Euclid-equivariant operations and harmonic polynomials for JAX.
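A polynomial is harmonic when its Laplacian vanishes. The homogeneous harmonic polynomials of degree l in three variables span the (2l + 1)-dimensional irreducible representation of SO(3), which is why they appear in equivariant models. The sketch below checks this numerically with `jax.hessian` for the standard real, unnormalized degree-2 basis. `harmonic_l2` and `laplacian` are hypothetical helpers written for illustration; they are not e3j's API and do not follow its normalization convention.

```python
import jax
import jax.numpy as jnp


def harmonic_l2(r):
    """Real degree-2 harmonic polynomials (unnormalized), a basis of the l = 2 irrep."""
    x, y, z = r
    return jnp.stack([
        x * y,
        y * z,
        2 * z**2 - x**2 - y**2,
        x * z,
        x**2 - y**2,
    ])


def laplacian(f, r):
    """Laplacian of each output component of f at point r."""
    hess = jax.hessian(f)(r)  # shape (5, 3, 3): one 3x3 Hessian per component
    return jnp.trace(hess, axis1=-2, axis2=-1)


r = jnp.array([0.3, -1.2, 0.7])
print(laplacian(harmonic_l2, r))  # each entry is zero: every component is harmonic
```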

This library is intended as a faster, more reliable substitute for the e3nn and e3x Euclid-equivariant backends. Our longer-term goal is to replace slow components in equivariant MLIPs, such as those shipped with the mlip-jax repo.

For now, e3j is meant to operate alongside e3nn, on which it depends for a few non-critical operations (e.g. Irreps manipulations and Clebsch-Gordan coefficients). This dependency may be dropped in the future.
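Clebsch-Gordan coefficients are what decompose a product of irreps back into irreps. For two l = 1 (vector) inputs, the nine components of x ⊗ y split into an l = 0 piece (dot product), an l = 1 piece (cross product) and an l = 2 piece (symmetric traceless part). The Cartesian sketch below skips the spherical basis and normalization a real backend such as e3nn or e3j would use, and simply checks that each piece transforms as its degree predicts under a proper rotation. `rotation_matrix` and `decompose_1x1` are illustrative helpers, not library functions.

```python
import jax.numpy as jnp


def rotation_matrix(axis, angle):
    """Rodrigues' formula: proper rotation (det = +1) about `axis` by `angle` radians."""
    axis = axis / jnp.linalg.norm(axis)
    kx, ky, kz = axis
    K = jnp.array([[0.0, -kz, ky],
                   [kz, 0.0, -kx],
                   [-ky, kx, 0.0]])
    return jnp.eye(3) + jnp.sin(angle) * K + (1.0 - jnp.cos(angle)) * (K @ K)


def decompose_1x1(x, y):
    """Split x ⊗ y into l = 0, 1, 2 parts (Cartesian form, 1 + 3 + 5 = 9 components)."""
    scalar = jnp.dot(x, y)                      # l = 0
    vector = jnp.cross(x, y)                    # l = 1
    outer = jnp.outer(x, y)
    sym_traceless = 0.5 * (outer + outer.T) - scalar / 3.0 * jnp.eye(3)  # l = 2
    return scalar, vector, sym_traceless


R = rotation_matrix(jnp.array([1.0, 2.0, -0.5]), 0.8)
x = jnp.array([0.2, -1.0, 0.5])
y = jnp.array([1.5, 0.3, -0.7])

s, v, t = decompose_1x1(x, y)
s_rot, v_rot, t_rot = decompose_1x1(R @ x, R @ y)

print(jnp.allclose(s_rot, s, atol=1e-5))            # l = 0: invariant
print(jnp.allclose(v_rot, R @ v, atol=1e-5))        # l = 1: rotates like a vector
print(jnp.allclose(t_rot, R @ t @ R.T, atol=1e-5))  # l = 2: rotates like a rank-2 tensor
```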

Project structure:

  • src/ : Python source
  • lib/ : C++/CUDA source for the e3j_ops subpackage
    • cuda : custom kernel implementations
    • ffi : XLA and Python binding boilerplate

Installation

As a dependency

In the MLIP project, whose dependencies are managed with uv, we use the following setup in pyproject.toml:

[dependency-groups]
e3j = ["e3j"]
e3j_ops = ["e3j_ops"]

[tool.uv.sources]
e3j = {git="https://github.com/instadeepai/e3j", branch="main"}
#e3j = {git = "https://github.com/instadeepai/e3j", rev="fea221c4191204c87a960153738da38cc1923a6b"}

For development

Dependencies are managed with uv. After cloning the repository, run one of:

# CPU install
uv sync --group cpu
# Existing CUDA 12 install with `e3j_ops` kernels:
uv sync --group cuda_local --extra ops
# Install CUDA 12 via pip and the `exp` group for benchmarks:
uv sync --group cuda --group exp --extra ops

You should now be able to run the tests with:

uv run pytest

Note: if you change the code under lib/ and want to rebuild the CUDA extension for the uv environment, run:

make uv # forces `uv cache clean` of the e3j_ops shared object
make pytest # pytest -m "e3j_ops" tests/test_ops

Building e3j_ops

The e3j.ops subpackage optionally provides JAX bindings to our sparse kernels via the XLA FFI.
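The kernels themselves are not shown in this README. To illustrate what a sparse kernel for an equivariant operation computes, here is a pure-JAX reference of a sparse contraction, out[k] += c · x[i] · y[j], evaluated over the nonzero coefficients only. The (i, j, k, c) coordinate-list layout is an assumption made for this sketch, not necessarily the layout e3j's CUDA kernels use. With Levi-Civita coefficients the result is the cross product (an l = 1 ⊗ l = 1 → l = 1 path, up to normalization), so it can be checked against `jnp.cross`.

```python
import jax
import jax.numpy as jnp

# Hypothetical sparse coefficient list: entry n means out[K[n]] += C[n] * x[I[n]] * y[J[n]].
# These six entries are the nonzeros of the Levi-Civita symbol (cross product).
I = jnp.array([1, 2, 2, 0, 0, 1])
J = jnp.array([2, 1, 0, 2, 1, 0])
K = jnp.array([0, 0, 1, 1, 2, 2])
C = jnp.array([1.0, -1.0, 1.0, -1.0, 1.0, -1.0])


@jax.jit
def sparse_contract(x, y):
    """x, y: (batch, 3). Gather operand pairs, scale, then scatter-add into outputs."""
    terms = C * x[:, I] * y[:, J]                            # (batch, nnz)
    out = jax.ops.segment_sum(terms.T, K, num_segments=3)    # (3, batch)
    return out.T                                             # (batch, 3)


x = jnp.array([[0.2, -1.0, 0.5], [1.0, 0.0, 0.0]])
y = jnp.array([[1.5, 0.3, -0.7], [0.0, 1.0, 0.0]])
print(jnp.allclose(sparse_contract(x, y), jnp.cross(x, y)))
```

A fused kernel can do the gather, multiply and scatter-add in one pass without materializing `terms`; the reference above is mainly useful as a correctness oracle in tests.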

There are currently two ways to build or test the e3j_ops shared object:

  1. Installing e3j with the ops extra, which builds the extension internally with CMake and scikit-build.

    # Editable dev install
    uv sync --group cuda --extra ops  # or: --group cuda_local
    # Dependency install
    pip install "e3j[ops]"  # once uploaded to PyPI; for now, install from the git URL with an access token
    
  2. Building any of the recipes from the Makefile e.g.

    # Build the Pybind11 shared object
    make e3j_ops && export PYTHONPATH=$PWD/bin:$PYTHONPATH
    # Run a specific test
    make test_tensor_product
    # Build tests and run all e3j_ops dependent tests (C++ and Python)
    make test
    

Both (1) and (2) require nvcc and g++, and possibly pybind11, to be installed. You should then be able to import e3j_ops from Python, which exposes the raw XLA custom calls. However, kernels should be called from Python through the e3j.ops namespace API rather than through e3j_ops directly.
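After building, a quick way to confirm that Python can see the extension (for example, that bin/ is on PYTHONPATH after `make e3j_ops`) is to locate and then import it with the standard library. `check_extension` is a hypothetical helper; it only uses `importlib` and assumes nothing about what e3j_ops exports.

```python
import importlib
import importlib.util


def check_extension(name: str = "e3j_ops") -> bool:
    """Report whether a (compiled) extension module can be located and imported."""
    spec = importlib.util.find_spec(name)
    if spec is None:
        # Not on sys.path, e.g. bin/ missing from PYTHONPATH.
        print(f"{name}: not found")
        return False
    try:
        # Importing also loads the shared object and its dynamic dependencies (e.g. CUDA libs).
        importlib.import_module(name)
    except ImportError as err:
        print(f"{name}: found at {spec.origin} but failed to import: {err}")
        return False
    print(f"{name}: importable from {spec.origin}")
    return True


check_extension()
```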

Documentation

Docs are hosted on GCP and deployed by the docs.yaml workflow.

They can also be built locally with:

cd docs && uv run make html
# => file:///<path-to-e3j>/docs/_build/html/index.html

Running benchmarks

See the scripts/ directory for benchmarks comparing e3x, e3nn and e3j.

# run a collection of benchmarks
uv run python scripts/benchmark_main.py
# run individual benchmarks
uv run python -m scripts.benchmarks.benchmark_tensor_product
uv run python -m scripts.benchmarks.benchmark_harmonics
...
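The benchmark scripts are not reproduced here. Any wall-clock comparison of JAX backends has to exclude compilation time and wait for asynchronous dispatch to finish; the helper below shows one minimal way to do that. `time_jitted` is a hypothetical name, not a function from scripts/, and the sketch assumes nothing about how the E3Benchmark classes actually measure time.

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, n_warmup: int = 3, n_iters: int = 20) -> float:
    """Mean wall-clock seconds per call of jax.jit(fn), excluding compilation."""
    jitted = jax.jit(fn)
    # Warm-up calls trigger (and cache) compilation so it is not timed below.
    for _ in range(n_warmup):
        jax.block_until_ready(jitted(*args))
    start = time.perf_counter()
    for _ in range(n_iters):
        # JAX dispatches asynchronously; block so the measured time covers the computation.
        jax.block_until_ready(jitted(*args))
    return (time.perf_counter() - start) / n_iters


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (10_000, 3))
y = jax.random.normal(key_y, (10_000, 3))
print(f"cross: {time_jitted(jnp.cross, x, y) * 1e6:.1f} us/call")
```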

The main benchmarking script is configured through the scripts/config.yaml file. For each of the E3Benchmark subclasses listed there:

  • the skip field controls whether the benchmark case is executed,
  • the grid field is expanded into a Cartesian product of hyperparameters,
  • each hyperparameter combination is converted by BENCHMARK_PARSERS[name] into a tuple of constructor arguments (e.g. l_max determines the source and target representations).
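The skip / grid / BENCHMARK_PARSERS rules can be sketched in a few lines of Python. The config contents, the benchmark names and the parser bodies below are hypothetical; only the roles of the three fields come from the description. The config is written as the dict that `yaml.safe_load` would return, to keep the example free of a YAML dependency, and irreps are written in e3nn string notation.

```python
import itertools

# Hypothetical contents of scripts/config.yaml after loading.
config = {
    "TensorProductBenchmark": {"skip": False, "grid": {"l_max": [1, 2, 3], "batch": [128, 1024]}},
    "HarmonicsBenchmark": {"skip": True, "grid": {"l_max": [4]}},
}


def sh_irreps(l_max):
    """Spherical-harmonic irreps up to l_max, e.g. '0e+1o+2e' (parity (-1)^l)."""
    return "+".join(f"{l}{'e' if l % 2 == 0 else 'o'}" for l in range(l_max + 1))


# Hypothetical parsers: one hyperparameter dict -> tuple of constructor arguments.
BENCHMARK_PARSERS = {
    "TensorProductBenchmark": lambda hp: (sh_irreps(hp["l_max"]), hp["batch"]),
    "HarmonicsBenchmark": lambda hp: (sh_irreps(hp["l_max"]),),
}


def expand(config):
    """Return (benchmark name, constructor args) for every non-skipped grid point."""
    cases = []
    for name, entry in config.items():
        if entry.get("skip", False):
            continue
        keys = list(entry["grid"])
        for values in itertools.product(*(entry["grid"][k] for k in keys)):
            hparams = dict(zip(keys, values))
            cases.append((name, BENCHMARK_PARSERS[name](hparams)))
    return cases


for case in expand(config):
    print(case)
```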

TODO: clean up the benchmarks to automate the definition of backward cases, and get rid of the hacky BENCHMARK_PARSERS, whose logic should be implemented in the benchmark classes themselves.
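One possible shape for this TODO, sketched with hypothetical names: each subclass parses its own hyperparameters through a classmethod, and the backward case is derived from the forward pass with `jax.grad` of a summed output instead of being written by hand. This `E3Benchmark` is a stand-in for illustration, not the repository's class, and `CrossProductBenchmark` is not one of its benchmarks.

```python
import jax
import jax.numpy as jnp


class E3Benchmark:
    """Hypothetical base class: subclasses parse their own hyperparameters."""

    @classmethod
    def from_hparams(cls, **hparams):
        raise NotImplementedError

    def forward(self, *args):
        raise NotImplementedError

    def backward(self, *args):
        # Derived automatically: gradient of the summed forward output
        # with respect to every positional input.
        loss = lambda *a: jnp.sum(self.forward(*a))
        return jax.grad(loss, argnums=tuple(range(len(args))))(*args)


class CrossProductBenchmark(E3Benchmark):
    """Illustrative subclass."""

    def __init__(self, batch: int):
        self.batch = batch

    @classmethod
    def from_hparams(cls, batch, **_):
        # Unused grid keys (e.g. l_max) are ignored by this benchmark.
        return cls(batch=batch)

    def forward(self, x, y):
        return jnp.cross(x, y)


bench = CrossProductBenchmark.from_hparams(batch=4, l_max=1)
x = jnp.ones((bench.batch, 3))
y = jnp.arange(12.0).reshape(bench.batch, 3)
gx, gy = bench.backward(x, y)
```

Summing the output is the simplest way to get a scalar for `jax.grad`; it corresponds to an all-ones cotangent, whereas `jax.vjp` with random cotangents would exercise the backward pass with less structured inputs.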


Download files


Source Distributions

No source distribution files available for this release.

Built Distribution


e3j-0.1.0b0-py3-none-any.whl (77.9 kB)

Uploaded Python 3

File details

Details for the file e3j-0.1.0b0-py3-none-any.whl.

File metadata

  • Download URL: e3j-0.1.0b0-py3-none-any.whl
  • Upload date:
  • Size: 77.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.13

File hashes

Hashes for e3j-0.1.0b0-py3-none-any.whl
Algorithm     Hash digest
SHA256        b2f5138a2aff27a03ba5ec9b6d3d6d63a4d02b68b46d61f4ea91132395eb486a
MD5           d593b2d87aabb58310d0b512a66da4c8
BLAKE2b-256   cefbede32366026c40a59074723cda81c4fe270ad3697657eec98720a1d25ebf
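To check a downloaded wheel against the SHA256 digest listed above, hash the file in chunks and compare hex digests. This is a generic standard-library sketch (`sha256_of` and `verify` are illustrative names); pip's hash-checking mode (`--require-hashes`, with `--hash=sha256:...` entries in a requirements file) performs the same check automatically during installation.

```python
import hashlib
from pathlib import Path

# SHA256 digest published for e3j-0.1.0b0-py3-none-any.whl.
EXPECTED_SHA256 = "b2f5138a2aff27a03ba5ec9b6d3d6d63a4d02b68b46d61f4ea91132395eb486a"


def sha256_of(path, chunk_size: int = 1 << 16) -> str:
    """Hex SHA256 of a file, read in chunks so large files need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected: str = EXPECTED_SHA256) -> bool:
    return sha256_of(path) == expected.lower()


wheel = Path("e3j-0.1.0b0-py3-none-any.whl")
if wheel.exists():
    print("hash OK" if verify(wheel) else "hash MISMATCH")
```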

