Autotuning for JAX/Pallas kernels
tonno
Autotuning for JAX/Pallas kernels — a lightweight @autotune decorator that sweeps
block-size configs, times them on the current device, and caches the winner so the
sweep only runs once per (hardware, problem shape) pair.
Inspired by Triton's @triton.autotune, tonno fills the same gap for Pallas (see
jax-ml/jax#24340).
Install
pip install tonno # or: uv add tonno
Usage
1. Define a config type
Use a NamedTuple: it is hashable, fully typed, and JSON-serialisable out of the box:
from typing import NamedTuple

class GemmConfig(NamedTuple):
    bm: int  # output tile rows
    bn: int  # output tile cols
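A quick standard-library check (independent of tonno) of the two properties claimed above: a GemmConfig is hashable, and it survives a JSON round trip because json writes any tuple subclass as a list, which GemmConfig(*data) turns back into a config. This mirrors the T(*data) reconstruction the default decoder is described as using; it is a sketch, not tonno's code.

```python
import json
from typing import NamedTuple


class GemmConfig(NamedTuple):
    bm: int  # output tile rows
    bn: int  # output tile cols


cfg = GemmConfig(bm=32, bn=128)

# Hashable: usable as a dict key, which is what static JAX arguments require.
seen = {cfg: "tested"}

# A NamedTuple is a tuple subclass, so json encodes it as a plain list.
encoded = json.dumps(cfg)
print(encoded)  # [32, 128]

# Positional reconstruction from the decoded list.
decoded = GemmConfig(*json.loads(encoded))
assert decoded == cfg and decoded in seen
```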
2. Decorate your Pallas kernel
The config is the first positional argument: autotune injects it, so callers never pass it directly. The kwargs listed in key= identify the problem shape and are used as the cache key.
import jax
from jax.experimental import pallas as pl
from tonno import autotune

@autotune(
    configs=[
        GemmConfig(bm=16, bn=64),
        GemmConfig(bm=32, bn=128),
        GemmConfig(bm=64, bn=128),
    ],
    key=["M", "K", "N"],
)
def matmul(
    cfg: GemmConfig,
    a: jax.Array,
    b: jax.Array,
    *,
    M: int | None = None,  # key param: must have a default
    K: int | None = None,
    N: int | None = None,
) -> jax.Array:
    # cfg.bm / cfg.bn are concrete ints at JIT compile time (static_argnums=0)
    # Derive grid from array shapes, not from key params (those are popped)
    return pl.pallas_call(
        lambda a_ref, b_ref, c_ref: ...,  # kernel body elided
        out_shape=jax.ShapeDtypeStruct((a.shape[0], b.shape[1]), a.dtype),
        grid=(a.shape[0] // cfg.bm, b.shape[1] // cfg.bn),
        # ... remaining pallas_call arguments elided
    )(a, b)
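The calling convention above (config injected as the first positional argument, key kwargs popped before the body runs, one choice cached per key) can be illustrated with a toy decorator. This is a hypothetical pure-Python sketch, not tonno's implementation: toy_autotune, its choose parameter, and describe are made up for illustration, choose replaces real kernel timing, and the cache lives only in memory.

```python
import functools
from typing import Callable, NamedTuple, Sequence


class GemmConfig(NamedTuple):
    bm: int
    bn: int


def toy_autotune(configs: Sequence, key: Sequence[str], choose: Callable):
    """Hypothetical stand-in for @autotune: no timing, no JIT, in-memory cache."""

    def decorator(fn):
        cache = {}  # tuple of key values -> chosen config

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Pop the key kwargs so the body only ever sees their defaults (None).
            key_values = tuple(kwargs.pop(name, None) for name in key)
            if key_values not in cache:  # cache miss: pick a config once
                cache[key_values] = choose(configs, key_values)
            # Inject the chosen config as the first positional argument.
            return fn(cache[key_values], *args, **kwargs)

        wrapper.cache = cache  # exposed only so the sketch can be inspected
        return wrapper

    return decorator


@toy_autotune(
    configs=[GemmConfig(16, 64), GemmConfig(32, 128), GemmConfig(64, 128)],
    key=["M", "K", "N"],
    choose=lambda cfgs, kv: max(cfgs, key=lambda c: c.bm * c.bn),  # stand-in for timing
)
def describe(cfg, rows, cols, *, M=None, K=None, N=None):
    # Key params arrive as None; the grid is derived from the real sizes.
    return cfg, (rows // cfg.bm, cols // cfg.bn), (M, K, N)


print(describe(4096, 4096, M=4096, K=4096, N=4096))
# (GemmConfig(bm=64, bn=128), (64, 32), (None, None, None))
```

The output shows why the grid must come from array shapes: inside the body, M, K and N are None.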
3. Call it
# First call: sweeps all configs, compiles in parallel, times sequentially,
# writes the best GemmConfig to .tonno-cache/matmul.json
c = matmul(a, b, M=4096, K=4096, N=4096)
# Subsequent calls with the same (M, K, N) on the same device: cache hit,
# no sweep, runs immediately with the best config
c = matmul(a, b, M=4096, K=4096, N=4096)
The cache is keyed per device (H100-80GB, TPU-v4, cpu, …), so a tuned config is
reused across runs on the same hardware, while different hardware gets its own entry.
How it works
- On first call (cache miss): dummy inputs are built from the args' shapes/dtypes. All configs are compiled in parallel via ThreadPoolExecutor (XLA compilation is CPU-bound). Each compiled artifact is then timed sequentially on the dummy inputs for accurate device timing. The winner is written to .tonno-cache/<fn>.json.
- On subsequent calls (cache hit): the best config is loaded from disk and injected as static_argnums=0. JAX's own compilation cache takes over from there.
- Inside jax.jit: the sweep runs as a side channel during the first trace (via jax.ensure_compile_time_eval), then the winning config is baked into the jaxpr as a compile-time static.
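The cache-miss path (compile all configs in parallel, time them one at a time, keep the fastest) can be modelled with the standard library alone. This is an illustrative sketch, not tonno's code: sweep and fake_compile are hypothetical, compile_fn stands in for XLA compilation and returns a plain callable, and timing takes num_warmup untimed runs followed by the median of num_timing timed runs, matching the defaults in the API reference. Timing real JAX work would additionally require waiting for results (for example with block_until_ready on the outputs), since dispatch is asynchronous.

```python
import statistics
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Hashable, Sequence


def sweep(
    configs: Sequence[Hashable],
    compile_fn: Callable[[Hashable], Callable[[], Any]],
    num_warmup: int = 1,
    num_timing: int = 3,
) -> tuple[Hashable, float]:
    """Return (best_config, median_ms). Hypothetical sketch only."""
    # Phase 1: compile every config concurrently (pool.map preserves order).
    with ThreadPoolExecutor() as pool:
        compiled = list(pool.map(compile_fn, configs))

    # Phase 2: time each artifact one at a time so runs never overlap.
    results = []
    for cfg, run in zip(configs, compiled):
        for _ in range(num_warmup):
            run()  # untimed warmup
        samples = []
        for _ in range(num_timing):
            start = time.perf_counter()
            run()
            samples.append((time.perf_counter() - start) * 1e3)
        results.append((statistics.median(samples), cfg))

    best_ms, best_cfg = min(results, key=lambda r: r[0])
    return best_cfg, best_ms


# Usage: each fake "kernel" just sleeps for a config-dependent time.
def fake_compile(cfg):
    delay = {16: 0.02, 32: 0.002, 64: 0.04}[cfg]
    return lambda: time.sleep(delay)


best, ms = sweep([16, 32, 64], fake_compile)
print(best, ms)
```

Using the median rather than the mean keeps a single slow outlier run from deciding the winner.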
Config types
Any hashable type works. NamedTuple is recommended because it is:
- Hashable → required by static_argnums
- Typed → cfg.bm: int, not cfg.bm: int | float | str | bool
- JSON-serialisable natively (tuple → list; the default decoder reconstructs via T(*data))
import dataclasses
from dataclasses import dataclass
from typing import NamedTuple

from tonno import autotune

# NamedTuple: recommended
class KC(NamedTuple):
    bm: int
    bk: int

# frozen dataclass: works with explicit encode/decode
# (alternative to the NamedTuple above; this redefines KC)
@dataclass(frozen=True)
class KC:
    bm: int
    bk: int

@autotune(
    configs=[KC(32, 64), KC(64, 32)],
    key=["N"],
    encode=dataclasses.asdict,
    decode=lambda d: KC(**d),
)
def kernel(cfg: KC, x, *, N=None): ...
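Why the frozen dataclass needs encode/decode while the NamedTuple does not: json.dumps rejects dataclass instances, but dataclasses.asdict produces a dict that JSON handles, and KC(**d) rebuilds the config. A standard-library check, independent of tonno:

```python
import dataclasses
import json
from dataclasses import dataclass


@dataclass(frozen=True)
class KC:
    bm: int
    bk: int


cfg = KC(32, 64)

# frozen=True (with the default eq=True) generates __hash__, so configs are hashable.
assert hash(cfg) == hash(KC(32, 64))

# json cannot serialise the dataclass directly...
try:
    json.dumps(cfg)
except TypeError as err:
    print("json.dumps failed:", err)

# ...so encode with dataclasses.asdict and decode with KC(**d).
text = json.dumps(dataclasses.asdict(cfg))
print(text)  # {"bm": 32, "bk": 64}
restored = KC(**json.loads(text))
assert restored == cfg
```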
API reference
autotune(
    configs: Iterable[_C],           # configs to sweep, must be hashable
    key: Sequence[str],              # kwargs naming the problem shape
    *,
    num_warmup: int = 1,             # warmup iterations before timing
    num_timing: int = 3,             # timed iterations (median used)
    encode: Callable | None = None,  # config → JSON-serialisable (default: identity)
    decode: Callable | None = None,  # JSON-loaded → config (default: T(*data))
)
Rules for the decorated function:
- Config is the first positional argument, typed as _C.
- Key params must have a default value (N: int | None = None); they are popped by autotune and never forwarded to the function body.
- Derive Pallas grids from array shapes (a.shape[0] // cfg.bm), not from key params (which are None inside the function).
- All configs must have the same pytree structure (same type, same fields).
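Most of these rules can be checked mechanically before any sweep runs. The sketch below is a hypothetical helper (check_autotune_target is not part of tonno) that uses inspect.signature to confirm the first parameter is positional and every key names a parameter with a default; it uses "all configs share one type" as a simple proxy for "same pytree structure".

```python
import inspect
from typing import Any, Callable, NamedTuple, Sequence


def check_autotune_target(fn: Callable, configs: Sequence[Any], key: Sequence[str]) -> None:
    """Hypothetical pre-flight check of the rules above; raises ValueError on violation."""
    params = list(inspect.signature(fn).parameters.values())
    positional = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    if not params or params[0].kind not in positional:
        raise ValueError("config must be the first positional parameter")
    by_name = {p.name: p for p in params}
    for name in key:
        p = by_name.get(name)
        if p is None:
            raise ValueError(f"key {name!r} is not a parameter of {fn.__name__}")
        if p.default is inspect.Parameter.empty:
            raise ValueError(f"key param {name!r} needs a default (e.g. None)")
    # Same type is used here as a proxy for "same pytree structure".
    if len({type(c) for c in configs}) > 1:
        raise ValueError("all configs must share one type")


class KC(NamedTuple):
    bm: int
    bk: int


def good(cfg, x, *, N=None): ...
def bad(cfg, x, *, N): ...


check_autotune_target(good, [KC(32, 64), KC(64, 32)], key=["N"])  # passes
try:
    check_autotune_target(bad, [KC(32, 64)], key=["N"])
except ValueError as err:
    print(err)  # key param 'N' needs a default (e.g. None)
```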
Cache
Results are stored in .tonno-cache/<qualname>.json, or under $TONNO_CACHE_DIR when that variable is set.
The file is human-readable JSON; you can inspect or delete entries manually.
{
  "NVIDIA H100 80GB": {
    "{\"M\":4096,\"K\":4096,\"N\":4096}": {
      "config": [64, 128],
      "time_ms": 0.312,
      "key_values": {"M": 4096, "K": 4096, "N": 4096}
    }
  }
}
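The layout above (device name, then a compact JSON key string, then the entry) can be written and read back with the json module. A minimal sketch under stated assumptions: cache_file, key_string, record_best and load_best are hypothetical helpers, the compact key string mirrors the example entry, the device names are illustrative, and treating $TONNO_CACHE_DIR as a replacement for the .tonno-cache directory is an assumption.

```python
import json
import os
import tempfile
from pathlib import Path
from typing import Any, Optional, Sequence


def cache_file(qualname: str) -> Path:
    # Assumption: $TONNO_CACHE_DIR, when set, replaces the default directory.
    root = os.environ.get("TONNO_CACHE_DIR", ".tonno-cache")
    return Path(root) / f"{qualname}.json"


def key_string(key: Sequence[str], values: dict) -> str:
    # Compact JSON in key order, e.g. {"M":4096,"K":4096,"N":4096}
    return json.dumps({k: values[k] for k in key}, separators=(",", ":"))


def record_best(path: Path, device: str, key: Sequence[str], values: dict,
                config: Any, time_ms: float) -> None:
    data = json.loads(path.read_text()) if path.exists() else {}
    data.setdefault(device, {})[key_string(key, values)] = {
        "config": config,  # tuples and NamedTuples are written as JSON lists
        "time_ms": time_ms,
        "key_values": {k: values[k] for k in key},
    }
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(data, indent=2))


def load_best(path: Path, device: str, key: Sequence[str], values: dict) -> Optional[Any]:
    if not path.exists():
        return None
    entry = json.loads(path.read_text()).get(device, {}).get(key_string(key, values))
    return None if entry is None else entry["config"]


# Usage with a throwaway directory and illustrative device names.
with tempfile.TemporaryDirectory() as tmp:
    os.environ["TONNO_CACHE_DIR"] = tmp
    path = cache_file("matmul")
    key, shape = ["M", "K", "N"], {"M": 4096, "K": 4096, "N": 4096}
    record_best(path, "NVIDIA H100 80GB", key, shape, (64, 128), 0.312)
    print(load_best(path, "NVIDIA H100 80GB", key, shape))  # [64, 128]
    print(load_best(path, "TPU v4", key, shape))  # None: other hardware misses
    os.environ.pop("TONNO_CACHE_DIR")
```

Because the decoded config comes back as a list, a decoder such as T(*data) is still needed to turn it into a config object.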
Example
See examples/matmul.py for a complete autotuned tiled GEMM.
Download files
File details
Details for the file tonno-0.1.0.tar.gz.
File metadata
- Download URL: tonno-0.1.0.tar.gz
- Upload date:
- Size: 7.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | dc78cd067553f7572d454cb06436b2dc58341c91406989b8ad61d21d73be4ef4 |
| MD5 | bdf29c0a3f598d4557cea145e63ef6a3 |
| BLAKE2b-256 | 8b00f83ae094bdf9014ec421058ab10628ce4cec8a473539e6967516848f1e79 |
Provenance
The following attestation bundles were made for tonno-0.1.0.tar.gz:
Publisher: release.yml on Cusp-AI/tonno
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: tonno-0.1.0.tar.gz
- Subject digest: dc78cd067553f7572d454cb06436b2dc58341c91406989b8ad61d21d73be4ef4
- Sigstore transparency entry: 1215575953
- Sigstore integration time:
- Permalink: Cusp-AI/tonno@5a9428313f32bd3cbf0b04dc787254ecc4865f9c
- Branch / Tag: refs/tags/v0.1.0
- Owner: https://github.com/Cusp-AI
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@5a9428313f32bd3cbf0b04dc787254ecc4865f9c
- Trigger Event: push
File details
Details for the file tonno-0.1.0-py3-none-any.whl.
File metadata
- Download URL: tonno-0.1.0-py3-none-any.whl
- Upload date:
- Size: 9.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 92b4eac3aeaa847ce78d890778c75ed6f36b1bb24d0b28703347858c78b9075b |
| MD5 | 0902d4c2d7ca7692c0b3b3f5e99e7171 |
| BLAKE2b-256 | 002a230803d5efb7afa0c331c470c793477842643068612480135297ec8ebe17 |
Provenance
The following attestation bundles were made for tonno-0.1.0-py3-none-any.whl:
Publisher: release.yml on Cusp-AI/tonno
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: tonno-0.1.0-py3-none-any.whl
- Subject digest: 92b4eac3aeaa847ce78d890778c75ed6f36b1bb24d0b28703347858c78b9075b
- Sigstore transparency entry: 1215575990
- Sigstore integration time:
- Permalink: Cusp-AI/tonno@5a9428313f32bd3cbf0b04dc787254ecc4865f9c
- Branch / Tag: refs/tags/v0.1.0
- Owner: https://github.com/Cusp-AI
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@5a9428313f32bd3cbf0b04dc787254ecc4865f9c
- Trigger Event: push