JAX + OpenAI Triton integration
jax-triton
The jax-triton repository contains integrations between JAX and Triton.
Installation
You may optionally install into a virtualenv, or use pip install --user.
- Install the latest Triton
$ git clone https://github.com/openai/triton.git
$ cd triton/python
$ pip install cmake
$ pip install -e ".[tests]"
To verify that it worked, try running (from within triton/python):
$ pytest test/unit
- Get JAX with Triton support
$ git clone https://github.com/sharadmv/jax.git
$ cd jax
$ git checkout triton
$ pip install -e ".[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
$ pip install pybind11
$ cd triton
$ make # compiles our custom call
$ pip install .
A couple of examples are already written. Try running (from within jax/triton/examples):
$ python matrix_multiplication.py
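If an example fails at import time, a short check like the one below can show which package is missing. This is only a sketch and not part of jax-triton: the helper name missing_packages is hypothetical, and the module names (triton, jax, pybind11) are assumed from the installation steps above.

```python
import importlib.util

# Module names assumed from the installation steps above.
REQUIRED = ("triton", "jax", "pybind11")


def missing_packages(names=REQUIRED):
    """Return the top-level module names that cannot be located for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Not importable:", ", ".join(missing))
    else:
        print("All required packages are importable.")
```

The check only looks the modules up and does not import them, so it says nothing about whether the custom call compiled correctly or whether a GPU is visible.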
Download files
Source Distribution
jax-triton-0.1.0.tar.gz (4.9 kB)
File details
Details for the file jax-triton-0.1.0.tar.gz.
File metadata
- Download URL: jax-triton-0.1.0.tar.gz
- Upload date:
- Size: 4.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.9.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | 30237639eeb79d0d391881abdf2a4231390676f8be8f54bb12b5cf7746f86e1e
MD5 | 3e59a493f6223fcddc0d2e12179e06fa
BLAKE2b-256 | b619ffcde70c993f34feaf83db5166a997c2cb001f0d2e56debdbc62ec0d98e5
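
To compare a downloaded archive against the SHA256 digest listed above, a small script along these lines may help. It is a minimal sketch using only the Python standard library. The helper names sha256_of and matches_expected are hypothetical, and the script assumes the archive was saved under its original file name in the current directory.

```python
import hashlib

# SHA256 digest copied from the file hashes listed above.
EXPECTED_SHA256 = "30237639eeb79d0d391881abdf2a4231390676f8be8f54bb12b5cf7746f86e1e"


def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_expected(path, expected=EXPECTED_SHA256):
    """Return True if the file's SHA256 digest equals the expected digest."""
    return sha256_of(path) == expected.lower()


if __name__ == "__main__":
    # Assumed location of the downloaded source distribution.
    print(matches_expected("jax-triton-0.1.0.tar.gz"))
```

Reading the file in chunks keeps memory use constant, which matters little for a 4.9 kB archive but lets the same helper be reused for larger downloads.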