Fast Euclidean-equivariant operations for JAX
Project description
CUDA/C++ source
- lib : standalone library for custom CUDA kernels
- ffi : C++ binding code, using pybind11 and the legacy XLA custom call API.
Note: the following is subject to change as we migrate to the new XLA FFI API.
lib
To keep the cognitive load of foreign bindings to a minimum, we suggest that every custom kernel's code expose a common interface; see e.g. scatter_add.cuh for detailed signatures. The interface consists of:
- a namespace e3j::op_name within the enclosing e3j namespace,
- a struct definition e3j::op_name::Params,
- a __global__ CUDA function e3j::op_name::kernel,
- a host wrapper e3j::op_name::launch.
The Params struct serves as the target type of UnpackDescriptor, which decodes XLA's opaque parameter (see below).
Update (June 25)
Since we rely on dynamic linking to operate with XLA, the external headers from xla/ffi/api have been copied:
- api.h
- c_api.h
- ffi.h
Note that the XLA FFI API is not 100% stable yet:
WARNING: XLA FFI is under construction and currently does not provide any backward compatibility guarantees. Once we reach a point when we are reasonably confident that we got all APIs right, we will define XLA_FFI_API_MAJOR and XLA_FFI_API_MINOR API versions and will start providing API and ABI backward compatibility.
ffi
NOTE: Within e3j, the goal is to use the two helpers defined in kernel_helpers.h generically for every op:
- PackDescriptor<Params> (Params p) -> std::string(opaque, opaque_len)
- UnpackDescriptor<Params> (char *opaque, size_t opaque_len) -> Params p
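A minimal sketch of what the two helpers could look like, assuming Params is a trivially copyable struct that is handed to XLA as its raw bytes. The signatures follow the description above, but the bodies and the e3j::example_op::Params struct are hypothetical illustrations, not the contents of e3j's kernel_helpers.h.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <stdexcept>
#include <string>
#include <type_traits>

// Hypothetical stand-ins for the helpers in kernel_helpers.h; the real
// e3j signatures may differ.
template <typename Params>
std::string PackDescriptor(const Params &p) {
  static_assert(std::is_trivially_copyable<Params>::value,
                "Params must be trivially copyable to be sent as raw bytes");
  // Copy the raw bytes of the struct into a string: this is the opaque payload.
  return std::string(reinterpret_cast<const char *>(&p), sizeof(Params));
}

template <typename Params>
Params UnpackDescriptor(const char *opaque, std::size_t opaque_len) {
  static_assert(std::is_trivially_copyable<Params>::value,
                "Params must be trivially copyable to be sent as raw bytes");
  // Reject payloads whose size does not match the C++ struct layout.
  if (opaque_len != sizeof(Params)) {
    throw std::runtime_error("opaque descriptor has the wrong size");
  }
  Params p;
  // memcpy avoids relying on the alignment of the incoming char buffer.
  std::memcpy(&p, opaque, sizeof(Params));
  return p;
}

// Hypothetical Params for an illustrative op (not part of e3j).
namespace e3j { namespace example_op {
struct Params {
  int num_segments;
  int feature_dim;
  float scale;
};
}}
```

The size check catches mismatches between the packing side and the C++ struct layout early, instead of letting a kernel read garbage hyperparameters.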
so as to pass kernel hyperparameters to the XLA custom call. To make the
FFI and XLA binding boilerplate generic, the current strategy is to
define each custom kernel in its own namespace within e3j, with
its own kernel and launch functions:
namespace e3j { namespace op_name {
struct Params;
template <typename T>
__global__ void kernel (T *a, ..., Params p);
template <typename T>
void launch (T *a, ..., Params p, cudaStream_t stream);
...
}}
This way, the FFI-facing e3j_ops namespace can define the XLA custom
call with minimal boilerplate:
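namespace e3j_ops {
void op_name (
  cudaStream_t stream,
  void **buffers,
  const char *opaque,
  size_t opaque_len
){
  // Decode the kernel hyperparameters packed on the Python side.
  e3j::op_name::Params params =
    UnpackDescriptor<e3j::op_name::Params>(opaque, opaque_len);
  // buffers holds the operands followed by the results (device pointers);
  // cast each one to the kernel's element type before launching.
  e3j::op_name::launch(
    static_cast<float *>(buffers[0]),  // float shown for illustration
    ...,
    params,
    stream
  );
}
}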
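To make the data flow concrete, here is a host-only mock of the same call path that compiles without CUDA or XLA. It assumes a hypothetical op scale_op, replaces cudaStream_t with an opaque FakeStream placeholder, and replaces the kernel launch with a CPU loop; only the shape of the entry point (stream, buffers, opaque, opaque_len) mirrors the legacy custom call signature.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <stdexcept>
#include <string>

// Placeholder for cudaStream_t so the sketch builds without CUDA (assumption).
using FakeStream = void *;

namespace e3j { namespace scale_op {  // hypothetical op
struct Params {
  int n;
  float alpha;
};

// Stand-in for the host wrapper: the real launch would enqueue
// kernel<T><<<grid, block, 0, stream>>>(...) instead of looping on the host.
template <typename T>
void launch(const T *in, T *out, Params p, FakeStream /*stream*/) {
  for (int i = 0; i < p.n; ++i) out[i] = p.alpha * in[i];
}
}}

// Hypothetical helper: copy the opaque bytes back into a Params struct.
template <typename Params>
Params UnpackDescriptor(const char *opaque, std::size_t opaque_len) {
  if (opaque_len != sizeof(Params)) {
    throw std::runtime_error("opaque descriptor has the wrong size");
  }
  Params p;
  std::memcpy(&p, opaque, sizeof(Params));
  return p;
}

namespace e3j_ops {
// Same argument layout as the legacy GPU custom call.
void scale_op(FakeStream stream, void **buffers, const char *opaque,
              std::size_t opaque_len) {
  auto params = UnpackDescriptor<e3j::scale_op::Params>(opaque, opaque_len);
  // Operands come first in buffers, followed by results.
  e3j::scale_op::launch(static_cast<const float *>(buffers[0]),
                        static_cast<float *>(buffers[1]), params, stream);
}
}
```

In the real binding, buffers holds device pointers and launch runs on the given stream; the mock only checks that the descriptor round trip and the buffer ordering line up.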
Download files
Built Distributions
File details
Details for the file e3j_ops-0.1.0b0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.
File metadata
- Download URL: e3j_ops-0.1.0b0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
- Upload date:
- Size: 1.2 MB
- Tags: CPython 3.13, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5682f896b7ff57fc1851e17d5f5c0d78515fee34b767cbb28b3f02f76006d193 |
| MD5 | fbb6f9e47249efff7fb2d9df6f676d2a |
| BLAKE2b-256 | 5f4668cc14a5e3413d73b4ac4471369a11cbb43b9619b6f500feff773c58c07a |
File details
Details for the file e3j_ops-0.1.0b0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.
File metadata
- Download URL: e3j_ops-0.1.0b0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
- Upload date:
- Size: 1.2 MB
- Tags: CPython 3.12, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b22694935b0a6773d4148609cb4a0a5df6aed38e2a5735d94d77657064524094 |
| MD5 | cdd45d4d0480b35dbdff80b74b5eafa8 |
| BLAKE2b-256 | 6f3192ebc5048d6090ab8e65974165263f45003644232ec997e88ff8d2791e6a |
File details
Details for the file e3j_ops-0.1.0b0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.
File metadata
- Download URL: e3j_ops-0.1.0b0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
- Upload date:
- Size: 1.2 MB
- Tags: CPython 3.11, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ba5cf702d7341c134015649e2b01102e29bc534933766873c7f17731a9cc3c9c |
| MD5 | 9e2ee25aafd06f43d128f8ec2f1a863d |
| BLAKE2b-256 | b860af47c7b1eb5fbe6a15ea9022f8c71e42d4e6ff3ed355ac8e06bb9cc5265d |