Convenience functions for adding arbitrary linear operations to JAX.
JAXbind: Easy bindings to JAX
Found a bug? github.com/NIFTy-PPL/JAXbind/issues | Need help? github.com/NIFTy-PPL/JAXbind/discussions
Summary
The existing interface in JAX for connecting custom code requires deep knowledge of JAX and its C++ backend.
The aim of JAXbind is to drastically lower the burden of connecting custom functions implemented in other programming languages to JAX.
Specifically, JAXbind provides an easy-to-use Python interface for defining custom, so-called JAX primitives supporting any JAX transformations.
Automatic Differentiation and Code Example
Automatic differentiation is a core feature of JAX and often one of the main reasons for using it.
Thus, it is essential that custom functions registered with JAX support automatic differentiation.
In the following, we outline which functions JAXbind, and through it JAX, requires to enable automatic differentiation.
For simplicity, we assume that we want to connect the nonlinear function $f(x_1,x_2) = x_1x_2^2$ to JAX.
The JAXbind package expects the Python function for $f$ to take three positional arguments.
The first argument, out, is a tuple into which the function results are written.
The second argument is also a tuple containing the input to the function, in our case, $x_1$ and $x_2$.
Via kwargs_dump, potential keyword arguments given to the later registered JAX primitive can be forwarded to f in serialized form.
import jaxbind

def f(out, args, kwargs_dump):
    kwargs = jaxbind.load_kwargs(kwargs_dump)
    x1, x2 = args
    out[0][()] = x1 * x2**2
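To make the calling convention concrete, here is a minimal sketch that runs without JAXbind. It assumes the entries of out behave like preallocated NumPy arrays (an assumption made for illustration), and f_standalone is a hypothetical stand-in for f that omits kwargs_dump.

```python
import numpy as np

# Hypothetical stand-in for f above, without the kwargs_dump argument.
def f_standalone(out, args):
    x1, x2 = args
    # Indexing with the empty tuple writes into the existing buffer
    # instead of rebinding the name; it also works for 0-d arrays.
    out[0][()] = x1 * x2**2

x1 = np.full((4, 3), 4.0)
x2 = np.full((4, 3), 2.0)
out = (np.empty((4, 3)),)  # assumed: the caller preallocates the output buffer
f_standalone(out, (x1, x2))

# The same convention with scalar (0-d) inputs and output.
out_scalar = (np.empty(()),)
f_standalone(out_scalar, (np.float64(3.0), np.float64(2.0)))
```

Because the result is written in place, the function does not need to return anything.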
JAX's automatic differentiation engine can compute the Jacobian-vector product (jvp) and the vector-Jacobian product (vjp) of JAX primitives.
The Jacobian-vector product in JAX is a function applying the Jacobian of $f$ at a position $x$ to a tangent vector.
In mathematical nomenclature this operation is called the pushforward of $f$ and can be denoted as $\partial f(x): T_x X \to T_{f(x)} Y$, with $T_x X$ and $T_{f(x)} Y$ being the tangent spaces of $X$ and $Y$ at the positions $x$ and $f(x)$.
As the implementation of $f$ is not JAX native, JAX cannot automatically compute the jvp.
Instead, an implementation of the pushforward has to be provided, which JAXbind will register as the jvp of the JAX primitive of $f$.
For our example, this Jacobian-vector-product function is given by $\partial f(x_1,x_2)(dx_1,dx_2) = x_2^2dx_1 + 2x_1x_2dx_2$.
def f_jvp(out, args, kwargs_dump):
    kwargs = jaxbind.load_kwargs(kwargs_dump)
    x1, x2, dx1, dx2 = args
    out[0][()] = x2**2 * dx1 + 2 * x1 * x2 * dx2
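A hand-written pushforward is easy to get wrong, so it can be worth checking it against a finite difference before registering it. The following sketch does this for the example function with plain Python floats; the names f_plain and jvp_plain are hypothetical helpers and not part of JAXbind.

```python
# Plain-float versions of the example function and its pushforward.
def f_plain(x1, x2):
    return x1 * x2**2

def jvp_plain(x1, x2, dx1, dx2):
    return x2**2 * dx1 + 2 * x1 * x2 * dx2

x1, x2 = 4.0, 2.0      # position at which the Jacobian is taken
dx1, dx2 = 1.0, 1.0    # tangent vector
eps = 1e-6

# Central finite difference along the tangent direction.
fd = (f_plain(x1 + eps * dx1, x2 + eps * dx2)
      - f_plain(x1 - eps * dx1, x2 - eps * dx2)) / (2 * eps)
analytic = jvp_plain(x1, x2, dx1, dx2)

print(fd, analytic)
```

The two numbers should agree up to the finite-difference error.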
The vector-Jacobian product (vjp) in JAX is the linear transpose of the Jacobian-vector product.
In mathematical nomenclature this is the pullback $(\partial f(x))^{T}: T_{f(x)}Y \to T_x X$ of $f$.
Analogously to the jvp, the user has to implement this function as JAX cannot automatically construct it.
For our example function, the vector-Jacobian product is $(\partial f(x_1,x_2))^{T}(dy) = (x_2^2dy, 2x_1x_2dy)$.
def f_vjp(out, args, kwargs_dump):
    kwargs = jaxbind.load_kwargs(kwargs_dump)
    x1, x2, dy = args
    out[0][()] = x2**2 * dy
    out[1][()] = 2 * x1 * x2 * dy
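Since the vjp is the transpose of the jvp, the two implementations must satisfy $dy \cdot \partial f(x)(dx) = (\partial f(x))^{T}(dy) \cdot dx$ for arbitrary $dx$ and $dy$. The sketch below checks this identity for the example with plain Python floats; jvp_plain and vjp_plain are hypothetical helper names.

```python
def jvp_plain(x1, x2, dx1, dx2):
    return x2**2 * dx1 + 2 * x1 * x2 * dx2

def vjp_plain(x1, x2, dy):
    return (x2**2 * dy, 2 * x1 * x2 * dy)

x1, x2 = 4.0, 2.0
dx1, dx2 = 0.5, -1.5   # arbitrary tangent vector
dy = 6.0               # arbitrary cotangent vector

# <dy, J dx> on the output side
lhs = dy * jvp_plain(x1, x2, dx1, dx2)
# <J^T dy, dx> on the input side
g1, g2 = vjp_plain(x1, x2, dy)
rhs = g1 * dx1 + g2 * dx2

print(lhs, rhs)
```

If the two values differ by more than rounding error, one of the two derivative implementations is inconsistent with the other.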
To just-in-time compile the function, JAX needs to abstractly evaluate the code, i.e., it needs to be able to know the shape and dtype of the output of the custom function given only the shape and dtype of the input.
We have to provide these abstract evaluation functions, which return the output shape and dtype given an input shape and dtype, for f as well as for the vjp application.
The output shape of the jvp is identical to the output shape of f itself and does not need to be specified again.
Due to the internals of JAX, the abstract evaluation functions take normal keyword arguments and not serialized keyword arguments.
def f_abstract(*args, **kwargs):
    assert args[0].shape == args[1].shape
    return ((args[0].shape, args[0].dtype),)

def f_abstract_T(*args, **kwargs):
    return (
        (args[0].shape, args[0].dtype),
        (args[0].shape, args[0].dtype),
    )
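To illustrate what these functions return, the sketch below calls them on stand-in objects that only carry a shape and a dtype. AbstractArray is a hypothetical placeholder for the abstract values JAX passes in, and the assumption that f_abstract_T receives the same positional arguments as f_vjp (x1, x2, dy) is ours, not confirmed by the text above.

```python
from collections import namedtuple

# Hypothetical stand-in for JAX's abstract values: only shape and dtype.
AbstractArray = namedtuple("AbstractArray", ["shape", "dtype"])

def f_abstract(*args, **kwargs):
    assert args[0].shape == args[1].shape
    return ((args[0].shape, args[0].dtype),)

def f_abstract_T(*args, **kwargs):
    return (
        (args[0].shape, args[0].dtype),
        (args[0].shape, args[0].dtype),
    )

x1 = AbstractArray((4, 3), "float64")
x2 = AbstractArray((4, 3), "float64")
dy = AbstractArray((4, 3), "float64")

out_f = f_abstract(x1, x2)        # one output: f(x1, x2)
out_T = f_abstract_T(x1, x2, dy)  # two outputs: one cotangent per input
print(out_f)
print(out_T)
```

No array data is touched at any point; the functions only propagate shape and dtype information.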
We have now defined all ingredients necessary to register a JAX primitive for our function $f$ using the JAXbind package.
f_jax = jaxbind.get_nonlinear_call(
    f, (f_jvp, f_vjp), f_abstract, f_abstract_T
)
f_jax is a JAX primitive registered via the JAXbind package supporting all JAX transformations.
We can now compute the jvp and vjp of the new JAX primitive and even jit-compile and batch it.
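import jax
import jax.numpy as jnp

inp = (jnp.full((4, 3), 4.0), jnp.full((4, 3), 2.0))
tan = (jnp.full((4, 3), 1.0), jnp.full((4, 3), 1.0))
# Jacobian-vector product at inp applied to tan
res, res_tan = jax.jvp(f_jax, inp, tan)

cotan = (jnp.full((4, 3), 6.0),)
# vjp_fun is named differently from f_vjp above to avoid shadowing it
res, vjp_fun = jax.vjp(f_jax, *inp)
res_cotan = vjp_fun(cotan)

# just-in-time compilation
f_jax_jit = jax.jit(f_jax)
res = f_jax_jit(*inp)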
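For a function as simple as the example, the results can be cross-checked against a JAX-native implementation, for which JAX derives the jvp and vjp itself. The sketch below is such a reference and does not use JAXbind; f_native is a hypothetical name. It returns a single array rather than a tuple, so its cotangent is passed as a plain array.

```python
import jax
import jax.numpy as jnp

# JAX-native reference implementation of f(x1, x2) = x1 * x2**2.
def f_native(x1, x2):
    return x1 * x2**2

inp = (jnp.full((4, 3), 4.0), jnp.full((4, 3), 2.0))
tan = (jnp.full((4, 3), 1.0), jnp.full((4, 3), 1.0))

# Reference Jacobian-vector product.
ref, ref_tan = jax.jvp(f_native, inp, tan)

# Reference vector-Jacobian product.
ref, ref_vjp_fun = jax.vjp(f_native, *inp)
ref_cotan = ref_vjp_fun(jnp.full((4, 3), 6.0))
```

With these inputs, the analytic formulas above give 16 for the function value, 20 for the tangent, and (24, 96) for the cotangents, which the bound primitive should reproduce.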
Higher Order Derivatives and Linear Functions
JAX supports higher-order derivatives and can differentiate a jvp or vjp with respect to the position at which the Jacobian was taken.
Similar to first derivatives, JAX cannot automatically compute higher derivatives of a general function $f$ that is not natively implemented in JAX.
Higher-order derivatives would again need to be provided by the user.
For many algorithms, first derivatives are sufficient, and higher-order derivatives are often not implemented by the high-performance codes.
Therefore, the current interface of JAXbind is, for simplicity, restricted to first derivatives.
In the future, the interface could be easily expanded if specific use cases require higher-order derivatives.
In scientific computing, linear functions, such as spherical harmonic transforms, are widespread.
If the function $f$ is linear, differentiation becomes trivial.
Specifically, for a linear function $f$, the pushforward (the jvp) of $f$ is identical to $f$ itself and independent of the position at which it is computed.
Expressed in formulas, $\partial f(x)(dx) = f(dx)$ if $f$ is linear in $x$.
Analogously, the pullback (the vjp) becomes independent of the initial position and is given by the linear transpose of $f$, thus $(\partial f(x))^{T}(dy) = f^T(dy)$.
Also, all higher-order derivatives can be expressed in terms of $f$ and its transpose.
To make use of these simplifications, JAXbind provides a special interface for linear functions, supporting higher-order derivatives, only requiring an implementation of the function and its transpose.
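The two identities for linear functions can be verified numerically. The following sketch uses a small matrix as the linear function; the matrix A and the names f_lin and f_lin_T are hypothetical and chosen only for illustration, and the JAXbind interface for linear functions is not used here.

```python
import numpy as np

# A small linear map f(x) = A x and its transpose f^T(y) = A^T y.
A = np.array([[1.0, 2.0, 0.0],
              [0.0, -1.0, 3.0]])

def f_lin(x):
    return A @ x

def f_lin_T(y):
    return A.T @ y

x = np.array([1.0, 2.0, 3.0])     # position (irrelevant for a linear map)
dx = np.array([0.5, -1.0, 2.0])   # tangent vector
dy = np.array([1.0, -2.0])        # cotangent vector

# Pushforward: a finite difference at x reproduces f_lin(dx).
eps = 1e-6
fd = (f_lin(x + eps * dx) - f_lin(x)) / eps
jvp_lin = f_lin(dx)

# Pullback: f_lin_T is the transpose of f_lin, so both inner products agree.
lhs = dy @ f_lin(dx)
rhs = f_lin_T(dy) @ dx
```

Because neither jvp_lin nor f_lin_T(dy) depends on x, the function and its transpose are enough to express the derivatives.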
Demos
Additional demos can be found in the demos folder. Specifically, there is a basic demo 01_linear_function.py showcasing the interface for linear functions and custom batching rules. 02_multilinear_function.py binds a multi-linear function as a JAX primitive. Finally, 03_nonlinear_function.py demonstrates the interface for non-linear functions and shows how to deal with fixed arguments, which cannot be differentiated.
Platforms
Currently, JAXbind only has CPU but no GPU support.
With some expertise on Python bindings for GPU kernels, adding GPU support should be fairly simple.
The interfacing with the JAX automatic differentiation engine is identical for CPU and GPU.
Contributions are welcome!
Requirements
- Python >= 3.8
- only when compiling from source: pybind11
- only when compiling from source: a C++17-capable compiler, e.g.
  - g++ 7 or later
  - clang++
  - MSVC 2019 or later
  - Intel icpx (oneAPI compiler series). (Note that the older icpc compilers are not supported.)
Installation
FIXME: PyPI Installation!
To install JAXbind from source, clone the repository and install it via pip.
git clone https://github.com/NIFTy-PPL/jaxbind.git
cd jaxbind
pip install --user .
TODOs
- Paper
- final editing
- PyPI release
- README
Licensing terms
All source code in this package is released under the 2-clause BSD license. All of JAXbind is distributed without any warranty.