
Fast Euclidean equivariant operations for JAX

Project description

⚙️ e3j-ops

This package contains the CUDA/C++ source for e3j. The Python bindings to XLA handlers are bundled into the e3j_ops shared object, which the main e3j package wraps in custom JAX primitives via the ffi_call API.

Note: the e3j_ops ABI is not stable yet and should be considered private to the e3j Python API.

Building from source

The CMake build recipes are defined in CMakeLists.txt. The e3j Makefile also defines fine-grained recipes to build and test individual CUDA/C++ objects.

Project structure

The source is organized as follows:

  • cuda : CUDA kernel implementations
  • ffi : XLA FFI handler declarations and pybind11 module definition
  • xla : vendored XLA FFI headers
  • tests : C++ kernel tests

CUDA/C++ source

Each operation is defined in its own namespace following a common interface (see e.g. tensor_product.cuh for detailed signatures):

  • a namespace e3j::op_name within the enclosing e3j namespace,
  • a struct e3j::op_name::Params for passing hyperparameters,
  • a __global__ CUDA function e3j::op_name::kernel(),
  • a __host__ launcher e3j::op_name::launch().

Object-oriented patterns are hard to mix with __device__ code, hence this name-based convention; it is nonetheless subject to change.

namespace e3j {
namespace op_name {

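    // Hyperparameters of the operation, passed to the kernel by value
    struct Params;

    // Device kernel; "..." stands for the op-specific buffer arguments
    template <typename Idx, typename Val>
    __global__ void kernel (Params p, ...);

    // Host launcher: launches kernel<Idx, Val> on the given CUDA stream
    template <typename Idx, typename Val>
    e3j::Error launch (..., Params p, cudaStream_t stream);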

} // namespace op_name
} // namespace e3j
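To make the convention concrete without a GPU, here is a host-only C++ sketch of a hypothetical elementwise op "scale" written in the same shape: a Params struct, a per-thread kernel body, and a launcher that validates its inputs, computes a ceil-divided grid size and returns an error value. The namespace e3j_demo, the op itself, the Error stand-in and the nested loops emulating blockIdx/threadIdx are all assumptions for illustration; a real launcher would presumably issue kernel<<<grid, block, 0, stream>>> on the CUDA stream instead of looping on the host.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

namespace e3j_demo {  // hypothetical stand-in for namespace e3j
namespace scale {     // hypothetical op, not part of e3j_ops

struct Params {
    std::size_t n;  // number of elements
    float alpha;    // scaling factor
};

struct Error {      // minimal stand-in for e3j::Error
    bool ok;
};

// Host emulation of what one CUDA thread of kernel<Val> would compute.
// On the device, block_idx / block_dim / thread_idx would be
// blockIdx.x / blockDim.x / threadIdx.x.
template <typename Val>
void kernel_body(const Params& p, const Val* x, Val* out,
                 std::size_t block_idx, std::size_t block_dim, std::size_t thread_idx) {
    const std::size_t i = block_idx * block_dim + thread_idx;
    if (i < p.n) {  // guard: the last block may be only partially used
        out[i] = static_cast<Val>(p.alpha) * x[i];
    }
}

// Launcher: validates inputs, sizes the grid, then runs every (block, thread) pair.
template <typename Val>
Error launch(const Val* x, Val* out, Params p, std::size_t block_dim = 128) {
    if (x == nullptr || out == nullptr || block_dim == 0) {
        return Error{false};
    }
    const std::size_t grid_dim = (p.n + block_dim - 1) / block_dim;  // ceil(n / block_dim)
    for (std::size_t b = 0; b < grid_dim; ++b) {
        for (std::size_t t = 0; t < block_dim; ++t) {
            kernel_body(p, x, out, b, block_dim, t);
        }
    }
    return Error{true};
}

}  // namespace scale
}  // namespace e3j_demo
```

The ceil-division guarantees grid_dim * block_dim >= n, and the i < p.n guard discards the surplus threads of the last block.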

Most kernels are templated across a range of index and/or value data types. Some e3j primitives automatically dispatch to narrow index dtypes (uint8 / uint16) when feature space dimensions are small enough. Half-precision float16 arithmetic for values is not yet supported, but is planned for an upcoming release.
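One plausible selection rule is "use the narrowest unsigned type that can hold the largest index", sketched below in host-only C++. The IndexDtype enum, the exact threshold rule and the index-triple table used to show the memory saving are assumptions for illustration; e3j's actual criterion for "small enough" may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>

// Hypothetical index dtype tags; e3j's own representation may differ.
enum class IndexDtype { U8, U16, U32 };

// Narrowest unsigned type able to hold every index in [0, dim).
// Assumption: "small enough" means the largest index dim - 1 fits in the type.
constexpr IndexDtype narrowest_index_dtype(std::size_t dim) {
    if (dim == 0 ||
        dim - 1 <= static_cast<std::size_t>(std::numeric_limits<std::uint8_t>::max()))
        return IndexDtype::U8;
    if (dim - 1 <= static_cast<std::size_t>(std::numeric_limits<std::uint16_t>::max()))
        return IndexDtype::U16;
    return IndexDtype::U32;  // dims beyond 2^32 are not handled in this sketch
}

// Bytes per stored index for each dtype.
constexpr std::size_t index_bytes(IndexDtype d) {
    return d == IndexDtype::U8 ? 1 : d == IndexDtype::U16 ? 2 : 4;
}

// Bytes used by a hypothetical table of `entries` index triples (i, j, k),
// e.g. a sparse coupling pattern, when stored with the narrowest index dtype.
constexpr std::size_t index_table_bytes(std::size_t entries, std::size_t dim) {
    return 3 * entries * index_bytes(narrowest_index_dtype(dim));
}

static_assert(narrowest_index_dtype(256) == IndexDtype::U8, "indices 0..255 fit in uint8");
static_assert(narrowest_index_dtype(257) == IndexDtype::U16, "index 256 needs uint16");
```

With uint8 indices each entry takes one byte instead of four for uint32, so the index table shrinks fourfold, which matters for memory-bound kernels.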

XLA-FFI handlers

The FFI layer uses the XLA FFI API to bind kernel launchers as XLA custom calls. Each handler is defined with XLA_FFI_DEFINE_HANDLER and receives typed buffers and attributes directly (no opaque descriptor packing):

xla::Error OpNameHandler(
    cudaStream_t stream,
    int32_t num_out,
    xla::AnyBuffer x,
    xla::Result<xla::AnyBuffer> out
) {
    // ... dtype dispatch ...
    return e3j::op_name::launch<Idx, Val>(
        x.typed_data<Idx>(),
        out->typed_data<Val>(),
        params, stream
    ).to_xla();
}

XLA_FFI_DEFINE_HANDLER(
    xla_op_name,
    OpNameHandler,
    xla::Ffi::Bind()
        .Ctx<xla::PlatformStream<cudaStream_t>>()
        .Attr<int32_t>("num_out")
        .Arg<xla::AnyBuffer>()
        .Ret<xla::AnyBuffer>()
);
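The elided "dtype dispatch" step in the handler maps runtime buffer dtypes onto the template parameters of launch<Idx, Val>. Below is a host-only sketch of one way to do this with nested switches. It assumes a hypothetical DType enum loosely mirroring XLA buffer element types (in a real handler the dtype would presumably come from the buffer, e.g. x.element_type()), a hypothetical Error stand-in, and a stub launcher that only records which instantiation was selected. Unsupported combinations, such as float16 values, return an error instead of launching.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <utility>

// Hypothetical runtime dtype tags, loosely mirroring XLA buffer element types.
enum class DType { U8, U16, U32, F16, F32, F64 };

// Hypothetical stand-in for e3j::Error.
struct Error {
    bool ok;
    std::string msg;
    static Error success() { return {true, ""}; }
    static Error invalid(std::string m) { return {false, std::move(m)}; }
};

// Stub for e3j::op_name::launch<Idx, Val>: records "sizeof(Idx)/sizeof(Val)"
// instead of launching a kernel.
template <typename Idx, typename Val>
Error launch_stub(std::string* selected) {
    *selected = std::to_string(sizeof(Idx)) + "/" + std::to_string(sizeof(Val));
    return Error::success();
}

// Inner switch: Idx is already fixed, pick Val.
template <typename Idx>
Error dispatch_val(DType val, std::string* selected) {
    switch (val) {
        case DType::F32: return launch_stub<Idx, float>(selected);
        case DType::F64: return launch_stub<Idx, double>(selected);
        default:         return Error::invalid("unsupported value dtype");  // e.g. F16
    }
}

// Outer switch: pick Idx, then delegate to the inner switch.
Error dispatch(DType idx, DType val, std::string* selected) {
    switch (idx) {
        case DType::U8:  return dispatch_val<std::uint8_t>(val, selected);
        case DType::U16: return dispatch_val<std::uint16_t>(val, selected);
        case DType::U32: return dispatch_val<std::uint32_t>(val, selected);
        default:         return Error::invalid("unsupported index dtype");
    }
}
```

Nested switches keep every (Idx, Val) instantiation explicit, so the compiled kernels are exactly those listed in the cases; the trade-off is that their number grows as the product of supported index and value types.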

Handlers are exposed to Python as PyCapsules via pyEncapsulateFunction in e3j_ops.h, and registered in the pybind11 module defined in e3j_ops.cpp.

Contributing

Although it is too early for e3j_ops to accept significant external contributions, bug reports and questions are very welcome via GitHub issues and discussions.

Citing

If you use e3j in your work, we kindly ask you to cite the following preprint:

@article{Peltre26-e3j,
    title   = {{E3J}: an Efficient and Open-Source Euclidean Equivariance Backend},
    author  = {Peltre, Olivier and Picard, Armand and Pichard, Adrien and Giacomoni, Luca and Braganca, Miguel and Heyraud, Valentin and Brunken, Christoph and Tilly, Jules},
    journal = {preprint},
    year    = {2026},
    url     = {(preprint)}
}


Download files


Source Distributions

No source distribution files are available for this release.

Built Distributions


e3j_ops-0.1.0b2-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (2.7 MB)

CPython 3.14, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64

e3j_ops-0.1.0b2-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (2.7 MB)

CPython 3.13, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64

e3j_ops-0.1.0b2-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (2.7 MB)

CPython 3.12, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64

e3j_ops-0.1.0b2-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (2.7 MB)

CPython 3.11, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64

File hashes

Hashes for e3j_ops-0.1.0b2-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 2975fcfdb22869863cefa5f76b202c4e1d5a06bd65ccccececdcefe57e7ef1d6
MD5 6767ad31a22068563025bdd193e225e8
BLAKE2b-256 1ac39e2bbfff8cb672ab1307713feb54497bede7a719b38658cec8c98c1fa438

Hashes for e3j_ops-0.1.0b2-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 3a1c302b86a75aea2419ef1c1a786c4f71071f6ed3022af55f313765756d6604
MD5 d58dc1d914f0858c21f1541ccdb9dc75
BLAKE2b-256 30351b6f6abb8547193dbd2aa59b2a05426077eeb388aab6dac459cd8cdff9c7

Hashes for e3j_ops-0.1.0b2-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 1f74e6b00a058a63738ab13b7a4331e9bd6f4b7ceb187d171304c70da1d28879
MD5 13d3d7c8e932cbe6ae65992e9b70e1f5
BLAKE2b-256 683efc3364079962f2211aff0e7b2292d096beb1b59ee7236a38f7549829f2ad

Hashes for e3j_ops-0.1.0b2-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 57e778a8c50092ecca580a9edd987bae194253a5890574ccba2b507b3bb75060
MD5 a4605bd15086a5942906d5b4c3742829
BLAKE2b-256 ba8115a738055b32fb7ba9236d0dcc00a996dea69f7064f3e4a7ed99b4f1b468

