jaxify
Write Python. Run JAX.
| ⚠️ jaxify is an experimental project under development |
|---|
| Feel free to test it out and report any issues. Do not use in production. |
jaxify lets you apply JAX transformations (like @jax.jit and/or @jax.vmap) to functions containing Python control flow that JAX normally cannot trace, and therefore cannot compile, such as if/elif/else statements whose conditions depend on input values.
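For context, here is a minimal sketch in plain JAX (no jaxify involved) of the failure that jaxify is meant to avoid: a value-dependent `if` inside a `jax.jit`-compiled function raises an error at trace time, because Python needs a concrete `bool` while `x` is only an abstract tracer. The exact exception class may differ between JAX versions; as far as we know, recent versions raise `TracerBoolConversionError`, a subclass of `jax.errors.ConcretizationTypeError`, which is what the sketch catches.

```python
import jax
import jax.numpy as jnp


def absolute_value(x):
    # Python's `if` calls bool(x >= 0), but under jax.jit `x` is an abstract
    # tracer with no concrete value, so JAX raises an error while tracing.
    if x >= 0:
        return x
    else:
        return -x


try:
    jax.jit(absolute_value)(jnp.float32(-3.0))
    error_name = None
except jax.errors.ConcretizationTypeError as err:
    # Recent JAX versions typically raise TracerBoolConversionError here,
    # which subclasses ConcretizationTypeError.
    error_name = type(err).__name__

print(error_name)
```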
Installation
```bash
pip install jaxify
```
Getting started
```python
import jax
import jax.numpy as jnp
from jaxify import jaxify

@jax.jit
@jax.vmap
@jaxify  # <-- Just add a @jaxify decorator
def absolute_value(x):
    if x >= 0:  # <-- If conditional in a JIT-compiled function!
        return x
    else:
        return -x

xs = jnp.arange(-1000, 1000)
ys = absolute_value(xs)  # <-- Runs at JAX speed!
print(ys)
```
How it works
@jaxify is a decorator that transforms Python functions by rewriting their abstract syntax tree (AST), replacing unsupported control flow constructs with JAX-compatible alternatives. It currently supports if/elif/else statements whose conditions depend on input values, so you can write ordinary Python branching while still getting the speedups of JAX transformations such as @jax.jit.
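To make "JAX-compatible alternative" concrete, here is a hand-written sketch of what the `absolute_value` example might look like after such a rewrite, plus an if/elif/else chain expressed as nested `jax.lax.cond` calls. This is plain JAX written by hand, not output from jaxify; the code jaxify actually generates may look different, and `sign` is a hypothetical example function.

```python
import jax
import jax.numpy as jnp


# Hand-written equivalent of `absolute_value`: the if/else becomes
# jax.lax.cond(pred, true_fun, false_fun, *operands), which calls exactly one
# branch function on the operands (under vmap, both are computed and selected).
def absolute_value(x):
    return jax.lax.cond(x >= 0, lambda v: v, lambda v: -v, x)


# An if/elif/else chain maps naturally onto nested jax.lax.cond calls:
#   if x > 0: 1  elif x < 0: -1  else: 0
def sign(x):
    return jax.lax.cond(
        x > 0,
        lambda v: jnp.ones_like(v),
        lambda v: jax.lax.cond(v < 0, lambda w: -jnp.ones_like(w), jnp.zeros_like, v),
        x,
    )


xs = jnp.arange(-3, 4)
print(jax.jit(jax.vmap(absolute_value))(xs))
print(jax.jit(jax.vmap(sign))(xs))
```

Both branches of every `jax.lax.cond` must return values of the same shape and dtype, which is why `sign` builds its constants with `jnp.ones_like` and `jnp.zeros_like` rather than Python literals.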
When you decorate a function with @jaxify, it analyzes the function's source code, identifies control flow constructs, and rewrites them to use JAX's functional control flow primitives (like jax.lax.cond). The transformed function is then traceable by JAX, enabling you to apply JAX transformations like @jax.jit and @jax.vmap seamlessly.
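The source-rewriting step can be illustrated with Python's standard `ast` module. The following is a toy sketch, not jaxify's implementation: the `IfReturnToCond` transformer, the `_thunk` helper, and the `<jaxify-sketch>` filename are hypothetical; it only handles if/else (and elif) chains whose branches are each a single `return`; and it parses a source string, whereas a real decorator would more likely obtain the source with `inspect.getsource`. `ast.unparse` requires Python 3.9+.

```python
import ast
import textwrap

import jax
import jax.numpy as jnp

# Source of the function to rewrite (a real decorator would likely use
# inspect.getsource(fn); a string keeps this sketch self-contained).
SOURCE = textwrap.dedent(
    """
    def absolute_value(x):
        if x >= 0:
            return x
        else:
            return -x
    """
)


def _thunk(expr):
    """Wrap an expression in a zero-argument lambda: `lambda: <expr>`."""
    no_args = ast.arguments(
        posonlyargs=[], args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]
    )
    return ast.Lambda(args=no_args, body=expr)


class IfReturnToCond(ast.NodeTransformer):
    """Toy rewrite: `if c: return a / else: return b` -> `return jax.lax.cond(c, lambda: a, lambda: b)`."""

    def visit_If(self, node):
        # Rewrite nested ifs first, so an `elif` (an If inside `orelse`) becomes a Return.
        self.generic_visit(node)
        branches = (node.body, node.orelse)
        if not all(len(b) == 1 and isinstance(b[0], ast.Return) and b[0].value is not None for b in branches):
            return node  # leave anything more complex untouched
        cond_fn = ast.Attribute(
            value=ast.Attribute(value=ast.Name(id="jax", ctx=ast.Load()), attr="lax", ctx=ast.Load()),
            attr="cond",
            ctx=ast.Load(),
        )
        call = ast.Call(
            func=cond_fn,
            args=[node.test, _thunk(node.body[0].value), _thunk(node.orelse[0].value)],
            keywords=[],
        )
        return ast.copy_location(ast.Return(value=call), node)


transformed_tree = ast.fix_missing_locations(IfReturnToCond().visit(ast.parse(SOURCE)))
print(ast.unparse(transformed_tree))

# Compile the rewritten module and pull the new function out of its namespace.
namespace = {"jax": jax}
exec(compile(transformed_tree, filename="<jaxify-sketch>", mode="exec"), namespace)
absolute_value = namespace["absolute_value"]

print(jax.jit(jax.vmap(absolute_value))(jnp.arange(-3, 4)))
```

Because Python represents `elif` as an `If` nested inside the outer `orelse`, visiting children before the parent lets this single rule turn a whole if/elif/else chain into nested `jax.lax.cond` calls.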
Compatibility status
The following Python control flow constructs are currently supported within @jaxify-decorated functions:
| Python construct | Support status | Notes |
|---|---|---|
| if / elif / else | ✅ | Should mostly work |
| if-else expressions | ⚠️ | Static values only |
| and / or | ⚠️ | Static values only. For dynamic values, use & or jnp.logical_and instead of and, and \| or jnp.logical_or instead of or |
| for loops | ❌ | Use jax.lax.fori_loop, jax.lax.scan, or jax.lax.while_loop instead |
| while loops | ❌ | Use jax.lax.while_loop instead |
| match-case | ⚠️ | Static values only. For dynamic values, use an if-elif-else chain or jax.lax.switch instead |
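The workarounds named in the table are ordinary JAX APIs. Below is a minimal plain-JAX sketch (no jaxify) of each one; the example functions are hypothetical illustrations, and using `jnp.where` for dynamic if-else expressions is our suggestion rather than something the table names.

```python
import jax
import jax.numpy as jnp


# and / or -> elementwise & / | (equivalently jnp.logical_and / jnp.logical_or).
@jax.jit
def in_unit_interval(x):
    return (x >= 0) & (x <= 1)  # `(x >= 0) and (x <= 1)` would fail on traced values


# if-else expression -> jnp.where (computes both sides, then selects elementwise).
@jax.jit
def absolute_value(x):
    return jnp.where(x >= 0, x, -x)


# for loop -> jax.lax.fori_loop(lower, upper, body_fun, init_val); body_fun gets (i, carry).
@jax.jit
def sum_of_squares(n):
    return jax.lax.fori_loop(0, n, lambda i, total: total + i * i, jnp.int32(0))


# while loop -> jax.lax.while_loop(cond_fun, body_fun, init_val).
@jax.jit
def count_halvings(x):
    def cond_fun(state):
        value, _ = state
        return value >= 1.0

    def body_fun(state):
        value, count = state
        return value / 2.0, count + 1

    init = (jnp.asarray(x, dtype=jnp.float32), jnp.int32(0))
    _, count = jax.lax.while_loop(cond_fun, body_fun, init)
    return count


# match-case on a traced integer -> jax.lax.switch(index, branches, *operands).
# jax.lax.switch clamps out-of-range indices into [0, len(branches) - 1].
@jax.jit
def apply_op(op_index, x):
    branches = [lambda v: v + 1, lambda v: v * 2, lambda v: -v]
    return jax.lax.switch(op_index, branches, x)


print(in_unit_interval(jnp.array([-0.5, 0.5, 1.5])))
print(absolute_value(jnp.array([-2.0, 3.0])))
print(sum_of_squares(4))
print(count_halvings(10.0))
print(apply_op(1, 3))
```

The loop state passed to `jax.lax.while_loop` and `jax.lax.fori_loop` must keep the same shape and dtype on every iteration, which is why the initial values are created with explicit dtypes.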
Download files
File details
Details for the file jaxify-0.0.3.tar.gz.
File metadata
- Download URL: jaxify-0.0.3.tar.gz
- Upload date:
- Size: 3.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.9.18 {"installer":{"name":"uv","version":"0.9.18","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Ubuntu","version":"24.04","id":"noble","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":true}
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8fce37b14335fc3d61be1beec53c40bff6077ef59a200b761b2bdc2bd88b76e5 |
| MD5 | 6562ddc8ca5dd85e642c78f8962fc72b |
| BLAKE2b-256 | c6e0a82bc3c5bde13f26ca10cefbbe75acf4a5d759275f4560cf625cf15e473a |
File details
Details for the file jaxify-0.0.3-py3-none-any.whl.
File metadata
- Download URL: jaxify-0.0.3-py3-none-any.whl
- Upload date:
- Size: 4.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.9.18 {"installer":{"name":"uv","version":"0.9.18","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Ubuntu","version":"24.04","id":"noble","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":true}
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 71a70279c08f00f24912309e7d7aa46746bd0da19d79a6263c7202988a119f3e |
| MD5 | 3bf16b8f56812ada764e4a87966124ae |
| BLAKE2b-256 | f023b0b5d5026fdebae0936b23a8668cb34605a82242ed6553d4ef9aadced490 |