Fast Euclidean-equivariant operations for JAX
⚙️ e3j-ops
This package contains the CUDA/C++ source for e3j.
The Python bindings to XLA handlers are bundled into the e3j_ops shared object, which
the main e3j package wraps in custom JAX primitives via the ffi_call API.
Note: The e3j_ops ABI should not be considered stable for now; treat it as private to the e3j Python API.
Building from source
The CMake build recipes are defined in CMakeLists.txt.
The Makefile of e3j also defines fine-grained recipes to build and test individual CUDA/C++ objects.
Project structure
The source is organized as follows:
- cuda : CUDA kernel implementations
- ffi : XLA-FFI handler declarations and pybind11 module definition
- xla : vendored XLA FFI headers
- tests : C++ kernel tests
CUDA/C++ source
Each operation is defined in its own namespace following a common interface (see e.g. tensor_product.cuh for detailed signatures):
- a namespace e3j::op_name within the enclosing e3j namespace,
- a struct e3j::op_name::Params for passing hyperparameters,
- a __global__ CUDA function e3j::op_name::kernel(),
- a __host__ launcher e3j::op_name::launch().
This name-based convention is used because object-oriented patterns are hard to mix with __device__ code; it is nonetheless subject to change.
```cpp
namespace e3j {
namespace op_name {

struct Params;  // hyperparameters of the operation

template <typename Idx, typename Val>
__global__ void kernel(Params p /* , ... */);  // device kernel

template <typename Idx, typename Val>
e3j::Error launch(/* ..., */ Params p, cudaStream_t stream);  // host launcher

}  // namespace op_name
}  // namespace e3j
```
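As a host-only illustration of the launcher role, the sketch below shows the kind of bookkeeping a launch() function typically performs before the kernel launch: validating Params and deriving a 1-D grid by ceiling division. The Error type, the Params fields (num_items, block_size) and the configure() helper are hypothetical stand-ins, not the e3j definitions; the actual <<<grid, block, 0, stream>>> launch is left as a comment so the sketch compiles with a plain C++ compiler.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <utility>

namespace e3j_sketch {

// Hypothetical stand-in for e3j::Error: success flag plus message.
struct Error {
  bool ok = true;
  std::string message;
  static Error Success() { return Error{}; }
  static Error Invalid(std::string msg) { return Error{false, std::move(msg)}; }
};

namespace op_name {

// Hypothetical hyperparameters: one thread per item, fixed block size.
struct Params {
  int64_t num_items;
  int32_t block_size;
};

// 1-D grid/block sizes a launcher would pass to <<<grid, block, 0, stream>>>.
struct LaunchConfig {
  int64_t grid = 0;
  int32_t block = 0;
};

// Host-side part of a launcher: validate Params, derive the launch shape.
inline Error configure(const Params& p, LaunchConfig* cfg) {
  if (p.num_items < 0) return Error::Invalid("num_items must be non-negative");
  if (p.block_size <= 0 || p.block_size > 1024)  // 1024 = CUDA max threads per block
    return Error::Invalid("block_size must be in [1, 1024]");
  cfg->block = p.block_size;
  // Ceiling division: enough blocks to give every item a thread
  // (grid == 0 means there is nothing to launch).
  cfg->grid = (p.num_items + p.block_size - 1) / p.block_size;
  assert(cfg->grid * cfg->block >= p.num_items);
  // A CUDA launcher would now do something like
  //   kernel<Idx, Val><<<cfg->grid, cfg->block, 0, stream>>>(p, ...);
  // and convert cudaGetLastError() into an Error.
  return Error::Success();
}

}  // namespace op_name
}  // namespace e3j_sketch
```

Keeping this logic on the host side means that invalid hyperparameters can be reported as an error value instead of surfacing later as a failed kernel launch.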
Most kernels are templated over a range of index and/or value data types. Some e3j primitives automatically dispatch to narrow index dtypes (uint8 / uint16) when feature-space dimensions are small enough. Half-precision (float16) arithmetic for values is not yet supported, but is planned for an upcoming release.
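The narrow-index dispatch described above can be sketched on the host as follows, assuming the choice depends only on the largest index that must be representable (a dimension minus one). The IndexType enum and the pick_index_type() and dispatch_index() helpers are hypothetical names, not part of e3j_ops, and the actual selection criteria in e3j may differ.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <stdexcept>

namespace e3j_sketch {

// Hypothetical tag for the index dtype a primitive would be instantiated with.
enum class IndexType { U8, U16, U32 };

// Smallest unsigned index type able to address `dim` entries (indices 0..dim-1).
inline IndexType pick_index_type(uint64_t dim) {
  assert(dim > 0);
  const uint64_t max_index = dim - 1;
  if (max_index <= std::numeric_limits<uint8_t>::max()) return IndexType::U8;
  if (max_index <= std::numeric_limits<uint16_t>::max()) return IndexType::U16;
  if (max_index <= std::numeric_limits<uint32_t>::max()) return IndexType::U32;
  throw std::out_of_range("dimension too large for 32-bit indices");
}

// Turn the runtime choice into a compile-time type: `f` is a generic callable
// invoked with a value of the selected index type (it must return the same
// type for every instantiation).
template <typename F>
auto dispatch_index(IndexType t, F&& f) {
  switch (t) {
    case IndexType::U8:  return f(uint8_t{});
    case IndexType::U16: return f(uint16_t{});
    default:             return f(uint32_t{});
  }
}

}  // namespace e3j_sketch
```

With this pattern each index width gets its own template instantiation, and narrower widths shrink the index buffers a kernel has to read.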
XLA-FFI handlers
The FFI layer uses the XLA FFI API to bind kernel launchers
as XLA custom calls. Each handler is registered with XLA_FFI_DEFINE_HANDLER,
receiving typed buffers and attributes directly (no opaque descriptor packing):
```cpp
xla::Error OpNameHandler(
    cudaStream_t stream,
    int32_t num_out,
    xla::AnyBuffer x,
    xla::Result<xla::AnyBuffer> out
) {
  // ... dtype dispatch selecting Idx and Val, and construction of
  // e3j::op_name::Params params from the attributes (e.g. num_out) ...
  return e3j::op_name::launch<Idx, Val>(
      x.typed_data<Idx>(),
      out->typed_data<Val>(),
      params, stream
  ).to_xla();
}

XLA_FFI_DEFINE_HANDLER(
    xla_op_name,
    OpNameHandler,
    xla::Ffi::Bind()
        .Ctx<xla::PlatformStream<cudaStream_t>>()
        .Attr<int32_t>("num_out")
        .Arg<xla::AnyBuffer>()
        .Ret<xla::AnyBuffer>()
);
```
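The "dtype dispatch" step elided in the handler can be illustrated with a self-contained mock. In a real handler the runtime dtype tags would come from the buffers themselves (XLA FFI's AnyBuffer exposes an element type); here the DType enum is a hypothetical stand-in mirroring a subset of XLA FFI's DataType, and launch() is a hypothetical stub that only reports which <Idx, Val> instantiation was selected.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

namespace e3j_sketch {

// Stand-in for the runtime dtype tag carried by an XLA FFI AnyBuffer.
enum class DType { U8, U16, S32, F32, F64 };

// Stand-in for e3j::op_name::launch<Idx, Val>(): reports the chosen widths.
template <typename Idx, typename Val>
std::string launch() {
  return std::to_string(sizeof(Idx)) + "/" + std::to_string(sizeof(Val));
}

// Inner level: resolve the index type for a fixed value type.
template <typename Val>
bool dispatch_idx(DType idx, std::string* out) {
  switch (idx) {
    case DType::U8:  *out = launch<uint8_t, Val>();  return true;
    case DType::U16: *out = launch<uint16_t, Val>(); return true;
    case DType::S32: *out = launch<int32_t, Val>();  return true;
    default: return false;  // not an index dtype
  }
}

// Outer level: resolve the value type, then the index type. Returns false
// where a real handler would return an XLA error for unsupported dtypes.
inline bool dispatch(DType idx, DType val, std::string* out) {
  switch (val) {
    case DType::F32: return dispatch_idx<float>(idx, out);
    case DType::F64: return dispatch_idx<double>(idx, out);
    default: return false;  // e.g. float16 values are not supported yet
  }
}

}  // namespace e3j_sketch
```

Nesting one switch per template parameter keeps every supported <Idx, Val> pair explicit, at the cost of instantiating one launcher per combination.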
Handlers are exposed to Python as PyCapsules via pyEncapsulateFunction
in e3j_ops.h, and registered in the
pybind11 module defined in e3j_ops.cpp.
Contributing
Although it is too early for e3j_ops to accept significant external contributions, bug reports or questions are very welcome via GitHub issues and discussions.
Citing
If you use e3j in your work, we kindly ask that you cite the following preprint:
```bibtex
@article{Peltre26-e3j,
  title = {{E3J}: an Efficient and Open-Source Euclidean Equivariance Backend},
  author = {Peltre, Olivier and Picard, Armand and Pichard, Adrien and Giacomoni, Luca and Braganca, Miguel and Heyraud, Valentin and Brunken, Christoph and Tilly, Jules},
  journal = {preprint},
  year = {2026},
  url = {(preprint)}
}
```