FlashAttention JAX
This repository provides a JAX binding to https://github.com/Dao-AILab/flash-attention. It is a fork of the official repo rather than a dependency on it, to avoid depending on PyTorch, since torch and jax installations often conflict.
Please see Tri Dao's repo for more information about flash attention.
Usage
FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE). Please cite (see below) and credit FlashAttention if you use it.
Installation and features
Requirements:
- CUDA 11.8 and above.
- Linux. Same story as with the PyTorch repo; I haven't tested compilation of the JAX bindings on Windows.
- JAX >= 0.4.24. The custom sharding used for ring attention requires fairly recent JAX features.
To install: For now, download the appropriate release from the releases page and install it with pip.
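For example, assuming you want the CPython 3.12 manylinux wheel listed under Built Distributions below:

pip install flash_attn_jax-0.2.0-cp312-cp312-manylinux_2_28_x86_64.whl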
Interface: src/flash_attn_jax/flash.py
from flash_attn_jax import flash_mha
flash_mha(q, k, v, softmax_scale=None, is_causal=False, window_size=(-1,-1))
Accepts q, k, v with shape [n, l, h, d] and returns output of shape [n, l, h, d]. softmax_scale is the multiplier for the softmax, defaulting to 1/sqrt(d). Set window_size to positive values for sliding window attention.
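For concreteness, a minimal usage sketch (batch size, sequence length, head count, and head dim here are illustrative; inputs must be fp16 or bf16, per the feature list below):

import jax
import jax.numpy as jnp
from flash_attn_jax import flash_mha

n, l, h, d = 2, 1024, 8, 64  # batch, length, heads, head dim
q, k, v = (jax.random.normal(key, (n, l, h, d), dtype=jnp.float16)
           for key in jax.random.split(jax.random.PRNGKey(0), 3))
out = flash_mha(q, k, v, is_causal=True)  # -> [n, l, h, d]
# Sliding-window variant (positive window values), per the description above:
# out = flash_mha(q, k, v, window_size=(128, 128))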
Now Supports Ring Attention
Use jax.Array and shard your tensors along the length dimension, and flash_mha will automatically use the ring attention algorithm:
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()  # one mesh axis, named 'len'
with Mesh(devices, axis_names=('len',)) as mesh:
    sharding = NamedSharding(mesh, P(None, 'len', None))  # n l d
    tokens = jax.device_put(tokens, sharding)
    # invoke your jax.jit'd transformer.forward
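Putting the pieces together, a sketch of calling flash_mha directly on length-sharded q, k, v (device count and shapes are illustrative, not from the original docs):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from flash_attn_jax import flash_mha

n, l, h, d = 1, 8192, 8, 64
q = k = v = jnp.zeros((n, l, h, d), dtype=jnp.float16)
with Mesh(jax.devices(), axis_names=('len',)) as mesh:
    spec = NamedSharding(mesh, P(None, 'len', None, None))  # shard the length axis
    q, k, v = (jax.device_put(x, spec) for x in (q, k, v))
    out = jax.jit(flash_mha)(q, k, v)  # ring attention is used under length sharding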
FlashAttention-2 currently supports:
- Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing GPUs (T4, RTX 2080) is coming soon; for now, please use FlashAttention 1.x on Turing GPUs.
- Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
- All head dimensions up to 256.
Head dim > 192 backward requires A100/A800 or H100/H800. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.
Citation
If you use this codebase, or otherwise find our work valuable, please cite:
@inproceedings{dao2022flashattention,
title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle={Advances in Neural Information Processing Systems},
year={2022}
}
@article{dao2023flashattention2,
title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
author={Dao, Tri},
year={2023}
}
Download files
Download the file for your platform.
Source Distribution
Built Distributions
File details
Details for the file flash_attn_jax-0.2.0.tar.gz.
File metadata
- Download URL: flash_attn_jax-0.2.0.tar.gz
- Upload date:
- Size: 2.3 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.11.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | bc06fe46d2550788bcd484177dc957835441754e08511b4b97204ab0043825ae
MD5 | c9c5b3834ba3f55f7e71fcf35711574b
BLAKE2b-256 | 127cabfaec2502074c43eae1cc912f8ba921fa55723b4ead7aeb7b0e26f29c3b
File details
Details for the file flash_attn_jax-0.2.0-cp312-cp312-manylinux_2_28_x86_64.whl.
File metadata
- Download URL: flash_attn_jax-0.2.0-cp312-cp312-manylinux_2_28_x86_64.whl
- Upload date:
- Size: 62.6 MB
- Tags: CPython 3.12, manylinux: glibc 2.28+ x86-64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.11.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | 474514e088aebd7b864e4f96fed9f7378f12a07b0b4dabc411728b4f8550b402
MD5 | 7d6444c334c6158b9d5264eb6badcb30
BLAKE2b-256 | 3f793ffa4359d38391b949955481be38b6bc04d5040f4a1f61fcc629c2a94b30
File details
Details for the file flash_attn_jax-0.2.0-cp311-cp311-manylinux_2_28_x86_64.whl.
File metadata
- Download URL: flash_attn_jax-0.2.0-cp311-cp311-manylinux_2_28_x86_64.whl
- Upload date:
- Size: 62.6 MB
- Tags: CPython 3.11, manylinux: glibc 2.28+ x86-64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.11.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | 6f0646ed6eb1bb2eb46e752979ce3f8b6f4f485b86587d1cbd40264ccd90f025
MD5 | 81a5f30d42c1d39915135115384e030e
BLAKE2b-256 | cc7becd41053d2e3a799172056ca32305164dddc99ee8dde4f4de44c07f10a38
File details
Details for the file flash_attn_jax-0.2.0-cp310-cp310-manylinux_2_28_x86_64.whl.
File metadata
- Download URL: flash_attn_jax-0.2.0-cp310-cp310-manylinux_2_28_x86_64.whl
- Upload date:
- Size: 62.6 MB
- Tags: CPython 3.10, manylinux: glibc 2.28+ x86-64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.11.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | ae5d65c63405aab077e558a0d229a57887eb4c60f767fd3907a0184b6ba77d97
MD5 | d1216ba3fb174e7a7088f6545eb0183e
BLAKE2b-256 | 924fc5faac156c4a5c7a18e5d43a319c35e5d895861b3fd89794d27bd5afd8eb
File details
Details for the file flash_attn_jax-0.2.0-cp39-cp39-manylinux_2_28_x86_64.whl.
File metadata
- Download URL: flash_attn_jax-0.2.0-cp39-cp39-manylinux_2_28_x86_64.whl
- Upload date:
- Size: 62.6 MB
- Tags: CPython 3.9, manylinux: glibc 2.28+ x86-64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.11.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | dbb44dd165b02553f2591c2fa1a17421606fd6a45a5757666634fe8e47500076
MD5 | a9ef2eab7faf26572900a1462dde07ef
BLAKE2b-256 | bf162c695472b5b89b9771f92c70647222ebbdd5649c9f4fe09e2780715023b8