jaxify
Write Python. Run JAX.
| ⚠️ jaxify is an experimental project under development |
|---|
| You're welcome to try it out and report any issues! |
jaxify lets you apply JAX transformations (like `@jax.jit` and `@jax.vmap`) to functions that use common Python constructs JAX cannot handle on its own, such as `if` conditions that depend on input values.
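To make the limitation concrete, here is a minimal sketch of what plain JAX does with such a function when `@jaxify` is not used. The name `plain_absolute_value` is illustrative only. Under `jax.jit` the argument is an abstract tracer, so Python cannot decide which branch of the `if` to take and JAX raises an error. The sketch catches `jax.errors.ConcretizationTypeError`; recent JAX releases raise the more specific `TracerBoolConversionError`, which, as far as we know, subclasses it.

```python
import jax
import jax.numpy as jnp


@jax.jit
def plain_absolute_value(x):
    # Under jit, x is an abstract tracer, so `x >= 0` has no concrete
    # Python truth value and the `if` cannot pick a branch.
    if x >= 0:
        return x
    else:
        return -x


try:
    plain_absolute_value(jnp.float32(-3.0))
    failed = False
except jax.errors.ConcretizationTypeError as err:
    # Recent JAX versions raise TracerBoolConversionError here, a subclass
    # of ConcretizationTypeError; older versions raise the base class.
    failed = True
    print(type(err).__name__)
```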
Installation
```shell
pip install jaxify
```
Getting started
```python
import jax
import jax.numpy as jnp
from jaxify import jaxify

@jax.jit
@jax.vmap
@jaxify  # <-- Just decorate your function with @jaxify
def absolute_value(x):
    if x >= 0:  # <-- If block in a JIT-compiled function
        return x
    else:
        return -x

xs = jnp.arange(-1000, 1000)
ys = absolute_value(xs)  # <-- Runs at JAX speed!
print(ys)
```
How it works
The `@jaxify` decorator transforms Python functions using a mix of static analysis and dynamic tracing, replacing unsupported Python constructs with JAX-compatible alternatives. After this transformation the functions are traceable by JAX, so you can apply functional JAX transformations such as `@jax.jit` and `@jax.vmap` to them seamlessly.
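As an illustration of the kind of rewrite described above (a hand-written sketch, not the code jaxify actually generates), the `absolute_value` function from Getting started can be expressed directly with `jax.lax.cond`, the target the compatibility tables name for `if` statements. The name `absolute_value_by_hand` is hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
@jax.vmap
def absolute_value_by_hand(x):
    # The Python `if`/`else` becomes jax.lax.cond(pred, true_fun, false_fun, operand).
    # Under vmap the predicate is batched, so JAX evaluates both branches
    # and selects the result per element.
    return jax.lax.cond(
        x >= 0,
        lambda v: v,   # `if` branch
        lambda v: -v,  # `else` branch
        x,
    )


xs = jnp.arange(-1000, 1000)
ys = absolute_value_by_hand(xs)
print(ys)
```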
Compatibility status
The following tables summarize which Python constructs are currently supported within `@jaxify`-decorated functions:
🔀 Conditionals
| Construct | Works? | Notes |
|---|---|---|
| `if` statements | ✅ | Fully supported, including `elif` and `else` clauses. All branches are traced and translated to calls to `jax.lax.cond`. |
| `if` expressions (e.g. `a if b else c`) | ✅ | Traced and translated to `jax.lax.cond`. |
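One natural way to express an `if`-`elif`-`else` chain with `jax.lax.cond` is to nest calls, one per branch point, while a conditional expression maps to a single call. This is a hand-written sketch under that assumption, not a description of jaxify's internals; the functions `sign` and `relu` are hypothetical examples. Note that both branches of a `cond` must return values of the same shape and dtype.

```python
import jax
import jax.numpy as jnp


# Plain-Python version (as you might write it under @jaxify):
#     if x > 0:
#         return 1.0
#     elif x < 0:
#         return -1.0
#     else:
#         return 0.0
@jax.jit
def sign(x):
    def positive(v):
        return jnp.ones_like(v)

    def not_positive(v):
        # The `elif` / `else` part becomes a second, nested cond.
        return jax.lax.cond(v < 0, lambda w: -jnp.ones_like(w), jnp.zeros_like, v)

    return jax.lax.cond(x > 0, positive, not_positive, x)


# Conditional expression:  x if x > 0 else 0.0
@jax.jit
def relu(x):
    return jax.lax.cond(x > 0, lambda v: v, jnp.zeros_like, x)
```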
⚖️ Comparisons
| Construct | Works? | Notes |
|---|---|---|
| `==`, `!=`, `<`, `>`, `<=`, `>=` | ✅ | Chained comparisons (e.g. `x < y <= z`) are supported by translation to the equivalent chain of individual comparisons. |
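In Python, `x < y <= z` means `(x < y) and (y <= z)`, with `y` evaluated once. The sketch below shows that decomposition for traced arrays. It combines the individual comparisons with `jnp.logical_and`, an elementwise alternative chosen for illustration; jaxify's own translation of the implied `and` may differ (the logical-operator table states it uses `jax.lax.cond`). The name `lt_le_chain` is hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def lt_le_chain(x, y, z):
    # Equivalent of `x < y <= z`, split into its individual comparisons.
    # jnp.logical_and is elementwise and does not short-circuit.
    return jnp.logical_and(x < y, y <= z)


print(lt_le_chain(jnp.array([0.0, 1.0]), jnp.array([1.0, 1.0]), jnp.array([1.0, 2.0])))
```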
1️⃣ Logical operators
| Construct | Works? | Notes |
|---|---|---|
| `and` / `or` | ✅ | Short-circuiting on traced values is supported via translation to `jax.lax.cond` calls. |
| `not` | ✅ | Translated to `jnp.logical_not` for single traced values. |
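Python's `a and b` yields `b` if `a` is truthy and `a` otherwise, while `a or b` yields `a` if `a` is truthy and `b` otherwise. A minimal sketch of how those semantics map onto `jax.lax.cond` for boolean scalars, plus `jnp.logical_not` for `not`, could look like this. The helper names `py_and`, `py_or`, and `py_not` are hypothetical, and this is not claimed to be jaxify's exact translation.

```python
import jax
import jax.numpy as jnp


@jax.jit
def py_and(a, b):
    # `a and b`: if a is truthy, the result is b; otherwise it is a.
    # Both branches are traced, but only the selected one runs at execution
    # time (unless the predicate is batched under vmap).
    return jax.lax.cond(a, lambda x, y: y, lambda x, y: x, a, b)


@jax.jit
def py_or(a, b):
    # `a or b`: if a is truthy, the result is a; otherwise it is b.
    return jax.lax.cond(a, lambda x, y: x, lambda x, y: y, a, b)


@jax.jit
def py_not(a):
    # `not a` on a traced value.
    return jnp.logical_not(a)
```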
🔄 Loops
| Construct | Works? | Notes |
|---|---|---|
| `for` loops | ❌ | Currently unsupported. Use `jax.lax.fori_loop`, `jax.lax.scan`, or `jax.lax.while_loop` instead. |
| `while` loops | ❌ | Currently unsupported. Use `jax.lax.while_loop` instead. |
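Until loops are supported, the recommended workaround is to rewrite them by hand. The sketch below shows one `for` loop rewritten with `jax.lax.fori_loop` and one `while` loop rewritten with `jax.lax.while_loop`; the functions `sum_of_squares` and `halvings` are hypothetical examples. In both cases the loop state is threaded explicitly through a carry value whose shape and dtype must stay fixed across iterations.

```python
import jax
import jax.numpy as jnp


# Python version:
#     total = 0
#     for i in range(n):
#         total += i * i
@jax.jit
def sum_of_squares(n):
    # fori_loop(lower, upper, body_fun, init_val); body_fun(i, carry) -> new carry.
    return jax.lax.fori_loop(0, n, lambda i, total: total + i * i, jnp.int32(0))


# Python version:
#     steps = 0
#     while x > 1.0:
#         x = x / 2
#         steps += 1
@jax.jit
def halvings(x):
    def cond_fun(state):
        value, _ = state
        return value > 1.0

    def body_fun(state):
        value, steps = state
        return value / 2, steps + 1

    init = (jnp.asarray(x, dtype=jnp.float32), jnp.int32(0))
    _, steps = jax.lax.while_loop(cond_fun, body_fun, init)
    return steps
```

For example, `sum_of_squares(10)` computes 0 + 1 + 4 + ... + 81 = 285, and `halvings(10.0)` returns 4 (10 → 5 → 2.5 → 1.25 → 0.625).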
🎯 Pattern matching
| Construct | Works? | Notes |
|---|---|---|
| `match`-`case` | ✅⚠️ | Static values only. For traced values, use an `if`-`elif`-`else` chain or `jax.lax.switch` instead. |
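For a `match` on a traced integer, `jax.lax.switch(index, branches, *operands)` runs the branch at position `index`. Because `switch` clamps out-of-range indices to the first or last branch, this sketch maps the case values to indices explicitly so that every unmatched value, including negative ones, falls through to the wildcard branch. The function `apply_op` and its case values are hypothetical.

```python
import jax
import jax.numpy as jnp


# Python version (not supported by @jaxify when `op` is traced):
#     match op:
#         case 0: return x + y
#         case 1: return x - y
#         case _: return x * y
@jax.jit
def apply_op(op, x, y):
    branches = [
        lambda a, b: a + b,  # case 0
        lambda a, b: a - b,  # case 1
        lambda a, b: a * b,  # case _
    ]
    # Map case values to branch indices; anything else goes to the wildcard (2).
    index = jnp.where(op == 0, 0, jnp.where(op == 1, 1, 2))
    return jax.lax.switch(index, branches, x, y)
```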
Download files
File details
Details for the file jaxify-0.0.4.tar.gz.
File metadata
- Download URL: jaxify-0.0.4.tar.gz
- Size: 4.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.9.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `5164294b72f9c2412cde58b7005a4b2547918b1e5c34a1b9e1b0240462309c4e` |
| MD5 | `7255cfe526b066a717b86db1d4b71c42` |
| BLAKE2b-256 | `5a86674fca5668f3e5f1da332faf05f1e698bfd82e11c1ab5827bda2706f5169` |
File details
Details for the file jaxify-0.0.4-py3-none-any.whl.
File metadata
- Download URL: jaxify-0.0.4-py3-none-any.whl
- Size: 5.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.9.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `461e39cbd342c37f8480611a71873c38e8d79c5ddfd2b416945ba7ca18d8aec1` |
| MD5 | `7d15c57144437d016717a2d5dab744f3` |
| BLAKE2b-256 | `ffddff93d31acffe338aad17bf29998223d8532668ce3d608bd6a8509ff94ef5` |