JAX + OpenAI Triton integration

Project description

jax-triton

The jax-triton repository contains integrations between JAX and Triton.

Installation

You may optionally install into a virtualenv, or use pip install --user instead.

  1. Install the latest Triton
$ git clone https://github.com/openai/triton.git
$ cd triton/python
$ pip install cmake
$ pip install -e ".[tests]"

To verify the installation, run the unit tests (from within triton/python):

$ pytest test/unit

  2. Get JAX with Triton support
$ git clone https://github.com/sharadmv/jax.git
$ cd jax
$ git checkout triton
$ pip install -e ".[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
$ pip install pybind11
$ cd triton
$ make # compiles our custom call
$ pip install .
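
After both steps, a quick sanity check is to confirm that Python can locate the triton, jax and jaxlib packages and to see which versions were picked up. The sketch below uses only the standard library; the module list is an assumption (the package built from jax/triton is not named on this page, so it is left out), the helper name check_install is hypothetical, and a forked JAX build may report an unexpected version string.

```python
import importlib.util
from importlib.metadata import PackageNotFoundError, version


def check_install(modules):
    """Map each top-level module name to its installed version.

    The value is None if the module cannot be found, or "unknown" if it is
    importable but has no distribution metadata under the same name.
    """
    report = {}
    for name in modules:
        # find_spec locates a top-level module without importing it and
        # returns None when the module cannot be found.
        if importlib.util.find_spec(name) is None:
            report[name] = None
            continue
        try:
            # Assumes the distribution name matches the module name,
            # which holds for triton, jax and jaxlib.
            report[name] = version(name)
        except PackageNotFoundError:
            # Importable (e.g. a standard-library module) but no metadata.
            report[name] = "unknown"
    return report


if __name__ == "__main__":
    for name, ver in check_install(["triton", "jax", "jaxlib"]).items():
        print(f"{name}: {ver if ver is not None else 'NOT FOUND'}")
```

A NOT FOUND line usually means the package was installed into a different interpreter or virtualenv than the one running the check.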

We have a couple of examples already written. Try running (from within jax/triton/examples):

$ python matrix_multiplication.py

Download files

Source Distribution

jax-triton-0.1.0.tar.gz (4.9 kB)

File details

Details for the file jax-triton-0.1.0.tar.gz.

File metadata

  • Download URL: jax-triton-0.1.0.tar.gz
  • Size: 4.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for jax-triton-0.1.0.tar.gz
Algorithm    Hash digest
SHA256       30237639eeb79d0d391881abdf2a4231390676f8be8f54bb12b5cf7746f86e1e
MD5          3e59a493f6223fcddc0d2e12179e06fa
BLAKE2b-256  b619ffcde70c993f34feaf83db5166a997c2cb001f0d2e56debdbc62ec0d98e5
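
The published digests can be checked locally before installing the tarball. The sketch below is one way to do it with Python's hashlib, computing all three digests in a single pass; the script name check_hashes.py and the helper names are hypothetical. BLAKE2b-256 here means BLAKE2b with a 32-byte (256-bit) digest. When installing from a requirements file, pip's hash-checking mode (--require-hashes) is an alternative.

```python
import hashlib

# Expected digests for jax-triton-0.1.0.tar.gz, copied from the hash table.
EXPECTED = {
    "sha256": "30237639eeb79d0d391881abdf2a4231390676f8be8f54bb12b5cf7746f86e1e",
    "md5": "3e59a493f6223fcddc0d2e12179e06fa",
    "blake2b_256": "b619ffcde70c993f34feaf83db5166a997c2cb001f0d2e56debdbc62ec0d98e5",
}


def digests(chunks):
    """Feed an iterable of bytes chunks into all three hashers at once."""
    hashers = {
        "sha256": hashlib.sha256(),
        "md5": hashlib.md5(),
        # BLAKE2b-256 is BLAKE2b truncated to a 32-byte digest.
        "blake2b_256": hashlib.blake2b(digest_size=32),
    }
    for chunk in chunks:
        for h in hashers.values():
            h.update(chunk)
    return {name: h.hexdigest() for name, h in hashers.items()}


def file_digests(path, chunk_size=1 << 16):
    """Hash a file in fixed-size chunks so large files need little memory."""
    with open(path, "rb") as f:
        # iter(callable, sentinel) keeps reading until read() returns b"".
        return digests(iter(lambda: f.read(chunk_size), b""))


def mismatches(actual, expected=EXPECTED):
    """Return the algorithm names whose digest differs from the expected one."""
    return [name for name, digest in expected.items() if actual.get(name) != digest]


if __name__ == "__main__":
    import sys

    # Usage: python check_hashes.py path/to/jax-triton-0.1.0.tar.gz
    for path in sys.argv[1:]:
        bad = mismatches(file_digests(path))
        print(f"{path}: " + ("OK" if not bad else "MISMATCH in " + ", ".join(bad)))
```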
