Fast Euclidean-equivariant operations for JAX

⚙️ e3j-ops

This package contains the CUDA/C++ source for e3j. The Python bindings to XLA handlers are bundled into the e3j_ops shared object, which the main e3j package wraps in custom JAX primitives via the ffi_call API.

Note: The e3j_ops ABI is not stable yet and should be considered private to the e3j Python API.

Building from source

The CMake build recipes are defined in CMakeLists.txt. The e3j Makefile also defines fine-grained recipes to build and test individual CUDA/C++ objects.

Project structure

The source is organized as follows:

  • cuda : CUDA kernel implementations
  • ffi : XLA FFI handler declarations and the pybind11 module definition
  • xla : vendored XLA FFI headers
  • tests : C++ kernel tests

CUDA/C++ source

Each operation is defined in its own namespace following a common interface (see e.g. tensor_product.cuh for detailed signatures):

  • a namespace e3j::op_name within the enclosing e3j namespace,
  • a struct e3j::op_name::Params for passing hyperparameters,
  • a __global__ CUDA function e3j::op_name::kernel(),
  • a __host__ launcher e3j::op_name::launch().

Because object-oriented patterns are hard to mix with __device__ code, operations follow this name-based convention rather than a class hierarchy. This ABI is subject to change.

namespace e3j {
namespace op_name {

    struct Params;

    template <typename Idx, typename Val>
    __global__ void kernel (Params p, ...);

    template <typename Idx, typename Val>
    e3j::Error launch (..., Params p, cudaStream_t stream);

} // namespace op_name
} // namespace e3j

Most kernels are templated over a range of index and/or value data types. Some e3j primitives automatically dispatch to narrow index dtypes (uint8 / uint16) when feature-space dimensions are small enough. Half-precision (float16) arithmetic on values is not yet supported, but is planned for an upcoming release.
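The automatic narrowing could look roughly like the following host-side sketch. It assumes the choice depends only on the feature dimension dim a kernel must index into; the actual e3j criterion may differ, and IndexWidth, narrowest_index_width, and dispatch_index are hypothetical names.

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <stdexcept>

// Hypothetical tag for the chosen index width.
enum class IndexWidth { U8, U16, U32 };

constexpr std::size_t kU8Max = std::numeric_limits<std::uint8_t>::max();
constexpr std::size_t kU16Max = std::numeric_limits<std::uint16_t>::max();
constexpr std::size_t kU32Max = std::numeric_limits<std::uint32_t>::max();

// Smallest unsigned type that can address every position in [0, dim).
inline IndexWidth narrowest_index_width(std::size_t dim) {
    if (dim == 0) return IndexWidth::U8;  // nothing to index
    const std::size_t max_index = dim - 1;
    if (max_index <= kU8Max) return IndexWidth::U8;
    if (max_index <= kU16Max) return IndexWidth::U16;
    if (max_index <= kU32Max) return IndexWidth::U32;
    throw std::out_of_range("dimension needs indices wider than 32 bits");
}

// Calls f with a value of the chosen index type, so a generic lambda can
// instantiate Idx-templated code (e.g. a launcher) for that width.
template <typename F>
auto dispatch_index(std::size_t dim, F&& f) {
    switch (narrowest_index_width(dim)) {
        case IndexWidth::U8:  return f(std::uint8_t{});
        case IndexWidth::U16: return f(std::uint16_t{});
        case IndexWidth::U32: break;
    }
    return f(std::uint32_t{});
}
```

With dim = 1000, indices 0..999 no longer fit in uint8_t but do fit in uint16_t, so the lambda is instantiated with a 2-byte index type.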

XLA-FFI handlers

The FFI layer uses the XLA FFI API to bind kernel launchers as XLA custom calls. Each handler is defined with XLA_FFI_DEFINE_HANDLER and receives typed buffers and attributes directly, with no opaque descriptor packing:

xla::Error OpNameHandler(
    cudaStream_t stream,
    int32_t num_out,
    xla::AnyBuffer x,
    xla::Result<xla::AnyBuffer> out
) {
    // ... dtype dispatch ...
    return e3j::op_name::launch<Idx, Val>(
        x.typed_data<Idx>(),
        out->typed_data<Val>(),
        params, stream
    ).to_xla();
}

XLA_FFI_DEFINE_HANDLER(
    xla_op_name,
    OpNameHandler,
    xla::Ffi::Bind()
        .Ctx<xla::PlatformStream<cudaStream_t>>()
        .Attr<int32_t>("num_out")
        .Arg<xla::AnyBuffer>()
        .Ret<xla::AnyBuffer>()
);
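The "// ... dtype dispatch ..." step in the handler above is where runtime element types (read from each buffer, e.g. via AnyBuffer::element_type() in the XLA FFI headers) become the Idx / Val template arguments. The host-only sketch below shows one common two-level switch pattern. DType is a hypothetical stand-in for the XLA FFI data-type enum, Status for xla::Error, and launch_stub for e3j::op_name::launch; this is not the actual e3j dispatch code.

```cpp
#include <cstdint>
#include <string>

// Hypothetical stand-in for the XLA FFI data-type enum.
enum class DType { U8, U16, U32, F32, F64 };

// Hypothetical stand-in for xla::Error.
struct Status {
    bool ok;
    std::string message;
};

// Stand-in for e3j::op_name::launch<Idx, Val>: reports the chosen type sizes.
template <typename Idx, typename Val>
Status launch_stub() {
    return {true, std::to_string(sizeof(Idx)) + "/" + std::to_string(sizeof(Val))};
}

// Second level: pick the value type once the index type is fixed.
template <typename Idx>
Status dispatch_val(DType val) {
    switch (val) {
        case DType::F32: return launch_stub<Idx, float>();
        case DType::F64: return launch_stub<Idx, double>();
        default:         return {false, "unsupported value dtype"};
    }
}

// First level: pick the index type.
inline Status dispatch(DType idx, DType val) {
    switch (idx) {
        case DType::U8:  return dispatch_val<std::uint8_t>(val);
        case DType::U16: return dispatch_val<std::uint16_t>(val);
        case DType::U32: return dispatch_val<std::uint32_t>(val);
        default:         return {false, "unsupported index dtype"};
    }
}
```

Nesting the switches instantiates one launcher per supported (Idx, Val) pair, while unsupported combinations are reported as runtime errors rather than compile errors.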

Handlers are exposed to Python as PyCapsules via pyEncapsulateFunction in e3j_ops.h, and registered in the pybind11 module defined in e3j_ops.cpp.

Contributing

It is still too early for e3j_ops to accept significant external contributions, but bug reports and questions are very welcome via GitHub issues and discussions.

Citing

If you use e3j in your work, we kindly ask you to cite the following preprint:

@article{Peltre26-e3j,
    title   = {{E3J}: an Efficient and Open-Source Euclidean Equivariance Backend},
    author  = {Peltre, Olivier and Picard, Armand and Pichard, Adrien and Giacomoni, Luca and Braganca, Miguel and Heyraud, Valentin and Brunken, Christoph and Tilly, Jules},
    journal = {preprint},
    year    = {2026},
    url     = {(preprint)}
}

Download files

Source Distributions

No source distribution files are available for this release.

Built Distributions

e3j_ops-0.1.0b1-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (1.2 MB)

Built for CPython 3.14; manylinux: glibc 2.24+ x86-64; manylinux: glibc 2.28+ x86-64

e3j_ops-0.1.0b1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (1.2 MB)

Built for CPython 3.13; manylinux: glibc 2.24+ x86-64; manylinux: glibc 2.28+ x86-64

e3j_ops-0.1.0b1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (1.2 MB)

Built for CPython 3.12; manylinux: glibc 2.24+ x86-64; manylinux: glibc 2.28+ x86-64

e3j_ops-0.1.0b1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (1.2 MB)

Built for CPython 3.11; manylinux: glibc 2.24+ x86-64; manylinux: glibc 2.28+ x86-64

File details

Details for the file e3j_ops-0.1.0b1-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File hashes

Hashes for e3j_ops-0.1.0b1-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 5315364f85bf844f677727d38e782815b0ecf529a9532c16d35ee19ba3899959
MD5 76954dd18ab4370f14a774ec1e8a7ccc
BLAKE2b-256 2a1c050efce96ea16ff731852c42d8e15f4b20c7997d68c6967ad7a16b171cbe

File details

Details for the file e3j_ops-0.1.0b1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File hashes

Hashes for e3j_ops-0.1.0b1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 e77e7b0ea2d56e5c610339f84402fd8098c983cea04ad2644bd291f87e954a3f
MD5 2a79a64de32590912dece1fd307e3fbd
BLAKE2b-256 617f2f91751bacda6ff2aaf46fbea4c57dfa49ffaa930c8cfcedb5cd6bdc8fca

File details

Details for the file e3j_ops-0.1.0b1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File hashes

Hashes for e3j_ops-0.1.0b1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 73b0e319711cbd8266222f9aa010fb3d9b18a1a3340cf64319a118a17a9bd2d1
MD5 d172dc76747f2fcf6ea6f0ddfda93047
BLAKE2b-256 aec5ae8b01f8bfb522cb304ff17a96011cb06915d50830da211b1cc2e9093d3e

File details

Details for the file e3j_ops-0.1.0b1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File hashes

Hashes for e3j_ops-0.1.0b1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 f9ff1aaa490e595eb4633207b54303b0934fb5cd6363f4ca420cfe05d32c59c5
MD5 08ffdb53a7bf11fc96f0de2e8191151d
BLAKE2b-256 521578035375b7c0d7fddb001da75d1b6e3992cbae562eb145b4a6d763e01ff8
