Bind any function written in another language to JAX with support for JVP/VJP/batching/jit compilation

Project description

JAXbind: Bind any function to JAX

JAXbind API documentation: nifty-ppl.github.io/JAXbind/ | Found a bug? github.com/NIFTy-PPL/JAXbind/issues | Need help? github.com/NIFTy-PPL/JAXbind/discussions

Summary

The existing interface in JAX for connecting fully differentiable custom code requires deep knowledge of JAX and its C++ backend. The aim of JAXbind is to drastically lower the burden of connecting custom functions implemented in other programming languages to JAX. Specifically, JAXbind provides an easy-to-use Python interface for defining custom, so-called JAX primitives. Via JAXbind, any function callable from Python can be exposed as a JAX primitive. JAXbind lets users interface the JAX function transformation engine with custom derivatives and batching rules, enabling all JAX transformations for the custom primitive. In contrast, JAX's built-in external callback interface also has a Python endpoint, but external callbacks cannot be fully integrated into the JAX transformation engine, because only the Jacobian-vector product or the vector-Jacobian product can be added, not both.

Automatic Differentiation and Code Example

Automatic differentiation is a core feature of JAX and often one of the main reasons for using it. Thus, it is essential that custom functions registered with JAX support automatic differentiation. In the following, we outline which functions JAXbind, and thereby JAX, requires to enable automatic differentiation. For simplicity, we assume that we want to connect the nonlinear function $f(x_1,x_2) = x_1x_2^2$ to JAX. The JAXbind package expects the Python function for $f$ to take three positional arguments. The first argument, out, is a tuple of output arrays into which the function writes its results. The second argument, args, is also a tuple and contains the inputs to the function, in our case $x_1$ and $x_2$. Via the third argument, kwargs_dump, keyword arguments given to the later registered JAX primitive can be forwarded to f in serialized form.

import jaxbind

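def f(out, args, kwargs_dump):
    # deserialize keyword arguments passed to the JAX primitive (unused here)
    kwargs = jaxbind.load_kwargs(kwargs_dump)
    x1, x2 = args
    # write the result in place into the first output array
    out[0][()] = x1 * x2**2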
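The out/args calling convention can be tried without JAX. The sketch below uses a hypothetical, kwargs-free copy of f (`f_plain`, not part of JAXbind) and calls it on NumPy buffers allocated by the caller, to illustrate that `out[0][()] = ...` fills an existing output array in place rather than returning a new one.

```python
import numpy as np

def f_plain(out, args):
    # Hypothetical copy of f without the kwargs_dump argument
    x1, x2 = args
    # [()] selects the whole array, so this assignment fills the
    # caller's buffer in place instead of rebinding a local name
    out[0][()] = x1 * x2**2

x1 = np.full((4, 3), 4.0)
x2 = np.full((4, 3), 2.0)
out = (np.empty((4, 3)),)  # output buffer allocated by the caller
f_plain(out, (x1, x2))
print(out[0][0, 0])  # 16.0
```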

JAX's automatic differentiation engine can compute the Jacobian-vector product (jvp) and the vector-Jacobian product (vjp) of JAX primitives. The Jacobian-vector product in JAX is a function applying the Jacobian of $f$ at a position $x$ to a tangent vector. In mathematical nomenclature, this operation is called the pushforward of $f$ and can be denoted as $\partial f(x): T_x X \to T_{f(x)} Y$, with $T_x X$ and $T_{f(x)} Y$ being the tangent spaces of $X$ and $Y$ at the positions $x$ and $f(x)$. As the implementation of $f$ is not JAX-native, JAX cannot compute the jvp automatically. Instead, an implementation of the pushforward has to be provided, which JAXbind registers as the jvp of the JAX primitive of $f$. For our example, this Jacobian-vector product is given by $\partial f(x_1,x_2)(dx_1,dx_2) = x_2^2\,dx_1 + 2x_1x_2\,dx_2$.

def f_jvp(out, args, kwargs_dump):
    kwargs = jaxbind.load_kwargs(kwargs_dump)
    x1, x2, dx1, dx2 = args
    out[0][()] = x2**2 * dx1 + 2 * x1 * x2 * dx2
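As a sanity check, the jvp formula can be compared against a central finite difference of f along the tangent direction. This is a minimal NumPy sketch; `f_ref` and `jvp_ref` are hypothetical plain-NumPy counterparts of f and f_jvp without the JAXbind calling convention.

```python
import numpy as np

def f_ref(x1, x2):
    # NumPy reference for f(x1, x2) = x1 * x2**2
    return x1 * x2**2

def jvp_ref(x1, x2, dx1, dx2):
    # Pushforward from the text: x2^2 dx1 + 2 x1 x2 dx2
    return x2**2 * dx1 + 2 * x1 * x2 * dx2

rng = np.random.default_rng(0)
x1, x2, dx1, dx2 = rng.normal(size=(4, 5))  # four random vectors of length 5
eps = 1e-6
# Central difference of f along the tangent direction (dx1, dx2)
fd = (f_ref(x1 + eps * dx1, x2 + eps * dx2)
      - f_ref(x1 - eps * dx1, x2 - eps * dx2)) / (2 * eps)
print(np.max(np.abs(fd - jvp_ref(x1, x2, dx1, dx2))))  # small, near rounding level
```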

The vector-Jacobian product (vjp) in JAX is the linear transpose of the Jacobian-vector product. In mathematical nomenclature, this is the pullback $(\partial f(x))^{T}: T_{f(x)}Y \to T_x X$ of $f$. As with the jvp, the user has to implement this function, since JAX cannot construct it automatically. For our example function, the vector-Jacobian product is $(\partial f(x_1,x_2))^{T}(dy) = (x_2^2\,dy,\ 2x_1x_2\,dy)$.

def f_vjp(out, args, kwargs_dump):
    kwargs = jaxbind.load_kwargs(kwargs_dump)
    x1, x2, dy = args
    out[0][()] = x2**2 * dy
    out[1][()] = 2 * x1 * x2 * dy
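Because the vjp is the linear transpose of the jvp, the pair must satisfy the adjoint (dot-product) identity $\langle dy, \partial f(x)\,dx\rangle = \langle (\partial f(x))^T dy, dx\rangle$ for any tangents and cotangents. The NumPy sketch below checks this for the formulas above; `jvp_ref` and `vjp_ref` are hypothetical plain-NumPy versions of f_jvp and f_vjp.

```python
import numpy as np

def jvp_ref(x1, x2, dx1, dx2):
    # Pushforward: J (dx1, dx2) = x2^2 dx1 + 2 x1 x2 dx2
    return x2**2 * dx1 + 2 * x1 * x2 * dx2

def vjp_ref(x1, x2, dy):
    # Pullback: J^T dy = (x2^2 dy, 2 x1 x2 dy)
    return x2**2 * dy, 2 * x1 * x2 * dy

rng = np.random.default_rng(1)
x1, x2, dx1, dx2, dy = rng.normal(size=(5, 4, 3))
# Adjoint identity: <dy, J dx> must equal <J^T dy, dx>
lhs = np.vdot(dy, jvp_ref(x1, x2, dx1, dx2))
ct1, ct2 = vjp_ref(x1, x2, dy)
rhs = np.vdot(ct1, dx1) + np.vdot(ct2, dx2)
print(np.isclose(lhs, rhs))  # True
```

The same test is a cheap way to catch a transposition or sign error in a hand-written vjp before registering it.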

To just-in-time compile the function, JAX needs to evaluate the code abstractly, i.e., it needs to know the shape and dtype of the output of the custom function given only the shape and dtype of the inputs. We therefore have to provide abstract evaluation functions that return the output shapes and dtypes for given input shapes and dtypes, both for f and for the vjp. The output shape of the jvp is identical to that of f itself and does not need to be specified again. Due to the internals of JAX, the abstract evaluation functions take regular keyword arguments rather than serialized ones.

def f_abstract(*args, **kwargs):
    assert args[0].shape == args[1].shape
    return ((args[0].shape, args[0].dtype),)

def f_abstract_T(*args, **kwargs):
    return (
        (args[0].shape, args[0].dtype),
        (args[0].shape, args[0].dtype),
    )
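The abstract evaluation functions only touch `.shape` and `.dtype`, so they can be exercised without any data. In this sketch, a `types.SimpleNamespace` object is a hypothetical stand-in for the abstract values JAX passes in; it is not what JAX actually supplies, it merely exposes the same two attributes the functions above use. Each function returns one `(shape, dtype)` pair per output.

```python
from types import SimpleNamespace

import numpy as np

def f_abstract(*args, **kwargs):
    assert args[0].shape == args[1].shape
    return ((args[0].shape, args[0].dtype),)

def f_abstract_T(*args, **kwargs):
    # vjp inputs are (x1, x2, dy); one output each for x1 and x2
    return (
        (args[0].shape, args[0].dtype),
        (args[0].shape, args[0].dtype),
    )

# Hypothetical stand-in for an abstract value: shape and dtype, no data
aval = SimpleNamespace(shape=(4, 3), dtype=np.dtype("float64"))
print(f_abstract(aval, aval))
print(f_abstract_T(aval, aval, aval))
```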

We have now defined all ingredients necessary to register a JAX primitive for our function $f$ using the JAXbind package.

f_jax = jaxbind.get_nonlinear_call(
    f, (f_jvp, f_vjp), f_abstract, f_abstract_T
)

f_jax is a JAX primitive registered via the JAXbind package, and it supports all JAX transformations. We can now compute the jvp and vjp of the new JAX primitive and even jit-compile and batch it.

import jax
import jax.numpy as jnp

inp = (jnp.full((4,3), 4.), jnp.full((4,3), 2.))
tan = (jnp.full((4,3), 1.), jnp.full((4,3), 1.))
res, res_tan = jax.jvp(f_jax, inp, tan)

cotan = [jnp.full((4,3), 6.)]
# f_jax_vjp maps output cotangents to input cotangents
res, f_jax_vjp = jax.vjp(f_jax, *inp)
res_cotan = f_jax_vjp(cotan)

f_jax_jit = jax.jit(f_jax)
res = f_jax_jit(*inp)
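The values this example should produce can be worked out by hand from the formulas above: with $x_1 = 4$, $x_2 = 2$ and unit tangents, $f = 16$ and the jvp is $2^2 + 2 \cdot 4 \cdot 2 = 20$; with cotangent 6, the vjp is $(2^2 \cdot 6,\ 2 \cdot 4 \cdot 2 \cdot 6) = (24, 96)$. The NumPy sketch below evaluates the same closed-form expressions elementwise. It does not call f_jax; it only provides reference numbers to compare res, res_tan and res_cotan against.

```python
import numpy as np

# Same inputs as in the example above
x1 = np.full((4, 3), 4.0)
x2 = np.full((4, 3), 2.0)
dx1 = dx2 = np.full((4, 3), 1.0)
dy = np.full((4, 3), 6.0)

f_val = x1 * x2**2                          # primal output
jvp_val = x2**2 * dx1 + 2 * x1 * x2 * dx2   # tangent output
vjp_val = (x2**2 * dy, 2 * x1 * x2 * dy)    # cotangents for (x1, x2)

print(f_val[0, 0], jvp_val[0, 0], vjp_val[0][0, 0], vjp_val[1][0, 0])
# 16.0 20.0 24.0 96.0
```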

Higher Order Derivatives and Linear Functions

JAX supports higher-order derivatives and can differentiate a jvp or vjp with respect to the position at which the Jacobian was taken. As with first derivatives, JAX cannot automatically compute higher derivatives of a general function $f$ that is not natively implemented in JAX; they would again need to be provided by the user. For many algorithms, first derivatives are sufficient, and high-performance codes often do not implement higher-order derivatives. Therefore, for simplicity, the current interface of JAXbind is restricted to first derivatives. In the future, the interface could easily be extended if specific use cases require higher-order derivatives.

In scientific computing, linear functions such as spherical harmonic transforms are widespread. If the function $f$ is linear, differentiation becomes trivial. Specifically, for a linear function $f$, the pushforward, i.e. the jvp, of $f$ is identical to $f$ itself and independent of the position at which it is computed. Expressed in formulas, $\partial f(x)(dx) = f(dx)$ if $f$ is linear in $x$. Analogously, the pullback, i.e. the vjp, becomes independent of the position and is given by the linear transpose of $f$: $(\partial f(x))^{T}(dy) = f^T(dy)$. Moreover, all higher-order derivatives can be expressed in terms of $f$ and its transpose. To exploit these simplifications, JAXbind provides a special interface for linear functions that supports higher-order derivatives and only requires an implementation of the function and its transpose.
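These simplifications can be checked numerically. In this sketch, a hypothetical linear function $f(x) = Ax$ with a random matrix $A$ stands in for, e.g., a spherical harmonic transform, and its transpose is $f^T(y) = A^T y$. The finite-difference jvp equals $f(dx)$ at every position tried, and $f$ and $f^T$ pass the dot-product test, so the vjp is $f^T$ regardless of $x$.

```python
import numpy as np

rng = np.random.default_rng(2)
A = rng.normal(size=(5, 3))

def f_lin(x):
    # Hypothetical linear function f(x) = A x
    return A @ x

def f_lin_T(y):
    # Its linear transpose f^T(y) = A^T y
    return A.T @ y

dx = rng.normal(size=3)
dy = rng.normal(size=5)
eps = 1e-3
for x in (np.zeros(3), rng.normal(size=3)):
    # jvp at position x via central difference; equals f(dx) at every x
    fd = (f_lin(x + eps * dx) - f_lin(x - eps * dx)) / (2 * eps)
    assert np.allclose(fd, f_lin(dx))

# vjp is f^T, independent of x: <dy, f(dx)> == <f^T(dy), dx>
assert np.isclose(dy @ f_lin(dx), f_lin_T(dy) @ dx)
```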

Demos and Documentation

Additional demos can be found in the demos folder. Specifically, the basic demo 01_linear_function.py showcases the interface for linear functions and custom batching rules, 02_multilinear_function.py binds a multilinear function as a JAX primitive, and 03_nonlinear_function.py demonstrates the interface for nonlinear functions and shows how to deal with fixed arguments, which cannot be differentiated. JAXbind provides bindings to parts of the functionality of the DUCC package. The DUCC bindings are also exposed as a webpage to showcase a real-world example of the usage of JAXbind. The documentation of the JAXbind API is available at nifty-ppl.github.io/JAXbind/.

Platforms

Currently, JAXbind supports only CPUs; GPUs are not supported. With some expertise in Python bindings for GPU kernels, adding GPU support should be fairly simple, since the interfacing with the JAX automatic differentiation engine is identical for CPU and GPU. Contributions are welcome!

Installation

Binary wheels for JAXbind can be obtained and installed from PyPI via:

pip install jaxbind

To install JAXbind from source, clone the repository and install the package via pip.

git clone https://github.com/NIFTy-PPL/jaxbind.git
cd jaxbind
pip install .

Contributing

Contributions are highly appreciated! If your PR changes existing code substantially, please open an issue first. Please format your code using black. PRs affecting the public API, including those adding new features, should update the public documentation. If possible, add appropriate tests to your PR. Feel free to open a PR early in the development process; we are happy to help and to provide feedback along the way.

Licensing terms

All source code in this package is released under the 2-clause BSD license. All of JAXbind is distributed without any warranty.

Download files

Download the file for your platform.

Source Distribution

  • jaxbind-1.1.0.tar.gz (28.4 kB)

Built Distributions

  • jaxbind-1.1.0-cp312-cp312-win_amd64.whl (76.1 kB): CPython 3.12, Windows x86-64
  • jaxbind-1.1.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (918.9 kB): CPython 3.12, manylinux (glibc 2.24+ and 2.28+) x86-64
  • jaxbind-1.1.0-cp312-cp312-macosx_11_0_arm64.whl (77.2 kB): CPython 3.12, macOS 11.0+ ARM64
  • jaxbind-1.1.0-cp312-cp312-macosx_10_14_x86_64.whl (78.6 kB): CPython 3.12, macOS 10.14+ x86-64
  • jaxbind-1.1.0-cp311-cp311-win_amd64.whl (75.9 kB): CPython 3.11, Windows x86-64
  • jaxbind-1.1.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (920.0 kB): CPython 3.11, manylinux (glibc 2.24+ and 2.28+) x86-64
  • jaxbind-1.1.0-cp311-cp311-macosx_11_0_arm64.whl (78.5 kB): CPython 3.11, macOS 11.0+ ARM64
  • jaxbind-1.1.0-cp311-cp311-macosx_10_14_x86_64.whl (79.5 kB): CPython 3.11, macOS 10.14+ x86-64
  • jaxbind-1.1.0-cp310-cp310-win_amd64.whl (74.9 kB): CPython 3.10, Windows x86-64
  • jaxbind-1.1.0-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (899.1 kB): CPython 3.10, manylinux (glibc 2.24+ and 2.28+) x86-64
  • jaxbind-1.1.0-cp310-cp310-macosx_11_0_arm64.whl (77.3 kB): CPython 3.10, macOS 11.0+ ARM64
  • jaxbind-1.1.0-cp310-cp310-macosx_10_14_x86_64.whl (78.2 kB): CPython 3.10, macOS 10.14+ x86-64

File details

Details for the file jaxbind-1.1.0.tar.gz.

File metadata

  • Download URL: jaxbind-1.1.0.tar.gz
  • Size: 28.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.1.dev0+g94f810c.d20240510 CPython/3.12.3

File hashes

Hashes for jaxbind-1.1.0.tar.gz
Algorithm Hash digest
SHA256 713d155322df7ed1a259f329ae7034aab312e69ed56f1c30adcd33ccaa6f69a7
MD5 2b97d4f1c9385e704cc82404d0a56fb3
BLAKE2b-256 58e364dd0d2ab4b3528287de4d4c3cc1643b1861e01325507604a6c1e4fe65aa

See more details on using hashes here.

File details

Details for the file jaxbind-1.1.0-cp312-cp312-win_amd64.whl.

File metadata

  • Download URL: jaxbind-1.1.0-cp312-cp312-win_amd64.whl
  • Size: 76.1 kB
  • Tags: CPython 3.12, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.1.dev0+g94f810c.d20240510 CPython/3.12.3

File hashes

Hashes for jaxbind-1.1.0-cp312-cp312-win_amd64.whl
Algorithm Hash digest
SHA256 e355b7648250fec2c1d42c1ddf220d8ccae70782ee3293520d5c0771ca4917a1
MD5 93903a7de32f4d694c9ca3b60fbe3e05
BLAKE2b-256 8aebf56845127c749457e84363cc6afa7bf1deb3fcebdf00e9be48f77f4af81f

File details

Details for the file jaxbind-1.1.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File hashes

Hashes for jaxbind-1.1.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 7b28f0b4a614cb39f0faaa080c0c2476753cbe58c0d0d6c6e82e955c983a1f64
MD5 1967adc1498f75c9d1a52171e0ae4ee2
BLAKE2b-256 8c01e8d3e82b5cdd53e9e249e4cbce09859dc30f9b3a690cbb2a6547203576a4

File details

Details for the file jaxbind-1.1.0-cp312-cp312-macosx_11_0_arm64.whl.

File hashes

Hashes for jaxbind-1.1.0-cp312-cp312-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 8fd64b5b17e949875ff5b1ddc4e0ecedb0c6de04df4cb3cfd61af613a3680d82
MD5 bceff549673638732552b63f1771a1f3
BLAKE2b-256 4ed5a2f1d2f7082e462ac08286d2f06cc3beefed42150bae8e121b54a1650c48

File details

Details for the file jaxbind-1.1.0-cp312-cp312-macosx_10_14_x86_64.whl.

File hashes

Hashes for jaxbind-1.1.0-cp312-cp312-macosx_10_14_x86_64.whl
Algorithm Hash digest
SHA256 68578ace3d77297ef4fed5c8e6a7cb33b519078cc7ec601f2f364fa6c041a3f7
MD5 c35b24a46c079b8f94468c918371bf88
BLAKE2b-256 31e2465fbf3587c23959466d309796aa18a491f622242698d3654f69e915bad2

File details

Details for the file jaxbind-1.1.0-cp311-cp311-win_amd64.whl.

File metadata

  • Download URL: jaxbind-1.1.0-cp311-cp311-win_amd64.whl
  • Size: 75.9 kB
  • Tags: CPython 3.11, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.1.dev0+g94f810c.d20240510 CPython/3.12.3

File hashes

Hashes for jaxbind-1.1.0-cp311-cp311-win_amd64.whl
Algorithm Hash digest
SHA256 a7627dd35e540eb31886fcd6e2095aae47abacdede7b50fe077d51e07d6d4035
MD5 c9fb09cab7b8a2b986c0bc42dcc31f65
BLAKE2b-256 36699e7f02abc53bf004fd05b3b044a7b1a605042e79c5682e56c7f289713c78

File details

Details for the file jaxbind-1.1.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File hashes

Hashes for jaxbind-1.1.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 0e55b7aa2041dda41c5cf8fb8d2bf3206fa8fccc2faa5eadd21cfdd3b5178b96
MD5 895008e733d47d21467e674fbd2f38c2
BLAKE2b-256 ae8360e17d3705a2887c5874f6fc13f2b52318977b4e74bb9c75ea3bd7974e26

File details

Details for the file jaxbind-1.1.0-cp311-cp311-macosx_11_0_arm64.whl.

File hashes

Hashes for jaxbind-1.1.0-cp311-cp311-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 faf3b60db2f440a32c6109bdd7ed0b269387ff0d5f00e820316ddd2b0f630f99
MD5 9d14fd4d28f5786ce492c684d237642e
BLAKE2b-256 2702fbb4741d1685676f473f0659d48a73b3964ef637c96b4a9f4fea017636da

File details

Details for the file jaxbind-1.1.0-cp311-cp311-macosx_10_14_x86_64.whl.

File hashes

Hashes for jaxbind-1.1.0-cp311-cp311-macosx_10_14_x86_64.whl
Algorithm Hash digest
SHA256 828bec8eb85c840066fa0d43cde60501550112d251430ea3dfd858c5ad378e70
MD5 f4f7043746f2ce0e36e732ba29f760b0
BLAKE2b-256 79ba6aebc66e6d0d5a7bf085adcbd4e407d70233db7347446695d5d6354da27e

File details

Details for the file jaxbind-1.1.0-cp310-cp310-win_amd64.whl.

File metadata

  • Download URL: jaxbind-1.1.0-cp310-cp310-win_amd64.whl
  • Size: 74.9 kB
  • Tags: CPython 3.10, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.1.dev0+g94f810c.d20240510 CPython/3.12.3

File hashes

Hashes for jaxbind-1.1.0-cp310-cp310-win_amd64.whl
Algorithm Hash digest
SHA256 e3841004529bcf5e550ee36fad0c3529c40dfe1e780c6a585b31e710c86d371f
MD5 20a2045ade0cb4e10748f1244575138c
BLAKE2b-256 433688fd2470a025d4a3281a1494f995d410b591ff288b5347c1b35fbf5fc124

File details

Details for the file jaxbind-1.1.0-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for jaxbind-1.1.0-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 da83d34e93d6e4863a2fc2de3971c468d3abd54804323f576a1c23222f30b068
MD5 4313a2658865c5151f432df0f0a0a337
BLAKE2b-256 1036a8007bca2fc6fe8fafaafb4b0fd4047adbbf9e6531eda8c3727fac323c28

File details

Details for the file jaxbind-1.1.0-cp310-cp310-macosx_11_0_arm64.whl.

File metadata

File hashes

Hashes for jaxbind-1.1.0-cp310-cp310-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 ba1b4579fcee8b3d4545aaeedfd63fb8c5538373a18c149e4eae4599d0f5167d
MD5 0036f0a8e249794f86a9c10232bd65b0
BLAKE2b-256 ba35042659eaf5a62b6c9d20d4539d24b508cff73c0a2c8de023cc621fe52218

File details

Details for the file jaxbind-1.1.0-cp310-cp310-macosx_10_14_x86_64.whl.

File metadata

File hashes

Hashes for jaxbind-1.1.0-cp310-cp310-macosx_10_14_x86_64.whl
Algorithm Hash digest
SHA256 3f2566368a3dbd0a2f7e69ea043479823fe062aaf49aa10ae2abd87989edef1e
MD5 2c87ac71809d829c7500e6c228452dae
BLAKE2b-256 b3b530591b44de8e61baaea8533589848cffe8741e53afc82e54cb8090bcee85
