Autotuning for JAX/Pallas kernels

tonno

Autotuning for JAX/Pallas kernels — a lightweight @autotune decorator that sweeps block-size configs, times them on the current device, and caches the winner so the sweep only runs once per (hardware, problem shape) pair.

Inspired by Triton's @triton.autotune. Fills the same gap for Pallas (see jax-ml/jax#24340).

Install

pip install tonno          # or: uv add tonno

Usage

Stack @autotune on top of @jax.jit. The tunable parameters (keys of each config dict) must be declared as static_argnames in the jax.jit call and have defaults in the function signature — autotune injects the best values at call time.

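import jax
from tonno import autotune

@autotune(configs=[
    {"BM": 32, "BN": 64},
    {"BM": 64, "BN": 128},
])
# On JAX releases where jax.jit cannot be called with keyword arguments only,
# use functools.partial(jax.jit, static_argnames=["BM", "BN"]) instead.
@jax.jit(static_argnames=["BM", "BN"])
def matmul(a: jax.Array, b: jax.Array, BM: int = 32, BN: int = 64) -> jax.Array:
    # BM / BN are concrete ints at compile time (static_argnames).
    # pallas_matmul stands for your own Pallas kernel wrapper; it is not defined here.
    return pallas_matmul(a, b, BM=BM, BN=BN)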

Calling it

# First call: sweeps all configs, compiles in parallel, times sequentially,
# writes the winner to .tonno-cache/matmul.json
c = matmul(a, b)

# Subsequent calls with the same input shapes on the same device: cache hit,
# no sweep, runs immediately with the best config
c = matmul(a, b)

# Explicit override — bypass autotune entirely
c = matmul(a, b, BM=16, BN=32)
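
The precedence between cached values and explicit arguments can be pictured with a small wrapper. This is only a sketch, not tonno's implementation: with_best_config is a hypothetical name, and the rule that any explicitly passed tunable bypasses injection entirely is an assumption based on the override example above.

```python
import functools

def with_best_config(best):
    """Hypothetical decorator: inject `best` unless the caller passes a tunable."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Assumption: one explicit tunable is enough to bypass injection,
            # so the remaining tunables fall back to the signature defaults.
            if any(name in kwargs for name in best):
                return fn(*args, **kwargs)
            return fn(*args, **best, **kwargs)
        return wrapper
    return decorator

@with_best_config({"BM": 64, "BN": 128})
def tiles(x, BM=32, BN=64):
    # Stand-in for a jitted kernel; returns the values it was called with.
    return (x, BM, BN)

injected = tiles(1)                    # uses the "best" config
overridden = tiles(1, BM=16, BN=32)    # explicit values win
```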

The cache key is derived from the input shapes and dtypes, the same information jax.jit uses to decide whether new inputs require a recompile. No explicit key= parameter is needed.
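
As a rough illustration of such a key, the sketch below serialises shapes, dtypes and any non-tunable static arguments into a deterministic string. The helper name make_cache_key is hypothetical, and the `__arg0` naming is borrowed from the cache file shown later in this README; the exact format tonno writes may differ.

```python
import json
from types import SimpleNamespace

def make_cache_key(args, static_kwargs=None):
    """Hypothetical helper: build a stable string key from array-like
    arguments (anything with .shape and .dtype) and static kwargs."""
    entries = {}
    for i, arg in enumerate(args):
        entries[f"__arg{i}"] = {
            "dtype": str(arg.dtype),
            "shape": list(arg.shape),
        }
    for name, value in (static_kwargs or {}).items():
        entries[name] = value
    # sort_keys and fixed separators keep the key identical across runs.
    return json.dumps(entries, sort_keys=True, separators=(",", ":"))

# Stand-ins for jax.Array; real arrays expose the same two attributes.
a = SimpleNamespace(shape=(4096, 4096), dtype="float32")
b = SimpleNamespace(shape=(4096, 4096), dtype="float32")
key = make_cache_key((a, b), {"transpose": False})
```

Two calls produce the same key only when every shape, dtype and static value matches, which is what makes the key usable as a dictionary index in the JSON cache.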

Non-tunable static args

If your kernel has static args beyond the tunable ones, declare them in jax.jit as usual. They are automatically part of the cache key:

@autotune(configs=[{"BM": 32}, {"BM": 64}])
@jax.jit(static_argnames=["BM", "transpose"])
def matmul(a: jax.Array, b: jax.Array, transpose: bool = False, BM: int = 32) -> jax.Array:
    ...

How it works

  1. On first call (cache miss): dummy inputs are built from the args' shapes/dtypes via jax.ensure_compile_time_eval. All configs are compiled in parallel via ThreadPoolExecutor (XLA compilation is CPU-bound). Each compiled artifact is then timed sequentially for accurate device timing. The winner is written to .tonno-cache/<qualname>.json.

  2. On subsequent calls (cache hit): the best config is loaded from disk and injected as static kwargs. JAX's own compilation cache takes over from there.

  3. Inside jax.jit / jax.grad / jax.vmap: the sweep runs as a side channel during the first trace, then the winning config is baked into the jaxpr as a compile-time constant.
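
The cache-miss path in step 1 can be sketched in plain Python. This is an illustration under assumptions, not tonno's code: sweep and build are hypothetical names, and build(config) stands in for whatever produces a compiled callable. With JAX that could plausibly be the ahead-of-time path fn.lower(*args, **config).compile(), with each timed call ending in jax.block_until_ready.

```python
import statistics
import time
from concurrent.futures import ThreadPoolExecutor

def sweep(build, configs, args, num_warmup=1, num_timing=3):
    """Hypothetical sweep: build candidates in parallel, time them one by one."""
    # Build (compile) every candidate concurrently.
    with ThreadPoolExecutor() as pool:
        candidates = list(pool.map(build, configs))

    results = []
    # Time sequentially so candidates never compete for the device.
    for config, run in zip(configs, candidates):
        for _ in range(num_warmup):
            run(*args)
        samples = []
        for _ in range(num_timing):
            start = time.perf_counter()
            run(*args)  # with JAX, block on the result before stopping the clock
            samples.append((time.perf_counter() - start) * 1e3)
        results.append((statistics.median(samples), config))

    best_ms, best_config = min(results, key=lambda r: r[0])
    return best_config, best_ms

def fake_build(config):
    # Pretend that BM=32 yields a slower kernel than BM=64.
    delay = 0.05 if config["BM"] == 32 else 0.001
    return lambda x: time.sleep(delay)

best_config, best_ms = sweep(fake_build, [{"BM": 32}, {"BM": 64}], args=(None,))
```

Taking the median of the timed calls, as the num_timing parameter below describes, makes the choice less sensitive to a single slow run than taking the mean.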

API reference

autotune(
    configs: list[dict[str, Any]],  # configs to sweep; all dicts must share keys
    *,
    num_warmup: int = 1,            # warmup calls per config after compilation
    num_timing: int = 3,            # timed calls per config (median used)
)

Contract:

  • @autotune must wrap a @jax.jit-decorated function.
  • The tunable param keys must appear in static_argnames of the jax.jit call.
  • Tunable params must have defaults in the function signature.
  • Config values must be JSON-serialisable (int, float, str, bool).
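
The signature-related parts of this contract can be checked up front. The validator below is a hypothetical sketch (check_configs is not part of tonno's documented API) that takes the undecorated function plus the static_argnames list; rejecting an empty config list is an extra assumption.

```python
import inspect
import json

def check_configs(fn, configs, static_argnames):
    """Hypothetical validator mirroring the contract; raises ValueError."""
    if not configs:
        raise ValueError("configs must not be empty")  # assumption, not in the contract
    keys = set(configs[0])
    for config in configs:
        if set(config) != keys:
            raise ValueError(f"config keys differ: {sorted(config)} vs {sorted(keys)}")
        try:
            json.dumps(config)
        except TypeError as exc:
            raise ValueError(f"config is not JSON-serialisable: {config!r}") from exc
    params = inspect.signature(fn).parameters
    for name in sorted(keys):
        if name not in static_argnames:
            raise ValueError(f"{name!r} is missing from static_argnames")
        if name not in params or params[name].default is inspect.Parameter.empty:
            raise ValueError(f"{name!r} needs a default in the function signature")

def matmul(a, b, BM=32, BN=64):
    ...

# Passes silently: shared keys, JSON-friendly values, defaults present.
check_configs(matmul, [{"BM": 32, "BN": 64}, {"BM": 64, "BN": 128}], ["BM", "BN"])
```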

Cache

Results are stored in .tonno-cache/<qualname>.json (or $TONNO_CACHE_DIR). The file is human-readable JSON; you can inspect or delete entries manually.

{
  "NVIDIA H100 80GB": {
    "{\"__arg0\":{\"dtype\":\"float32\",\"shape\":[4096,4096]}, ...}": {
      "config": {"BM": 64, "BN": 128},
      "time_ms": 0.312,
      "key_values": {...}
    }
  }
}
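
Because the file is plain JSON, it can be inspected or pruned with a few lines of standard-library Python. The helpers below are hypothetical and assume the device → key → entry layout shown above; that dropping a device's entries triggers a fresh sweep on the next call is inferred from the cache-miss behaviour, not confirmed.

```python
import json
import os
from pathlib import Path

def cache_file(qualname):
    """Path of the cache file for a function, honouring $TONNO_CACHE_DIR."""
    root = Path(os.environ.get("TONNO_CACHE_DIR", ".tonno-cache"))
    return root / f"{qualname}.json"

def list_entries(path):
    """Return (device, config, time_ms) for every entry, fastest first."""
    data = json.loads(Path(path).read_text())
    rows = [
        (device, entry["config"], entry["time_ms"])
        for device, entries in data.items()
        for entry in entries.values()
    ]
    return sorted(rows, key=lambda row: row[2])

def drop_device(path, device):
    """Remove all entries for one device; returns True if anything was removed."""
    path = Path(path)
    data = json.loads(path.read_text())
    removed = data.pop(device, None) is not None
    path.write_text(json.dumps(data, indent=2))
    return removed
```

For example, list_entries(cache_file("matmul")) would list the recorded winners for matmul, assuming the cache file exists.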

Example

See examples/matmul.py for a complete autotuned tiled GEMM.

Source distribution: tonno-0.2.2.tar.gz (7.7 kB)
Built distribution: tonno-0.2.2-py3-none-any.whl (9.3 kB)
Publisher: release.yml on Cusp-AI/tonno