Bind any function written in another language to JAX with support for JVP/VJP/batching/jit compilation

JAXbind: Bind any function to JAX

JAXbind API documentation: nifty-ppl.github.io/JAXbind/ | Found a bug? github.com/NIFTy-PPL/JAXbind/issues | Need help? github.com/NIFTy-PPL/JAXbind/discussions

Summary

The existing interface in JAX for connecting fully differentiable custom code requires deep knowledge of JAX and its C++ backend. The aim of JAXbind is to drastically lower the burden of connecting custom functions implemented in other programming languages to JAX. Specifically, JAXbind provides an easy-to-use Python interface for defining custom, so-called JAX primitives. Via JAXbind, any function callable from Python can be exposed as a JAX primitive. JAXbind lets users interface the JAX function transformation engine with custom derivatives and batching rules, enabling all JAX transformations for the custom primitive. In contrast, the built-in JAX external callback interface also has a Python endpoint, but external callbacks cannot be fully integrated into the JAX transformation engine, because only the Jacobian-vector product or the vector-Jacobian product can be added, not both.

Automatic Differentiation and Code Example

Automatic differentiation is a core feature of JAX and often one of the main reasons for using it. It is therefore essential that custom functions registered with JAX support automatic differentiation. In the following, we outline which functions JAXbind, and thereby JAX, requires to enable automatic differentiation. For simplicity, we assume that we want to connect the nonlinear function $f(x_1,x_2) = x_1x_2^2$ to JAX. The JAXbind package expects the Python function for $f$ to take three positional arguments. The first argument, out, is a tuple into which the function results are written. The second argument, args, is also a tuple and contains the inputs to the function, in our case $x_1$ and $x_2$. The third argument, kwargs_dump, forwards any keyword arguments given to the later registered JAX primitive to f in serialized form.

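import jaxbind

def f(out, args, kwargs_dump):
    # Deserialize the keyword arguments (not used in this example)
    kwargs = jaxbind.load_kwargs(kwargs_dump)
    x1, x2 = args
    # Write the result in place into the first output array
    out[0][()] = x1 * x2**2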

JAX's automatic differentiation engine can compute the Jacobian-vector product (jvp) and the vector-Jacobian product (vjp) of JAX primitives. The Jacobian-vector product in JAX is a function that applies the Jacobian of $f$ at a position $x$ to a tangent vector. In mathematical nomenclature, this operation is called the pushforward of $f$ and can be denoted as $\partial f(x): T_x X \to T_{f(x)} Y$, where $T_x X$ and $T_{f(x)} Y$ are the tangent spaces of $X$ and $Y$ at the positions $x$ and $f(x)$. Because the implementation of $f$ is not JAX-native, JAX cannot compute the jvp automatically. Instead, the user has to provide an implementation of the pushforward, which JAXbind registers as the jvp of the JAX primitive of $f$. For our example, this Jacobian-vector product function is given by $\partial f(x_1,x_2)(dx_1,dx_2) = x_2^2dx_1 + 2x_1x_2dx_2$.

def f_jvp(out, args, kwargs_dump):
    kwargs = jaxbind.load_kwargs(kwargs_dump)
    x1, x2, dx1, dx2 = args
    out[0][()] = x2**2 * dx1 + 2 * x1 * x2 * dx2
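JAX uses a registered jvp exactly as given, so an error in it silently propagates into every derivative computed from it. A quick numerical check before registering the rule is therefore worthwhile. The following pure-Python sketch uses scalar inputs and involves neither JAX nor JAXbind. It compares the analytic pushforward above with a central finite difference of $f$ along the tangent direction. The helper names `f_scalar`, `jvp_scalar`, and `finite_difference_jvp` are hypothetical.

```python
def f_scalar(x1, x2):
    # Scalar version of the example function f(x1, x2) = x1 * x2**2
    return x1 * x2**2

def jvp_scalar(x1, x2, dx1, dx2):
    # Analytic pushforward: x2**2 * dx1 + 2 * x1 * x2 * dx2
    return x2**2 * dx1 + 2 * x1 * x2 * dx2

def finite_difference_jvp(fun, x, dx, eps=1e-6):
    # Central difference of fun at x along the tangent direction dx
    plus = [xi + eps * di for xi, di in zip(x, dx)]
    minus = [xi - eps * di for xi, di in zip(x, dx)]
    return (fun(*plus) - fun(*minus)) / (2 * eps)

x, dx = (4.0, 2.0), (1.0, 1.0)
analytic = jvp_scalar(*x, *dx)  # 20.0 for this input
numeric = finite_difference_jvp(f_scalar, x, dx)
print(analytic, numeric)
```

If the two numbers disagree beyond the finite-difference error, the jvp is wrong.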

The vector-Jacobian product (vjp) in JAX is the linear transpose of the Jacobian-vector product. In mathematical nomenclature, this is the pullback $(\partial f(x))^{T}: T_{f(x)}Y \to T_x X$ of $f$. As with the jvp, the user has to implement this function, because JAX cannot construct it automatically. For our example function, the vector-Jacobian product is $(\partial f(x_1,x_2))^{T}(dy) = (x_2^2dy, 2x_1x_2dy)$.

def f_vjp(out, args, kwargs_dump):
    kwargs = jaxbind.load_kwargs(kwargs_dump)
    x1, x2, dy = args
    out[0][()] = x2**2 * dy
    out[1][()] = 2 * x1 * x2 * dy
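The vjp is written independently of the jvp, so the two can easily disagree. A standard sanity check is the dot-product test: for any tangent $dx$ and cotangent $dy$, $\langle dy, \partial f(x)(dx)\rangle$ must equal $\langle (\partial f(x))^T(dy), dx\rangle$. The pure-Python sketch below applies this test to scalar versions of the two rules above. The helper names are hypothetical, and the tangent and cotangent values are arbitrary.

```python
def jvp_scalar(x1, x2, dx1, dx2):
    # Pushforward of f(x1, x2) = x1 * x2**2
    return x2**2 * dx1 + 2 * x1 * x2 * dx2

def vjp_scalar(x1, x2, dy):
    # Pullback of f: one cotangent per input
    return (x2**2 * dy, 2 * x1 * x2 * dy)

x1, x2 = 4.0, 2.0
dx1, dx2 = 0.3, -1.7  # arbitrary tangent
dy = 6.0              # arbitrary cotangent

# <dy, J dx> must equal <J^T dy, dx> if the vjp is the transpose of the jvp
lhs = dy * jvp_scalar(x1, x2, dx1, dx2)
g1, g2 = vjp_scalar(x1, x2, dy)
rhs = g1 * dx1 + g2 * dx2
print(lhs, rhs)
```

The test uses random-looking values deliberately. A transposition error can cancel out for special tangents such as all ones.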

To just-in-time compile the function, JAX needs to evaluate the code abstractly, i.e., it needs to know the shape and dtype of the output of the custom function given only the shape and dtype of the input. We therefore have to provide abstract evaluation functions that return the output shape and dtype for a given input shape and dtype, both for f and for the vjp application. The output shape of the jvp is identical to the output shape of f itself and does not need to be specified again. Due to JAX internals, the abstract evaluation functions take ordinary keyword arguments rather than serialized keyword arguments.

def f_abstract(*args, **kwargs):
    assert args[0].shape == args[1].shape
    return ((args[0].shape, args[0].dtype),)

def f_abstract_T(*args, **kwargs):
    return (
        (args[0].shape, args[0].dtype),
        (args[0].shape, args[0].dtype),
    )
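To see what abstract evaluation amounts to, you can call the two functions by hand with placeholder objects that carry only a shape and a dtype. The sketch below uses a hypothetical `AbstractArray` namedtuple as a stand-in for the abstract values JAX passes, which expose `.shape` and `.dtype`. The dtype is written as a plain string purely for illustration. The two functions are repeated so that the snippet runs on its own without JAX or JAXbind.

```python
from collections import namedtuple

# Hypothetical stand-in for a JAX abstract value: shape and dtype, no data
AbstractArray = namedtuple("AbstractArray", ["shape", "dtype"])

def f_abstract(*args, **kwargs):
    # f maps two equally shaped inputs to one output of the same shape/dtype
    assert args[0].shape == args[1].shape
    return ((args[0].shape, args[0].dtype),)

def f_abstract_T(*args, **kwargs):
    # The vjp receives (x1, x2, dy) and returns one cotangent per input
    return (
        (args[0].shape, args[0].dtype),
        (args[0].shape, args[0].dtype),
    )

x1 = AbstractArray((4, 3), "float64")
x2 = AbstractArray((4, 3), "float64")
dy = AbstractArray((4, 3), "float64")

print(f_abstract(x1, x2))        # (((4, 3), 'float64'),)
print(f_abstract_T(x1, x2, dy))  # (((4, 3), 'float64'), ((4, 3), 'float64'))

# Mismatched input shapes are rejected before any numerical code runs
try:
    f_abstract(x1, AbstractArray((3, 4), "float64"))
    shape_check_failed = False
except AssertionError:
    shape_check_failed = True
```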

We have now defined all ingredients necessary to register a JAX primitive for our function $f$ using the JAXbind package.

f_jax = jaxbind.get_nonlinear_call(
    f, (f_jvp, f_vjp), f_abstract, f_abstract_T
)

f_jax is a JAX primitive registered via the JAXbind package supporting all JAX transformations. We can now compute the jvp and vjp of the new JAX primitive and even jit-compile and batch it.

import jax
import jax.numpy as jnp

inp = (jnp.full((4,3), 4.), jnp.full((4,3), 2.))
tan = (jnp.full((4,3), 1.), jnp.full((4,3), 1.))
res, res_tan = jax.jvp(f_jax, inp, tan)

cotan = [jnp.full((4,3), 6.)]
# Use a new name so that the f_vjp implementation defined above is not shadowed
res, f_jax_vjp = jax.vjp(f_jax, *inp)
res_cotan = f_jax_vjp(cotan)

f_jax_jit = jax.jit(f_jax)
res = f_jax_jit(*inp)

Higher Order Derivatives and Linear Functions

JAX supports higher-order derivatives and can differentiate a jvp or vjp with respect to the position at which the Jacobian was taken. As with first derivatives, JAX cannot automatically compute higher derivatives of a general function $f$ that is not natively implemented in JAX. The user would again have to provide the higher-order derivatives. For many algorithms, first derivatives are sufficient, and high-performance codes often do not implement higher-order derivatives. For simplicity, the current interface of JAXbind is therefore restricted to first derivatives. The interface could easily be extended in the future if specific use cases require higher-order derivatives.
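As an illustration of what a user-supplied second-order rule would involve for the running example, take the jvp $x_2^2dx_1 + 2x_1x_2dx_2$ and differentiate it once more with respect to the position, in a second direction $dx' = (dx_1', dx_2')$. The result is the symmetric bilinear form defined by the Hessian of $f$. This is a worked derivation only; for nonlinear functions, JAXbind's current interface does not accept such a rule.

$$
\partial^2 f(x_1,x_2)(dx, dx') = \begin{pmatrix} dx_1 & dx_2 \end{pmatrix}
\begin{pmatrix} 0 & 2x_2 \\ 2x_2 & 2x_1 \end{pmatrix}
\begin{pmatrix} dx_1' \\ dx_2' \end{pmatrix}
= 2x_2\,dx_1dx_2' + 2x_2\,dx_2dx_1' + 2x_1\,dx_2dx_2'
$$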

In scientific computing, linear functions such as spherical harmonic transforms are widespread. If the function $f$ is linear, differentiation becomes trivial. Specifically, for a linear function $f$, the pushforward (i.e., the jvp of $f$) is identical to $f$ itself and independent of the position at which it is computed. Expressed in formulas, $\partial f(x)(dx) = f(dx)$ if $f$ is linear in $x$. Analogously, the pullback (i.e., the vjp) becomes independent of the position and is given by the linear transpose of $f$, thus $(\partial f(x))^{T}(dy) = f^T(dy)$. Moreover, all higher-order derivatives can be expressed in terms of $f$ and its transpose. To exploit these simplifications, JAXbind provides a special interface for linear functions that supports higher-order derivatives and requires only an implementation of the function and its transpose.
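The simplification for linear functions can be checked on a toy example. The following plain-Python sketch uses a hypothetical fixed matrix `A` to define $f(x) = Ax$ and its transpose $f^T(y) = A^Ty$. It is not JAXbind's linear interface. It only illustrates two points: for a linear $f$, the jvp reduces to $f(dx)$ and the vjp to $f^T(dy)$ at every position, and the pair satisfies $\langle dy, f(dx)\rangle = \langle f^T(dy), dx\rangle$.

```python
# Hypothetical linear map f(x) = A x with a fixed 2x3 matrix A
A = [[1.0, 2.0, 0.0],
     [0.0, -1.0, 3.0]]

def f_linear(x):
    # Apply A to a length-3 vector
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]

def f_linear_T(y):
    # Apply the transpose A^T to a length-2 vector
    return [sum(A[i][j] * y[i] for i in range(len(A))) for j in range(len(A[0]))]

def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))

# For linear f, jvp(x, dx) = f(dx) and vjp(x, dy) = f^T(dy) for every x,
# so only f and its transpose are needed.
dx = [1.0, 0.5, -2.0]
dy = [3.0, -1.0]
jvp = f_linear(dx)    # [2.0, -6.5]
vjp = f_linear_T(dy)  # [3.0, 7.0, -3.0]

# Dot-product test: <dy, f(dx)> == <f^T(dy), dx>
print(dot(dy, jvp), dot(vjp, dx))  # 12.5 12.5
```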

Demos and Documentation

Additional demos can be found in the demos folder. The basic demo 01_linear_function.py showcases the interface for linear functions and custom batching rules. 02_multilinear_function.py binds a multilinear function as a JAX primitive. Finally, 03_nonlinear_function.py demonstrates the interface for nonlinear functions and shows how to deal with fixed arguments, which cannot be differentiated. JAXbind provides bindings to parts of the functionality of the DUCC package. The DUCC bindings are also exposed as a webpage to showcase a real-world example of the usage of JAXbind. The documentation of the JAXbind API is available at nifty-ppl.github.io/JAXbind/.

Platforms

Currently, JAXbind supports only CPUs, not GPUs. With some expertise in Python bindings for GPU kernels, adding GPU support should be fairly simple, since the interfacing with the JAX automatic differentiation engine is identical for CPU and GPU. Contributions are welcome!

Installation

Binary wheels for JAXbind can be obtained and installed from PyPI via:

pip install jaxbind

To install JAXbind from source, clone the repository and install the package via pip.

git clone https://github.com/NIFTy-PPL/jaxbind.git
cd jaxbind
pip install .

Contributing

Contributions are highly appreciated! Please open an issue first if you think your PR changes current code substantially. Please format your code using black. PRs affecting the public API, including those adding new features, should update the public documentation. If possible, add appropriate tests to your PR. Feel free to open a PR early in the development process; we are happy to help and provide feedback along the way.

Licensing terms

All source code in this package is released under the 2-clause BSD license. All of JAXbind is distributed without any warranty.

Citing JAXbind

To cite JAXbind, please use the citation provided below.

@article{jaxbind,
    title = {JAXbind: Bind any function to JAX},
    author = {Jakob Roth and Martin Reinecke and Gordian Edenhofer},
    year = {2024},
    journal = {Journal of Open Source Software},
    publisher = {The Open Journal},
    volume = {9},
    number = {98},
    pages = {6532},
    doi = {10.21105/joss.06532},
    url = {https://doi.org/10.21105/joss.06532},
}

Download files

Source Distribution

  • jaxbind-1.3.0.tar.gz (48.9 kB)

Built Distributions

  • jaxbind-1.3.0-cp313-cp313-win_amd64.whl (71.3 kB): CPython 3.13, Windows x86-64
  • jaxbind-1.3.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (82.4 kB): CPython 3.13, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64
  • jaxbind-1.3.0-cp313-cp313-macosx_11_0_arm64.whl (87.8 kB): CPython 3.13, macOS 11.0+ ARM64
  • jaxbind-1.3.0-cp312-cp312-win_amd64.whl (71.3 kB): CPython 3.12, Windows x86-64
  • jaxbind-1.3.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (82.4 kB): CPython 3.12, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64
  • jaxbind-1.3.0-cp312-cp312-macosx_11_0_arm64.whl (87.8 kB): CPython 3.12, macOS 11.0+ ARM64
  • jaxbind-1.3.0-cp311-cp311-win_amd64.whl (71.6 kB): CPython 3.11, Windows x86-64
  • jaxbind-1.3.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (82.6 kB): CPython 3.11, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64
  • jaxbind-1.3.0-cp311-cp311-macosx_11_0_arm64.whl (88.7 kB): CPython 3.11, macOS 11.0+ ARM64
  • jaxbind-1.3.0-cp310-cp310-win_amd64.whl (71.8 kB): CPython 3.10, Windows x86-64
  • jaxbind-1.3.0-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (82.8 kB): CPython 3.10, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64
  • jaxbind-1.3.0-cp310-cp310-macosx_11_0_arm64.whl (89.0 kB): CPython 3.10, macOS 11.0+ ARM64

File details

Details for the file jaxbind-1.3.0.tar.gz.

File metadata

  • Download URL: jaxbind-1.3.0.tar.gz
  • Size: 48.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.11

File hashes

Hashes for jaxbind-1.3.0.tar.gz
Algorithm Hash digest
SHA256 6707e654da0bb3eb939e4bfb50a87f32d401c4534cd16fd38706af9c59caee45
MD5 25fa27dd95a49a10eb05c530716b524d
BLAKE2b-256 df916d4fbe7a241ddb3fe2852093859951445dccef7874f3177f5344b09217b2

File details

Details for the file jaxbind-1.3.0-cp313-cp313-win_amd64.whl.

File metadata

  • Download URL: jaxbind-1.3.0-cp313-cp313-win_amd64.whl
  • Size: 71.3 kB
  • Tags: CPython 3.13, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.11

File hashes

Hashes for jaxbind-1.3.0-cp313-cp313-win_amd64.whl
Algorithm Hash digest
SHA256 9fcaa845b470e81e1038c4b2f79d96e2c3298f85738eac77d83bd5ee62e5f6e7
MD5 9d63fe6ba1536618102a7d98eed7eaa5
BLAKE2b-256 40559cc05ca623924d4e92ca999487f1c1844c87d049546ff68705bbca89ef47

File details

Details for the file jaxbind-1.3.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File hashes

Hashes for jaxbind-1.3.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 fc9fc69f885228d26ea3bb88800c57786bf1513780c6c270bcd3bffeb13e7b86
MD5 a58c45ff73228b3c93aa7411914b9862
BLAKE2b-256 6cf4ea5cc824fd14c0536b6dd41e4035f267d46ede842234dabc20ca76f7b41f

File details

Details for the file jaxbind-1.3.0-cp313-cp313-macosx_11_0_arm64.whl.

File hashes

Hashes for jaxbind-1.3.0-cp313-cp313-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 cf263a1891e6bbbea1a07e281108c0fbe31e8a08e78b50761337a10a9e83772c
MD5 7318cb41d40499f22f4c46344f9a2273
BLAKE2b-256 a7ad8cb4d8a85dcf1fbfd70263850a518c0869aa2f734d0ea03edfbec2675cd2

File details

Details for the file jaxbind-1.3.0-cp312-cp312-win_amd64.whl.

File metadata

  • Download URL: jaxbind-1.3.0-cp312-cp312-win_amd64.whl
  • Size: 71.3 kB
  • Tags: CPython 3.12, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.11

File hashes

Hashes for jaxbind-1.3.0-cp312-cp312-win_amd64.whl
Algorithm Hash digest
SHA256 e01ccf06470426f80925a0d64bda20c36fb45e124d05ba544fff834b1ab4953f
MD5 7791fd2f65f3d937174a4fc573b60fd7
BLAKE2b-256 5c6b2a04b073e7a0ef2bc5c4cd16f5dc9ebd79bea64abf5ab31e0b0c721c9d03

File details

Details for the file jaxbind-1.3.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File hashes

Hashes for jaxbind-1.3.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 88cd0ce7c9b657d19031fd4703114f993828a0e65b1fdc5ae4cfaa85ec895c6f
MD5 aaef69a1bffd2afef656b011f5f94049
BLAKE2b-256 584259e8c0ad3062681043e97d359f1e995a9823e6885dfa1295be3758ef0cb6

File details

Details for the file jaxbind-1.3.0-cp312-cp312-macosx_11_0_arm64.whl.

File hashes

Hashes for jaxbind-1.3.0-cp312-cp312-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 4b293734735011748622c34f3c26243822b9e5cba308e5898d3016dac7a6e31d
MD5 69eba4c3233f3cd9ec7710f871060a66
BLAKE2b-256 9a3984932a8bc573ff8c6d9d95a1d694818a3f404c46479881c9f6bc717630e3

File details

Details for the file jaxbind-1.3.0-cp311-cp311-win_amd64.whl.

File metadata

  • Download URL: jaxbind-1.3.0-cp311-cp311-win_amd64.whl
  • Size: 71.6 kB
  • Tags: CPython 3.11, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.11

File hashes

Hashes for jaxbind-1.3.0-cp311-cp311-win_amd64.whl
Algorithm Hash digest
SHA256 014b4d23bba63889c322a4284b4419488a69c6aa696b0f3f4d40ceaac57b148e
MD5 a83d9dc90e1d6257922ff87a5af0a1b7
BLAKE2b-256 0d595c8f7919a9c3f1b168af133dc8e877fa0f9d12d7e7d8b7d466c4d4bf1f01

File details

Details for the file jaxbind-1.3.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File hashes

Hashes for jaxbind-1.3.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 59d0491c9ef8bb60a1a0935a53a3b25856669fa66502ce2251e2f2944d1dcbd3
MD5 33adb63e8b4ce6743958410a7259ce0f
BLAKE2b-256 b67c9e59a5e1f15bd40a1540e532d268a1c6761f802ea72f4a34eeb5c485f75b

File details

Details for the file jaxbind-1.3.0-cp311-cp311-macosx_11_0_arm64.whl.

File hashes

Hashes for jaxbind-1.3.0-cp311-cp311-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 658a4cc0addec2fb1524adaa9f030685b5b988f1feb516d3ff38df122f830bfa
MD5 784a10d9ef49e4786b05d4ba0cbb8207
BLAKE2b-256 2b8c22b72f87c44c0850a4e2d185919f87acfec17b1666e7c7ca663e79597319

File details

Details for the file jaxbind-1.3.0-cp310-cp310-win_amd64.whl.

File metadata

  • Download URL: jaxbind-1.3.0-cp310-cp310-win_amd64.whl
  • Size: 71.8 kB
  • Tags: CPython 3.10, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.11

File hashes

Hashes for jaxbind-1.3.0-cp310-cp310-win_amd64.whl
Algorithm Hash digest
SHA256 2e443925afcc3f53d682d251687419cb78786a4726a284a38ceb97c3ffbb5069
MD5 a5acffd8212d2e5a035d720a0f4f1597
BLAKE2b-256 659c93fedec4e2ae71fa4f168436da7c26a09dd94682fc1f8e9b34cdb53a5cce

File details

Details for the file jaxbind-1.3.0-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File hashes

Hashes for jaxbind-1.3.0-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 4d2e854718f1ae2017619b9c60e9ad5569f4788bfc4cf5aeb64953698f6f0fa9
MD5 ad1b30737836c3ab0dfe4e8b289063cb
BLAKE2b-256 18650cf48a2a06bc285dbadda5c1d2aad668045bc7d8d06e405200e13500be86

File details

Details for the file jaxbind-1.3.0-cp310-cp310-macosx_11_0_arm64.whl.

File hashes

Hashes for jaxbind-1.3.0-cp310-cp310-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 b2e6445653abfa3608fb05508ffba49b4c63e81d130c307588391796626c217c
MD5 b1704d3982fc461cadcdbb48490ec881
BLAKE2b-256 9fd0a5e09196d90c03c671898f61ee5892e82d2338a1fa8b80f6751f758e683e
