
JAX + OpenAI Triton integration

Project description

jax-triton


The jax-triton repository contains integrations between JAX and Triton.

Documentation can be found here.

This is not an officially supported Google product.

Quickstart

The main function of interest is jax_triton.triton_call for applying Triton functions to JAX arrays, including inside jax.jit-compiled functions. For example, we can define a kernel from the Triton tutorial:

import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    length,
    output_ptr,
    block_size: tl.constexpr,
):
  """Adds two vectors."""
  pid = tl.program_id(axis=0)
  block_start = pid * block_size
  offsets = block_start + tl.arange(0, block_size)
  mask = offsets < length
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  output = x + y
  tl.store(output_ptr + offsets, output, mask=mask)
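To make the indexing in add_kernel concrete, here is a sequential, pure-Python emulation of what the launched program instances compute between them. It is only an illustration that runs without Triton or a GPU. Real Triton programs run in parallel and handle each block as a vector. `emulate_add_kernel` and its `num_programs` parameter are hypothetical names and are not part of Triton or jax-triton.

```python
def emulate_add_kernel(x, y, block_size, num_programs):
    """Sequentially mimic the indexing of add_kernel over a 1-D launch grid.

    Hypothetical helper for illustration only: real Triton programs run in
    parallel and process each block as a vector, not element by element.
    """
    length = len(x)
    output = [None] * length  # None marks elements that no program wrote
    for pid in range(num_programs):  # pid plays the role of tl.program_id(axis=0)
        block_start = pid * block_size
        for offset in range(block_start, block_start + block_size):
            if offset < length:  # same check as mask = offsets < length
                output[offset] = x[offset] + y[offset]
    return output


# 10 elements in blocks of 4: three programs cover offsets 0..11,
# and the mask skips offsets 10 and 11 in the last block.
print(emulate_add_kernel(list(range(10)), list(range(10, 20)), block_size=4, num_programs=3))
# prints [10, 12, 14, 16, 18, 20, 22, 24, 26, 28]
```

Without the mask, the last program would read and write offsets 10 and 11, past the end of the arrays.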

Then we can apply it to JAX arrays using jax_triton.triton_call:

import jax
import jax.numpy as jnp
import jax_triton as jt

def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  block_size = 8
  return jt.triton_call(
      x,
      y,
      x.size,
      kernel=add_kernel,
      out_shape=out_shape,
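      # Round up so a final partial block is still launched; the kernel's
      # mask guards the out-of-bounds offsets in that block.
      grid=(triton.cdiv(x.size, block_size),),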
      block_size=block_size)

x_val = jnp.arange(8)
y_val = jnp.arange(8, 16)
print(add(x_val, y_val))
print(jax.jit(add)(x_val, y_val))
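The grid sets how many program instances are launched. Each instance handles block_size consecutive elements, so covering n elements takes ceil(n / block_size) programs. Floor division (n // block_size) is exact only when n is a multiple of block_size. Otherwise the final partial block is never launched and its elements are never computed. Triton provides triton.cdiv for this ceiling division. The sketch below uses a local `cdiv` helper (a hypothetical stand-in with the same arithmetic) so it runs without Triton installed.

```python
def cdiv(n, block_size):
    """Ceiling division: number of blocks of size block_size needed to cover n elements."""
    return (n + block_size - 1) // block_size


block_size = 8
for n in (8, 10, 16, 17):
    # Compare the floor-division grid with the rounded-up grid.
    print(f"n={n}: floor grid={n // block_size}, ceil grid={cdiv(n, block_size)}")
# prints:
# n=8: floor grid=1, ceil grid=1
# n=10: floor grid=1, ceil grid=2
# n=16: floor grid=2, ceil grid=2
# n=17: floor grid=2, ceil grid=3
```

With n=10, floor division launches a single program that covers only the first 8 elements. Rounding up launches a second program, and add_kernel's mask disables its out-of-range offsets.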

See the examples directory for more, especially fused_attention.py and the fused attention notebook (.ipynb).

Installation

$ pip install jax-triton

Make sure you have a CUDA-compatible JAX installed. For example, you could run:

$ pip install "jax[cuda12]"

jax-triton currently requires building the latest version of Triton from source.

Development

To develop jax-triton, you can clone the repo with:

$ git clone https://github.com/jax-ml/jax-triton.git

and do an editable install with:

$ cd jax-triton
$ pip install -e .

To run the jax-triton tests, you'll need pytest:

$ pip install pytest
$ pytest tests/

Download files

Source Distribution

jax_triton-0.3.1.tar.gz (16.3 kB)

Uploaded Source

Built Distribution

jax_triton-0.3.1-py3-none-any.whl (17.4 kB)

Uploaded Python 3

File details

Details for the file jax_triton-0.3.1.tar.gz.

File metadata

  • Download URL: jax_triton-0.3.1.tar.gz
  • Size: 16.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for jax_triton-0.3.1.tar.gz
  • SHA256: 8ee1e53dfb199f5bebedbe743854f351fc30d931eede82b953c1bbeb9c22e590
  • MD5: eacf4cef7c1c469f7782c5e8888658f8
  • BLAKE2b-256: 1f049a54927ffa73c2cce3751391dd7d2f33aa8663c00febcc2a69325c708ba6
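PyPI publishes these digests so that a downloaded file can be checked before it is installed. The following is a minimal sketch that uses only Python's standard hashlib module. The function names are hypothetical helpers and are not part of jax-triton or pip. "BLAKE2b-256" is taken to mean BLAKE2b with a 32-byte digest. MD5 is omitted because it is not collision-resistant.

```python
import hashlib


def digest_chunks(chunks):
    """Hash an iterable of byte chunks with SHA256 and BLAKE2b-256 in one pass."""
    sha256 = hashlib.sha256()
    blake2b_256 = hashlib.blake2b(digest_size=32)  # 32-byte (256-bit) BLAKE2b digest
    for chunk in chunks:
        sha256.update(chunk)
        blake2b_256.update(chunk)
    return {"SHA256": sha256.hexdigest(), "BLAKE2b-256": blake2b_256.hexdigest()}


def file_digests(path, chunk_size=1 << 16):
    """Hash the file at `path`, reading it in chunks so it need not fit in memory."""
    with open(path, "rb") as f:
        return digest_chunks(iter(lambda: f.read(chunk_size), b""))


def verify_sha256(path, expected_hex):
    """Return True if the file's SHA256 digest matches the published hex digest."""
    return file_digests(path)["SHA256"] == expected_hex.strip().lower()
```

For example, with the source distribution downloaded into the current directory, `verify_sha256("jax_triton-0.3.1.tar.gz", "8ee1e53dfb199f5bebedbe743854f351fc30d931eede82b953c1bbeb9c22e590")` should return True for an unmodified copy. The wheel can be checked the same way against its own SHA256 digest.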

Provenance

The following attestation bundles were made for jax_triton-0.3.1.tar.gz:

Publisher: pypi-publish.yml on jax-ml/jax-triton

File details

Details for the file jax_triton-0.3.1-py3-none-any.whl.

File metadata

  • Download URL: jax_triton-0.3.1-py3-none-any.whl
  • Size: 17.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for jax_triton-0.3.1-py3-none-any.whl
  • SHA256: c25277cb5bd454fd6a524265447f3533cf902cb1ab2b97d91c24f530c4e71024
  • MD5: 681aaa4bc8242ec73004080a68033ac0
  • BLAKE2b-256: ccfed1784f79ae7a84a809fa1adc8487a399ecfab1066a1dceaf80836fdcbbf3

Provenance

The following attestation bundles were made for jax_triton-0.3.1-py3-none-any.whl:

Publisher: pypi-publish.yml on jax-ml/jax-triton

