
Fast Euclid equivariant operations for JAX

Project description

CUDA/C++ source

Note: the following is subject to change as we migrate to the new XLA FFI API.

lib

To keep the cognitive load of foreign bindings to a minimum, we suggest that every custom kernel's code expose a common interface (see e.g. scatter_add.cuh for detailed signatures). The interface consists of:

  • a namespace e3j::op_name within the enclosing e3j namespace,
  • a struct definition e3j::op_name::Params,
  • a __global__ CUDA function e3j::op_name::kernel,
  • a host wrapper e3j::op_name::launch
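As an illustration of this four-part convention, here is a minimal sketch for a hypothetical op named axpy (y <- alpha * x + y). The op name, the Params fields and the bodies are assumptions made for this example. Because a __global__ kernel cannot run without a GPU, kernel is written as a CPU stand-in that performs the work of one thread for index i, and launch simply loops over all indices (the real launch would also take a cudaStream_t).

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

// Hypothetical op "axpy" (y <- alpha * x + y), for illustration only.
namespace e3j { namespace axpy {

// Kernel hyperparameters; later packed into XLA's opaque descriptor,
// so they must be trivially copyable.
struct Params {
	float alpha;    // scale factor applied to x
	std::size_t n;  // number of elements
};
static_assert(std::is_trivially_copyable<Params>::value,
              "Params must be trivially copyable");

// CPU stand-in for the __global__ kernel: the body one thread would run for index i.
template <typename T>
void kernel(const T *x, T *y, Params p, std::size_t i) {
	if (i < p.n) y[i] = static_cast<T>(p.alpha) * x[i] + y[i];
}

// Host wrapper. A CUDA version would compute a grid size and launch
// kernel<T><<<grid, block, 0, stream>>>; this stand-in just loops.
template <typename T>
void launch(const T *x, T *y, Params p) {
	assert(x != nullptr && y != nullptr);
	for (std::size_t i = 0; i < p.n; ++i) kernel(x, y, p, i);
}

}}  // namespace e3j::axpy
```

Keeping all four pieces under e3j::op_name means the binding layer only needs to know the op's name to find its Params and launch.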

The Params struct serves as the target type of UnpackDescriptor, which decodes XLA's opaque custom-call parameter (see below).
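UnpackDescriptor is only named here. Assuming, as in common JAX custom-call examples, that a Params descriptor is a trivially copyable struct shipped to XLA as raw bytes, a matching pack/unpack pair could look like the sketch below. The bodies are our assumption and not necessarily the contents of e3j's kernel_helpers.h.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <stdexcept>
#include <string>
#include <type_traits>

// Serialize a trivially copyable Params struct into the opaque byte string
// handed to the XLA custom call.
template <typename Params>
std::string PackDescriptor(const Params &p) {
	static_assert(std::is_trivially_copyable<Params>::value,
	              "descriptor must be trivially copyable");
	return std::string(reinterpret_cast<const char *>(&p), sizeof(Params));
}

// Rebuild Params from the opaque bytes. std::memcpy avoids relying on the
// alignment of `opaque`, which a reinterpret_cast of the pointer would need.
template <typename Params>
Params UnpackDescriptor(const char *opaque, std::size_t opaque_len) {
	static_assert(std::is_trivially_copyable<Params>::value,
	              "descriptor must be trivially copyable");
	if (opaque_len != sizeof(Params))
		throw std::runtime_error("invalid opaque descriptor size");
	assert(opaque != nullptr);
	Params p;
	std::memcpy(&p, opaque, sizeof(Params));
	return p;
}
```

The size check catches a mismatch between the struct layout used to pack the descriptor and the one used to unpack it, which is the most likely failure when Params changes on only one side.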

Update (June 25)

Since we rely on dynamic linking to interoperate with XLA, the external headers from xla/ffi/api have been copied:

  • api.h
  • c_api.h
  • ffi.h

Note that the XLA FFI API is not 100% stable yet:

WARNING: XLA FFI is under construction and currently does not provide any backward compatibility guarantees. Once we reach a point when we are reasonably confident that we got all APIs right, we will define XLA_FFI_API_MAJOR and XLA_FFI_API_MINOR API versions and will start providing API and ABI backward compatibility.

ffi

NOTE: Within e3j, the goal is to use the two helpers defined in kernel_helpers.h generically, e.g.

- PackDescriptor<Params>(Params p) -> std::string(opaque, opaque_len)
- UnpackDescriptor<Params>(const char *opaque, size_t opaque_len) -> Params p

so as to pass kernel hyperparameters to the XLA custom call. To make the FFI and XLA binding boilerplate generic, the current strategy is to define every custom kernel in its own namespace within e3j, with its own kernel and launch functions:

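namespace e3j { namespace op_name {

	// kernel hyperparameters, passed to XLA as the opaque descriptor
	struct Params;

	// device kernel (remaining buffer arguments elided)
	template <typename T>
	__global__ void kernel(T *a, ..., Params p);

	// host wrapper that configures and launches kernel<T> on `stream`
	template <typename T>
	void launch(T *a, ..., Params p, cudaStream_t stream);

	...

}}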

This way, it is straightforward for the FFI-directed e3j_ops namespace to define the XLA custom call (without cognitive overload) as:

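namespace e3j_ops {

	// XLA GPU custom-call signature
	void op_name (
		cudaStream_t stream,
		void **buffers,
		const char *opaque,
		size_t opaque_len
	){
		// decode the hyperparameters packed by PackDescriptor
		e3j::op_name::Params params =
			UnpackDescriptor<e3j::op_name::Params>(opaque, opaque_len);

		// buffers holds device pointers (inputs, then outputs);
		// float is assumed as the element type for illustration
		e3j::op_name::launch(
			static_cast<float *>(buffers[0]),
			...,
			params,
			stream
		);
	}
}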
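The whole round trip, from packed descriptor to launch, can be exercised on the CPU with a small stand-in. In the sketch below the op name scale, its Params fields and the loop body are hypothetical, void * replaces cudaStream_t so the code compiles without CUDA, and launch runs on the host instead of starting a __global__ kernel.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <stdexcept>
#include <string>

// Hypothetical op "scale" (out = alpha * in), for illustration only.
namespace e3j { namespace scale {

struct Params {
	float alpha;    // scale factor
	std::size_t n;  // number of elements
};

// CPU stand-in for e3j::scale::launch; a CUDA version would launch a
// __global__ kernel on `stream` instead of looping on the host.
template <typename T>
void launch(const T *in, T *out, Params p, void *stream) {
	(void)stream;  // unused in this CPU sketch
	for (std::size_t i = 0; i < p.n; ++i) out[i] = static_cast<T>(p.alpha) * in[i];
}

}}  // namespace e3j::scale

// Byte-wise descriptor decoding, assumed to match UnpackDescriptor<Params>.
template <typename Params>
Params UnpackDescriptor(const char *opaque, std::size_t opaque_len) {
	if (opaque_len != sizeof(Params))
		throw std::runtime_error("invalid opaque descriptor size");
	Params p;
	std::memcpy(&p, opaque, sizeof(Params));
	return p;
}

namespace e3j_ops {

// Same shape as the custom call above; void * replaces cudaStream_t.
// buffers holds the inputs first, then the outputs.
void scale(void *stream, void **buffers, const char *opaque, std::size_t opaque_len) {
	assert(buffers != nullptr);
	e3j::scale::Params params =
		UnpackDescriptor<e3j::scale::Params>(opaque, opaque_len);
	e3j::scale::launch(static_cast<const float *>(buffers[0]),
	                   static_cast<float *>(buffers[1]), params, stream);
}

}  // namespace e3j_ops
```

A caller packs the Params bytes into a string, places the input and output pointers in a void * array, and passes both to e3j_ops::scale exactly as XLA would.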


Download files

Download the file for your platform.

Source Distributions

No source distribution files are available for this release.

Built Distributions


e3j_ops-0.1.0b0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (1.2 MB)

Uploaded for CPython 3.13; manylinux: glibc 2.24+ x86-64; manylinux: glibc 2.28+ x86-64

e3j_ops-0.1.0b0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (1.2 MB)

Uploaded for CPython 3.12; manylinux: glibc 2.24+ x86-64; manylinux: glibc 2.28+ x86-64

e3j_ops-0.1.0b0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (1.2 MB)

Uploaded for CPython 3.11; manylinux: glibc 2.24+ x86-64; manylinux: glibc 2.28+ x86-64

File details

Details for the file e3j_ops-0.1.0b0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File hashes

Hashes for e3j_ops-0.1.0b0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 5682f896b7ff57fc1851e17d5f5c0d78515fee34b767cbb28b3f02f76006d193
MD5 fbb6f9e47249efff7fb2d9df6f676d2a
BLAKE2b-256 5f4668cc14a5e3413d73b4ac4471369a11cbb43b9619b6f500feff773c58c07a

File details

Details for the file e3j_ops-0.1.0b0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File hashes

Hashes for e3j_ops-0.1.0b0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 b22694935b0a6773d4148609cb4a0a5df6aed38e2a5735d94d77657064524094
MD5 cdd45d4d0480b35dbdff80b74b5eafa8
BLAKE2b-256 6f3192ebc5048d6090ab8e65974165263f45003644232ec997e88ff8d2791e6a

File details

Details for the file e3j_ops-0.1.0b0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File hashes

Hashes for e3j_ops-0.1.0b0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 ba5cf702d7341c134015649e2b01102e29bc534933766873c7f17731a9cc3c9c
MD5 9e2ee25aafd06f43d128f8ec2f1a863d
BLAKE2b-256 b860af47c7b1eb5fbe6a15ea9022f8c71e42d4e6ff3ed355ac8e06bb9cc5265d
