A Pallas Custom Kernel Library.

Tokamax

Tokamax is a library of custom accelerator kernels, supporting both NVIDIA GPUs and Google TPUs. Tokamax provides state-of-the-art custom kernel implementations built on top of JAX and Pallas.

Tokamax also provides tooling for users to build and autotune their own custom accelerator kernels.

Status

Tokamax is still heavily under development. Incomplete features and API changes are to be expected.

We currently support the following GPU kernels:

And the following for both GPU and TPU:

And the following TPU kernels:

Installation

The latest Tokamax PyPI release:

pip install -U tokamax

The latest bleeding-edge version from GitHub, with no stability guarantees:

pip install git+https://github.com/openxla/tokamax.git
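To confirm which release ended up in your environment, the standard library can read the installed distribution's metadata. This is useful given the API churn mentioned under Status. It is a minimal sketch, and the helper name installed_version is ours, not part of Tokamax:

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str = "tokamax") -> Optional[str]:
  """Returns the installed version of `dist_name`, or None if it is absent."""
  try:
    # Reads the version recorded in the installed distribution's metadata.
    return version(dist_name)
  except PackageNotFoundError:
    return None


if __name__ == "__main__":
  print(installed_version() or "tokamax is not installed")
```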

Using Tokamax

Consider a function that calls several Tokamax kernels, running on an H100 GPU:

import jax
import jax.numpy as jnp
import tokamax

def loss(x, scale):
  x = tokamax.layer_norm(
      x, scale=scale, offset=None, implementation="triton"
  )
  x = tokamax.dot_product_attention(x, x, x, implementation="xla_chunked")
  x = tokamax.layer_norm(x, scale=scale, offset=None, implementation=None)
  x = tokamax.dot_product_attention(x, x, x, implementation="mosaic")
  return jnp.sum(x)

f_grad = jax.jit(jax.grad(loss))

With implementation=None, Tokamax is allowed to select the best implementation for each kernel shape, and may even choose different implementations for the forward pass and the gradient. implementation=None is also always supported, because Tokamax can fall back to the XLA implementation (implementation='xla').

However, you may want to choose a specific implementation of the kernel and fail if it is unsupported. For instance, implementation="mosaic" will try to use a Pallas:Mosaic GPU kernel and raise an exception if this is unsupported for any reason, for example if the inputs are FP64 or the GPU is too old.
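You can think of the difference between implementation=None and an explicit choice as a dispatcher that either has a fallback or does not. The sketch below is hypothetical: dispatch, UnsupportedError and the toy implementations are illustrative names only. It does not cover Tokamax's real selection logic, which also considers kernel shapes and may differ between the forward and backward passes.

```python
from typing import Callable, Dict, Optional, Sequence


class UnsupportedError(Exception):
  """Raised when an implementation cannot handle the given inputs."""


def dispatch(
    impls: Dict[str, Callable[[str], str]],
    order: Sequence[str],
    implementation: Optional[str],
    dtype: str,
) -> str:
  """Runs `implementation`, or the first supported one when it is None."""
  if implementation is not None:
    # Explicit choice: no fallback, so UnsupportedError reaches the caller.
    return impls[implementation](dtype)
  # implementation=None: try candidates in order, with "xla" last as the
  # always-available fallback.
  for name in order:
    try:
      return impls[name](dtype)
    except UnsupportedError:
      continue
  raise UnsupportedError(f"no implementation supports {dtype}")


def mosaic_impl(dtype: str) -> str:
  # Toy restriction mirroring the README's example: FP64 is rejected.
  if dtype == "float64":
    raise UnsupportedError("FP64 inputs are not supported")
  return "mosaic"


def xla_impl(dtype: str) -> str:
  return "xla"


IMPLS = {"mosaic": mosaic_impl, "xla": xla_impl}
ORDER = ("mosaic", "xla")

print(dispatch(IMPLS, ORDER, None, "bfloat16"))  # mosaic
print(dispatch(IMPLS, ORDER, None, "float64"))   # xla
```

Calling dispatch(IMPLS, ORDER, "mosaic", "float64") raises instead of falling back. This mirrors the "fail if it is unsupported" behaviour described above.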

Evaluate the Gradient

channels, seq_len, batch_size, num_heads = (64, 2048, 32, 16)
scale = jax.random.normal(jax.random.key(0), (channels,), dtype=jnp.float32)
x = jax.random.normal(
    jax.random.key(1),
    (batch_size, seq_len, num_heads, channels),
    dtype=jnp.bfloat16,
)

out = f_grad(x, scale)
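A little arithmetic gives a sense of the scale of the shapes above. The second figure assumes a naive attention that materializes the full score matrix in float32. Chunked and fused attention kernels are generally designed to avoid that kind of memory cost. The figure is not a measurement of any Tokamax implementation.

```python
# Sizes from the example above.
channels, seq_len, batch_size, num_heads = (64, 2048, 32, 16)

BF16_BYTES = 2
FP32_BYTES = 4

# Input activations x: (batch_size, seq_len, num_heads, channels) in bfloat16.
x_elements = batch_size * seq_len * num_heads * channels
x_mib = x_elements * BF16_BYTES / 2**20

# A naive attention that stores every (query, key) score for each batch
# element and head holds batch_size * num_heads * seq_len**2 values.
score_elements = batch_size * num_heads * seq_len**2
score_gib_fp32 = score_elements * FP32_BYTES / 2**30

print(f"x: {x_elements:,} elements, {x_mib:.0f} MiB in bf16")
print(f"scores: {score_elements:,} elements, {score_gib_fp32:.0f} GiB in fp32")
```

This prints 67,108,864 elements (128 MiB) for x and 2,147,483,648 elements (8 GiB) for the full score matrix.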

Autotuning

To get the best performance, autotune all Tokamax kernels in f_grad:

autotune_result: tokamax.AutotuningResult = tokamax.autotune(f_grad, x, scale)

autotune_result can be used as a context manager, which applies the autotuned configs to all Tokamax kernels in f_grad:

with autotune_result:
  out_autotuned = f_grad(x, scale)
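One common way to build a context manager like this in Python is a contextvars.ContextVar that kernels consult when they are traced. The sketch below is hypothetical: TuningResult, lookup_config and the config dicts are illustrative names. The README does not describe Tokamax's actual mechanism, including how it interacts with jax.jit tracing and caching.

```python
import contextvars

# Hypothetical registry of active configs: kernel key -> config dict.
_ACTIVE_CONFIGS = contextvars.ContextVar("active_configs", default=None)


class TuningResult:
  """Hypothetical stand-in for an autotuning result used as a context manager."""

  def __init__(self, configs):
    self.configs = dict(configs)
    self._tokens = []

  def __enter__(self):
    # Make this result's configs visible to code inside the `with` block.
    self._tokens.append(_ACTIVE_CONFIGS.set(self.configs))
    return self

  def __exit__(self, exc_type, exc, tb):
    # Restore whatever was active before; this also makes nesting work.
    _ACTIVE_CONFIGS.reset(self._tokens.pop())
    return False


def lookup_config(kernel_key, default):
  """Returns the active config for `kernel_key`, or `default` if none is set."""
  configs = _ACTIVE_CONFIGS.get()
  if configs is None:
    return default
  return configs.get(kernel_key, default)


with TuningResult({"layer_norm": {"block_size": 128}}):
  print(lookup_config("layer_norm", {"block_size": 64}))  # {'block_size': 128}
print(lookup_config("layer_norm", {"block_size": 64}))    # {'block_size': 64}
```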

To serialize and reuse the result of a potentially expensive tokamax.autotune call:

autotune_result_json: str = autotune_result.dumps()
autotune_result = tokamax.AutotuningResult.loads(autotune_result_json)
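The dumps()/loads() pair can back a simple on-disk cache, so autotuning does not re-run on every launch. This is a minimal sketch: load_or_compute and the file name are hypothetical. Putting the GPU model in the file name is also our assumption, because a config tuned on one device may not suit another.

```python
import os
from typing import Callable, TypeVar

T = TypeVar("T")


def load_or_compute(
    path: str,
    compute: Callable[[], T],
    dumps: Callable[[T], str],
    loads: Callable[[str], T],
) -> T:
  """Returns the result cached at `path`, computing and saving it on a miss."""
  if os.path.exists(path):
    with open(path, "r", encoding="utf-8") as f:
      return loads(f.read())
  result = compute()
  # Write to a temporary file, then rename, so a crash mid-write never
  # leaves a truncated cache behind.
  tmp_path = path + ".tmp"
  with open(tmp_path, "w", encoding="utf-8") as f:
    f.write(dumps(result))
  os.replace(tmp_path, path)
  return result


# Hypothetical usage with the calls shown above:
# autotune_result = load_or_compute(
#     "autotune_h100.json",
#     compute=lambda: tokamax.autotune(f_grad, x, scale),
#     dumps=lambda r: r.dumps(),
#     loads=tokamax.AutotuningResult.loads,
# )
```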

Users can also autotune their own kernels with tokamax.autotune: inherit from the tokamax.Op class and override the tokamax.Op._get_autotuning_configs method to define the autotuning search space.
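The README does not give the signature of tokamax.Op._get_autotuning_configs. The following is therefore only a sketch of the kind of search space such an override might describe: a grid of tiling parameters, filtered to values that evenly tile the problem. BlockConfig, candidate_configs and every value in the grid are hypothetical.

```python
import dataclasses
import itertools
from typing import List


@dataclasses.dataclass(frozen=True)
class BlockConfig:
  """Hypothetical tiling parameters for an attention-like kernel."""
  block_q: int
  block_k: int
  num_warps: int


def candidate_configs(seq_len: int) -> List[BlockConfig]:
  """Enumerates a grid of configs, keeping block sizes that divide seq_len."""
  configs = []
  for block_q, block_k, num_warps in itertools.product(
      (64, 128, 256), (64, 128), (4, 8)
  ):
    # Skip tilings that would leave a ragged final block.
    if seq_len % block_q == 0 and seq_len % block_k == 0:
      configs.append(BlockConfig(block_q, block_k, num_warps))
  return configs


print(len(candidate_configs(2048)))  # 12: every grid point divides 2048
print(len(candidate_configs(192)))   # 2: only block_q = block_k = 64 fit
```

Pruning invalid points up front keeps the autotuner from spending time on configs that could never run.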

Note that autotuning is fundamentally non-deterministic: measuring kernel execution times is noisy. As different configs chosen during autotuning can lead to different numerics, this is a potential source of numerical non-determinism. Serializing and reusing fixed autotuning results is a way to ensure the same numerics across sessions.
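To see why timing noise matters, consider a generic wall-clock autotuning loop that takes the median of several runs per config and keeps the fastest. This is a sketch, not Tokamax's implementation: the Benchmarking section describes measuring accelerator time rather than Python wall-clock time. time_median and pick_best are hypothetical names. The median damps noise but cannot remove it, so close configs can still swap places between runs.

```python
import statistics
import time
from typing import Callable, Dict, Hashable, Iterable, Tuple


def time_median(fn: Callable[[], object], repeats: int = 5) -> float:
  """Median wall-clock seconds of fn() over `repeats` runs.

  fn must block until its work is finished; for JAX code that means
  calling jax.block_until_ready on the result.
  """
  samples = []
  for _ in range(repeats):
    start = time.perf_counter()
    fn()
    samples.append(time.perf_counter() - start)
  return statistics.median(samples)


def pick_best(
    configs: Iterable[Hashable],
    make_fn: Callable[[Hashable], Callable[[], object]],
    repeats: int = 5,
) -> Tuple[Hashable, Dict[Hashable, float]]:
  """Times each config and returns (fastest config, median time per config)."""
  timings = {cfg: time_median(make_fn(cfg), repeats) for cfg in configs}
  best = min(timings, key=timings.get)
  return best, timings


# Toy usage: each "config" is just a sleep duration in seconds.
best, timings = pick_best([0.0, 0.05], lambda s: (lambda: time.sleep(s)), repeats=3)
print(best)
```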

Serialization

Kernels can be serialized to StableHLO. Tokamax kernel calls are JAX custom calls, which jax.export rejects by default. To allow all Tokamax kernels to be exported, pass tokamax.DISABLE_JAX_EXPORT_CHECKS as disabled_checks:

from jax import export

f_grad_exported = export.export(f_grad, disabled_checks=tokamax.DISABLE_JAX_EXPORT_CHECKS)(
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    jax.ShapeDtypeStruct(scale.shape, scale.dtype),
)

Note that functions serialized with Tokamax kernels lose the device-independence of standard StableHLO. Tokamax makes two serialization guarantees:

  1. A function serialized for a specific device is guaranteed to run, after deserialization, on that exact device.
  2. Tokamax gives the same compatibility guarantees as JAX: 6-month backward compatibility.
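The first guarantee only covers the exact device a function was serialized for. It can therefore help to record that device next to the serialized bytes and refuse to load the function elsewhere. This is a minimal sketch with a hypothetical header format, and pack and unpack are our own names. In practice, the payload could be the bytes from jax.export's Exported.serialize() and the device string could be jax.devices()[0].device_kind. Whether device_kind is precise enough to identify "the exact device" for Tokamax's purposes is an assumption.

```python
import json


def pack(serialized: bytes, device_kind: str) -> bytes:
  """Prefixes `serialized` with a length-prefixed JSON header naming the device."""
  header = json.dumps({"device_kind": device_kind}).encode("utf-8")
  return len(header).to_bytes(4, "big") + header + serialized


def unpack(blob: bytes, current_device_kind: str) -> bytes:
  """Returns the payload, raising if it was packed for a different device."""
  header_len = int.from_bytes(blob[:4], "big")
  header = json.loads(blob[4:4 + header_len].decode("utf-8"))
  if header["device_kind"] != current_device_kind:
    raise RuntimeError(
        f"serialized for {header['device_kind']!r}, "
        f"but running on {current_device_kind!r}"
    )
  return blob[4 + header_len:]


# Hypothetical usage with the export above:
# blob = pack(bytes(f_grad_exported.serialize()), jax.devices()[0].device_kind)
# payload = unpack(blob, jax.devices()[0].device_kind)
# f_grad_restored = export.deserialize(bytearray(payload))
```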

Benchmarking

JAX's Python overhead is often much larger than the actual accelerator kernel execution time, so the usual approach of timing jax.block_until_ready(f_grad(x, scale)) won't give useful kernel timings. Tokamax has utilities for measuring only the accelerator execution time:

f_std, args = tokamax.benchmarking.standardize_function(f_grad, kwargs={'x': x, 'scale': scale})
run = tokamax.benchmarking.compile_benchmark(f_std, args)
bench: tokamax.benchmarking.BenchmarkData = run(args)

There are different measurement techniques. For example, on GPU, the CUPTI profiler can be selected via run(args, method='cupti'); this instruments the kernel and adds a small overhead. The default, run(args, method=None), lets Tokamax choose the method and works on both TPU and GPU. Benchmark noise can be reduced by increasing the number of iterations, e.g. run(args, iterations=10).
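Extra iterations help most when the resulting samples are summarized robustly. The README does not list the fields of tokamax.benchmarking.BenchmarkData, so this hypothetical helper just takes a plain list of per-iteration times in milliseconds. It reports the median and the median absolute deviation (MAD), and neither of these is thrown off by an occasional slow outlier iteration.

```python
import statistics
from typing import Dict, Sequence


def summarize(times_ms: Sequence[float]) -> Dict[str, float]:
  """Robust summary of per-iteration times in milliseconds."""
  if not times_ms:
    raise ValueError("need at least one timing")
  median = statistics.median(times_ms)
  # Median absolute deviation from the median: a spread measure that one
  # slow outlier cannot inflate the way it inflates the standard deviation.
  mad = statistics.median(abs(t - median) for t in times_ms)
  return {
      "n": len(times_ms),
      "median_ms": median,
      "mad_ms": mad,
      "min_ms": min(times_ms),
  }


print(summarize([1.0, 1.1, 0.9, 1.0, 5.0]))
```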

Disclaimer

This is not an official Google product.


Download files

Source Distribution

tokamax-0.0.8.tar.gz (253.2 kB)

Uploaded Source

Built Distribution

tokamax-0.0.8-py3-none-any.whl (375.2 kB)

Uploaded Python 3

File details

Details for the file tokamax-0.0.8.tar.gz.

File metadata

  • Download URL: tokamax-0.0.8.tar.gz
  • Upload date:
  • Size: 253.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for tokamax-0.0.8.tar.gz
Algorithm Hash digest
SHA256 75987d1353c01925de0730cd460b9d7ef01d2cdfdd6ee3fed640f633fdb59e4b
MD5 1824e7b7222f21d8cb6f52387b621104
BLAKE2b-256 ee563ce59293cc7d5993f559410992e7ae8db10d72f8552ac6033e16f40e6997

Provenance

The following attestation bundles were made for tokamax-0.0.8.tar.gz:

Publisher: publish.yml on openxla/tokamax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file tokamax-0.0.8-py3-none-any.whl.

File metadata

  • Download URL: tokamax-0.0.8-py3-none-any.whl
  • Upload date:
  • Size: 375.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for tokamax-0.0.8-py3-none-any.whl
Algorithm Hash digest
SHA256 694ecec368ae192de925496ad511837989d3d1705f8ad46adccd1f131542631f
MD5 93a8e52e28b227d06aaa8f6304bb5030
BLAKE2b-256 95a8e3270a48738795c950f0628547673b0beecd6953484f8b9dd6ecf93678cc

Provenance

The following attestation bundles were made for tokamax-0.0.8-py3-none-any.whl:

Publisher: publish.yml on openxla/tokamax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
