Filtered call wrappers, traceable enum state, PyTree registration, flat-vector raveling, and dtype-aware Warp interop for JAX code.
`liblaf.jarp` helps when JAX code mixes traceable arrays with ordinary Python
metadata, and when the same program needs to cross into NVIDIA Warp. It is
usually imported as `from liblaf import jarp` and provides a few focused tools
around that boundary:
- `filter_jit` and `fallback_jit` wrap callables while partitioning arrays away from static metadata.
- `Enum`, `define`, `frozen`, `array()`, `static()`, and `auto()` make enum state and `attrs` classes flatten the way JAX expects.
- `ravel` turns the dynamic leaves of a tree into one flat vector and returns a reusable `Structure` for round trips.
- `tree.where` and `tree.select` apply `jax.numpy.where` and `jax.numpy.select` across matching PyTree leaves.
- `jarp.lax` retries a small slice of `jax.lax` eagerly when JAX rejects Python-only callback logic, while preserving the wrapped primitive metadata.
- `to_warp`, `jarp.struct`, `jarp.warp.jax_callable`, and `jarp.warp.jax_kernel` cover the common JAX-to-Warp interop paths.
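The partitioning idea behind `filter_jit` can be pictured in plain JAX. The sketch below is not jarp's implementation: `filter_jit_sketch` and `_is_dynamic` are hypothetical names, and it assumes that array leaves are traced while every other leaf is static (and therefore must be hashable). Static leaves end up in the `jax.jit` cache key, so a new static value triggers a retrace.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


def _is_dynamic(leaf) -> bool:
    # Arrays are traced; every other leaf (str, int, float, callables, ...) is static.
    return isinstance(leaf, (jax.Array, np.ndarray))


def filter_jit_sketch(fn):
    """Hypothetical filtered jit: trace array leaves, bake everything else in as static."""

    @functools.partial(jax.jit, static_argnums=(1, 2))
    def compiled(dynamic, static, treedef):
        # Merge traced arrays with the static leaves (None marks the empty slots).
        leaves = [s if d is None else d for d, s in zip(dynamic, static)]
        args, kwargs = jax.tree_util.tree_unflatten(treedef, leaves)
        return fn(*args, **kwargs)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        leaves, treedef = jax.tree_util.tree_flatten((args, kwargs))
        dynamic = [x if _is_dynamic(x) else None for x in leaves]
        static = tuple(None if _is_dynamic(x) else x for x in leaves)  # must be hashable
        return compiled(dynamic, static, treedef)

    return wrapper


@filter_jit_sketch
def scale(tree):
    # len() needs a concrete str, which works because "name" is a static leaf.
    return tree["x"] * len(tree["name"])


out = scale({"x": jnp.array([1.0, 2.0]), "name": "abc"})
```

Unlike a full filtered jit, this sketch requires the wrapped function to return only arrays, because `jax.jit` cannot return strings or other non-array values.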
📦 Installation
> [!NOTE]
> `liblaf-jarp` requires Python 3.12 or newer.
Install the published package with uv:

```bash
uv add liblaf-jarp
```

If you want a CUDA-enabled JAX, install the extra that matches your CUDA version instead:

```bash
uv add 'liblaf-jarp[cuda12]'  # CUDA 12
uv add 'liblaf-jarp[cuda13]'  # CUDA 13
```
🚀 Quick Start
This example shows the core workflow: define a mixed data-and-metadata PyTree once, then reuse the same split at the function boundary.
```python
import jax.numpy as jnp
from liblaf import jarp


@jarp.define
class Batch:
    values: object = jarp.array()  # dynamic: traced by JAX
    label: str = jarp.static()  # static: kept as metadata


@jarp.filter_jit
def normalize(batch: Batch) -> Batch:
    centered = batch.values - jnp.mean(batch.values)
    return Batch(values=centered, label=batch.label)


batch = Batch(values=jnp.array([1.0, 2.0, 3.0]), label="train")
result = normalize(batch)
```
The array payload stays on the dynamic side of the partition, while the string
label remains static metadata. `auto()` is the middle ground: it lets a field's
runtime value decide whether the field is treated as dynamic or static.
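For comparison, a similar data/metadata split can be expressed with JAX alone. This is a sketch, not how `jarp.define` works internally: it assumes `jax.tree_util.register_dataclass` from recent JAX releases and uses a frozen standard-library dataclass in place of an `attrs` class; `PlainBatch` and `normalize_plain` are illustrative names.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class PlainBatch:
    values: jax.Array
    label: str


# "values" becomes a PyTree leaf; "label" is stored in the treedef as static metadata.
jax.tree_util.register_dataclass(PlainBatch, data_fields=["values"], meta_fields=["label"])


@jax.jit
def normalize_plain(batch: PlainBatch) -> PlainBatch:
    centered = batch.values - jnp.mean(batch.values)
    return PlainBatch(values=centered, label=batch.label)


batch = PlainBatch(values=jnp.array([1.0, 2.0, 3.0]), label="train")
result = normalize_plain(batch)
```

Because `label` lives in the treedef rather than among the leaves, `jax.jit` treats it as part of the compilation cache key, so a different label string triggers a retrace instead of being traced.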
Enum state can stay dynamic too, which is useful for state machines inside JAX control flow:
```python
import enum

import jax.numpy as jnp
from liblaf import jarp


class Phase(jarp.Enum):
    START = enum.auto()
    RUNNING = enum.auto()


phase = Phase.where(jnp.array([True, False]), Phase.START, Phase.RUNNING)
```
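One way to picture dynamic enum state is to store each member's integer value in an array and drive transitions with `jnp.where`. The sketch uses a plain `enum.IntEnum` named `PlainPhase` and a hypothetical `advance` step function; it is a stand-in for illustration, not `jarp.Enum`, and makes no claim about how jarp encodes its members.

```python
import enum

import jax
import jax.numpy as jnp


class PlainPhase(enum.IntEnum):
    START = 0
    RUNNING = 1
    DONE = 2


@jax.jit
def advance(codes: jax.Array, finished: jax.Array) -> jax.Array:
    """Vectorized transition: START -> RUNNING, RUNNING -> DONE once finished, DONE stays."""
    nxt = jnp.where(codes == int(PlainPhase.START), int(PlainPhase.RUNNING), codes)
    return jnp.where((codes == int(PlainPhase.RUNNING)) & finished, int(PlainPhase.DONE), nxt)


# Plain-JAX analogue of picking START or RUNNING per element from a boolean mask.
codes = jnp.where(jnp.array([True, False]), int(PlainPhase.START), int(PlainPhase.RUNNING))
codes = advance(codes, jnp.array([False, True]))
phases = [PlainPhase(int(c)) for c in codes]  # decode back to members outside of JAX
```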
jarp.ravel handles the other common workflow: flatten only the dynamic leaves
into one vector and keep enough structure around to rebuild the tree later.
```python
import jax.numpy as jnp
from liblaf import jarp

payload = {"a": jnp.zeros((3,)), "b": jnp.ones((4,)), "static": "foo"}
flat, structure = jarp.ravel(payload)
round_trip = structure.unravel(flat)
```
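`jax.flatten_util.ravel_pytree` expects every leaf to be array-like, so it cannot be applied directly to a payload with a string leaf. A minimal sketch of raveling only the dynamic leaves, assuming arrays are the only dynamic leaves, filters them out first and splices the static leaves back in when unraveling. `ravel_dynamic` is a hypothetical helper, not `jarp.ravel`, and it returns a plain function rather than a `Structure`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.flatten_util import ravel_pytree


def ravel_dynamic(tree):
    """Hypothetical helper: ravel only array leaves, keep the rest for the round trip."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    is_dynamic = [isinstance(x, (jax.Array, np.ndarray)) for x in leaves]
    flat, unravel_arrays = ravel_pytree([x for x, d in zip(leaves, is_dynamic) if d])

    def unravel(vector):
        arrays = iter(unravel_arrays(vector))
        # Put rebuilt arrays back in their slots; static leaves are reused as-is.
        new_leaves = [next(arrays) if d else x for x, d in zip(leaves, is_dynamic)]
        return jax.tree_util.tree_unflatten(treedef, new_leaves)

    return flat, unravel


payload = {"a": jnp.zeros((3,)), "b": jnp.ones((4,)), "static": "foo"}
flat, unravel = ravel_dynamic(payload)
round_trip = unravel(flat)
```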
When a JAX or NumPy pipeline needs to cross into Warp, jarp.to_warp can infer
vector and matrix dtypes from trailing dimensions:
```python
from typing import Any

import jax.numpy as jnp
from liblaf import jarp

arr_wp = jarp.to_warp(jnp.zeros((5, 3), jnp.float32), (-1, Any))
```
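The trailing-dimension inference can be illustrated without Warp. The helper below is hypothetical and encodes only one reading of the spec, which is an assumption rather than jarp's documented rule: `(-1, Any)` describes a vector whose length (`-1`) and scalar type (`Any`) are both taken from the array, and a three-entry spec such as `(3, 3, Any)` would describe a matrix.

```python
from typing import Any

import numpy as np


def resolve_trailing_layout(shape: tuple[int, ...], dtype: Any, spec: tuple[Any, ...]):
    """Hypothetical: split `shape` into leading element dims and a per-element vector/matrix."""
    *dims, scalar = spec
    ndim = len(dims)  # 1 -> vector, 2 -> matrix (an assumption of this sketch)
    if ndim not in (1, 2) or len(shape) < ndim:
        raise ValueError(f"cannot match spec {spec!r} against shape {shape!r}")
    inner = []
    for wanted, actual in zip(dims, shape[-ndim:]):
        # -1 means "take the size from the array"; otherwise it must match exactly.
        if wanted != -1 and wanted != actual:
            raise ValueError(f"expected trailing size {wanted}, got {actual}")
        inner.append(actual)
    # Any means "take the scalar type from the array".
    scalar_dtype = np.dtype(dtype) if scalar is Any else np.dtype(scalar)
    kind = "vector" if ndim == 1 else "matrix"
    return shape[:-ndim], kind, tuple(inner), scalar_dtype


# A (5, 3) float32 array reads as 5 elements of a length-3 float32 vector.
layout = resolve_trailing_layout((5, 3), np.float32, (-1, Any))
```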
For Warp structs whose field dtypes should follow the surrounding JAX precision
mode, jarp.struct can specialize annotations from a small factory:
```python
from typing import Any

import warp as wp
from liblaf import jarp


@jarp.struct
class Particle[T]:
    @classmethod
    def __annotations_factory__(cls, dtype: Any) -> dict[str, Any]:
        return {
            "position": wp.array1d(dtype=wp.types.vector(3, dtype)),
            "mass": wp.array1d(dtype=dtype),
        }


particles64 = Particle[wp.float64]()
particles_default = Particle()
```
Particle() uses jarp.warp.types.floating, so it follows JAX's active
jax_enable_x64 setting.
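Following JAX's precision mode comes down to choosing float64 only when `jax_enable_x64` is on. The sketch below mirrors the annotation-factory pattern with NumPy dtypes and JAX arrays instead of Warp types, so it runs without Warp; `ParticleFields`, `annotations_for`, and `default_floating` are hypothetical names, not part of jarp.

```python
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np


def default_floating() -> np.dtype:
    # A Python float becomes float64 only when jax_enable_x64 is enabled.
    return jnp.asarray(0.0).dtype


class ParticleFields:
    """Hypothetical stand-in for a dtype-parameterized struct (NumPy dtypes, not Warp types)."""

    @classmethod
    def annotations_for(cls, dtype: Any) -> dict[str, tuple[np.dtype, tuple[int, ...]]]:
        # field name -> (scalar dtype, per-particle shape)
        return {"position": (np.dtype(dtype), (3,)), "mass": (np.dtype(dtype), ())}

    @classmethod
    def zeros(cls, n: int, dtype: Any = None) -> dict[str, jax.Array]:
        dtype = default_floating() if dtype is None else dtype
        return {
            name: jnp.zeros((n, *shape), scalar)
            for name, (scalar, shape) in cls.annotations_for(dtype).items()
        }


particles_default = ParticleFields.zeros(4)
```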
When a branch or loop body relies on Python-only callback logic that JAX
control-flow primitives reject, `jarp.lax.cond`, `switch`, `fori_loop`, and
`while_loop` can still run it. Each wrapper tries the corresponding `jax.lax`
primitive first; if that call raises one of the selected JAX tracing or
indexing errors, the wrapper logs the error, caches the failed metadata
signature, and reruns the call eagerly.
For broader PyTree traversal helpers, see jarp.PyTreeProxy,
jarp.partial, jarp.tree.register_generic, and the lower-level
jarp.tree.codegen module. Importing jarp.tree also registers the built-in
PyTree adapters for bound methods and warp.array.
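To see what a bound-method adapter buys you, here is one hand-written way to register `types.MethodType` with `jax.tree_util.register_pytree_node`. It is a sketch, not jarp's registration: the receiver (`__self__`) becomes the child and the underlying function (`__func__`) becomes static auxiliary data, so the method's arrays are visible to tree utilities and `jax.jit`. `Model` and `apply` are illustrative names. Registration is process-wide and JAX rejects duplicate registrations, so this should not be combined with an import that already registers bound methods.

```python
import dataclasses
import types

import jax
import jax.numpy as jnp

# Hypothetical adapter: children = (receiver,), aux data = underlying function.
jax.tree_util.register_pytree_node(
    types.MethodType,
    lambda method: ((method.__self__,), method.__func__),
    lambda func, children: types.MethodType(func, children[0]),
)


@dataclasses.dataclass(frozen=True)
class Model:
    w: jax.Array

    def scale(self, x):
        return self.w * x


jax.tree_util.register_dataclass(Model, data_fields=["w"], meta_fields=[])


@jax.jit
def apply(fn, x):
    # fn arrives as a PyTree: its receiver's arrays are traced, the function is static.
    return fn(x)


model = Model(w=jnp.array([1.0, 2.0]))
scaled = apply(model.scale, jnp.array(2.0))
```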
🛠️ Local Development
Clone the repository, sync the workspace, and use nox to list and run the
maintained automation sessions:
```bash
git clone https://github.com/liblaf/jarp.git
cd jarp
uv sync --all-groups
nox --list-sessions
nox --tags test
```
To build the documentation site locally:
```bash
uv run zensical build
```
📚 Documentation
- Documentation site
- Getting started guide
- Call wrappers guide
- PyTree workflows
- Warp interop guide
- API reference
🤝 Contributing
Issues and pull requests are welcome, especially around PyTree ergonomics, Warp integration, and edge cases that show up in real JAX code.
🔗 Links
📝 License
File details

Details for the file liblaf_jarp-0.2.1.tar.gz.

File metadata

- Download URL: liblaf_jarp-0.2.1.tar.gz
- Size: 27.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.13

File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | eecb4d654b9f690c3c07076f941a3b63ca622c00a1ef5c6eca3101b7199456d4 |
| MD5 | 14b996362298d03f10244ccf5c5cfaac |
| BLAKE2b-256 | 86819ed169dabcd57daa5d1b2ca5f26a0968a713f00661ddfc96e77c521c0080 |
Provenance

The following attestation bundles were made for liblaf_jarp-0.2.1.tar.gz:

Publisher: python-release.yaml on liblaf/jarp

- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: liblaf_jarp-0.2.1.tar.gz
- Subject digest: eecb4d654b9f690c3c07076f941a3b63ca622c00a1ef5c6eca3101b7199456d4
- Sigstore transparency entry: 1519165572
- Permalink: liblaf/jarp@4948cae8cff89459b7084f667cde21d6782702a1
- Branch / Tag: refs/tags/v0.2.1
- Owner: https://github.com/liblaf
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: python-release.yaml@4948cae8cff89459b7084f667cde21d6782702a1
- Trigger Event: release
File details

Details for the file liblaf_jarp-0.2.1-py3-none-any.whl.

File metadata

- Download URL: liblaf_jarp-0.2.1-py3-none-any.whl
- Size: 44.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.13

File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | f15de953adb3205f38354bcd6d9b29b71a0f37f2acd10d6e31389cae3c736376 |
| MD5 | 3d39858d7e2c47c2895391b34bae6516 |
| BLAKE2b-256 | 9d5803dbad522c86790201940e94a49bef379bc6d31c1cc89fb3e47da7ec7c7a |
Provenance

The following attestation bundles were made for liblaf_jarp-0.2.1-py3-none-any.whl:

Publisher: python-release.yaml on liblaf/jarp

- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: liblaf_jarp-0.2.1-py3-none-any.whl
- Subject digest: f15de953adb3205f38354bcd6d9b29b71a0f37f2acd10d6e31389cae3c736376
- Sigstore transparency entry: 1519165619
- Permalink: liblaf/jarp@4948cae8cff89459b7084f667cde21d6782702a1
- Branch / Tag: refs/tags/v0.2.1
- Owner: https://github.com/liblaf
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: python-release.yaml@4948cae8cff89459b7084f667cde21d6782702a1
- Trigger Event: release