
Minimal JAX interpreter layer for threading custom context through traced computations


Slub

Minimal JAX interpreter layer for threading a custom context through computations.

Installation

git clone https://github.com/cusp-ai-oss/slub.git
cd slub
uv sync

Quick Start

import jax
import jax.numpy as jnp
from dataclasses import dataclass
from functools import partial
from slub.interpreter import Interpreter, Dispatcher, InterpreterContext, reinterpret
from slub.handlers import default_primitive_handler

@partial(jax.tree_util.register_dataclass, meta_fields=("tags",), data_fields=())
@dataclass(frozen=True)
class OpCtx(InterpreterContext):
    # 'tags' marked meta so it is NOT traced or differentiated over
    tags: tuple[str, ...] = ()
    def add_meta(self, k: str): return OpCtx(self.tags + (k,))
    def add_value(self, v): return self  # no value collection here
    def push(self): return self
    def pop(self): return self

def tag(name: str):
    def h(interpreter, ctx, eqn, invals):
        ctx = ctx.add_meta(name)
        return default_primitive_handler(interpreter, ctx, eqn, invals)
    return h

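# Handlers are keyed by primitive. add_p is registered too, but f below traces to
# sin and reduce_sum (from .sum()), so only the "sin" handler is expected to fire.
dispatcher = Dispatcher({jax.lax.sin_p: tag("sin"), jax.lax.add_p: tag("add")})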
interp = Interpreter(dispatcher=dispatcher)

@partial(reinterpret, interpreter=interp)
def f(xs):
    return jnp.sin(xs).sum()

out, ctx = f(OpCtx(), jnp.linspace(0., 1., 5))
print(out, ctx.tags)
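Because the dispatcher is keyed by JAX primitives, it can help to list the primitives a function actually traces to before choosing handler keys. A minimal sketch using only `jax.make_jaxpr`; `primitive_names` is a hypothetical helper, not part of Slub:

```python
import jax
import jax.numpy as jnp


def primitive_names(fn, *args):
    """List the primitive names in fn's top-level jaxpr, in trace order."""
    closed_jaxpr = jax.make_jaxpr(fn)(*args)
    return [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]


def f(xs):
    # Same body as the Quick Start function, without the reinterpret wrapper.
    return jnp.sin(xs).sum()


names = primitive_names(f, jnp.linspace(0.0, 1.0, 5))
print(names)
# Primitive objects expose the same names, e.g. jax.lax.sin_p.name.
print(jax.lax.sin_p.name in names)
```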

Core Concepts

Slub provides four main components that work together:

  • Context: Your custom data structure that threads through the computation
  • Handlers: Functions that process individual operations and update the context
  • Dispatcher: Routes operations to their corresponding handlers
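  • Interpreter: Executes JAX computations while applying your handlers. Wrap a function with the reinterpret decorator; the wrapped function takes the initial context as its first argument and returns (output, final context).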
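To make the four roles concrete, here is a minimal plain-JAX sketch of the underlying idea. It is not Slub's implementation: `interpret_with_context` and its `(ctx, eqn) -> ctx` handler signature are hypothetical. It traces a function to a jaxpr, then walks the equations, letting a handler looked up by primitive name (standing in for the dispatcher) update the context before each primitive is evaluated. It does not descend into jit or control-flow sub-jaxprs, which is the job of the built-in handlers listed below.

```python
import jax
import jax.numpy as jnp


def interpret_with_context(fn, handlers, ctx, *args):
    """Evaluate fn equation by equation, threading ctx through handlers.

    handlers maps a primitive name (e.g. "sin") to a function
    (ctx, eqn) -> new_ctx; this dict plays the dispatcher's role.
    Returns (flat list of outputs, final ctx).
    """
    closed = jax.make_jaxpr(fn)(*args)
    jaxpr = closed.jaxpr
    env = {}

    def read(atom):
        # Literals carry their value in .val; variables are looked up in env.
        return atom.val if hasattr(atom, "val") else env[atom]

    for var, const in zip(jaxpr.constvars, closed.consts):
        env[var] = const
    for var, arg in zip(jaxpr.invars, jax.tree.leaves(args)):
        env[var] = arg

    for eqn in jaxpr.eqns:
        handler = handlers.get(eqn.primitive.name)
        if handler is not None:
            ctx = handler(ctx, eqn)  # context update happens before evaluation
        outvals = eqn.primitive.bind(*map(read, eqn.invars), **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for var, val in zip(eqn.outvars, outvals):
            env[var] = val

    return [read(v) for v in jaxpr.outvars], ctx


# Usage: record a tag every time a sin primitive is evaluated.
outs, tags = interpret_with_context(
    lambda xs: jnp.sin(xs).sum(),
    {"sin": lambda ctx, eqn: ctx + ("sin",)},
    (),
    jnp.linspace(0.0, 1.0, 5),
)
print(outs, tags)
```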

Built-in Handlers

Slub includes handlers for common JAX operations:

| Handler | Covers | Description |
|---|---|---|
| default_primitive_handler | Primitive ops | Base handler for leaf operations |
| default_jit_handler | jax.jit | Recursively interprets JIT-compiled functions |
| default_scan_handler | lax.scan | Handles scan loops with carry/result threading |
| default_while_handler | lax.while_loop | Supports context growth via initializer/updater |
| default_cond_handler | lax.cond | Ensures branch contexts have matching structure |
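The jit and control-flow handlers exist because those primitives carry nested jaxprs in their parameters (a `branches` tuple for cond, `cond_jaxpr` and `body_jaxpr` for while_loop, a `jaxpr` for scan and jit), which an interpreter has to descend into. The plain-JAX sketch below makes that nesting visible; `walk` is a hypothetical helper, not Slub's API, and it detects sub-jaxprs by duck typing.

```python
import jax
import jax.numpy as jnp


def walk(jaxpr, depth=0):
    """Yield (depth, primitive_name) for every equation, descending into sub-jaxprs."""
    for eqn in jaxpr.eqns:
        yield depth, eqn.primitive.name
        for param in eqn.params.values():
            # cond stores its branches as a tuple; while/scan/jit store single jaxprs.
            candidates = param if isinstance(param, (tuple, list)) else (param,)
            for sub in candidates:
                inner = getattr(sub, "jaxpr", sub)  # ClosedJaxpr -> Jaxpr
                if hasattr(inner, "eqns"):
                    yield from walk(inner, depth + 1)


inner_fn = jax.jit(lambda v: jnp.cos(v))


def g(x):
    y = inner_fn(x)  # shows up as one jit equation whose body holds cos
    return jax.lax.while_loop(lambda a: a < 3.0, lambda a: a + jnp.sin(a) + 1.0, y)


for depth, name in walk(jax.make_jaxpr(g)(0.0).jaxpr):
    print("  " * depth + name)
```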

Control Flow Details

When working with JAX control flow primitives, keep these behaviors in mind:

  • while_loop: The condition function must be pure (no context modification). If the body grows the context, provide an initializer and optionally an updater function.
  • scan: The scan body cannot drop context leaves. You must explicitly choose which parts go into the carry versus the result.
  • cond: All branches must produce identical context tree structures to maintain type consistency.
  • jit: The inner graph is recursively reinterpreted with the same interpreter.
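The cond and scan rules mirror constraints JAX itself places on control flow. A plain-JAX sketch, using a hypothetical tuple-of-arrays context and a hypothetical `record` helper (neither is Slub's API): lax.cond rejects branches whose outputs have different pytree structures, so a branch that records nothing has to be padded; lax.scan splits each step's output into a carry (fixed structure, threaded across steps) and a result (stacked across steps).

```python
import jax
import jax.numpy as jnp


def record(ctx, value):
    """Hypothetical context update: append one recorded array to a tuple."""
    return ctx + (value,)


def cond_mismatched(pred, x):
    # One branch grows the context and the other does not: structures differ.
    return jax.lax.cond(pred, lambda c: record(c, jnp.sin(x)), lambda c: c, ())


def cond_matched(pred, x):
    # Pad the other branch with a same-shaped placeholder so structures agree.
    return jax.lax.cond(
        pred,
        lambda c: record(c, jnp.sin(x)),
        lambda c: record(c, jnp.zeros_like(x)),
        (),
    )


def scan_step(count, x):
    # Running state goes in the carry; per-step values go in the stacked result.
    return count + 1, jnp.sin(x)


x = jnp.array(1.0)
print(cond_matched(jnp.array(True), x))
try:
    cond_mismatched(jnp.array(True), x)
except (TypeError, ValueError) as err:
    print("cond rejected mismatched branches:", type(err).__name__)

count, per_step = jax.lax.scan(scan_step, jnp.array(0), jnp.arange(3.0))
print(count, per_step)
```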

Example: While Loop with Context Growth

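This example handles a while_loop whose body grows the context: every sin call inside the loop adds a tag and a value, so the while handler is given an initializer (which replaces Uninitialized leaves with zeros) and an updater: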

import jax
import jax.numpy as jnp
from dataclasses import dataclass
from functools import partial
from slub.interpreter import Interpreter, Dispatcher, InterpreterContext, reinterpret
from slub.handlers import (
    Uninitialized,
    default_while_handler,
    default_primitive_handler,
)


@partial(
    jax.tree_util.register_dataclass, meta_fields=("tags",), data_fields=("values",)
)
@dataclass(frozen=True)
class Ctx(InterpreterContext):
    tags: tuple[str, ...] = ()
    values: tuple[jax.Array, ...] = ()

    def add_meta(self, tag: str):
        return Ctx(self.tags + (tag,), self.values)

    def add_value(self, v: jax.Array):
        return Ctx(self.tags, self.values + (v,))

    def push(self):
        return self

    def pop(self):
        return self


def sin_handler(interpreter, ctx, eqn, invals):
    # adds metadata and one value
    ctx = ctx.add_meta("sin").add_value(jnp.array(1))
    return default_primitive_handler(interpreter, ctx, eqn, invals)


def initializer(old_ctx, sentinel_ctx):
    # Replace Uninitialized leaves with zeros
    leaves, tree = jax.tree.flatten(sentinel_ctx)
    leaves = [jnp.zeros_like(x) if isinstance(x, Uninitialized) else x for x in leaves]
    return jax.tree.unflatten(tree, leaves)


def updater(old_ctx, new_ctx):
    # Replace old context with new context from loop body
    return new_ctx


def while_with_init(interpreter, ctx, eqn, invals):
    ctx = ctx.add_meta("while").add_value(jnp.array(1))
    return default_while_handler(
        interpreter, ctx, eqn, invals, initializer=initializer, updater=updater
    )


dispatcher = Dispatcher({"while": while_with_init, jax.lax.sin_p: sin_handler})
interpreter = Interpreter(dispatcher=dispatcher)


@partial(reinterpret, interpreter=interpreter)
def run_loop_with_init(x):
    def cond(a):
        return a < 3

    def body(a):
        _ = jnp.sin(a)  # introduces extra context via sin_handler
        return a + 1

    return jax.lax.while_loop(cond, body, x)  # start the loop from the argument x


result, out_ctx = run_loop_with_init(Ctx(), jnp.array(0))
print(result, out_ctx.tags)
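The initializer is needed because lax.while_loop requires the loop carry to keep the same pytree structure on every iteration. The plain-JAX sketch below (not Slub code; the tuple "context" inside the carry is hypothetical) shows a body that grows its carry being rejected, and the same loop working once the slot is pre-allocated with zeros, analogous to what the initializer above does for Uninitialized leaves.

```python
import jax
import jax.numpy as jnp


def body_grows(carry):
    a, recorded = carry
    # Adds a new leaf to the carry: while_loop rejects this.
    return a + 1, recorded + (jnp.sin(a),)


def body_fixed(carry):
    a, recorded = carry
    # Same structure as the pre-allocated slot: one array in the tuple.
    return a + 1, (jnp.sin(a),)


def loop_cond(carry):
    a, _ = carry
    return a < 3.0


x = jnp.array(0.0)
# Pre-allocate the slot the body will fill (zeros), like an initializer would.
init = (x, (jnp.zeros_like(x),))
a_final, (last_sin,) = jax.lax.while_loop(loop_cond, body_fixed, init)
print(a_final, last_sin)

try:
    jax.lax.while_loop(loop_cond, body_grows, (x, ()))
except (TypeError, ValueError) as err:
    print("growing carry rejected:", type(err).__name__)
```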

Use Cases

Slub is designed for lightweight instrumentation and experimentation:

  • Instrumentation: Track operations, collect metrics, or monitor computation flow
  • Provenance: Record the history and lineage of values through a computation
  • Lightweight metrics: Gather statistics without heavyweight frameworks
  • Research prototyping: Quickly experiment with custom computation semantics

Examples

  • examples/monitoring_pipeline.py — Demonstrates primitive handlers, scan, and JIT compilation
  • notebooks/example.ipynb — Interactive notebook with step-by-step examples

Advanced Features

For advanced usage patterns, see the source code for:

  • Custom matching rules for handler dispatch
  • Error policies (RAISE, WARN, IGNORE) for mismatched contexts
  • Branch combiners for merging contexts from conditional branches

Development

To set up the development environment:

uv sync --group dev          # Install dev dependencies
uv run pytest                # Run tests
uvx pre-commit run --all-files  # Run linters and formatters

License

Licensed under the Apache License, Version 2.0. See LICENSE or visit http://www.apache.org/licenses/LICENSE-2.0.

Version

The project version is defined in pyproject.toml under [project].version.

Citation

@software{slub2026,
  title={Slub},
  author={Cusp AI},
  year={2026},
  url={https://github.com/cusp-ai-oss/slub}
}

Download files


Source Distribution

slub-0.1.0.tar.gz (83.3 kB)

Uploaded Source

Built Distribution


slub-0.1.0-py3-none-any.whl (16.3 kB)

Uploaded Python 3

File details

Details for the file slub-0.1.0.tar.gz.

File metadata

  • Download URL: slub-0.1.0.tar.gz
  • Size: 83.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.11.6 (uv publish on Ubuntu 24.04)

File hashes

Hashes for slub-0.1.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 34d272ceff2bea5809decc6cb93c10e8c89d9364c3377fb141860e46e02cf137 |
| MD5 | df0217ec8fd1802021c937af8ddafb6b |
| BLAKE2b-256 | 65f06849fd4c5b7b36a977fa3cf63ab5f075a9a92695a573406abfc892a53734 |
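To check a downloaded file against the published SHA256 digest, a minimal sketch using only Python's standard `hashlib`; the helper names `sha256_of` and `verify`, and the assumption that the archive sits in the current directory, are illustrative.

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 16):
    """Compute the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected_hex):
    """Return True if the file's SHA256 matches the published hex digest."""
    return sha256_of(path) == expected_hex.lower()


# Published SHA256 for slub-0.1.0.tar.gz, copied from the listing above.
EXPECTED = "34d272ceff2bea5809decc6cb93c10e8c89d9364c3377fb141860e46e02cf137"

archive = Path("slub-0.1.0.tar.gz")  # assumed download location
if archive.exists():
    print("hash ok" if verify(archive, EXPECTED) else "hash MISMATCH")
```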


File details

Details for the file slub-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: slub-0.1.0-py3-none-any.whl
  • Size: 16.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.11.6 (uv publish on Ubuntu 24.04)

File hashes

Hashes for slub-0.1.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0b88dc0cf1753bc5dd702cf16f996913e075401a3dceb1255c663a732c08591e |
| MD5 | 3a9c740f8929deae6d0e486d13ec6a1b |
| BLAKE2b-256 | 48b4f313a9838a89a452b2dff40dc15293bfa504d53b936e7cbabd990a12b7e3 |

