Skip to main content

Autotuning for JAX/Pallas kernels

Project description

tonno

Autotuning for JAX/Pallas kernels — a lightweight @autotune decorator that sweeps block-size configs, times them on the current device, and caches the winner so the sweep only runs once per (hardware, problem shape) pair.

Inspired by Triton's @triton.autotune. Fills the same gap for Pallas (see jax-ml/jax#24340).

Install

pip install tonno          # or: uv add tonno

Usage

Stack @autotune on top of @jax.jit. The tunable parameters (keys of each config dict) must be declared as static_argnames in the jax.jit call and have defaults in the function signature — autotune injects the best values at call time.

import jax
from tonno import autotune

@autotune(configs=[
    {"BM": 32, "BN": 64},
    {"BM": 64, "BN": 128},
])
@jax.jit(static_argnames=["BM", "BN"])
def matmul(a: jax.Array, b: jax.Array, BM: int = 32, BN: int = 64) -> jax.Array:
    # BM / BN are concrete ints at compile time (static_argnames)
    return pallas_matmul(a, b, BM=BM, BN=BN)

Calling it

# First call: sweeps all configs, compiles in parallel, times sequentially,
# writes the winner to .tonno-cache/matmul.json
c = matmul(a, b)

# Subsequent calls with the same input shapes on the same device: cache hit,
# no sweep, runs immediately with the best config
c = matmul(a, b)

# Explicit override — bypass autotune entirely
c = matmul(a, b, BM=16, BN=32)

The cache key is derived from the input shapes and dtypes — exactly the same information jax.jit uses to decide whether to recompile. No explicit key= parameter needed.

Non-tunable static args

If your kernel has static args beyond the tunable ones, declare them in jax.jit as usual. They are automatically part of the cache key:

@autotune(configs=[{"BM": 32}, {"BM": 64}])
@jax.jit(static_argnames=["BM", "transpose"])
def matmul(a: jax.Array, b: jax.Array, transpose: bool = False, BM: int = 32) -> jax.Array:
    ...

How it works

  1. On first call (cache miss): dummy inputs are built from the args' shapes/dtypes via jax.ensure_compile_time_eval. All configs are compiled in parallel via ThreadPoolExecutor (XLA compilation is CPU-bound). Each compiled artifact is then timed sequentially for accurate device timing. The winner is written to .tonno-cache/<fn>.json.

  2. On subsequent calls (cache hit): the best config is loaded from disk and injected as static kwargs. JAX's own compilation cache takes over from there.

  3. Inside jax.jit / jax.grad / jax.vmap: the sweep runs as a side channel during the first trace, then the winning config is baked into the jaxpr as a compile-time constant.

API reference

autotune(
    configs: list[dict[str, Any]],  # configs to sweep; all dicts must share keys
    *,
    num_warmup: int = 1,            # warmup calls per config after compilation
    num_timing: int = 3,            # timed calls per config (median used)
)

Contract:

  • @autotune must wrap a @jax.jit-decorated function.
  • The tunable param keys must appear in static_argnames of the jax.jit call.
  • Tunable params must have defaults in the function signature.
  • Config values must be JSON-serialisable (int, float, str, bool).

Cache

Results are stored in .tonno-cache/<qualname>.json (or $TONNO_CACHE_DIR). The file is human-readable JSON; you can inspect or delete entries manually.

{
  "NVIDIA H100 80GB": {
    "{\"__arg0\":{\"dtype\":\"float32\",\"shape\":[4096,4096]}, ...}": {
      "config": {"BM": 64, "BN": 128},
      "time_ms": 0.312,
      "key_values": {...}
    }
  }
}

Example

See examples/matmul.py for a complete autotuned tiled GEMM.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

tonno-0.2.3.tar.gz (7.8 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

tonno-0.2.3-py3-none-any.whl (9.4 kB view details)

Uploaded Python 3

File details

Details for the file tonno-0.2.3.tar.gz.

File metadata

  • Download URL: tonno-0.2.3.tar.gz
  • Upload date:
  • Size: 7.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for tonno-0.2.3.tar.gz
Algorithm Hash digest
SHA256 97a1cea39eaca0526a1f7899627b7ba1fc6835bf74416bb5b633ac22c2f91520
MD5 8c92152d17bae03176cdf7bc27c11ced
BLAKE2b-256 43e76f560a6d6edb9b4c43e00c09137312b48b1e51102eddca067088da36f32d

See more details on using hashes here.

Provenance

The following attestation bundles were made for tonno-0.2.3.tar.gz:

Publisher: release.yml on Cusp-AI/tonno

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file tonno-0.2.3-py3-none-any.whl.

File metadata

  • Download URL: tonno-0.2.3-py3-none-any.whl
  • Upload date:
  • Size: 9.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for tonno-0.2.3-py3-none-any.whl
Algorithm Hash digest
SHA256 3704053331cf16531dc67054457131d14300d77d6b50a5282f3c3083c172bbef
MD5 3285a216dbe8bd202b6334c198c9e76e
BLAKE2b-256 907ec8190872be78436358c8c9da343ff7cb0e28463464ad883a547bdf9ab6d7

See more details on using hashes here.

Provenance

The following attestation bundles were made for tonno-0.2.3-py3-none-any.whl:

Publisher: release.yml on Cusp-AI/tonno

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page