

jaxify

Write Python. Run JAX.


⚠️ jaxify is an experimental project under development
Feel free to try it out and report any issues. Do not use it in production.

jaxify lets you apply JAX transformations (such as @jax.jit, @jax.vmap, or both) to functions containing Python control flow that JAX normally cannot trace, such as if/elif/else statements whose conditions depend on input values.
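For context, here is a minimal plain-JAX sketch (no jaxify involved) of the failure that motivates the project. Under jax.jit the argument is an abstract tracer, so a Python if on it has no concrete True/False value and JAX raises a concretization error. The names plain_absolute_value and jit_fails are hypothetical, introduced only for this illustration, and catching jax.errors.ConcretizationTypeError assumes a reasonably recent JAX release.

```python
import jax
import jax.numpy as jnp


def plain_absolute_value(x):
    # Under jax.jit, `x` is an abstract tracer, so `x >= 0` has no concrete True/False value.
    if x >= 0:
        return x
    else:
        return -x


def jit_fails(fn, x) -> bool:
    """Return True if tracing `fn` under jax.jit raises JAX's concretization error."""
    try:
        jax.jit(fn)(x)
    except jax.errors.ConcretizationTypeError:
        return True
    return False


x = jnp.float32(-3.0)
print(plain_absolute_value(x))             # eager call: the condition is concrete, so this works
print(jit_fails(plain_absolute_value, x))  # jit call: the `if` on a tracer raises an error
```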

Installation

pip install jaxify

Getting started

import jax
import jax.numpy as jnp
from jaxify import jaxify

@jax.jit
@jax.vmap
@jaxify  # <-- Just add a @jaxify decorator
def absolute_value(x):
    if x >= 0:  # <-- An if statement in a JIT-compiled function!
        return x
    else:
        return -x

xs = jnp.arange(-1000, 1000)
ys = absolute_value(xs)  # <-- Runs at JAX speed!
print(ys)

How it works

@jaxify is a decorator that transforms Python functions by rewriting their abstract syntax tree (AST), replacing control flow constructs that JAX cannot trace with JAX-compatible alternatives. It currently supports if/elif/else statements whose conditions depend on input values, so you can write natural Python code while still benefiting from JAX's performance.
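To make the AST step more concrete, here is a small sketch using only Python's standard ast module that finds the if conditions referring to a function's parameters. It is not jaxify's implementation: the name input_dependent_conditions and its name-matching heuristic are hypothetical, and a real tool would need proper data-flow analysis to decide which conditions actually depend on traced values.

```python
import ast
import textwrap

# Example source; a decorator could obtain this from a function object via inspect.getsource.
SOURCE = textwrap.dedent("""
    def absolute_value(x):
        if x >= 0:
            return x
        else:
            return -x
""")


def input_dependent_conditions(source: str) -> list[str]:
    """Return the text of every `if` condition that mentions one of the function's parameters.

    Hypothetical, deliberately rough heuristic: it only matches names, not data flow.
    """
    func = ast.parse(source).body[0]
    if not isinstance(func, ast.FunctionDef):
        raise ValueError("expected a single function definition")
    params = {arg.arg for arg in func.args.args}
    conditions = []
    for node in ast.walk(func):
        if isinstance(node, ast.If):
            # Collect every bare name used in the condition expression.
            names = {n.id for n in ast.walk(node.test) if isinstance(n, ast.Name)}
            if names & params:
                conditions.append(ast.unparse(node.test))
    return conditions


print(input_dependent_conditions(SOURCE))  # ['x >= 0']
```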

When you decorate a function with @jaxify, it analyzes the function's source code, identifies control flow constructs, and rewrites them to use JAX's functional control flow primitives (like jax.lax.cond). The transformed function is then traceable by JAX, enabling you to apply JAX transformations like @jax.jit and @jax.vmap seamlessly.
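The rewrite described above can be sketched with ast.NodeTransformer. The toy transformer below is hypothetical and is not jaxify's actual code: it only handles if/else branches that each consist of a single return, turning them into jax.lax.cond calls with zero-argument lambdas. Because an elif is an if nested in the else branch, an if/elif/else chain becomes nested jax.lax.cond calls. The rewritten module is compiled, executed into a namespace where jax is bound, and the resulting function is run under jax.jit and jax.vmap.

```python
import ast
import textwrap

import jax
import jax.numpy as jnp

# Hypothetical example input; a real decorator would read the source of the decorated function.
SOURCE = textwrap.dedent("""
    def sign(x):
        if x > 0:
            return 1.0
        elif x < 0:
            return -1.0
        else:
            return 0.0
""")


def _is_single_return(stmts: list[ast.stmt]) -> bool:
    return len(stmts) == 1 and isinstance(stmts[0], ast.Return) and stmts[0].value is not None


def _thunk(expr: ast.expr) -> ast.Lambda:
    """Wrap an expression in a zero-argument lambda: `lambda: expr`."""
    no_args = ast.arguments(
        posonlyargs=[], args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]
    )
    return ast.Lambda(args=no_args, body=expr)


class IfReturnToCond(ast.NodeTransformer):
    """Toy rewrite: `if c: return a` / `else: return b` -> `return jax.lax.cond(c, lambda: a, lambda: b)`."""

    def visit_If(self, node: ast.If) -> ast.AST:
        self.generic_visit(node)  # rewrite nested if/elif statements first
        if not (_is_single_return(node.body) and _is_single_return(node.orelse)):
            return node  # anything more complex is left unchanged
        jax_lax_cond = ast.Attribute(
            value=ast.Attribute(value=ast.Name(id="jax", ctx=ast.Load()), attr="lax", ctx=ast.Load()),
            attr="cond",
            ctx=ast.Load(),
        )
        call = ast.Call(
            func=jax_lax_cond,
            args=[node.test, _thunk(node.body[0].value), _thunk(node.orelse[0].value)],
            keywords=[],
        )
        return ast.copy_location(ast.Return(value=call), node)


tree = ast.fix_missing_locations(IfReturnToCond().visit(ast.parse(SOURCE)))
print(ast.unparse(tree))  # the rewritten source, using nested jax.lax.cond calls

namespace = {"jax": jax}
exec(compile(tree, filename="<rewritten>", mode="exec"), namespace)
sign = namespace["sign"]

print(jax.jit(jax.vmap(sign))(jnp.arange(-3, 4)))
```

Under jax.vmap the predicate is batched, so JAX evaluates both branches and selects elementwise; this is standard jax.lax.cond behavior rather than something specific to this sketch.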

Compatibility status

The current support status of Python control flow constructs within @jaxify-decorated functions is:

  • if / elif / else: Should mostly work.
  • if-else expressions: ⚠️ Static values only.
  • and / or: ⚠️ Static values only. For dynamic values, use & or jnp.logical_and instead of and, and | or jnp.logical_or instead of or.
  • for loops: Use jax.lax.fori_loop, jax.lax.scan, or jax.lax.while_loop instead.
  • while loops: Use jax.lax.while_loop instead.
  • match-case: ⚠️ Static values only. For dynamic values, use an if/elif/else chain or jax.lax.switch instead.
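The plain-JAX alternatives mentioned above can be sketched as follows. These functions are hypothetical examples written directly against jax and jax.numpy, independent of jaxify: & for a dynamic and, jax.lax.fori_loop for a for loop with a traced bound, jax.lax.while_loop for a while loop, and jax.lax.switch for a match-case on a traced index.

```python
import jax
import jax.numpy as jnp


@jax.jit
def in_open_interval(x, lo, hi):
    # `lo < x and x < hi` would call bool() on a tracer; `&` combines traced booleans elementwise.
    return (lo < x) & (x < hi)


@jax.jit
def sum_of_squares(n):
    # Replaces `for i in range(n): total += i * i`; the trip count `n` may be a traced value.
    return jax.lax.fori_loop(0, n, lambda i, total: total + i * i, jnp.int32(0))


@jax.jit
def collatz_steps(n):
    # Replaces `while n != 1: ...`; the carry is (current value, step count). Assumes n >= 1.
    def cond_fun(carry):
        value, _ = carry
        return value != 1

    def body_fun(carry):
        value, steps = carry
        value = jnp.where(value % 2 == 0, value // 2, 3 * value + 1)
        return value, steps + 1

    _, steps = jax.lax.while_loop(cond_fun, body_fun, (jnp.asarray(n, dtype=jnp.int32), jnp.int32(0)))
    return steps


@jax.jit
def apply_op(op_index, x):
    # Replaces `match op_index: case 0: ... case 1: ... case 2: ...` with one branch per case.
    return jax.lax.switch(op_index, [lambda v: v + 1, lambda v: v * 2, lambda v: -v], x)


print(in_open_interval(jnp.arange(5), 1, 4))
print(sum_of_squares(5), collatz_steps(6), apply_op(1, 3.0))
```

Note that jax.lax.switch clamps an out-of-range index to the first or last branch, whereas a match statement with no matching case simply does nothing.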


Download files


Source Distribution

jaxify-0.0.3.tar.gz (3.8 kB)

Uploaded Source

Built Distribution


jaxify-0.0.3-py3-none-any.whl (4.4 kB)

Uploaded Python 3

File details

Details for the file jaxify-0.0.3.tar.gz.

File metadata

  • Download URL: jaxify-0.0.3.tar.gz
  • Upload date:
  • Size: 3.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.9.18

File hashes

Hashes for jaxify-0.0.3.tar.gz
Algorithm Hash digest
SHA256 8fce37b14335fc3d61be1beec53c40bff6077ef59a200b761b2bdc2bd88b76e5
MD5 6562ddc8ca5dd85e642c78f8962fc72b
BLAKE2b-256 c6e0a82bc3c5bde13f26ca10cefbbe75acf4a5d759275f4560cf625cf15e473a


File details

Details for the file jaxify-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: jaxify-0.0.3-py3-none-any.whl
  • Upload date:
  • Size: 4.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.9.18

File hashes

Hashes for jaxify-0.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 71a70279c08f00f24912309e7d7aa46746bd0da19d79a6263c7202988a119f3e
MD5 3bf16b8f56812ada764e4a87966124ae
BLAKE2b-256 f023b0b5d5026fdebae0936b23a8668cb34605a82242ed6553d4ef9aadced490

