A Pallas Custom Kernel Library.

Tokamax

Tokamax is a library of custom accelerator kernels, supporting both NVIDIA GPUs and Google TPUs. Tokamax provides state-of-the-art custom kernel implementations built on top of JAX and Pallas.

Tokamax also provides tooling for users to build and autotune their own custom accelerator kernels.

Status

Tokamax is still heavily under development. Incomplete features and API changes are to be expected.

We currently support the following GPU kernels:

And the following for both GPU and TPU:

And the following TPU kernels:

Installation

The latest Tokamax PyPI release:

pip install -U tokamax

The latest bleeding edge version from Github, with no stability guarantees:

pip install git+https://github.com/openxla/tokamax.git

Using Tokamax

Consider the following function, which calls several Tokamax kernels and runs on an H100 GPU:

import jax
import jax.numpy as jnp
import tokamax

def loss(x, scale):
  x = tokamax.layer_norm(
      x, scale=scale, offset=None, implementation="triton"
  )
  x = tokamax.dot_product_attention(x, x, x, implementation="xla_chunked")
  x = tokamax.layer_norm(x, scale=scale, offset=None, implementation=None)
  x = tokamax.dot_product_attention(x, x, x, implementation="mosaic")
  return jnp.sum(x)

f_grad = jax.jit(jax.grad(loss))

With implementation=None, Tokamax is allowed to select the best implementation for each kernel shape, and may even choose different implementations for the forward pass and the gradient. This mode is always supported, because it can fall back to the XLA implementation (implementation="xla").

However, you may want to choose a specific implementation of the kernel and fail if it is unsupported. For instance, implementation="mosaic" will try to use a Pallas:Mosaic GPU kernel and throw an exception if it is unsupported for any reason, for example FP64 inputs or an older GPU.
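The selection behaviour described above can be modelled in a few lines of plain Python. This is a hypothetical sketch, not Tokamax's dispatch logic. The registry, the support checks, and the preference order are assumptions made for illustration, and the sketch simply takes the first supported candidate instead of measuring which is fastest. An explicit implementation raises if it is unsupported, while `None` walks the candidates and always ends at an XLA fallback.

```python
from typing import Callable, Dict, Optional, Tuple

# Hypothetical registry: name -> (support check, kernel). Order encodes preference;
# "xla" is last and supports every input.
IMPLEMENTATIONS: Dict[str, Tuple[Callable[[dict], bool], Callable[[dict], str]]] = {
    "mosaic": (lambda spec: spec["dtype"] != "float64" and spec["gpu"] == "H100",
               lambda spec: "mosaic kernel"),
    "triton": (lambda spec: spec["dtype"] != "float64",
               lambda spec: "triton kernel"),
    "xla": (lambda spec: True,
            lambda spec: "xla reference"),
}


def dispatch(spec: dict, implementation: Optional[str] = None) -> str:
    if implementation is not None:
        is_supported, kernel = IMPLEMENTATIONS[implementation]
        if not is_supported(spec):
            # An explicitly requested implementation fails loudly.
            raise NotImplementedError(f"{implementation!r} does not support {spec}")
        return kernel(spec)
    # implementation=None: first supported candidate wins; "xla" always matches.
    for is_supported, kernel in IMPLEMENTATIONS.values():
        if is_supported(spec):
            return kernel(spec)
    raise AssertionError("unreachable: the xla fallback supports every input")


print(dispatch({"dtype": "float64", "gpu": "H100"}))   # falls back to XLA
print(dispatch({"dtype": "bfloat16", "gpu": "A100"}))  # older GPU: skips mosaic
```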

Evaluate the Gradient

channels, seq_len, batch_size, num_heads = (64, 2048, 32, 16)
scale = jax.random.normal(jax.random.key(0), (channels,), dtype=jnp.float32)
x = jax.random.normal(
    jax.random.key(1),
    (batch_size, seq_len, num_heads, channels),
    dtype=jnp.bfloat16,
)

out = f_grad(x, scale)

Autotuning

To get the best performance, autotune all Tokamax kernels in f_grad:

autotune_result: tokamax.AutotuningResult = tokamax.autotune(f_grad, x, scale)

autotune_result can be used as a context manager, which applies the autotuned configs to all Tokamax kernels in f_grad:

with autotune_result:
  out_autotuned = f_grad(x, scale)
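One common way to build a context manager of this kind in Python is a `contextvars.ContextVar` that holds the active configs and is restored on exit. The sketch below shows only that scoping pattern. `AutotuneConfigs`, `lookup_config`, and the string keys are hypothetical names, not Tokamax's API.

```python
import contextvars

# Hypothetical: maps a kernel key (e.g. op name plus input shapes) to a chosen config.
_ACTIVE_CONFIGS = contextvars.ContextVar("active_configs", default={})


class AutotuneConfigs:
    """Hypothetical stand-in for an autotuning result used as a context manager."""

    def __init__(self, configs: dict):
        self.configs = configs
        self._tokens = []  # one token per active `with`, so nesting works

    def __enter__(self):
        # Layer these configs over whatever an enclosing scope set.
        merged = {**_ACTIVE_CONFIGS.get(), **self.configs}
        self._tokens.append(_ACTIVE_CONFIGS.set(merged))
        return self

    def __exit__(self, *exc_info):
        # Restore the previous configs, even if the body raised.
        _ACTIVE_CONFIGS.reset(self._tokens.pop())
        return False


def lookup_config(key: str, default: dict) -> dict:
    """What a (hypothetical) kernel would call to find its config."""
    return _ACTIVE_CONFIGS.get().get(key, default)


default = {"block_q": 64}
print(lookup_config("attention", default))      # {'block_q': 64}
with AutotuneConfigs({"attention": {"block_q": 128}}):
    print(lookup_config("attention", default))  # {'block_q': 128}
print(lookup_config("attention", default))      # {'block_q': 64}
```

Using a context variable instead of a plain global keeps the override scoped to the `with` block and isolated between threads and async tasks.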

To serialize and reuse the result of a potentially expensive tokamax.autotune call:

autotune_result_json: str = autotune_result.dumps()
autotune_result = tokamax.AutotuningResult.loads(autotune_result_json)

Users can autotune their own kernels with tokamax.autotune by inheriting from the tokamax.Op class and overriding the tokamax.Op._get_autotuning_configs method to define the autotuning search space.
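An autotuning search space is often a Cartesian product of tuning parameters, filtered by hardware constraints. The plain-Python sketch below illustrates that idea only. The parameter names (`block_m`, `block_n`, `block_k`, `num_stages`), the shared-memory budget, and `fits_in_smem` are hypothetical and do not reflect the tokamax.Op interface or any real kernel's limits.

```python
import itertools
from typing import Iterator, NamedTuple


class TileConfig(NamedTuple):
    # Hypothetical tuning knobs for a tiled GPU kernel.
    block_m: int
    block_n: int
    block_k: int
    num_stages: int


# Assumed shared-memory budget per thread block, in bytes (illustrative only).
SMEM_BUDGET_BYTES = 228 * 1024


def fits_in_smem(cfg: TileConfig, bytes_per_elem: int = 2) -> bool:
    # Each pipeline stage buffers a (block_m x block_k) and a (block_k x block_n) tile.
    per_stage = (cfg.block_m * cfg.block_k + cfg.block_k * cfg.block_n) * bytes_per_elem
    return per_stage * cfg.num_stages <= SMEM_BUDGET_BYTES


def autotuning_configs() -> Iterator[TileConfig]:
    """Cartesian product of candidate values, minus configs that cannot fit."""
    grid = itertools.product((64, 128, 256), (64, 128, 256), (32, 64), (2, 3, 4))
    for values in grid:
        cfg = TileConfig(*values)
        if fits_in_smem(cfg):
            yield cfg


configs = list(autotuning_configs())
print(len(configs))  # 53 of the 54 combinations fit
```

Pruning invalid configs before timing matters because every remaining candidate has to be compiled and benchmarked.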

Note that autotuning is fundamentally non-deterministic: measuring kernel execution times is noisy. As different configs chosen during autotuning can lead to different numerics, this is a potential source of numerical non-determinism. Serializing and reusing fixed autotuning results is a way to ensure the same numerics across sessions.
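The toy autotuner below illustrates the point. It picks the fastest config from noisy timings, so separate sessions can disagree when two candidates are close. Serializing the winner to JSON and reloading it fixes the choice. Everything here is a hypothetical stand-in, with simulated timings from `random.gauss` and made-up config names. It is not Tokamax's autotuner or its serialization format.

```python
import json
import random
import statistics

# Hypothetical candidates: two are within measurement noise of each other.
TRUE_RUNTIME_MS = {"block_64": 1.00, "block_128": 1.01, "block_256": 1.30}


def measure_ms(config: str, rng: random.Random, iterations: int = 5) -> float:
    # Simulated timings: true runtime plus Gaussian noise, summarized by the median.
    samples = [TRUE_RUNTIME_MS[config] + rng.gauss(0.0, 0.05) for _ in range(iterations)]
    return statistics.median(samples)


def toy_autotune(seed: int) -> dict:
    """One 'session' of autotuning; the seed stands in for run-to-run noise."""
    rng = random.Random(seed)
    timings = {cfg: measure_ms(cfg, rng) for cfg in TRUE_RUNTIME_MS}
    return {"best": min(timings, key=timings.get), "timings_ms": timings}


# Across sessions, the winner can flip between the two near-identical configs.
winners = {toy_autotune(seed)["best"] for seed in range(20)}
print(sorted(winners))

# Serializing one result and reloading it pins the choice for later sessions.
saved = json.dumps(toy_autotune(seed=0))
pinned = json.loads(saved)["best"]
print(pinned)
```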

Serialization

Kernels can be serialized to StableHLO. Kernel calls are JAX custom calls, which jax.export rejects by default. To allow all Tokamax kernels to be exported, pass tokamax.DISABLE_JAX_EXPORT_CHECKS as disabled_checks:

from jax import export

f_grad_exported = export.export(f_grad, disabled_checks=tokamax.DISABLE_JAX_EXPORT_CHECKS)(
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    jax.ShapeDtypeStruct(scale.shape, scale.dtype),
)

Note that functions serialized with Tokamax kernels lose the device-independence of standard StableHLO. Tokamax makes two serialization guarantees:

  1. A function serialized on a specific device is guaranteed to run, once deserialized, on that exact device.
  2. Tokamax gives the same compatibility guarantees as JAX: 6-month backward compatibility.
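The serialize and deserialize round trip itself uses standard jax.export calls. The sketch below exports a plain JAX function that contains no Tokamax kernels, so no checks need to be disabled. With Tokamax kernels you would also pass disabled_checks=tokamax.DISABLE_JAX_EXPORT_CHECKS, and, per the guarantees above, run the result only on the device it was exported for. The function `f` is illustrative.

```python
import jax
import jax.numpy as jnp
from jax import export


def f(x):
    return jnp.sum(jnp.sin(x) ** 2)


# Export for a concrete input shape/dtype; by default this targets the current backend.
exported = export.export(jax.jit(f))(jax.ShapeDtypeStruct((8,), jnp.float32))
print(exported.platforms)  # e.g. ('cpu',) or ('cuda',)

# Serialize to bytes (e.g. to store on disk), then rehydrate and call.
blob = exported.serialize()
restored = export.deserialize(blob)

x = jnp.arange(8, dtype=jnp.float32)
print(restored.call(x), f(x))
```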

Benchmarking

JAX Python overhead is often much larger than the actual accelerator kernel execution time. As a result, timing jax.block_until_ready(f_grad(x, scale)) in the usual way mostly measures that overhead rather than the kernels. Tokamax has utilities for measuring only accelerator execution time:

f_std, args = tokamax.standardize_function(f_grad, kwargs={'x': x, 'scale': scale})
bench: tokamax.BenchmarkData = tokamax.benchmark(f_std, args)

There are different measurement techniques. For example, on GPU, the CUPTI profiler can be selected via tokamax.benchmark(f_std, args, method='cupti'). This instruments the kernel and adds a small overhead. The default method=None lets Tokamax choose the method and works on both TPU and GPU. Benchmark noise can be reduced by increasing the number of iterations:

tokamax.benchmark(f_std, args, iterations=10)
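A generic timing harness shows why more iterations help. It runs a few warm-up calls, times many iterations, and reports a robust statistic such as the median. This is a plain-Python sketch using `time.perf_counter`. Unlike tokamax.benchmark, it measures wall-clock time including Python overhead, so it illustrates only the statistics. `time_fn` is a hypothetical helper.

```python
import statistics
import time
from typing import Callable


def time_fn(fn: Callable[[], object], iterations: int = 10, warmup: int = 2) -> dict:
    # Warm-up calls absorb one-off costs (e.g. compilation, cache population).
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(iterations):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return {
        # The median is robust to occasional slow outliers; its uncertainty
        # shrinks roughly as 1/sqrt(iterations).
        "median_s": statistics.median(samples),
        "min_s": min(samples),
        "stdev_s": statistics.stdev(samples) if iterations > 1 else 0.0,
    }


result = time_fn(lambda: sum(range(10_000)), iterations=50)
print(result)
```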

Disclaimer

This is not an official Google product.

Download files

Source Distribution

tokamax-0.0.12.tar.gz (428.6 kB)

Uploaded Source

Built Distribution

tokamax-0.0.12-py3-none-any.whl (606.6 kB)

Uploaded Python 3

File details

Details for the file tokamax-0.0.12.tar.gz.

File metadata

  • Download URL: tokamax-0.0.12.tar.gz
  • Size: 428.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for tokamax-0.0.12.tar.gz
Algorithm Hash digest
SHA256 5c7f4d2f54caed15455c7be5a769e446f760f7b730d41a4a6b916dc9bb423f97
MD5 ca275291b85673f84f791f0339ca715e
BLAKE2b-256 6653b4d83852433afe45d167c56bfc645a46e92a3ce6ad8d93668bc65c379a25

Provenance

The following attestation bundles were made for tokamax-0.0.12.tar.gz:

Publisher: publish.yml on openxla/tokamax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file tokamax-0.0.12-py3-none-any.whl.

File metadata

  • Download URL: tokamax-0.0.12-py3-none-any.whl
  • Size: 606.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for tokamax-0.0.12-py3-none-any.whl
Algorithm Hash digest
SHA256 2d60e8602d88a6d5ea64e2432694044e44978d46fc5b829f338a238a9746da03
MD5 baeabbcffede19a07428e8f6ce1bb5e9
BLAKE2b-256 d35cc35d4183179b1d5eb96b287a9f7711d617ed0cc8b7066c2497581e5576f6

Provenance

The following attestation bundles were made for tokamax-0.0.12-py3-none-any.whl:

Publisher: publish.yml on openxla/tokamax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
