Filtered call wrappers, PyTree registration, flat-vector raveling, and dtype-aware Warp interop for JAX code.

liblaf.jarp helps when JAX code mixes traceable arrays with ordinary Python metadata, and when the same program needs to cross into NVIDIA Warp. It is usually imported with from liblaf import jarp, and it packages a few focused tools around that boundary:

  • filter_jit and fallback_jit wrap callables while partitioning arrays away from static metadata.
  • define, frozen, array(), static(), and auto() make attrs classes flatten the way JAX expects.
  • ravel turns the dynamic leaves of a tree into one flat vector and returns a reusable Structure for round trips.
  • jarp.lax wraps a small subset of jax.lax and reruns the call eagerly when JAX rejects Python-only callback logic, while preserving the wrapped primitive metadata.
  • to_warp, jarp.struct, jarp.warp.jax_callable, and jarp.warp.jax_kernel cover the common JAX-to-Warp interop paths.

📦 Installation

Note: liblaf-jarp requires Python 3.12 or newer.

Install the published package with uv:

uv add liblaf-jarp

If you want a CUDA-enabled JAX build, install the extra that matches your CUDA major version:

uv add 'liblaf-jarp[cuda12]'
uv add 'liblaf-jarp[cuda13]'

🚀 Quick Start

This example shows the core workflow: define a mixed data-and-metadata PyTree once, then reuse the same split at the function boundary.

import jax.numpy as jnp
from liblaf import jarp


@jarp.define
class Batch:
    values: object = jarp.array()
    label: str = jarp.static()


@jarp.filter_jit
def normalize(batch: Batch) -> Batch:
    centered = batch.values - jnp.mean(batch.values)
    return Batch(values=centered, label=batch.label)


batch = Batch(values=jnp.array([1.0, 2.0, 3.0]), label="train")
result = normalize(batch)

The array payload stays on the dynamic side of the partition, while the string label remains static metadata. auto() is the middle ground for a field whose side of the partition should be decided by the value it holds at runtime.
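For comparison, here is a minimal sketch of the same dynamic/static split written with stock JAX only. It assumes a JAX version that provides jax.tree_util.register_dataclass, and it reuses the Batch and normalize names from the example above for illustration. It is not how jarp.define or jarp.filter_jit are implemented.

```python
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp


# Stock-JAX stand-in for the Batch class above: `values` is a traced leaf,
# `label` is hashable metadata stored in the treedef.
@partial(
    jax.tree_util.register_dataclass,
    data_fields=["values"],
    meta_fields=["label"],
)
@dataclass
class Batch:
    values: jax.Array
    label: str


@jax.jit
def normalize(batch: Batch) -> Batch:
    # Only `batch.values` is traced; `batch.label` is an ordinary Python str.
    centered = batch.values - jnp.mean(batch.values)
    return Batch(values=centered, label=batch.label)


batch = Batch(values=jnp.array([1.0, 2.0, 3.0]), label="train")
result = normalize(batch)
leaves = jax.tree_util.tree_leaves(batch)  # contains only the array field
```

With plain JAX the data and metadata fields have to be listed by hand at registration time; the field specifiers described above move that decision onto each field.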

jarp.ravel handles the other common workflow: flatten only the dynamic leaves into one vector and keep enough structure around to rebuild the tree later.

import jax.numpy as jnp
from liblaf import jarp


payload = {"a": jnp.zeros((3,)), "b": jnp.ones((4,)), "static": "foo"}
flat, structure = jarp.ravel(payload)
round_trip = structure.unravel(flat)
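As a point of reference, stock JAX ships jax.flatten_util.ravel_pytree, which performs a similar round trip but expects every leaf to be array-like. The sketch below therefore leaves the "static" string entry out of the tree; it illustrates the underlying idea and makes no claim about how jarp.ravel is implemented.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Only array leaves here: ravel_pytree cannot flatten a str leaf.
tree = {"a": jnp.zeros((3,)), "b": jnp.ones((4,))}

flat, unravel = ravel_pytree(tree)  # flat has shape (7,)
restored = unravel(flat * 2.0)      # same dict structure, scaled leaves
```

A flat vector like this is the form many optimizers and linear solvers expect, which is why keeping a reusable way back to the original tree matters.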

When a JAX or NumPy pipeline needs to cross into Warp, jarp.to_warp can infer vector and matrix dtypes from trailing dimensions:

from typing import Any

import jax.numpy as jnp
from liblaf import jarp


arr_wp = jarp.to_warp(jnp.zeros((5, 3), jnp.float32), (-1, Any))

For Warp structs whose field dtypes should follow the surrounding JAX precision mode, jarp.struct can specialize annotations from a small factory:

from typing import Any

import warp as wp
from liblaf import jarp


@jarp.struct
class Particle[T]:
    @classmethod
    def __annotations_factory__(cls, dtype: Any) -> dict[str, Any]:
        return {
            "position": wp.array1d(dtype=wp.types.vector(3, dtype)),
            "mass": wp.array1d(dtype=dtype),
        }


particles64 = Particle[wp.float64]()
particles_default = Particle()

Particle() uses jarp.warp.types.floating, so it follows JAX's active jax_enable_x64 setting.
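To see which floating-point width "follows jax_enable_x64" in a given process, the default JAX float dtype can be queried directly. This is a small sketch using stock JAX only; the helper name default_float_dtype is hypothetical and is not part of jarp.

```python
import jax
import jax.numpy as jnp


def default_float_dtype():
    """Hypothetical helper: the float dtype JAX uses for Python floats."""
    # float32 unless jax_enable_x64 is switched on, in which case float64.
    return jnp.result_type(float)


active = default_float_dtype()
x64_enabled = bool(jax.config.jax_enable_x64)
```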

When JAX control-flow primitives reject Python-only callback logic, jarp.lax.cond, switch, fori_loop, and while_loop try the corresponding jax.lax primitive first. If that call raises one of the selected JAX tracing or indexing errors, the wrapper logs the error, caches the failed metadata signature, and reruns the fallback eagerly.
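The try-then-fall-back pattern can be sketched with stock JAX as follows. The helper name fori_loop_with_fallback is hypothetical, the sketch catches only jax.errors.ConcretizationTypeError, and it omits the logging and signature caching described above, so treat it as an illustration of the idea rather than the jarp.lax implementation.

```python
import jax
import jax.numpy as jnp


def fori_loop_with_fallback(lower, upper, body_fun, init_val):
    """Hypothetical helper: try jax.lax.fori_loop, else run a Python loop."""
    try:
        return jax.lax.fori_loop(lower, upper, body_fun, init_val)
    except jax.errors.ConcretizationTypeError:
        # Tracing rejected the body; rerun it eagerly on concrete values.
        val = init_val
        for i in range(lower, upper):
            val = body_fun(i, val)
        return val


def body(i, val):
    # A Python `if` on an array value fails under tracing but works eagerly.
    if val > 3:
        return val
    return val + i


result = fori_loop_with_fallback(0, 5, body, jnp.asarray(0))
```

The eager path gives up compilation of the loop in exchange for accepting ordinary Python control flow in the body.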

For broader PyTree traversal helpers, see jarp.PyTreeProxy, jarp.partial, jarp.tree.register_generic, and the lower-level jarp.tree.codegen module. Importing jarp.tree also registers the built-in PyTree adapters for bound methods and warp.array.

🛠️ Local Development

Clone the repository, sync the workspace, and use nox for the maintained automation surface:

git clone https://github.com/liblaf/jarp.git
cd jarp
uv sync --all-groups
nox --list-sessions
nox --tags test

To build the documentation site locally:

uv run zensical build

🤝 Contributing

Issues and pull requests are welcome, especially around PyTree ergonomics, Warp integration, and edge cases that show up in real JAX code.

📝 License

Copyright © 2026 liblaf.
This project is MIT licensed.
