Use numba in jax-compiled kernels.

Project description

numba4jax

A small, experimental Python package that lets you use numba-jitted functions from within jax with no overhead.

This package uses Numba's CFFI support to expose the C function pointer of your compiled function to XLA. It works for both CPU and GPU functions.
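
As a rough illustration of what exposing a C function pointer means, the sketch below uses only the standard-library ctypes module. A Python callback stands in for the numba-compiled function (with numba itself, a numba.cfunc object exposes its pointer as the integer attribute .address). The names DoubleFn and add_one are illustrative; this is not how numba4jax is implemented internally, only the underlying idea of handing a function to another component as a bare address.

```python
import ctypes

# C signature "double f(double)" as a ctypes function-pointer type.
DoubleFn = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)


# Stand-in for a numba-compiled function. Here ctypes wraps a Python callable
# so it can be called through a C function pointer; numba.cfunc would instead
# produce native machine code behind the same kind of pointer.
@DoubleFn
def add_one(x):
    return x + 1.0


# The function pointer is just an integer address. A value like this is what
# gets handed to the consumer (XLA, in numba4jax's case).
address = ctypes.cast(add_one, ctypes.c_void_p).value

# Consumer side: rebuild a callable from the bare address and call through it.
# add_one must stay alive for as long as the address is in use.
fn_from_address = DoubleFn(address)
result = fn_from_address(41.0)
print(result)  # 42.0
```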

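This package exports a single decorator, @njit4jax, which takes one argument: a function or tuple describing the output shape of the decorated function. When a function is given, it receives the abstract values (shape and dtype) of the inputs and returns those of the output(s). The decorated function itself receives a single tuple containing the output buffer(s) followed by the inputs, and writes its results into the output buffers in place. See the brief example below.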

import jax
import jax.numpy as jnp

from numba4jax import ShapedArray, njit4jax


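def compute_type(*x):
    # Receives the abstract values (shape and dtype) of the inputs and returns
    # the abstract value of the output: here, the same as the first input.
    return x[0]


@njit4jax(compute_type)
def test(args):
    # args holds the output buffer first, then the two inputs.
    y, x, x2 = args
    # Write the result into the output buffer in place.
    y[:] = x[:] + 1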


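z = jnp.ones((1, 2), dtype=float)

# Print the jaxpr that jax traces for the wrapped function.
print(jax.make_jaxpr(test)(z, z))

# The wrapped function can be called eagerly or under jax.jit.
print("output: ", test(z, z))
print("output: ", jax.jit(test)(z, z))

# Other input shapes work too: the output shape follows compute_type.
z = jnp.ones((2, 3), dtype=float)
print("output: ", jax.jit(test)(z, z))

z = jnp.ones((1, 3, 1), dtype=float)
print("output: ", jax.jit(test)(z, z))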
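
In the example, test unpacks args as y, x, x2: the output buffer y comes first, followed by the two inputs, and compute_type returns the first input's abstract value, so the output has the same shape and dtype as x. The sketch below models that calling convention eagerly with plain NumPy; nothing is jit-compiled. Spec and emulate_njit4jax are hypothetical names introduced only for illustration, and the handling of several outputs (a tuple of specs, with all output buffers placed before the inputs) is an assumption extrapolated from the single-output example, not confirmed numba4jax behavior.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class Spec:
    """Hypothetical stand-in for jax's ShapedArray: only shape and dtype."""
    shape: tuple
    dtype: np.dtype


def emulate_njit4jax(kernel, out_type_fn, *inputs):
    """Hypothetical eager emulation of calling an @njit4jax-wrapped function.

    Assumed convention: out_type_fn maps the input specs to one output spec
    or a tuple of them; the kernel gets a single tuple (outputs..., inputs...)
    and fills the outputs in place.
    """
    arrays = tuple(np.asarray(a) for a in inputs)
    out = out_type_fn(*(Spec(a.shape, a.dtype) for a in arrays))
    single = isinstance(out, Spec)
    out_specs = (out,) if single else tuple(out)
    # Allocate uninitialized output buffers; the kernel must overwrite them.
    outputs = tuple(np.empty(s.shape, dtype=s.dtype) for s in out_specs)
    kernel(outputs + arrays)
    return outputs[0] if single else outputs


def compute_type(*x):
    # Same rule as the example: the output looks like the first input.
    return x[0]


def add_one_kernel(args):
    # Same body as the example's test: output buffer first, then the inputs.
    y, x, x2 = args
    y[:] = x[:] + 1


z = np.ones((2, 3), dtype=np.float32)
result = emulate_njit4jax(add_one_kernel, compute_type, z, z)
print(result)  # a (2, 3) float32 array filled with 2.0
```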

Backend support

This package supports both the CPU and GPU backends of jax. The GPU backend is supported only on Linux and is highly experimental. It requires CUDA to be installed in a standard path. CUDA is located through numba.cuda, so you should first check that numba.cuda works on your machine.
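
A quick way to run that check before trying the GPU backend is sketched below. The helper gpu_backend_prerequisites is hypothetical; it relies only on sys.platform and on numba.cuda.is_available(), which reports whether numba can find a working CUDA driver and GPU. It does not exercise numba4jax or jax itself.

```python
import sys


def gpu_backend_prerequisites():
    """Hypothetical helper: check the GPU prerequisites listed above."""
    on_linux = sys.platform.startswith("linux")
    try:
        # numba.cuda is what locates the CUDA installation.
        from numba import cuda
    except ImportError:
        cuda_ok = False  # numba itself is not installed
    else:
        # True only if a usable CUDA driver and GPU were found.
        cuda_ok = bool(cuda.is_available())
    return {"linux": on_linux, "numba_cuda": cuda_ok}


status = gpu_backend_prerequisites()
print(status)
if not all(status.values()):
    print("GPU backend prerequisites not met; use the CPU backend.")
```

When is_available() returns False, numba.cuda.detect() prints the devices numba can see, which helps narrow down whether the driver or the CUDA installation is missing.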

Download files

Source Distribution

numba4jax-0.0.14.tar.gz (12.0 kB)

Uploaded Source

Built Distribution

numba4jax-0.0.14-py3-none-any.whl (16.2 kB)

Uploaded Python 3

File details

Details for the file numba4jax-0.0.14.tar.gz.

File metadata

  • Download URL: numba4jax-0.0.14.tar.gz
  • Upload date:
  • Size: 12.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.0 CPython/3.12.4

File hashes

Hashes for numba4jax-0.0.14.tar.gz
Algorithm    Hash digest
SHA256       a16911c3d3d1ac72cd6d9fdd003c285b4b86fe365ca072b8187c228c5011630f
MD5          9de84a067a64a4447dfd89c7d2a67d47
BLAKE2b-256  e2a4f97a263f88bcd6aed229214ee6508d44ae5dbb22bcd748d071fda9c3a54c
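
The published digests can be checked locally before installing. The sketch below uses only hashlib; sha256_of_file and verify_sha256 are illustrative names, and the usage line assumes the sdist has been downloaded into the current directory.

```python
import hashlib

# SHA256 digest of numba4jax-0.0.14.tar.gz, copied from the table above.
EXPECTED_SDIST_SHA256 = "a16911c3d3d1ac72cd6d9fdd003c285b4b86fe365ca072b8187c228c5011630f"


def sha256_of_file(path, chunk_size=1 << 16):
    """Return the SHA256 hex digest of a file, read in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps calling f.read until it returns b"".
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path, expected_hex):
    """True if the file's SHA256 digest matches the published hex digest."""
    return sha256_of_file(path) == expected_hex.strip().lower()


# Usage, assuming the sdist was downloaded into the current directory:
# verify_sha256("numba4jax-0.0.14.tar.gz", EXPECTED_SDIST_SHA256)
```

pip can enforce the same check at install time in hash-checking mode: list numba4jax==0.0.14 --hash=sha256:<digest> in a requirements file (one --hash option per distribution file you accept) and run pip install --require-hashes -r requirements.txt. In that mode every dependency must be pinned with a hash as well.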

File details

Details for the file numba4jax-0.0.14-py3-none-any.whl.

File metadata

  • Download URL: numba4jax-0.0.14-py3-none-any.whl
  • Upload date:
  • Size: 16.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.0 CPython/3.12.4

File hashes

Hashes for numba4jax-0.0.14-py3-none-any.whl
Algorithm    Hash digest
SHA256       cd4a23b5e25a3a4fc5e9adb21ca06cb1cccaf07a31e0dfa979619bf0447d33c2
MD5          06a20dab06cffc8d8b6d64857e2338ac
BLAKE2b-256  49ded1a5b8df5efaeed0cafd79a9b32f893291274034761ef523de3452e2b123
