JAX + OpenAI Triton integration
jax-triton
The jax-triton repository contains integrations between JAX and Triton.
This is not an officially supported Google product.
Installation
$ pip install jax-triton
Make sure you have a CUDA-compatible jaxlib installed.
For example, you could run:
$ pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
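To confirm that the installed jaxlib actually exposes a GPU backend, you could use a small Python check like the one below. This is a minimal sketch, not part of jax-triton: the function name has_cuda_jaxlib is hypothetical, and it relies on jax.default_backend(), which reports a GPU backend for ROCm builds as well, so it does not strictly distinguish CUDA from other GPU builds.

```python
def has_cuda_jaxlib() -> bool:
    """Return True if JAX is importable and its default backend is a GPU.

    Hypothetical helper: a GPU backend is reported for ROCm builds too,
    so this does not strictly prove the jaxlib was built for CUDA.
    """
    try:
        import jax
    except ImportError:
        # No JAX at all, so there is certainly no CUDA-enabled jaxlib.
        return False
    # A CPU-only jaxlib reports "cpu"; GPU builds report "gpu".
    # "cuda" is also accepted as a precaution across JAX versions.
    return jax.default_backend() in ("gpu", "cuda")


if __name__ == "__main__":
    print("GPU-enabled jaxlib detected:", has_cuda_jaxlib())
```

If this prints False after installing the CUDA wheel, JAX has fallen back to the CPU backend, which usually means the CUDA-enabled jaxlib was not picked up.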
Development
To develop jax-triton, you can clone the repo with:
$ git clone https://github.com/jax-ml/jax-triton.git
and do an editable install with:
$ cd jax-triton
$ pip install -e .
To run the jax-triton tests, you'll need pytest and absl-py:
$ pip install pytest absl-py
$ pytest tests/
Download files
Source Distribution
jax-triton-0.1.1.tar.gz (10.3 kB)
File details
Details for the file jax-triton-0.1.1.tar.gz.
File metadata
- Download URL: jax-triton-0.1.1.tar.gz
- Size: 10.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.9.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | 1d3406e81bf8a2268a4d7688a8ab7e2c10893c5939f5ba56c1d706fa4c377fa0
MD5 | ca1c7c484c67c9beed3e478766e95d34
BLAKE2b-256 | faf226cd0f387d7c0dab778a5e762acd6c8c5f4fdb5af535e8a168b488035175
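The published SHA256 digest can be used to check a downloaded copy of the source distribution. The sketch below uses only Python's standard hashlib module; the helper names sha256_of_file and verify_sha256, and the assumption that the archive sits in the current directory, are illustrative rather than part of any packaging tool.

```python
import hashlib
from pathlib import Path

# SHA256 digest listed in the file hashes table for jax-triton-0.1.1.tar.gz.
EXPECTED_SHA256 = "1d3406e81bf8a2268a4d7688a8ab7e2c10893c5939f5ba56c1d706fa4c377fa0"


def sha256_of_file(path, chunk_size=1 << 16):
    """Compute the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read fixed-size chunks until f.read() returns b"" at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path, expected=EXPECTED_SHA256):
    """Return True if the file's SHA256 hex digest matches `expected`."""
    return sha256_of_file(path) == expected.lower()


if __name__ == "__main__":
    # Hypothetical location: assumes the archive was downloaded to the working directory.
    archive = Path("jax-triton-0.1.1.tar.gz")
    if archive.exists():
        print("SHA256 OK" if verify_sha256(archive) else "SHA256 MISMATCH")
    else:
        print(f"{archive} not found")
```

Reading in chunks keeps memory use constant regardless of file size; for a 10.3 kB archive this hardly matters, but the same helper works for large wheels.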