
Flash Attention port for JAX


FlashAttention JAX

This repository provides JAX bindings for https://github.com/Dao-AILab/flash-attention. It is a fork of the official repo rather than a dependency on it, so that it doesn't require PyTorch: torch and jax installations often conflict.

Please see Tri Dao's repo for more information about flash attention. Also check there for how to cite the authors if you use flash attention in your work.

FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE). Please cite and credit FlashAttention if you use it (the citation is in Tri Dao's repo).

Installation

Requirements:

  • CUDA 12.8 and above.
  • Linux, as with the PyTorch repo. I haven't tested compiling the JAX bindings on Windows.
  • JAX >= 0.5, since the custom call API changed in that version.

To install, run pip install flash-attn-jax to get the latest release from PyPI. This gives you the CUDA 12.8 build. CUDA 11 is no longer supported, since JAX dropped support for it.

Installing from source

Flash attention takes a long time to compile unless you have a powerful machine. If you still want to build from source, I use cibuildwheel to compile the releases, and you can do the same. For example, for Python 3.12:

git clone https://github.com/nshepperd/flash-attn-jax
cd flash-attn-jax
cibuildwheel --only cp312-manylinux_x86_64  # builds inside Docker, so this may need superuser privileges on some systems

This creates a wheel in the wheelhouse directory, which you can install with pip install wheelhouse/flash_attn_jax_*.whl. Alternatively, you can build without Docker using uv build --wheel, but then you need CUDA installed locally.

Usage

Interface: src/flash_attn_jax/flash.py

from flash_attn_jax import flash_mha

# flash_mha : [n, l, h, d] x [n, lk, hk, d] x [n, lk, hk, d] -> [n, l, h, d]
# n = batch, l / lk = query / key length, h / hk = query / key-value heads, d = head dim
flash_mha(q, k, v, softmax_scale=None, is_causal=False, window_size=(-1, -1))

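This supports multi-query and grouped-query attention (when hk != h; h should be a multiple of hk). softmax_scale is the factor applied to the attention scores QK^T before the softmax, defaulting to 1/sqrt(d). Set window_size = (left, right) to positive values for sliding window attention; -1 leaves that side of the window unbounded.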
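When testing a model that uses flash_mha, it can help to have a slow, plain-JAX version of the same computation to compare against. The sketch below is a hypothetical helper (reference_mha is not part of flash_attn_jax) that takes the same argument layout. It assumes the upstream flash-attn conventions: query head i uses K/V head i // (h // hk), is_causal behaves like a window with right = 0, and masks are aligned to the bottom-right corner when l != lk (upstream flash-attn >= 2.1). Check these against flash.py before relying on it.

```python
import math

import jax
import jax.numpy as jnp


def reference_mha(q, k, v, softmax_scale=None, is_causal=False, window_size=(-1, -1)):
    """Hypothetical plain-JAX reference with flash_mha's argument layout.

    q: [n, l, h, d]; k, v: [n, lk, hk, d] with h divisible by hk.
    """
    n, l, h, d = q.shape
    lk, hk = k.shape[1], k.shape[2]
    assert h % hk == 0, "number of query heads must be a multiple of hk"
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(d)
    # MQA/GQA: repeat each K/V head so query head i uses K/V head i // (h // hk).
    k = jnp.repeat(k, h // hk, axis=2)
    v = jnp.repeat(v, h // hk, axis=2)
    # Scores [n, h, l, lk], computed in float32 for a stable reference.
    scores = jnp.einsum("nqhd,nkhd->nhqk", q.astype(jnp.float32), k.astype(jnp.float32))
    scores = scores * softmax_scale
    # Assumption: bottom-right aligned masks, as in upstream flash-attn >= 2.1.
    left, right = window_size
    if is_causal:
        right = 0  # causal: no keys to the right of the query's position
    qpos = jnp.arange(l)[:, None] + (lk - l)
    kpos = jnp.arange(lk)[None, :]
    mask = jnp.ones((l, lk), dtype=bool)
    if right >= 0:
        mask = mask & (kpos <= qpos + right)
    if left >= 0:
        mask = mask & (kpos >= qpos - left)
    scores = jnp.where(mask, scores, -jnp.inf)
    # Softmax that yields zeros (not NaN) for rows with no visible keys.
    row_max = jnp.max(scores, axis=-1, keepdims=True)
    row_max = jnp.where(jnp.isfinite(row_max), row_max, 0.0)
    p = jnp.exp(scores - row_max)
    denom = p.sum(axis=-1, keepdims=True)
    p = p / jnp.where(denom == 0, 1.0, denom)
    out = jnp.einsum("nhqk,nkhd->nqhd", p, v.astype(jnp.float32))
    return out.astype(q.dtype)
```

On a supported GPU, reference_mha(q, k, v, is_causal=True) and flash_mha(q, k, v, is_causal=True) on fp16/bf16 inputs should then agree up to half-precision rounding, if the conventions assumed above hold.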

Now Supports Ring Attention

Use jax.Array and shard your tensors along the length dimension, and flash_mha will automatically use the ring attention algorithm (forward and backward).

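import os
# Set this before importing jax so XLA picks up the flag.
os.environ["XLA_FLAGS"] = "--xla_gpu_enable_latency_hiding_scheduler=true"

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
#...
with Mesh(devices, axis_names=('len',)) as mesh:
    sharding = NamedSharding(mesh, P(None, 'len'))  # tokens are [n, l]; shard along l
    tokens = jax.device_put(tokens, sharding)
    # invoke your jax.jit'd transformer.forward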

Now that some bugs have been fixed, latency hiding (overlapping the ring communication with computation) seems to be reliable, as long as you enable the latency hiding scheduler as above.
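To make the ring attention idea concrete, here is a single-device simulation in plain JAX. It is a sketch of the general algorithm, not the code flash_attn_jax runs: the name ring_attention_sim is hypothetical, it only covers the non-causal case with h == hk, and a Python loop over blocks stands in for the devices. Each simulated device keeps its query block and sees every K/V block in turn, merging partial results with an online (log-sum-exp) softmax, so the result matches ordinary attention.

```python
import jax
import jax.numpy as jnp


def full_attention(q, k, v):
    """Ordinary softmax attention on [n, l, h, d] inputs, used as a reference."""
    s = jnp.einsum("nqhd,nkhd->nhqk", q, k) * q.shape[-1] ** -0.5
    return jnp.einsum("nhqk,nkhd->nqhd", jax.nn.softmax(s, axis=-1), v)


def ring_attention_sim(q, k, v, num_shards):
    """Hypothetical single-device simulation of non-causal ring attention.

    q, k, v: [n, l, h, d] with l divisible by num_shards.
    """
    scale = q.shape[-1] ** -0.5
    q_blocks = jnp.split(q, num_shards, axis=1)
    k_blocks = jnp.split(k, num_shards, axis=1)
    v_blocks = jnp.split(v, num_shards, axis=1)
    outputs = []
    for i, qb in enumerate(q_blocks):  # simulated device i owns query block i
        n, lb, h, _ = qb.shape
        acc = jnp.zeros(qb.shape, jnp.float32)      # unnormalized output
        row_max = jnp.full((n, lb, h), -jnp.inf)    # running max of the scores
        denom = jnp.zeros((n, lb, h), jnp.float32)  # running softmax denominator
        for step in range(num_shards):
            j = (i + step) % num_shards  # K/V block "received" at this step
            s = jnp.einsum("nqhd,nkhd->nqhk", qb, k_blocks[j]) * scale
            new_max = jnp.maximum(row_max, s.max(axis=-1))
            correction = jnp.exp(row_max - new_max)  # rescales earlier partial sums
            p = jnp.exp(s - new_max[..., None])
            denom = denom * correction + p.sum(axis=-1)
            acc = acc * correction[..., None] + jnp.einsum("nqhk,nkhd->nqhd", p, v_blocks[j])
            row_max = new_max
        outputs.append(acc / denom[..., None])
    return jnp.concatenate(outputs, axis=1)


q, k, v = jax.random.normal(jax.random.PRNGKey(0), (3, 2, 16, 4, 8))
out = ring_attention_sim(q, k, v, num_shards=4)
print(jnp.max(jnp.abs(out - full_attention(q, k, v))))  # expect only float32 rounding error
```

In a real multi-device run, each step's K/V block arrives from a neighbouring device through a collective, and that communication is the kind of work the latency hiding scheduler can overlap with the block computation.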

GPU support

FlashAttention-2 currently supports:

  1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing GPUs (T4, RTX 2080) is coming soon; please use FlashAttention 1.x for Turing GPUs for now.
  2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
  3. All head dimensions up to 256. Head dim 256 backward works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5; before that, head dim > 192 backward required A100/A800 or H100/H800.
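The constraints above can be written as a quick pre-flight check. This is a hypothetical helper, not part of flash_attn_jax. It takes a CUDA compute capability as a (major, minor) tuple and assumes Ampere and Ada correspond to compute capability 8.x and Hopper to 9.0. Because every accepted GPU is Ampere or newer, the bf16 requirement is covered automatically. It does not model the older backward-pass caveat for large head dims.

```python
import jax.numpy as jnp

SUPPORTED_DTYPES = (jnp.dtype(jnp.float16), jnp.dtype(jnp.bfloat16))


def check_flash_support(compute_capability, dtype, head_dim):
    """Hypothetical helper: return (ok, reason) for the constraints listed above."""
    major, minor = compute_capability
    # Ampere is sm_80/86/87 and Ada is sm_89 (all major 8); Hopper is sm_90.
    if not (major == 8 or (major, minor) == (9, 0)):
        return False, f"sm_{major}{minor} is not an Ampere, Ada, or Hopper GPU"
    if jnp.dtype(dtype) not in SUPPORTED_DTYPES:
        return False, f"dtype {jnp.dtype(dtype)} is not fp16 or bf16"
    if not 1 <= head_dim <= 256:
        return False, f"head dim {head_dim} is outside 1..256"
    return True, "ok"


print(check_flash_support((8, 0), jnp.bfloat16, 128))  # A100 with bf16, d = 128
print(check_flash_support((7, 5), jnp.float16, 64))    # Turing (T4)
```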



Download files


Source Distributions

No source distribution files are available for this release.

Built Distributions


flash_attn_jax-0.4.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (47.7 MB)

Uploaded: CPython 3.13, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64

flash_attn_jax-0.4.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (47.7 MB)

Uploaded: CPython 3.12, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64

flash_attn_jax-0.4.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (47.7 MB)

Uploaded: CPython 3.11, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64

File details

Details for the file flash_attn_jax-0.4.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.


File hashes

Hashes for flash_attn_jax-0.4.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 3c61824885e0c4300fff04a1e9305075fef31b464f2ddd623e62bd9eb888e439
MD5 f2aa91b9f6484cb00afe3718f64b763c
BLAKE2b-256 f8ecc12d3d67859bdd1cab33c26042cfc6bb12a48a5a4c0c32bf889d2479d4e4


File details

Details for the file flash_attn_jax-0.4.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.


File hashes

Hashes for flash_attn_jax-0.4.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 4f8546fd4cc045e2e640157da60b73872c5423e008ab83175d1d456657a36285
MD5 e74b5c398f1feeb3e0c7a9f0e85647f7
BLAKE2b-256 1a007cf77d0d541b0e2a44c18dc3c6dc820cbbecdd4496d662f5166dfad4eb16


File details

Details for the file flash_attn_jax-0.4.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.


File hashes

Hashes for flash_attn_jax-0.4.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 61c9ad654eb08347b066a518ab724b2ef9c84881be14fb6fac128a8dbcc2471f
MD5 c04e79290fbf07d20eddb53d8268ddaf
BLAKE2b-256 3a40f68a221438f8571166d707834dd20d4caa0a8117c82ee93f9c40b499c3ef

