Use numba in jax-compiled kernels.
numba4jax
A small experimental Python package that lets you call Numba-jitted functions from within JAX with no overhead.
This package uses Numba's CFFI support to expose the C function pointer of your compiled function to XLA. It works for both CPU and GPU functions.
This package exports a single decorator, @njit4jax, which takes one argument: a function or tuple describing the output shape of the decorated function. See the brief example below.
```python
import jax
import jax.numpy as jnp

from numba4jax import ShapedArray, njit4jax


def compute_type(*x):
    return x[0]


@njit4jax(compute_type)
def test(args):
    y, x, x2 = args
    y[:] = x[:] + 1


z = jnp.ones((1, 2), dtype=float)
jax.make_jaxpr(test)(z, z)
print("output: ", test(z, z))
print("output: ", jax.jit(test)(z, z))

z = jnp.ones((2, 3), dtype=float)
print("output: ", jax.jit(test)(z, z))

z = jnp.ones((1, 3, 1), dtype=float)
print("output: ", jax.jit(test)(z, z))
```
Backend support
This package supports both the CPU and GPU backends of JAX.
The GPU backend is only supported on Linux and is highly experimental.
It requires CUDA to be installed in a standard path. CUDA is located through numba.cuda, so you should first check that numba.cuda works.
Hashes for numba4jax-0.0.12.post1-py3-none-any.whl

Algorithm | Hash digest
---|---
SHA256 | 6261edd2109bc1051e25214df4f27015c71881c6ad4893db0d272776c2b3ab34
MD5 | 800f87d56baa608b6eccaeb3899db3bb
BLAKE2b-256 | f75b8e8bbf1bf8f6b3969d457ad911aa8fe902172bb7d201744161304d21ad0c