Slub
Minimal JAX interpreter layer for threading a custom context through traced computations.
Installation
```bash
git clone https://github.com/cusp-ai-oss/slub.git
cd slub
uv sync
```
Quick Start
```python
import jax
import jax.numpy as jnp
from dataclasses import dataclass
from functools import partial

from slub.interpreter import Interpreter, Dispatcher, InterpreterContext, reinterpret
from slub.handlers import default_primitive_handler


@partial(jax.tree_util.register_dataclass, meta_fields=("tags",), data_fields=())
@dataclass(frozen=True)
class OpCtx(InterpreterContext):
    # 'tags' is marked meta so it is NOT traced or differentiated over
    tags: tuple[str, ...] = ()

    def add_meta(self, k: str):
        return OpCtx(self.tags + (k,))

    def add_value(self, v):
        return self  # no value collection here

    def push(self):
        return self

    def pop(self):
        return self


def tag(name: str):
    # Handler factory: record `name` in the context, then run the primitive as usual
    def h(interpreter, ctx, eqn, invals):
        ctx = ctx.add_meta(name)
        return default_primitive_handler(interpreter, ctx, eqn, invals)

    return h


dispatcher = Dispatcher({jax.lax.sin_p: tag("sin"), jax.lax.add_p: tag("add")})
interp = Interpreter(dispatcher=dispatcher)


@partial(reinterpret, interpreter=interp)
def f(xs):
    # .sum() is a reduction primitive, not add_p, so only "sin" is expected in the tags
    return jnp.sin(xs).sum()


out, ctx = f(OpCtx(), jnp.linspace(0.0, 1.0, 5))  # the context is passed first
print(out, ctx.tags)
```
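To check which primitives a dispatcher like the one above will actually see, you can inspect the traced program with plain JAX. The snippet below uses only `jax.make_jaxpr` (no Slub) and collects the primitive name of each top-level equation. For this function you should see `sin` and a reduction, but no `add`, which suggests the `add` handler in the Quick Start is never triggered for this particular function.

```python
import jax
import jax.numpy as jnp


def f(xs):
    return jnp.sin(xs).sum()


# Trace f to a jaxpr and list the primitive behind each equation
closed = jax.make_jaxpr(f)(jnp.linspace(0.0, 1.0, 5))
prim_names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(prim_names)
```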
Core Concepts
Slub provides four main components that work together:
- Context: Your custom data structure that threads through the computation
- Handlers: Functions that process individual operations and update the context
- Dispatcher: Routes operations to their corresponding handlers
- Interpreter: Executes JAX computations while applying your handlers, via the `reinterpret` decorator
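As a rough mental model of how these four pieces fit together, here is a toy interpreter written in plain JAX. It is a sketch under simplifying assumptions, not Slub's implementation: `run_with_context`, `default_handler`, and `tagging` are hypothetical names, handlers take `(ctx, eqn, invals)` instead of Slub's `(interpreter, ctx, eqn, invals)`, the context is just a tuple of tags, arguments must be plain arrays, and equations that contain sub-programs (`jit`, `scan`, `while_loop`, `cond`) are evaluated as a whole rather than recursed into.

```python
import jax
import jax.numpy as jnp


def default_handler(ctx, eqn, invals):
    # Hypothetical fallback handler: evaluate the primitive unchanged, keep ctx as-is
    out = eqn.primitive.bind(*invals, **eqn.params)
    return ctx, (out if eqn.primitive.multiple_results else [out])


def tagging(name):
    # Hypothetical handler factory: append `name` to the context, then evaluate normally
    def handler(ctx, eqn, invals):
        ctx, outvals = default_handler(ctx, eqn, invals)
        return ctx + (name,), outvals

    return handler


def run_with_context(fn, handlers, ctx, *args):
    # Toy "interpreter": trace fn to a jaxpr, then walk it one equation at a time
    closed = jax.make_jaxpr(fn)(*args)
    jaxpr = closed.jaxpr
    env = {}

    def read(v):
        # Literals carry their value inline; variables are looked up in env
        return v.val if hasattr(v, "val") else env[v]

    for var, val in zip(jaxpr.constvars, closed.consts):
        env[var] = val
    for var, val in zip(jaxpr.invars, args):
        env[var] = val

    for eqn in jaxpr.eqns:
        # Toy "dispatcher": look up a handler by primitive, else use the default
        handler = handlers.get(eqn.primitive, default_handler)
        ctx, outvals = handler(ctx, eqn, [read(v) for v in eqn.invars])
        for var, val in zip(eqn.outvars, outvals):
            env[var] = val

    return [read(v) for v in jaxpr.outvars], ctx


xs = jnp.linspace(0.0, 1.0, 5)
out, ctx = run_with_context(
    lambda xs: jnp.sin(xs).sum(), {jax.lax.sin_p: tagging("sin")}, (), xs
)
print(out[0], ctx)
```

Threading the context as an explicit input and output of every handler, instead of mutating shared state, keeps each step a pure function of its inputs, which is the same shape as the handler signatures shown in the Quick Start.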
Built-in Handlers
Slub includes handlers for common JAX operations:
| Handler | Covers | Description |
|---|---|---|
| `default_primitive_handler` | Primitive ops | Base handler for leaf operations |
| `default_jit_handler` | `jax.jit` | Recursively interprets JIT-compiled functions |
| `default_scan_handler` | `lax.scan` | Handles scan loops with carry/result threading |
| `default_while_handler` | `lax.while_loop` | Supports context growth via initializer/updater |
| `default_cond_handler` | `lax.cond` | Ensures branch contexts have matching structure |
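The last four rows cover primitives whose equations carry one or more inner jaxprs, which a handler has to interpret recursively. The plain-JAX snippet below (no Slub; `count_subjaxprs` is a hypothetical helper) traces a function that uses all four and counts the inner jaxprs stored in each equation's params. Depending on the JAX version, the `jax.jit` equation may be reported as `pjit` or `jit`.

```python
import jax
import jax.numpy as jnp


def count_subjaxprs(params):
    # Count inner (Closed)Jaxprs in an equation's params; both expose `.eqns`.
    # cond stores its branches as a tuple, so unpack tuples/lists one level.
    n = 0
    for value in params.values():
        items = value if isinstance(value, (tuple, list)) else (value,)
        n += sum(1 for item in items if hasattr(item, "eqns"))
    return n


@jax.jit
def inner(x):
    return jnp.sin(x)


def f(x):
    y = inner(x)                                                    # jit
    y, _ = jax.lax.scan(lambda c, t: (c + t, None), y, jnp.arange(3.0))  # scan
    y = jax.lax.cond(y > 0, jnp.cos, jnp.sin, y)                    # cond
    return jax.lax.while_loop(lambda v: v < 10.0, lambda v: v + 1.0, y)  # while


closed = jax.make_jaxpr(f)(jnp.float32(1.0))
counts = {}
for eqn in closed.jaxpr.eqns:
    n = count_subjaxprs(eqn.params)
    if n:
        counts[eqn.primitive.name] = n
print(counts)
```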
Control Flow Details
When working with JAX control flow primitives, keep these behaviors in mind:
- `while_loop`: The condition function must be pure (no context modification). If the body grows the context, provide an `initializer` and, optionally, an `updater` function.
- `scan`: The scan body cannot drop context leaves. You must explicitly choose which parts go into the carry versus the result.
- `cond`: All branches must produce identical context tree structures to maintain type consistency.
- `jit`: The inner graph is recursively reinterpreted with the same interpreter.
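The `cond` and `while_loop` rules mirror constraints that JAX itself enforces: branch outputs must share one pytree structure, and a loop carry cannot change structure between iterations. This plain-JAX sketch (no Slub; the function names are illustrative) shows both violations being rejected at trace time with a `TypeError`, which is presumably why a context that grows inside a loop body needs an initializer.

```python
import jax
import jax.numpy as jnp


def cond_matching(x):
    # Both branches return a 2-tuple: accepted
    return jax.lax.cond(x > 0, lambda v: (v, v), lambda v: (v, -v), x)


def cond_mismatch(x):
    # Branches return (v,) vs (v, v): different pytree structures
    return jax.lax.cond(x > 0, lambda v: (v,), lambda v: (v, v), x)


def while_growing(x):
    # Body returns a larger carry than it receives: (c0,) -> (c0 + 1, c0)
    return jax.lax.while_loop(
        lambda c: c[0] < 3.0, lambda c: (c[0] + 1.0, c[0]), (x,)
    )


rejected = {}
for fn in (cond_matching, cond_mismatch, while_growing):
    try:
        jax.make_jaxpr(fn)(jnp.float32(0.0))
        rejected[fn.__name__] = False
    except TypeError:
        rejected[fn.__name__] = True
print(rejected)
```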
Example: While Loop with Context Growth
This example shows how to handle a while_loop that adds context during execution:
```python
import jax
import jax.numpy as jnp
from dataclasses import dataclass
from functools import partial

from slub.interpreter import Interpreter, Dispatcher, InterpreterContext, reinterpret
from slub.handlers import (
    Uninitialized,
    default_while_handler,
    default_primitive_handler,
)


@partial(
    jax.tree_util.register_dataclass, meta_fields=("tags",), data_fields=("values",)
)
@dataclass(frozen=True)
class Ctx(InterpreterContext):
    tags: tuple[str, ...] = ()  # static metadata (meta field)
    values: tuple[jax.Array, ...] = ()  # array leaves (data field)

    def add_meta(self, tag: str):
        return Ctx(self.tags + (tag,), self.values)

    def add_value(self, v: jax.Array):
        return Ctx(self.tags, self.values + (v,))

    def push(self):
        return self

    def pop(self):
        return self


def sin_handler(interpreter, ctx, eqn, invals):
    # Adds metadata and one value
    ctx = ctx.add_meta("sin").add_value(jnp.array(1))
    return default_primitive_handler(interpreter, ctx, eqn, invals)


def initializer(old_ctx, sentinel_ctx):
    # Replace Uninitialized leaves with zeros
    leaves, tree = jax.tree.flatten(sentinel_ctx)
    leaves = [jnp.zeros_like(x) if isinstance(x, Uninitialized) else x for x in leaves]
    return jax.tree.unflatten(tree, leaves)


def updater(old_ctx, new_ctx):
    # Replace the old context with the new context from the loop body
    return new_ctx


def while_with_init(interpreter, ctx, eqn, invals):
    ctx = ctx.add_meta("while").add_value(jnp.array(1))
    return default_while_handler(
        interpreter, ctx, eqn, invals, initializer=initializer, updater=updater
    )


# Keys mix a primitive name ("while") and a primitive object (sin_p)
dispatcher = Dispatcher({"while": while_with_init, jax.lax.sin_p: sin_handler})
interpreter = Interpreter(dispatcher=dispatcher)


@partial(reinterpret, interpreter=interpreter)
def run_loop_with_init(x):
    def cond(a):
        return a < 3

    def body(a):
        _ = jnp.sin(a)  # introduces extra context via sin_handler
        return a + 1

    return jax.lax.while_loop(cond, body, x)  # loop starts from the input x


result, out_ctx = run_loop_with_init(Ctx(), jnp.array(0))
```
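Conceptually, the initializer/updater pair works around the fixed-carry rule: whatever extra context the body adds must already exist, with the right shape and dtype, before the loop starts, and each iteration's value then replaces the previous one. The sketch below reproduces that idea in plain JAX, assuming the body's extra output can stand in for a context value; `body`, `loop_body`, and the use of `jax.eval_shape` are illustrative choices, not a description of Slub's internals.

```python
import jax
import jax.numpy as jnp


def body(a):
    # Hypothetical loop body: advances the counter and emits one extra value,
    # standing in for a value a handler would add to the context.
    return a + 1, jnp.sin(a)


# 1. Learn the shape/dtype of the extra value without running the body.
_, extra_spec = jax.eval_shape(body, jnp.array(0))
# 2. "Initializer": seed the extra slot with zeros of that shape/dtype.
init_extra = jax.tree.map(lambda s: jnp.zeros(s.shape, s.dtype), extra_spec)


def cond(carry):
    a, _ = carry
    return a < 3


def loop_body(carry):
    a, _old_extra = carry
    # 3. "Updater": this iteration's value replaces the previous one.
    return body(a)


final_a, last_extra = jax.lax.while_loop(cond, loop_body, (jnp.array(0), init_extra))
print(final_a, last_extra)
```

Because the carry has the same structure on every iteration, `while_loop` accepts it; the zeros are only a placeholder for the first pass.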
Use Cases
Slub is designed for lightweight instrumentation and experimentation:
- Instrumentation: Track operations, collect metrics, or monitor computation flow
- Provenance: Record the history and lineage of values through a computation
- Lightweight metrics: Gather statistics without heavyweight frameworks
- Research prototyping: Quickly experiment with custom computation semantics
Examples
- `examples/monitoring_pipeline.py`: Demonstrates primitive handlers, scan, and JIT compilation
- `notebooks/example.ipynb`: Interactive notebook with step-by-step examples
Advanced Features
For advanced usage patterns, see the source code for:
- Custom matching rules for handler dispatch
- Error policies (`RAISE`, `WARN`, `IGNORE`) for mismatched contexts
- Branch combiners for merging contexts from conditional branches
Development
To set up the development environment:
```bash
uv sync --group dev              # Install dev dependencies
uv run pytest                    # Run tests
uvx pre-commit run --all-files   # Run linters and formatters
```
License
Licensed under the Apache License, Version 2.0. See LICENSE or visit http://www.apache.org/licenses/LICENSE-2.0.
Version
The project version is defined in pyproject.toml under [project].version.
Citation
```bibtex
@software{slub2026,
  title={Slub},
  author={Cusp AI},
  year={2026},
  url={https://github.com/cusp-ai-oss/slub}
}
```
File details
Details for the file slub-0.1.1.tar.gz.
File metadata
- Download URL: slub-0.1.1.tar.gz
- Upload date:
- Size: 83.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.11.6 {"installer":{"name":"uv","version":"0.11.6","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Ubuntu","version":"24.04","id":"noble","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":null}
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `f5d07c8d1b9f53ae208dfa5c3840ab4d65f2f9e21ea532641794b10c8496729c` |
| MD5 | `5be0c4baedbb0d1b7a1b8da45663aad5` |
| BLAKE2b-256 | `328993ff76f3bba67c254c28071111c801f306365b08b3e6cc4edd82a5d755fc` |
File details
Details for the file slub-0.1.1-py3-none-any.whl.
File metadata
- Download URL: slub-0.1.1-py3-none-any.whl
- Upload date:
- Size: 16.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.11.6 {"installer":{"name":"uv","version":"0.11.6","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Ubuntu","version":"24.04","id":"noble","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":null}
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `547c314f10f74a3aefd68fd491f84d5511182bee2f61ba9821e1267009654fac` |
| MD5 | `4bd3a0ba7c6f6342d55ed46cdd98b6ec` |
| BLAKE2b-256 | `93813bc58931a04b0e809c1f4632965bac7877c2200c4621a51ab8bf907f982d` |