Autotuning for JAX/Pallas kernels
tonno
Autotuning for JAX/Pallas kernels — a lightweight @autotune decorator that sweeps
block-size configs, times them on the current device, and caches the winner so the
sweep only runs once per (hardware, problem shape) pair.
Inspired by Triton's @triton.autotune. Fills the same gap for Pallas (see
jax-ml/jax#24340).
Install
pip install tonno # or: uv add tonno
Usage
Stack @autotune on top of @jax.jit. The tunable parameters (keys of each config
dict) must be declared as static_argnames in the jax.jit call and have defaults
in the function signature — autotune injects the best values at call time.
import jax
from tonno import autotune

@autotune(configs=[
    {"BM": 32, "BN": 64},
    {"BM": 64, "BN": 128},
])
@jax.jit(static_argnames=["BM", "BN"])
def matmul(a: jax.Array, b: jax.Array, BM: int = 32, BN: int = 64) -> jax.Array:
    # BM / BN are concrete ints at compile time (static_argnames)
    # pallas_matmul stands in for your own Pallas kernel wrapper (not defined here)
    return pallas_matmul(a, b, BM=BM, BN=BN)
Calling it
# First call: sweeps all configs, compiles in parallel, times sequentially,
# writes the winner to .tonno-cache/matmul.json
c = matmul(a, b)
# Subsequent calls with the same input shapes on the same device: cache hit,
# no sweep, runs immediately with the best config
c = matmul(a, b)
# Explicit override — bypass autotune entirely
c = matmul(a, b, BM=16, BN=32)
The cache key is derived from the input shapes and dtypes, the same
information jax.jit uses to decide whether to recompile, so no explicit key=
parameter is needed.
Non-tunable static args
If your kernel has static args beyond the tunable ones, declare them in jax.jit
as usual. They are automatically part of the cache key:
@autotune(configs=[{"BM": 32}, {"BM": 64}])
@jax.jit(static_argnames=["BM", "transpose"])
def matmul(a: jax.Array, b: jax.Array, transpose: bool = False, BM: int = 32) -> jax.Array:
    ...
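As a rough illustration of how such a key could be assembled, the sketch below builds a deterministic JSON string from argument shapes, dtypes, and non-tunable static args. It is not tonno's actual implementation: `make_cache_key` and `ArraySpec` are hypothetical names, and only the `__arg0` naming is borrowed from the cache file shown later in this README. `ArraySpec` is a stand-in so the example runs without JAX; anything with `.shape` and `.dtype`, such as a `jax.Array`, would be handled the same way.

```python
import json
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ArraySpec:
    """Stand-in for anything exposing .shape and .dtype (e.g. a jax.Array)."""
    shape: tuple[int, ...]
    dtype: str


def make_cache_key(args: tuple, static_kwargs: dict[str, Any], tunable: set[str]) -> str:
    """Hypothetical helper: build a stable key from shapes, dtypes, and static args."""
    key: dict[str, Any] = {}
    for i, arg in enumerate(args):
        # Only shape and dtype matter; the array values never enter the key.
        key[f"__arg{i}"] = {"dtype": str(arg.dtype), "shape": list(arg.shape)}
    for name, value in static_kwargs.items():
        # Tunable params are what the sweep chooses, so they stay out of the key.
        if name not in tunable:
            key[name] = value
    # sort_keys plus compact separators give a byte-identical string every time.
    return json.dumps(key, sort_keys=True, separators=(",", ":"))


a = ArraySpec(shape=(4096, 4096), dtype="float32")
b = ArraySpec(shape=(4096, 4096), dtype="float32")
key = make_cache_key((a, b), {"transpose": False, "BM": 64}, tunable={"BM"})
print(key)
```

Under these assumptions, changing `transpose` or any input shape produces a different key, while changing only `BM` does not, which matches the behaviour described above.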
How it works
- On first call (cache miss): dummy inputs are built from the args' shapes/dtypes via jax.ensure_compile_time_eval. All configs are compiled in parallel via ThreadPoolExecutor (XLA compilation is CPU-bound). Each compiled artifact is then timed sequentially for accurate device timing. The winner is written to .tonno-cache/<fn>.json.
- On subsequent calls (cache hit): the best config is loaded from disk and injected as static kwargs. JAX's own compilation cache takes over from there.
- Inside jax.jit/jax.grad/jax.vmap: the sweep runs as a side channel during the first trace, then the winning config is baked into the jaxpr as a compile-time constant.
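The cache-miss path (build in parallel, time sequentially, keep the config with the lowest median) can be sketched in plain Python. This is an illustration under stated assumptions, not tonno's code: `sweep`, `build`, and `fake_build` are hypothetical names, and `time.sleep` stands in for a compiled kernel so the example runs without JAX or an accelerator.

```python
import statistics
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable

Config = dict[str, Any]


def sweep(
    configs: list[Config],
    build: Callable[[Config], Callable[[], Any]],
    num_warmup: int = 1,
    num_timing: int = 3,
) -> tuple[Config, float]:
    """Return (best_config, median_ms); `build` is a hypothetical compile step."""
    # Phase 1: build every candidate in parallel threads.
    with ThreadPoolExecutor() as pool:
        runners = list(pool.map(build, configs))

    # Phase 2: time one candidate at a time so runs do not compete for the device.
    results: list[tuple[float, Config]] = []
    for config, run in zip(configs, runners):
        for _ in range(num_warmup):
            run()
        samples = []
        for _ in range(num_timing):
            start = time.perf_counter()
            run()  # with JAX, block on the result here before stopping the clock
            samples.append((time.perf_counter() - start) * 1e3)
        results.append((statistics.median(samples), config))

    best_ms, best_config = min(results, key=lambda r: r[0])
    return best_config, best_ms


def fake_build(config: Config) -> Callable[[], None]:
    # Stand-in for compilation: a larger BM pretends to be the faster kernel.
    delay = 0.02 / (config["BM"] // 32)
    return lambda: time.sleep(delay)


best, ms = sweep([{"BM": 32}, {"BM": 128}], fake_build)
print(best, f"{ms:.1f} ms")
```

Using the median rather than the mean keeps a single slow run (for example, a one-off scheduling hiccup) from deciding the winner.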
API reference
autotune(
    configs: list[dict[str, Any]],  # configs to sweep; all dicts must share keys
    *,
    num_warmup: int = 1,            # warmup calls per config after compilation
    num_timing: int = 3,            # timed calls per config (median used)
)
Contract:
- @autotune must wrap a @jax.jit-decorated function.
- The tunable param keys must appear in static_argnames of the jax.jit call.
- Tunable params must have defaults in the function signature.
- Config values must be JSON-serialisable (int, float, str, bool).
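To make the contract concrete, here is a minimal sketch of the checks it implies. `check_contract` is a hypothetical helper, not part of tonno's API; it takes `static_argnames` explicitly because this README does not show how tonno reads them from the jitted function, and it skips the "must wrap a jitted function" rule so that it runs without JAX.

```python
import inspect
from typing import Any, Callable


def check_contract(
    fn: Callable, configs: list[dict[str, Any]], static_argnames: list[str]
) -> None:
    """Hypothetical validator: raise ValueError if configs break the contract."""
    if not configs:
        # Extra guard for this sketch; not part of the stated contract.
        raise ValueError("configs must not be empty")
    keys = set(configs[0])
    params = inspect.signature(fn).parameters
    for config in configs:
        if set(config) != keys:
            raise ValueError(f"all configs must share keys, got {sorted(config)}")
        for value in config.values():
            # bool is a subclass of int, so it passes the same check.
            if not isinstance(value, (int, float, str)):
                raise ValueError(f"config value {value!r} is not JSON-serialisable")
    for name in keys:
        if name not in static_argnames:
            raise ValueError(f"{name!r} is missing from static_argnames")
        if name not in params or params[name].default is inspect.Parameter.empty:
            raise ValueError(f"{name!r} needs a default in the function signature")


def kernel(a, b, BM=32, BN=64):
    return None


# Passes silently: shared keys, declared static, defaults present, plain ints.
check_contract(kernel, [{"BM": 32, "BN": 64}, {"BM": 64, "BN": 128}], ["BM", "BN"])
```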
Cache
Results are stored in .tonno-cache/<qualname>.json (or $TONNO_CACHE_DIR).
The file is human-readable JSON; you can inspect or delete entries manually.
{
  "NVIDIA H100 80GB": {
    "{\"__arg0\":{\"dtype\":\"float32\",\"shape\":[4096,4096]}, ...}": {
      "config": {"BM": 64, "BN": 128},
      "time_ms": 0.312,
      "key_values": {...}
    }
  }
}
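Because the file is plain JSON, inspecting or pruning it from a script is straightforward. The helpers below are a sketch based only on the layout shown above (device name, then key string, then an entry with a `config` field). `cache_path`, `lookup`, and `forget` are hypothetical names rather than tonno functions, and the sketch assumes `$TONNO_CACHE_DIR` replaces the `.tonno-cache` directory.

```python
import json
import os
from pathlib import Path
from typing import Any, Optional


def cache_path(qualname: str) -> Path:
    # Assumption: $TONNO_CACHE_DIR overrides the default .tonno-cache directory.
    root = os.environ.get("TONNO_CACHE_DIR", ".tonno-cache")
    return Path(root) / f"{qualname}.json"


def lookup(qualname: str, device: str, key: str) -> Optional[dict[str, Any]]:
    """Return the cached config for (device, key), or None on a miss."""
    path = cache_path(qualname)
    if not path.exists():
        return None
    data = json.loads(path.read_text())
    entry = data.get(device, {}).get(key)
    return None if entry is None else entry["config"]


def forget(qualname: str, device: str, key: str) -> bool:
    """Delete one entry from the file; return True if something was removed."""
    path = cache_path(qualname)
    if not path.exists():
        return False
    data = json.loads(path.read_text())
    removed = data.get(device, {}).pop(key, None) is not None
    if removed:
        path.write_text(json.dumps(data, indent=2))
    return removed
```

Removing an entry this way should have the same effect as deleting it by hand: the next call with that shape on that device would miss the cache and trigger a fresh sweep.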
Example
See examples/matmul.py for a complete autotuned tiled GEMM.