
A Pallas Custom Kernel Library.


Tokamax


Tokamax is a library of custom accelerator kernels, supporting both NVIDIA GPUs and Google TPUs. Tokamax provides state-of-the-art custom kernel implementations built on top of JAX and Pallas.

Tokamax also provides tooling for users to build and autotune their own custom accelerator kernels.

Status

Tokamax is still heavily under development. Incomplete features and API changes are to be expected.

We currently support the following GPU kernels:

And the following for both GPU and TPU:

And the following TPU kernels:

Installation

The latest Tokamax PyPI release:

pip install -U tokamax

The latest bleeding-edge version from GitHub, with no stability guarantees:

pip install git+https://github.com/openxla/tokamax.git

Using Tokamax

Consider a function that calls several Tokamax kernels, running on an H100 GPU:

import jax
import jax.numpy as jnp
import tokamax

def loss(x, scale):
  x = tokamax.layer_norm(
      x, scale=scale, offset=None, implementation="triton"
  )
  x = tokamax.dot_product_attention(x, x, x, implementation="xla_chunked")
  x = tokamax.layer_norm(x, scale=scale, offset=None, implementation=None)
  x = tokamax.dot_product_attention(x, x, x, implementation="mosaic")
  return jnp.sum(x)

f_grad = jax.jit(jax.grad(loss))

With implementation=None, Tokamax is allowed to select the best implementation for each kernel shape. It is even allowed to choose different implementations for the forward pass and gradient. It will also always be supported, as it can fall back to an XLA implementation implementation='xla'.

However, you may want to choose a specific implementation of the kernel, and fail if it is unsupported. For instance, implementation="mosaic" will try to use a Pallas:Mosaic GPU kernel if possible, and throw an exception if this is unsupported for any reason. For example, using FP64 inputs are unsupported, or older GPUs.

Evaluate the Gradient

channels, seq_len, batch_size, num_heads = (64, 2048, 32, 16)
scale = jax.random.normal(jax.random.key(0), (channels,), dtype=jnp.float32)
x = jax.random.normal(
    jax.random.key(1),
    (batch_size, seq_len, num_heads, channels),
    dtype=jnp.bfloat16,
)

out = f_grad(x, scale)

Autotuning

To get the best performance, autotune all Tokamax kernels in f_grad:

autotune_result: tokamax.AutotuningResult = tokamax.autotune(f_grad, x, scale)

autotune_result can be used as a context manager, applying the autotuned configs to all Tokamax kernels in f_grad:

with autotune_result:
  out_autotuned = f_grad(x, scale)

To serialize and reuse the result of a potentially expensive tokamax.autotune call:

autotune_result_json: str = autotune_result.dumps()
autotune_result = tokamax.AutotuningResult.loads(autotune_result_json)

Users can autotune their own kernels with tokamax.autotune by subclassing tokamax.Op and overriding the tokamax.Op._get_autotuning_configs method to define the autotuning search space.

Note that autotuning is fundamentally non-deterministic: measuring kernel execution times is noisy. As different configs chosen during autotuning can lead to different numerics, this is a potential source of numerical non-determinism. Serializing and reusing fixed autotuning results is a way to ensure the same numerics across sessions.
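The non-determinism described above, and why pinning a serialized result removes it, can be illustrated with a self-contained simulation. Everything here is hypothetical and unrelated to Tokamax internals: the search space, the made-up cost model, and toy_autotune. The point is only that taking the fastest of several noisy timings can pick different configs in different sessions, while a saved choice cannot.

```python
import itertools
import json
import random

# Hypothetical search space: tile sizes for some tiled kernel.
SEARCH_SPACE = [
    {"block_m": m, "block_n": n}
    for m, n in itertools.product([64, 128], [32, 64, 128])
]

def fake_kernel_time(config, rng):
    # Made-up cost model plus ~5% Gaussian measurement noise.
    ideal = 1.0 / (config["block_m"] * config["block_n"]) ** 0.5 + 0.001 * config["block_n"]
    return ideal * (1.0 + rng.gauss(0.0, 0.05))

def toy_autotune(configs, seed, repeats=5):
    # Time every config several times and keep the fastest observation.
    # The seed stands in for "which session"; near-ties can flip between seeds.
    rng = random.Random(seed)
    best_time, best_config = float("inf"), None
    for config in configs:
        t = min(fake_kernel_time(config, rng) for _ in range(repeats))
        if t < best_time:
            best_time, best_config = t, config
    return best_config

chosen = toy_autotune(SEARCH_SPACE, seed=0)

# Pin the choice: serialize once, then reload it instead of re-tuning.
saved = json.dumps(chosen)
pinned = json.loads(saved)
```

Reloading pinned in a later session reuses exactly the same config, which is the role AutotuningResult.dumps() and AutotuningResult.loads() play for real Tokamax kernels.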

Serialization

Functions using Tokamax kernels can be serialized to StableHLO with jax.export. Tokamax kernel calls are JAX custom calls, which jax.export rejects by default. To allow all Tokamax kernels to be exported, pass tokamax.DISABLE_JAX_EXPORT_CHECKS as disabled_checks:

from jax import export

f_grad_exported = export.export(f_grad, disabled_checks=tokamax.DISABLE_JAX_EXPORT_CHECKS)(
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    jax.ShapeDtypeStruct(scale.shape, scale.dtype),
)

Note that functions serialized with Tokamax kernels lose the device independence of standard StableHLO. Tokamax makes two serialization guarantees:

  1. A function serialized for a specific device is guaranteed to run on that exact device after deserialization.
  2. Tokamax gives the same compatibility guarantees as JAX: 6-month backward compatibility.
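Because of the first guarantee, it can be worth recording which device an artifact was built for and checking it before running. The sketch below uses the real jax.export calls: export, Exported.serialize, deserialize and Exported.call. The helpers serialize_with_device_kind and deserialize_checked are hypothetical. Comparing device_kind strings is an assumed, coarse policy, not something Tokamax does. For Tokamax functions, pass disabled_checks=tokamax.DISABLE_JAX_EXPORT_CHECKS.

```python
import jax
import jax.numpy as jnp
from jax import export

def serialize_with_device_kind(fn, *arg_specs, disabled_checks=()):
    # Hypothetical helper: export and serialize `fn`, recording the device kind
    # (e.g. the GPU model name) of the first visible device.
    exported = export.export(jax.jit(fn), disabled_checks=disabled_checks)(*arg_specs)
    return {
        "device_kind": jax.devices()[0].device_kind,
        "payload": bytes(exported.serialize()),
    }

def deserialize_checked(blob):
    # Hypothetical helper: refuse to deserialize on a different device kind,
    # since Tokamax only guarantees the exact device the artifact was built for.
    current = jax.devices()[0].device_kind
    if blob["device_kind"] != current:
        raise RuntimeError(
            f"serialized for {blob['device_kind']!r}, running on {current!r}"
        )
    return export.deserialize(bytearray(blob["payload"]))
```

The restored object is an Exported, so it is invoked with restored.call(*args) rather than called directly.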

Benchmarking

JAX Python overhead is often much larger than the actual accelerator kernel execution time, so the usual approach of timing jax.block_until_ready(f_grad(x, scale)) won't give useful kernel timings. Tokamax has utilities for measuring only accelerator execution time:

f_std, args = tokamax.benchmarking.standardize_function(f_grad, kwargs={'x': x, 'scale': scale})
run = tokamax.benchmarking.compile_benchmark(f_std, args)
bench: tokamax.benchmarking.BenchmarkData = run(args)

There are different measurement techniques. For example, on GPU, the CUPTI profiler can be selected via run(args, method='cupti'). It instruments the kernel and adds a small overhead. The default, run(args, method=None), lets Tokamax choose the method and works on both TPU and GPU. Benchmark noise can be reduced by increasing the number of iterations, e.g. run(args, iterations=10).
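For comparison, the wall-clock approach warned about above can be sketched as follows. Its numbers include Python dispatch overhead, so they are best treated as an upper bound on kernel time, not as a substitute for tokamax.benchmarking. wall_clock_ms is a hypothetical helper. It does a warm-up call so compilation is not timed, then reports the median over several iterations to damp noise.

```python
import statistics
import time

import jax
import jax.numpy as jnp

def wall_clock_ms(fn, *args, iterations=10):
    # Warm-up call: triggers compilation so it is excluded from the samples.
    jax.block_until_ready(fn(*args))
    samples = []
    for _ in range(iterations):
        start = time.perf_counter()
        # Includes Python/JAX dispatch overhead, not just accelerator time.
        jax.block_until_ready(fn(*args))
        samples.append((time.perf_counter() - start) * 1e3)
    # Median is less sensitive to occasional slow outliers than the mean.
    return statistics.median(samples)
```

With the earlier example, wall_clock_ms(f_grad, x, scale) can be set beside the Tokamax measurement to see how much of the wall-clock time is overhead.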

Disclaimer

This is not an official Google product.



Download files


Source Distribution

tokamax-0.0.5.tar.gz (243.5 kB)

Uploaded Source

Built Distribution


tokamax-0.0.5-py3-none-any.whl (362.6 kB)

Uploaded Python 3

File details

Details for the file tokamax-0.0.5.tar.gz.

File metadata

  • Download URL: tokamax-0.0.5.tar.gz
  • Upload date:
  • Size: 243.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for tokamax-0.0.5.tar.gz
Algorithm Hash digest
SHA256 aab0e44d1695b886e52f4526cef6b44a84a14ecb76e59e1eb9e20aef72094043
MD5 774b794780bd7f2e146d538974f59d9c
BLAKE2b-256 b629333f67d660171f5ca13736afa4d73c8513355d6e31f1d465f789c0c57256


Provenance

The following attestation bundles were made for tokamax-0.0.5.tar.gz:

Publisher: publish.yml on openxla/tokamax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file tokamax-0.0.5-py3-none-any.whl.

File metadata

  • Download URL: tokamax-0.0.5-py3-none-any.whl
  • Upload date:
  • Size: 362.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for tokamax-0.0.5-py3-none-any.whl
Algorithm Hash digest
SHA256 b379bbd1863c9abcbb7ca9a678c2ec2dcceeb574b31fbc29b7ed3b25480e0aeb
MD5 1d5e61006e56dadd8747a2c3dd6fca5a
BLAKE2b-256 567537c373fca6b3e48eb74002ac2cec7f94886260f78eddd1f25ad12f31bdbe


Provenance

The following attestation bundles were made for tokamax-0.0.5-py3-none-any.whl:

Publisher: publish.yml on openxla/tokamax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
