Filtered call wrappers, traceable enum state, PyTree registration, flat-vector raveling, and dtype-aware Warp interop for JAX code.

liblaf.jarp helps when JAX code mixes traceable arrays with ordinary Python metadata, or when the same program needs to cross into NVIDIA Warp. It is usually imported with from liblaf import jarp, and it packages a few focused tools around those boundaries:

  • filter_jit and fallback_jit wrap callables while partitioning arrays away from static metadata.
  • Enum, define, frozen, array(), static(), and auto() make enum state and attrs classes flatten the way JAX expects.
  • ravel turns the dynamic leaves of a tree into one flat vector and returns a reusable Structure for round trips.
  • tree.where and tree.select apply jax.numpy.where and jax.numpy.select across matching PyTree leaves.
  • jarp.lax wraps a small slice of jax.lax (cond, switch, fori_loop, while_loop) and retries it eagerly when JAX rejects Python-only callback logic, while preserving the wrapped primitive's metadata.
  • to_warp, jarp.struct, jarp.warp.jax_callable, and jarp.warp.jax_kernel cover the common JAX-to-Warp interop paths.
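
tree.where and tree.select are easiest to picture as a leaf-wise jax.tree.map. The sketch below only illustrates that idea with plain JAX; it assumes one condition array broadcast against every leaf and trees of identical structure. tree_where_sketch and tree_select_sketch are hypothetical names, and jarp's own signatures may differ.

```python
import jax
import jax.numpy as jnp


def tree_where_sketch(condition, x, y):
    # Apply jnp.where leaf by leaf; x and y must share the same tree structure.
    return jax.tree.map(lambda a, b: jnp.where(condition, a, b), x, y)


def tree_select_sketch(condlist, choices, default=0):
    # `choices` is a list of same-structured trees; jnp.select runs per leaf position.
    return jax.tree.map(lambda *leaves: jnp.select(condlist, list(leaves), default), *choices)


x = {"pos": jnp.zeros(3), "vel": jnp.ones(3)}
y = {"pos": jnp.full(3, 5.0), "vel": jnp.full(3, -1.0)}
mask = jnp.array([True, False, True])

# Take x where mask is True, y elsewhere, in every leaf at once.
merged = tree_where_sketch(mask, x, y)

# First matching condition wins; positions matching neither get the default.
picked = tree_select_sketch(
    [jnp.array([True, False, False]), jnp.array([False, True, False])],
    [x, y],
    default=7.0,
)
```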

📦 Installation

Note: liblaf-jarp requires Python 3.12 or newer.

Install the published package with uv:

uv add liblaf-jarp

For a CUDA-enabled JAX, add the one extra that matches your CUDA major version:

uv add 'liblaf-jarp[cuda12]'
uv add 'liblaf-jarp[cuda13]'

🚀 Quick Start

This example shows the core workflow: define a mixed data-and-metadata PyTree once, then reuse the same split at the function boundary.

import jax.numpy as jnp
from liblaf import jarp


@jarp.define
class Batch:
    values: object = jarp.array()
    label: str = jarp.static()


@jarp.filter_jit
def normalize(batch: Batch) -> Batch:
    centered = batch.values - jnp.mean(batch.values)
    return Batch(values=centered, label=batch.label)


batch = Batch(values=jnp.array([1.0, 2.0, 3.0]), label="train")
result = normalize(batch)

The array payload stays on the dynamic side of the partition, while the string label remains static metadata. auto() is the middle ground: use it when a field's dynamic or static role should follow the value it holds at runtime.
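
To make the partition concrete, here is a minimal sketch of the filtering idea using only plain JAX. It assumes a simple rule (JAX and NumPy arrays are dynamic, every other leaf is static and hashable) and caches one jax.jit per combination of tree structure and static values; filter_jit_sketch is a hypothetical name, not jarp's implementation, and unlike jarp it can only return array outputs.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


def _is_array(leaf):
    # Assumed rule: arrays are dynamic, every other leaf is static metadata.
    return isinstance(leaf, (jax.Array, np.ndarray))


def filter_jit_sketch(fn):
    @functools.lru_cache(maxsize=None)
    def compile_for(treedef, mask, static_leaves):
        # One compiled function per (structure, which-leaves-are-dynamic, static values).
        def run(dynamic_leaves):
            dynamic_iter = iter(dynamic_leaves)
            static_iter = iter(static_leaves)
            leaves = [next(dynamic_iter) if is_dyn else next(static_iter) for is_dyn in mask]
            args, kwargs = jax.tree.unflatten(treedef, leaves)
            return fn(*args, **kwargs)

        return jax.jit(run)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        leaves, treedef = jax.tree.flatten((args, kwargs))
        mask = tuple(_is_array(leaf) for leaf in leaves)
        dynamic = [leaf for leaf, is_dyn in zip(leaves, mask) if is_dyn]
        static = tuple(leaf for leaf, is_dyn in zip(leaves, mask) if not is_dyn)
        return compile_for(treedef, mask, static)(dynamic)

    return wrapper


@filter_jit_sketch
def scale(values, factor, label):
    # `label` is static, so ordinary Python branching on it works under jit.
    if label == "train":
        return values * factor
    return values


out_train = scale(jnp.array([1.0, 2.0]), 3.0, "train")
out_eval = scale(jnp.array([1.0, 2.0]), 3.0, "eval")
```

Each distinct static value triggers its own trace, which is the usual trade-off of keeping metadata out of the traced arguments.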

Enum state can stay dynamic too, which is useful for state machines inside JAX control flow:

import enum

import jax.numpy as jnp
from liblaf import jarp


class Phase(jarp.Enum):
    START = enum.auto()
    RUNNING = enum.auto()


phase = Phase.where(jnp.array([True, False]), Phase.START, Phase.RUNNING)
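
One way to picture enum state inside JAX control flow is to carry each phase as a traced integer code and dispatch on it with jax.lax.switch. The sketch below uses a standard enum.IntEnum and plain JAX under that assumption; it does not show how jarp.Enum encodes its members internally.

```python
import enum

import jax
import jax.numpy as jnp


class Phase(enum.IntEnum):
    START = 0
    RUNNING = 1


def step(phase_code, x):
    """One state-machine step; the phase travels as a traced int32 code."""

    def on_start(x):
        # Leaving START: move to RUNNING without touching x.
        return jnp.int32(Phase.RUNNING.value), x

    def on_running(x):
        # Staying in RUNNING: accumulate.
        return jnp.int32(Phase.RUNNING.value), x + 1.0

    # lax.switch dispatches on the traced code, so no Python `if` is needed.
    return jax.lax.switch(phase_code, [on_start, on_running], x)


step_jit = jax.jit(step)
code, x = jnp.int32(Phase.START.value), jnp.float32(0.0)
for _ in range(3):
    code, x = step_jit(code, x)

# Element-wise selection between two phases, in the spirit of Phase.where.
batch_codes = jnp.where(jnp.array([True, False]), Phase.START.value, Phase.RUNNING.value)
```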

jarp.ravel handles the other common workflow: flatten only the dynamic leaves into one vector and keep enough structure around to rebuild the tree later.

import jax.numpy as jnp
from liblaf import jarp


payload = {"a": jnp.zeros((3,)), "b": jnp.ones((4,)), "static": "foo"}
flat, structure = jarp.ravel(payload)
round_trip = structure.unravel(flat)
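
The round trip can be approximated with jax.flatten_util.ravel_pytree applied only to the array leaves, with non-array leaves set aside and re-interleaved on the way back. This is a minimal sketch under the assumption that "dynamic" means jax.Array; ravel_sketch is a hypothetical name, and jarp.ravel returns a Structure object rather than a bare function.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def ravel_sketch(tree):
    """Flatten array leaves into one vector; keep non-array leaves aside."""
    leaves, treedef = jax.tree.flatten(tree)
    mask = [isinstance(leaf, jax.Array) for leaf in leaves]
    dynamic = [leaf for leaf, is_dyn in zip(leaves, mask) if is_dyn]
    static = [leaf for leaf, is_dyn in zip(leaves, mask) if not is_dyn]
    flat, unravel_dynamic = ravel_pytree(dynamic)

    def unravel(vector):
        # Rebuild the arrays, then re-interleave them with the untouched static leaves.
        dynamic_iter = iter(unravel_dynamic(vector))
        static_iter = iter(static)
        rebuilt = [next(dynamic_iter) if is_dyn else next(static_iter) for is_dyn in mask]
        return jax.tree.unflatten(treedef, rebuilt)

    return flat, unravel


payload = {"a": jnp.zeros((3,)), "b": jnp.ones((4,)), "static": "foo"}
flat, unravel = ravel_sketch(payload)
# Edit the flat vector (as an optimizer would), then restore the tree.
round_trip = unravel(flat * 2.0)
```

Only the seven array entries end up in the vector; the "static" string never enters it.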

When a JAX or NumPy pipeline needs to cross into Warp, jarp.to_warp can infer vector and matrix dtypes from trailing dimensions:

from typing import Any

import jax.numpy as jnp
from liblaf import jarp


arr_wp = jarp.to_warp(jnp.zeros((5, 3), jnp.float32), (-1, Any))

For Warp structs whose field dtypes should follow the surrounding JAX precision mode, jarp.struct can specialize annotations from a small factory:

from typing import Any

import warp as wp
from liblaf import jarp


@jarp.struct
class Particle[T]:
    @classmethod
    def __annotations_factory__(cls, dtype: Any) -> dict[str, Any]:
        return {
            "position": wp.array1d(dtype=wp.types.vector(3, dtype)),
            "mass": wp.array1d(dtype=dtype),
        }


particles64 = Particle[wp.float64]()
particles_default = Particle()

Particle() without a type argument uses jarp.warp.types.floating, so its float width follows JAX's active jax_enable_x64 setting.
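
JAX itself exposes the precision decision through jax.dtypes.canonicalize_dtype, which narrows float64 to float32 unless x64 mode is on. The sketch below assumes a helper like floating only needs that answer; floating_dtype_sketch is a hypothetical name, and the step that maps the result to a Warp scalar type (wp.float32 or wp.float64) is left out so the example runs without Warp.

```python
import jax
import jax.numpy as jnp
import numpy as np


def floating_dtype_sketch() -> np.dtype:
    """Hypothetical helper: the widest float JAX will currently produce."""
    # canonicalize_dtype returns float32 for float64 unless jax_enable_x64 is on.
    return jax.dtypes.canonicalize_dtype(np.float64)


dtype = floating_dtype_sketch()
# A real helper would translate this into the matching Warp scalar type here.
example = jnp.zeros(3, dtype=dtype)
```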

When JAX control-flow primitives reject Python-only callback logic, jarp.lax.cond, switch, fori_loop, and while_loop try the corresponding jax.lax primitive first. If that call raises one of the selected JAX tracing or indexing errors, the wrapper logs the error, caches the failed metadata signature, and reruns the fallback eagerly.

For broader PyTree traversal helpers, see jarp.PyTreeProxy, jarp.partial, jarp.tree.register_generic, and the lower-level jarp.tree.codegen module. Importing jarp.tree also registers the built-in PyTree adapters for bound methods and warp.array.
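
A bound-method adapter can be sketched with jax.tree_util.register_pytree_node, assuming the bound instance is the child and the underlying function is static aux data. This mirrors the idea only; jarp's adapter may split things differently. The registration is global, so do not run it in a process that also imports jarp.tree, because JAX rejects a second registration for the same type. Scale is a hypothetical example class.

```python
import types

import jax
import jax.numpy as jnp


def _flatten_method(method):
    # The bound instance is the child; the underlying function is static aux data.
    return (method.__self__,), method.__func__


def _unflatten_method(func, children):
    (instance,) = children
    return types.MethodType(func, instance)


# Global side effect: skip this if jarp.tree is already imported.
jax.tree_util.register_pytree_node(types.MethodType, _flatten_method, _unflatten_method)


@jax.tree_util.register_pytree_node_class
class Scale:
    def __init__(self, factor):
        self.factor = factor

    def apply(self, x):
        return self.factor * x

    def tree_flatten(self):
        return (self.factor,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


method = Scale(jnp.float32(2.0)).apply
# tree.map reaches through the method into the instance's array leaf.
tripled = jax.tree.map(lambda leaf: leaf * 3, method)
# The bound method can now be passed to jit as an ordinary argument.
via_jit = jax.jit(lambda m, x: m(x))(method, jnp.float32(5.0))
```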

🛠️ Local Development

Clone the repository, sync the workspace, and use nox to run the maintained automation sessions:

git clone https://github.com/liblaf/jarp.git
cd jarp
uv sync --all-groups
nox --list-sessions
nox --tags test

To build the documentation site locally:

uv run zensical build

🤝 Contributing

Issues and pull requests are welcome, especially around PyTree ergonomics, Warp integration, and edge cases that show up in real JAX code.

📝 License

Copyright © 2026 liblaf.
This project is MIT licensed.
