
JAX + OpenAI Triton integration


jax-triton


The jax-triton repository contains integrations between JAX and Triton.

Documentation can be found here.

This is not an officially supported Google product.

Quickstart

The main function of interest is jax_triton.triton_call for applying Triton functions to JAX arrays, including inside jax.jit-compiled functions. For example, we can define a kernel from the Triton tutorial:

import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    length,
    output_ptr,
    block_size: tl.constexpr,
):
  """Adds two vectors."""
  pid = tl.program_id(axis=0)
  block_start = pid * block_size
  offsets = block_start + tl.arange(0, block_size)
  mask = offsets < length
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  output = x + y
  tl.store(output_ptr + offsets, output, mask=mask)
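To make the blocked, masked access pattern concrete, here is a minimal pure-Python model of what a launch of add_kernel computes. It is only an illustration and makes simplifying assumptions: it runs the program instances one after another on Python lists, while Triton runs them in parallel on the GPU and processes each block as a vector. The function name emulate_add_kernel is hypothetical and is not part of Triton or jax-triton.

```python
def emulate_add_kernel(x, y, block_size, num_programs):
  """Hypothetical model of launching add_kernel on a 1-D grid of num_programs."""
  length = len(x)
  # Stands in for the output buffer. Unwritten slots stay 0 here; on a GPU
  # their contents would generally be unspecified.
  output = [0] * length
  for pid in range(num_programs):  # plays the role of tl.program_id(axis=0)
    block_start = pid * block_size
    # tl.arange(0, block_size) shifted to this program's block.
    offsets = [block_start + i for i in range(block_size)]
    # Same bounds check as `mask = offsets < length` in the kernel.
    mask = [off < length for off in offsets]
    for off, active in zip(offsets, mask):
      if active:  # masked-off lanes neither load nor store
        output[off] = x[off] + y[off]
  return output


print(emulate_add_kernel(list(range(10)), list(range(10, 20)), 4, 3))
print(emulate_add_kernel(list(range(10)), list(range(10, 20)), 4, 2))
```

With 10 elements and block_size=4, three programs reach every element (the last two lanes of the third block are masked off), whereas two programs leave the final two positions unwritten.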

Then we can apply add_kernel to JAX arrays using jax_triton.triton_call:

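import jax
import jax.numpy as jnp
import jax_triton as jt
import triton  # provides triton.cdiv


def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  # The kernel writes a single output with the same shape and dtype as x.
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  # Must be a power of two, because the kernel calls tl.arange(0, block_size).
  block_size = 8
  return jt.triton_call(
      x,
      y,
      x.size,  # bound to the kernel's `length` parameter
      kernel=add_kernel,
      out_shape=out_shape,
      # Round up so a trailing partial block still gets a program; the
      # kernel's mask keeps it from reading or writing out of bounds.
      grid=(triton.cdiv(x.size, block_size),),
      block_size=block_size)

x_val = jnp.arange(8)
y_val = jnp.arange(8, 16)
print(add(x_val, y_val))
print(jax.jit(add)(x_val, y_val))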
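The grid passed to triton_call sets how many program instances run. Because add_kernel masks out-of-range offsets, sizing the grid with a ceiling division handles inputs whose length is not a multiple of block_size, whereas x.size // block_size silently skips the tail. The sketch below shows the arithmetic; cdiv mirrors the rounding of triton.cdiv, and launch_grid is a hypothetical helper name used only for illustration.

```python
def cdiv(a: int, b: int) -> int:
  """Ceiling division: smallest n with n * b >= a, for a >= 0 and b > 0."""
  return (a + b - 1) // b


def launch_grid(length: int, block_size: int) -> tuple:
  """Hypothetical helper: a 1-D grid with enough programs to cover `length`."""
  return (cdiv(length, block_size),)


block_size = 8
for length in (8, 10, 1000):
  floor_programs = length // block_size
  (ceil_programs,) = launch_grid(length, block_size)
  # Number of element slots each choice of grid reaches.
  print(length, floor_programs * block_size, ceil_programs * block_size)
```

For a length of 10, floor division launches one program and reaches only 8 elements, while ceiling division launches two programs covering 16 slots, 6 of which the kernel masks off.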

For more complete examples, see the examples directory, especially fused_attention.py and the accompanying fused attention notebook (.ipynb).

Installation

$ pip install jax-triton

You can use either a stable or a nightly release of Triton.

Make sure you have a CUDA-compatible version of JAX installed. For example, you could run:

$ pip install "jax[cuda12]"
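After installing, a quick sanity check is to confirm which versions of the relevant packages are present and which backend JAX selects by default. The following is a minimal sketch; the helper names installed_versions and jax_default_backend are hypothetical, and the package list is an assumption you may need to adjust. On a working CUDA setup, the default backend is expected to be a GPU backend rather than "cpu".

```python
import importlib.metadata as metadata


def installed_versions(packages=("jax", "jaxlib", "jax-triton", "triton")):
  """Hypothetical helper: maps each distribution name to its version, or None."""
  versions = {}
  for name in packages:
    try:
      versions[name] = metadata.version(name)
    except metadata.PackageNotFoundError:
      versions[name] = None  # not installed in this environment
  return versions


def jax_default_backend():
  """Hypothetical helper: JAX's default backend name, or None without JAX."""
  try:
    import jax
  except ImportError:
    return None
  return jax.default_backend()


if __name__ == "__main__":
  print(installed_versions())
  print(jax_default_backend())
```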

Development

To develop jax-triton, you can clone the repo with:

$ git clone https://github.com/jax-ml/jax-triton.git

and do an editable install with:

$ cd jax-triton
$ pip install -e .

To run the jax-triton tests, you'll need pytest:

$ pip install pytest
$ pytest tests/



Download files


Source Distribution

jax_triton-0.2.0.tar.gz (62.9 kB)

Built Distribution

jax_triton-0.2.0-py3-none-any.whl (27.1 kB, Python 3)

File details

Details for the file jax_triton-0.2.0.tar.gz.

File metadata

  • Download URL: jax_triton-0.2.0.tar.gz
  • Size: 62.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.5

File hashes

Hashes for jax_triton-0.2.0.tar.gz

  • SHA256: ed564a5ffb9e404557dc8d296e7eb30e501da5a4d7b03408acd0837c1c618c21
  • MD5: 34eb181dc13b85e0d9cd9fb43a64093b
  • BLAKE2b-256: 25c68e695bc3e3c3e4be47a9a90b215ab57ebb61b3d3295f287829fdb1206361


File details

Details for the file jax_triton-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: jax_triton-0.2.0-py3-none-any.whl
  • Size: 27.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.5

File hashes

Hashes for jax_triton-0.2.0-py3-none-any.whl

  • SHA256: 6db5f6147327f2462c35becab612d415cf9ee70c96e43020c3c7ea873eb8576e
  • MD5: acb871f67693a6020f1e427cf879ec00
  • BLAKE2b-256: f6eb44b69fb10a97d921c4312d7222b0757b5879cf1b6ca0d64815661ed1a533
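To check a manually downloaded file against the SHA256 digests listed here, you can hash it locally and compare. This is a minimal sketch using only the standard library; sha256_of and matches_sha256 are hypothetical helper names, and the example assumes the wheel was saved under its original filename in the current directory.

```python
import hashlib


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
  """Hex SHA256 digest of a file, read in chunks so memory use stays bounded."""
  digest = hashlib.sha256()
  with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(chunk_size), b""):
      digest.update(chunk)
  return digest.hexdigest()


def matches_sha256(path: str, expected_hex: str) -> bool:
  """True if the file's digest equals the expected hex string (case-insensitive)."""
  return sha256_of(path) == expected_hex.strip().lower()


# SHA256 digest listed for jax_triton-0.2.0-py3-none-any.whl.
WHEEL_SHA256 = "6db5f6147327f2462c35becab612d415cf9ee70c96e43020c3c7ea873eb8576e"

# Example (assumes the wheel is in the current directory):
# print(matches_sha256("jax_triton-0.2.0-py3-none-any.whl", WHEEL_SHA256))
```

pip can also enforce digests itself in hash-checking mode, for example with a requirements line such as `jax-triton==0.2.0 --hash=sha256:<digest>` installed via `pip install --require-hashes -r requirements.txt`; in that mode every dependency must be pinned with a hash as well.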

