jaxify

Write Python. Run JAX.

⚠️ jaxify is an experimental project under development
You're welcome to try it out and report any issues!

jaxify lets you apply JAX transformations (such as @jax.jit and @jax.vmap) to functions that use common Python constructs JAX cannot trace on its own, such as if conditions that depend on input values.
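To make the limitation concrete, here is a minimal sketch of what happens without jaxify when a jitted function branches on an input value. It uses only JAX itself; the function name absolute_value_plain is made up for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def absolute_value_plain(x):
    # Under jit, x is a tracer without a concrete value, so Python
    # cannot decide which branch of the `if` to take.
    if x >= 0:
        return x
    else:
        return -x


raised = False
try:
    absolute_value_plain(jnp.asarray(-3.0))
except jax.errors.ConcretizationTypeError as err:
    # Recent JAX versions raise TracerBoolConversionError, a subclass.
    raised = True
    print(type(err).__name__)
```

Converting a traced value to a Python bool is what triggers the error, which is why value-dependent if conditions need special handling.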

Installation

pip install jaxify

Getting started

import jax
import jax.numpy as jnp
from jaxify import jaxify

@jax.jit
@jax.vmap
@jaxify  # <-- Just decorate your function with @jaxify
def absolute_value(x):
    if x >= 0:  # <-- If block in a JIT-compiled function
        return x
    else:
        return -x

xs = jnp.arange(-1000, 1000)
ys = absolute_value(xs)  # <-- Runs at JAX speed!
print(ys)

How it works

The @jaxify decorator transforms Python functions using a mixture of static analysis and dynamic tracing, replacing unsupported Python constructs with JAX-compatible alternatives. Once transformed, the functions are traceable by JAX, so you can apply JAX transformations like @jax.jit and @jax.vmap to them directly.
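One way to picture the result of this rewriting is a hand-written equivalent of the Getting started example. The sketch below shows the kind of JAX code you would otherwise write yourself. It is not jaxify's actual output, and the name absolute_value_manual is an assumption made for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
@jax.vmap
def absolute_value_manual(x):
    # jax.lax.cond(pred, true_fun, false_fun, *operands) selects a branch
    # based on a traced predicate instead of a Python `if`.
    return jax.lax.cond(x >= 0, lambda v: v, lambda v: -v, x)


xs = jnp.arange(-3, 4)
ys = absolute_value_manual(xs)
print(ys)
```

Both branches must return values of the same shape and dtype, which holds here because each returns a value shaped like x.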

Compatibility status

The tables below summarize which Python constructs are currently supported within @jaxify-decorated functions, and which are not:

🔀 Conditionals

| Construct | Works? | Notes |
| --- | --- | --- |
| if statements | ✅ | Fully supported, including elif and else clauses. All branches are traced and translated to calls to jax.lax.cond |
| if expressions (e.g. a if b else c) | ✅ | Traced and translated to jax.lax.cond |
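For reference, an if-elif-else chain can be written by hand as nested jax.lax.cond calls. The sketch below is an illustration in plain JAX, not jaxify's generated code, and sign_manual is a hypothetical name.

```python
import jax
import jax.numpy as jnp


@jax.jit
@jax.vmap
def sign_manual(x):
    # Python equivalent:
    #   if x > 0: return 1.0
    #   elif x < 0: return -1.0
    #   else: return 0.0
    return jax.lax.cond(
        x > 0,
        lambda v: jnp.ones_like(v),
        # The elif/else part becomes a nested cond in the false branch.
        lambda v: jax.lax.cond(
            v < 0,
            lambda w: -jnp.ones_like(w),
            lambda w: jnp.zeros_like(w),
            v,
        ),
        x,
    )


signs = sign_manual(jnp.array([-2.0, 0.0, 5.0]))
print(signs)
```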

⚖️ Comparisons

| Construct | Works? | Notes |
| --- | --- | --- |
| ==, !=, <, >, <=, >= | ✅ | Chained comparisons (e.g. x < y <= z) are supported by translation to the equivalent chain of individual comparisons |
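As an illustration of what a chain of individual comparisons looks like in plain JAX, the sketch below splits lo < x <= hi into two comparisons and combines them. The function name in_half_open_interval is hypothetical, and jnp.logical_and is used here for simplicity; it evaluates both sides and may differ from how jaxify combines them.

```python
import jax
import jax.numpy as jnp


@jax.jit
def in_half_open_interval(lo, x, hi):
    # `lo < x <= hi` is equivalent to `(lo < x) and (x <= hi)`.
    return jnp.logical_and(lo < x, x <= hi)


mask = in_half_open_interval(0.0, jnp.array([0.0, 1.0, 2.0, 3.0]), 2.0)
print(mask)
```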

1️⃣ Logical operators

| Construct | Works? | Notes |
| --- | --- | --- |
| and / or | ✅ | Short-circuiting of traced values supported via translation to jax.lax.cond calls |
| not | ✅ | Translates to jnp.logical_not for traced single values |
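The short-circuit behaviour of and can be expressed by hand with jax.lax.cond, placing the right operand inside the branch that is only selected when the left operand is true. This is a sketch in plain JAX under that assumption, with hypothetical function names, and not necessarily the exact code jaxify produces.

```python
import jax
import jax.numpy as jnp


@jax.jit
def reciprocal_exceeds(x, threshold):
    # Python equivalent: x != 0 and 1 / x > threshold
    nonzero = x != 0
    return jax.lax.cond(
        nonzero,
        lambda: 1.0 / x > threshold,  # right operand of `and`
        lambda: nonzero,              # left operand was false: return it
    )


@jax.jit
def is_zero(x):
    # Python equivalent: not (x != 0)
    return jnp.logical_not(x != 0)


print(reciprocal_exceeds(0.25, 2.0), reciprocal_exceeds(0.0, 2.0), is_zero(0.0))
```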

🔄 Loops

| Construct | Works? | Notes |
| --- | --- | --- |
| for loops | ❌ | Currently unsupported. Use jax.lax.fori_loop, jax.lax.scan, or jax.lax.while_loop instead |
| while loops | ❌ | Currently unsupported. Use jax.lax.while_loop instead |
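Until loops are supported, the suggested JAX primitives can be used directly. The sketch below shows one for loop rewritten with jax.lax.fori_loop and one while loop rewritten with jax.lax.while_loop; the function names are made up for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sum_of_squares(n):
    # Python equivalent:
    #   total = 0
    #   for i in range(n): total += i * i
    return jax.lax.fori_loop(0, n, lambda i, total: total + i * i, jnp.int32(0))


@jax.jit
def count_halvings(x):
    # Python equivalent:
    #   steps = 0
    #   while x > 1: x = x / 2; steps += 1
    def cond_fun(state):
        value, _ = state
        return value > 1.0

    def body_fun(state):
        value, steps = state
        return value / 2.0, steps + 1

    _, steps = jax.lax.while_loop(cond_fun, body_fun, (x, jnp.int32(0)))
    return steps


print(sum_of_squares(5), count_halvings(8.0))
```

The loop state is passed explicitly as a carry value, and its shape and dtype must stay the same on every iteration.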

🎯 Pattern matching

| Construct | Works? | Notes |
| --- | --- | --- |
| match-case | ✅⚠️ | Static values only. For traced values, use an if-elif-else chain or jax.lax.switch instead |
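For a traced selector, jax.lax.switch picks one of several branches by integer index. The sketch below is a plain JAX illustration of that alternative; apply_op and its three branches are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def apply_op(op_index, x):
    # Python equivalent with static values:
    #   match op_index:
    #       case 0: return x + 1
    #       case 1: return x * 2
    #       case 2: return -x
    branches = (
        lambda v: v + 1.0,  # case 0
        lambda v: v * 2.0,  # case 1
        lambda v: -v,       # case 2
    )
    return jax.lax.switch(op_index, branches, x)


print(apply_op(0, 3.0), apply_op(1, 3.0), apply_op(2, 3.0))
```

As with jax.lax.cond, every branch must return a value of the same shape and dtype.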
