
A Pallas Custom Kernel Library.


Tokamax


Tokamax is a library of custom accelerator kernels, supporting both NVIDIA GPUs and Google TPUs. Tokamax provides state-of-the-art custom kernel implementations built on top of JAX and Pallas.

Tokamax also provides tooling for users to build and autotune their own custom accelerator kernels.

Status

Tokamax is still under heavy development. Expect incomplete features and API changes.

We currently support the following GPU kernels:

And the following for both GPU and TPU:

And the following TPU kernels:

Installation

The latest Tokamax PyPI release:

pip install -U tokamax

The latest bleeding-edge version from GitHub, with no stability guarantees:

pip install git+https://github.com/openxla/tokamax.git

Using Tokamax

Consider the following function, which calls several Tokamax kernels and runs on an H100 GPU:

import jax
import jax.numpy as jnp
import tokamax

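def loss(x, scale):
  # Explicitly request the Triton layer norm (raises if unsupported).
  x = tokamax.layer_norm(
      x, scale=scale, offset=None, implementation="triton"
  )
  # Request the XLA-based chunked attention implementation.
  x = tokamax.dot_product_attention(x, x, x, implementation="xla_chunked")
  # implementation=None lets Tokamax choose the implementation.
  x = tokamax.layer_norm(x, scale=scale, offset=None, implementation=None)
  # Explicitly request the Pallas:Mosaic GPU attention kernel.
  x = tokamax.dot_product_attention(x, x, x, implementation="mosaic")
  # Reduce to a scalar so that jax.grad can differentiate the loss.
  return jnp.sum(x)

# Gradient with respect to x (the first argument), compiled with jax.jit.
f_grad = jax.jit(jax.grad(loss))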

With implementation=None, Tokamax is free to select the best implementation for each kernel shape, and may even choose different implementations for the forward pass and the gradient. This option is always supported, because Tokamax can fall back to the XLA implementation (implementation='xla').

However, you may want to choose a specific implementation of a kernel and fail if it is unsupported. For instance, implementation="mosaic" tries to use a Pallas:Mosaic GPU kernel and raises an exception if that kernel is unsupported for any reason, for example with FP64 inputs or on older GPUs.
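The selection semantics described above can be illustrated with a small, generic dispatcher. This is a minimal sketch, not Tokamax's actual selection logic: UnsupportedError, make_dispatcher, fast_kernel and reference_kernel are hypothetical names, and where Tokamax picks the best implementation for each shape, the sketch simply takes the first candidate that accepts the inputs, with an always-available reference implementation (standing in for "xla") tried last.

```python
from typing import Callable, Dict, List, Optional


class UnsupportedError(Exception):
    """Raised when an implementation cannot handle the given inputs."""


def make_dispatcher(impls: Dict[str, Callable], order: List[str]) -> Callable:
    """Returns a function that runs one of several named implementations."""

    def dispatch(*args, implementation: Optional[str] = None):
        if implementation is not None:
            # Explicit choice: no fallback, so UnsupportedError reaches the caller.
            return impls[implementation](*args)
        # implementation=None: try candidates in order; the last one is assumed
        # to be a reference implementation that supports every input.
        for name in order:
            try:
                return impls[name](*args)
            except UnsupportedError:
                continue
        raise UnsupportedError("no implementation supports these inputs")

    return dispatch


def fast_kernel(dtype: str) -> str:
    # Hypothetical specialized kernel that rejects FP64 inputs.
    if dtype == "float64":
        raise UnsupportedError("float64 inputs are not supported")
    return "mosaic"


def reference_kernel(dtype: str) -> str:
    # Hypothetical reference implementation that accepts any dtype.
    return "xla"


dispatch = make_dispatcher(
    {"mosaic": fast_kernel, "xla": reference_kernel}, order=["mosaic", "xla"]
)
print(dispatch("bfloat16"))  # mosaic
print(dispatch("float64"))   # xla (automatic fallback)
```

By contrast, dispatch("float64", implementation="mosaic") raises UnsupportedError instead of falling back, mirroring the behavior of an explicitly requested implementation.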

Evaluate the Gradient

channels, seq_len, batch_size, num_heads = (64, 2048, 32, 16)
scale = jax.random.normal(jax.random.key(0), (channels,), dtype=jnp.float32)
x = jax.random.normal(
    jax.random.key(1),
    (batch_size, seq_len, num_heads, channels),
    dtype=jnp.bfloat16,
)

out = f_grad(x, scale)
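Since jax.grad differentiates with respect to the first argument by default, out is the gradient of loss with respect to x and has the same shape and dtype as x: (32, 2048, 16, 64) in bfloat16. A quick back-of-the-envelope calculation of the input sizes, in plain Python using only the shapes above, helps when reasoning about memory and benchmark numbers:

```python
import math

# Shapes from the example above.
channels, seq_len, batch_size, num_heads = (64, 2048, 32, 16)
x_shape = (batch_size, seq_len, num_heads, channels)

# Storage size per element for the dtypes used in the example.
BYTES_PER_ELEMENT = {"bfloat16": 2, "float32": 4}

num_elements = math.prod(x_shape)
x_bytes = num_elements * BYTES_PER_ELEMENT["bfloat16"]
scale_bytes = channels * BYTES_PER_ELEMENT["float32"]

print(num_elements)             # 67108864
print(x_bytes / 2**20, "MiB")   # 128.0 MiB
print(scale_bytes, "bytes")     # 256 bytes
```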

Autotuning

To get the best performance, autotune all Tokamax kernels in f_grad:

autotune_result: tokamax.AutotuningResult = tokamax.autotune(f_grad, x, scale)

autotune_result can be used as a context manager, so that all Tokamax kernels in f_grad use the autotuned configs:

with autotune_result:
  out_autotuned = f_grad(x, scale)
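One common way to build an object that behaves like this is a context manager backed by Python's contextvars module. The sketch below is a hypothetical illustration of that pattern in plain Python, not Tokamax's implementation: TuningResult, lookup_config and the kernel key format are invented for the example.

```python
import contextvars

# Configs that kernels should currently use; None means "no tuned configs".
_ACTIVE_CONFIGS = contextvars.ContextVar("active_configs", default=None)


class TuningResult:
    """Hypothetical stand-in for an autotuning result (kernel key -> config)."""

    def __init__(self, configs):
        self.configs = dict(configs)
        self._tokens = []

    def __enter__(self):
        # Install this result's configs and remember how to undo it.
        self._tokens.append(_ACTIVE_CONFIGS.set(self.configs))
        return self

    def __exit__(self, exc_type, exc, tb):
        # Restore whatever was active before entering (supports nesting).
        _ACTIVE_CONFIGS.reset(self._tokens.pop())
        return False


def lookup_config(kernel_key, default):
    """What a kernel might do when it is traced: prefer a tuned config."""
    active = _ACTIVE_CONFIGS.get()
    if active is None:
        return default
    return active.get(kernel_key, default)


key = "attention:2048x64"
result = TuningResult({key: {"block_q": 128}})
print(lookup_config(key, {"block_q": 64}))      # {'block_q': 64}
with result:
    print(lookup_config(key, {"block_q": 64}))  # {'block_q': 128}
print(lookup_config(key, {"block_q": 64}))      # {'block_q': 64}
```

Resetting with the token returned by ContextVar.set restores the previous value exactly, so nested with blocks and code running in separate contexts do not leak configs into each other.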

To serialize and reuse the result of a potentially expensive tokamax.autotune call:

autotune_result_json: str = autotune_result.dumps()
autotune_result = tokamax.AutotuningResult.loads(autotune_result_json)

Users can autotune their own kernels with tokamax.autotune by subclassing tokamax.Op and overriding its _get_autotuning_configs method, which defines the autotuning search space.

Note that autotuning is fundamentally non-deterministic: measuring kernel execution times is noisy. As different configs chosen during autotuning can lead to different numerics, this is a potential source of numerical non-determinism. Serializing and reusing fixed autotuning results is a way to ensure the same numerics across sessions.
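The source of this non-determinism is easy to see in a generic timing-based autotuner. The sketch below is a minimal, hypothetical illustration in plain Python, not how tokamax.autotune works internally: it times a workload under each config in a search space, keeps the config with the lowest median time, and persists the choice as JSON so that later sessions can reuse it instead of re-timing. autotune, toy_kernel and its block_size parameter are made up for the example.

```python
import json
import statistics
import time
from typing import Callable, Iterable


def autotune(run: Callable[[dict], object], configs: Iterable[dict],
             repeats: int = 5) -> dict:
    """Times `run` under each config and returns the fastest one.

    Uses the median of several wall-clock measurements to damp noise; when
    candidates are close, the winner can still change from run to run.
    """
    best_config, best_time = None, float("inf")
    for config in configs:
        run(config)  # Warm-up call to exclude one-time costs.
        times = []
        for _ in range(repeats):
            start = time.perf_counter()
            run(config)
            times.append(time.perf_counter() - start)
        median = statistics.median(times)
        if median < best_time:
            best_config, best_time = config, median
    return best_config


def toy_kernel(config: dict) -> int:
    # Hypothetical workload whose cost depends on a tunable block size.
    n, block = 20_000, config["block_size"]
    return sum(sum(range(i, min(i + block, n))) for i in range(0, n, block))


search_space = [{"block_size": b} for b in (16, 64, 256)]
best = autotune(toy_kernel, search_space)

saved = json.dumps({"toy_kernel": best})     # Persist the choice...
reloaded = json.loads(saved)["toy_kernel"]   # ...and reuse it later.
print(best == reloaded)  # True
```

When two configs have near-identical timings, repeated autotuning can return different winners; loading the saved JSON pins the choice, which is exactly what serializing autotuning results achieves.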

Serialization

Functions that use Tokamax kernels can be serialized to StableHLO. Tokamax kernel calls are JAX custom calls, which jax.export rejects by default; pass tokamax.DISABLE_JAX_EXPORT_CHECKS to allow all Tokamax kernels to be exported:

from jax import export

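# Export f_grad for abstract inputs (shapes and dtypes only), disabling the
# checks that would otherwise reject Tokamax's custom calls.
f_grad_exported = export.export(f_grad, disabled_checks=tokamax.DISABLE_JAX_EXPORT_CHECKS)(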
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    jax.ShapeDtypeStruct(scale.shape, scale.dtype),
)

Note that functions serialized with Tokamax kernels lose the device-independence of standard StableHLO. Tokamax makes two serialization guarantees:

  1. A function serialized on a specific device is guaranteed to run, once deserialized, on that exact device.
  2. Tokamax gives the same compatibility guarantees as JAX: 6-month backward compatibility.
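Because of guarantee 1, it can help to store the device an artifact was serialized for next to the serialized bytes and check it before loading. The sketch below is a hypothetical helper in plain Python, not part of Tokamax or jax.export, and it assumes that "device" can be identified by a device-kind string. In JAX such a string is available as jax.devices()[0].device_kind, and the bytes returned by f_grad_exported.serialize() could be stored hex-encoded as the payload.

```python
import json


def wrap_artifact(payload_hex: str, device_kind: str) -> str:
    """Bundles a serialized function with the device kind it was built for."""
    return json.dumps({"device_kind": device_kind, "payload": payload_hex})


def unwrap_artifact(blob: str, current_device_kind: str) -> str:
    """Returns the payload, refusing to load it on a different device kind."""
    record = json.loads(blob)
    if record["device_kind"] != current_device_kind:
        raise RuntimeError(
            f"artifact was serialized for {record['device_kind']!r}, "
            f"but the current device is {current_device_kind!r}"
        )
    return record["payload"]


blob = wrap_artifact("deadbeef", "H100")
print(unwrap_artifact(blob, "H100"))  # deadbeef
```

Failing early with a clear message is usually preferable to deserializing a device-specific kernel on hardware it was never compiled for.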

Benchmarking

JAX Python overhead is often much larger than the actual accelerator kernel execution time, so the usual approach of timing jax.block_until_ready(f_grad(x, scale)) mostly measures overhead rather than the kernels. Tokamax has utilities that measure only accelerator execution time:

f_std, args = tokamax.benchmarking.standardize_function(f_grad, kwargs={'x': x, 'scale': scale})
run = tokamax.benchmarking.compile_benchmark(f_std, args)
bench: tokamax.benchmarking.BenchmarkData = run(args)

There are different measurement techniques. For example, on GPU, the CUPTI profiler can be selected with run(args, method='cupti'); it instruments the kernel and adds a small overhead. The default, run(args, method=None), lets Tokamax choose the method and works on both TPU and GPU. Benchmark noise can be reduced by increasing the number of iterations, e.g. run(args, iterations=10).
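To turn a measured kernel time into achieved throughput, divide a FLOP count by the time. The sketch below uses the standard count for one forward call of non-causal attention with the shapes from the example above: two matmuls, Q @ K^T and P @ V, each costing 2 * B * N * T^2 * H FLOPs, with softmax and scaling ignored. The 1 ms timing is a placeholder, not a measured result; substitute the time obtained from run(args).

```python
def attention_flops(batch: int, seq_len: int, num_heads: int, head_dim: int) -> int:
    """Forward FLOPs of non-causal attention (two matmuls, 2 FLOPs per multiply-add)."""
    return 4 * batch * num_heads * seq_len**2 * head_dim


def tflops_per_second(flops: int, seconds: float) -> float:
    """Converts a FLOP count and an execution time into TFLOP/s."""
    return flops / seconds / 1e12


flops = attention_flops(batch=32, seq_len=2048, num_heads=16, head_dim=64)
print(flops)  # 549755813888

measured_seconds = 1e-3  # Placeholder; use the time reported by your benchmark.
print(round(tflops_per_second(flops, measured_seconds), 1))  # 549.8
```

Comparing this figure with the accelerator's peak throughput for the input dtype gives a rough utilization estimate for a single attention call.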

Disclaimer

This is not an official Google product.

Download files


Source Distribution

tokamax-0.0.10.tar.gz (315.6 kB)

Uploaded Source

Built Distribution


tokamax-0.0.10-py3-none-any.whl (467.1 kB)

Uploaded Python 3

File details

Details for the file tokamax-0.0.10.tar.gz.

File metadata

  • Download URL: tokamax-0.0.10.tar.gz
  • Upload date:
  • Size: 315.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for tokamax-0.0.10.tar.gz
Algorithm Hash digest
SHA256 2f4031aef31f4696b95dfbfd4d8ed07e820fb60652faeb88db21f04d1c037779
MD5 a756fe710b3e10b0bb75feb614e462a4
BLAKE2b-256 3ce61c12caf5557377a9f56ad39194dd1d2a576c9e430c3508303c77e3aacdc1


Provenance

The following attestation bundles were made for tokamax-0.0.10.tar.gz:

Publisher: publish.yml on openxla/tokamax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file tokamax-0.0.10-py3-none-any.whl.

File metadata

  • Download URL: tokamax-0.0.10-py3-none-any.whl
  • Upload date:
  • Size: 467.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for tokamax-0.0.10-py3-none-any.whl
Algorithm Hash digest
SHA256 f678507e206675c9ab62cc29bdc8ed81030285543ad784af03a3ba64df9552a0
MD5 4e6042b08e064582004c8d363d111246
BLAKE2b-256 ef90233a585c792f2ff95b641d545bf6d7dfe5fb4c0fecf5a6a57bd7cc84d046


Provenance

The following attestation bundles were made for tokamax-0.0.10-py3-none-any.whl:

Publisher: publish.yml on openxla/tokamax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
