JAX + OpenAI Triton integration
jax-triton
The jax-triton repository contains integrations between JAX and Triton.
This is not an officially supported Google product.
Installation
$ pip install jax-triton
Make sure you have a CUDA-compatible jaxlib installed.
For example, you could run:
$ pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
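To confirm that the installed jaxlib can actually see a GPU before you try jax-triton, a quick check from Python helps. This is a minimal sketch, not part of jax-triton. The helper name `jaxlib_sees_gpu` is hypothetical. It assumes that GPU devices returned by `jax.devices()` report `platform == "gpu"`, which is how JAX labels them. ROCm GPUs carry the same label, so the check does not distinguish CUDA from ROCm.

```python
def jaxlib_sees_gpu() -> bool:
    """Return True if JAX imports cleanly and exposes at least one GPU device.

    Hypothetical helper; assumes GPU devices report platform "gpu".
    """
    try:
        import jax
    except ImportError:
        # JAX (or its jaxlib dependency) is not installed.
        return False
    try:
        return any(d.platform == "gpu" for d in jax.devices())
    except RuntimeError:
        # Backend initialization failed, e.g. a broken CUDA setup.
        return False


if __name__ == "__main__":
    print("GPU-enabled jaxlib detected:", jaxlib_sees_gpu())
```

If this prints `False` on a machine with an NVIDIA GPU, the installed jaxlib is most likely a CPU-only build.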
Development
To develop jax-triton, you can clone the repo with:
$ git clone https://github.com/jax-ml/jax-triton.git
and do an editable install with:
$ cd jax-triton
$ pip install -e .
To run the jax-triton tests, you'll need pytest and absl-py:
$ pip install pytest absl-py
$ pytest tests/
Download files
Source Distribution
jax-triton-0.1.1.tar.gz (10.3 kB)
File details
Details for the file jax-triton-0.1.1.tar.gz.
File metadata
- Download URL: jax-triton-0.1.1.tar.gz
- Upload date:
- Size: 10.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.9.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1d3406e81bf8a2268a4d7688a8ab7e2c10893c5939f5ba56c1d706fa4c377fa0 |
| MD5 | ca1c7c484c67c9beed3e478766e95d34 |
| BLAKE2b-256 | faf226cd0f387d7c0dab778a5e762acd6c8c5f4fdb5af535e8a168b488035175 |
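To check a downloaded copy of the source distribution against the published digests, you can hash it locally with Python's standard `hashlib` module. This is a minimal sketch. The helper names `file_digests` and `verify` are hypothetical, and the usage line assumes the archive sits in the current directory under its published name. "BLAKE2b-256" here means BLAKE2b with a 32-byte digest, i.e. `hashlib.blake2b(digest_size=32)`.

```python
import hashlib

# Published digests for jax-triton-0.1.1.tar.gz.
EXPECTED = {
    "sha256": "1d3406e81bf8a2268a4d7688a8ab7e2c10893c5939f5ba56c1d706fa4c377fa0",
    "md5": "ca1c7c484c67c9beed3e478766e95d34",
    "blake2b_256": "faf226cd0f387d7c0dab778a5e762acd6c8c5f4fdb5af535e8a168b488035175",
}


def file_digests(path, chunk_size=1 << 16):
    """Compute SHA256, MD5 and BLAKE2b-256 hex digests of a file in one pass."""
    hashers = {
        "sha256": hashlib.sha256(),
        "md5": hashlib.md5(),
        # BLAKE2b-256 is BLAKE2b truncated to a 32-byte (256-bit) digest.
        "blake2b_256": hashlib.blake2b(digest_size=32),
    }
    with open(path, "rb") as f:
        # Read in fixed-size chunks so large files are not loaded into memory at once.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            for h in hashers.values():
                h.update(chunk)
    return {name: h.hexdigest() for name, h in hashers.items()}


def verify(path, expected=EXPECTED):
    """Return True only if every expected digest matches the file's digest."""
    actual = file_digests(path)
    return all(actual[name] == digest for name, digest in expected.items())
```

Usage: `verify("jax-triton-0.1.1.tar.gz")` should return `True` for an untampered download. pip can enforce the SHA256 check itself through a `--hash=sha256:...` entry in a requirements file.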