
Flash Attention port for JAX

Project description

FlashAttention JAX

This repository provides a JAX binding to https://github.com/Dao-AILab/flash-attention. It is a fork of the official repo rather than a dependency on it, to avoid depending on PyTorch, since torch and jax installations often conflict.

Please see Tri Dao's repo for more information about flash attention, and for how to cite the authors if you use flash attention in your work.

FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE). Please cite (see below) and credit FlashAttention if you use it.

Installation

Requirements:

  • CUDA 12.3 or above.
  • Linux. Same story as with the PyTorch repo; I haven't tested compiling the JAX bindings on Windows.
  • JAX >= 0.5. The custom call API changed in this version.

To install: pip install flash-attn-jax will get the latest release from PyPI. This gives you the CUDA 12.8 build. CUDA 11 is no longer supported (since JAX dropped support for it).

Installing from source

Flash attention takes a long time to compile unless you have a powerful machine. If you do want to build from source, I use cibuildwheel to compile the releases, and you can do the same. Something like this (for Python 3.12):

git clone https://github.com/nshepperd/flash-attn-jax
cd flash-attn-jax
cibuildwheel --only cp312-manylinux_x86_64 # I think cibuildwheel needs superuser privileges on some systems because of docker reasons?

This will create a wheel in the wheelhouse directory, which you can then install with pip install wheelhouse/flash_attn_jax_*.whl. Alternatively, build without docker using uv build --wheel; in that case you need CUDA installed locally.

Usage

Interface: src/flash_attn_jax/flash.py

from flash_attn_jax import flash_mha

# flash_mha : [n, l, h, d] x [n, lk, hk, d] x [n, lk, hk, d] -> [n, l, h, d]
flash_mha(q, k, v, softmax_scale=None, is_causal=False, window_size=(-1, -1))

This supports multi-query and grouped-query attention (when hk != h). softmax_scale is the multiplier applied before the softmax, defaulting to 1/sqrt(d). Set window_size to non-negative values for sliding window attention.
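For reference, here is a pure-JAX sketch of the semantics flash_mha computes, to my reading of the flash-attn masking rules (causal and sliding-window masks are aligned to the bottom-right corner when lk != l). mha_reference is a hypothetical name for illustration, not part of the package:

```python
import jax
import jax.numpy as jnp

def mha_reference(q, k, v, softmax_scale=None, is_causal=False, window_size=(-1, -1)):
    """Naive attention matching flash_mha's interface.

    q: [n, l, h, d]; k, v: [n, lk, hk, d] with h a multiple of hk.
    """
    n, l, h, d = q.shape
    _, lk, hk, _ = k.shape
    if softmax_scale is None:
        softmax_scale = d ** -0.5
    # Grouped-query attention: repeat kv heads to match the query heads.
    k = jnp.repeat(k, h // hk, axis=2)
    v = jnp.repeat(v, h // hk, axis=2)
    scores = jnp.einsum("nlhd,nkhd->nhlk", q, k) * softmax_scale
    # Bottom-right aligned masking: query i sits at kv position i + (lk - l).
    off = lk - l
    i = jnp.arange(l)[:, None]
    j = jnp.arange(lk)[None, :]
    left, right = window_size
    if is_causal:
        right = 0  # causal == window (-1, 0)
    mask = jnp.ones((l, lk), dtype=bool)
    if right >= 0:
        mask = mask & (j <= i + off + right)
    if left >= 0:
        mask = mask & (j >= i + off - left)
    scores = jnp.where(mask, scores, -jnp.inf)
    probs = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("nhlk,nkhd->nlhd", probs, v)
```

This runs anywhere JAX does, so it can serve as a CPU check against the fused kernel's output.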

Now Supports Ring Attention

Use jax.Array and shard your tensors along the length dimension, and flash_mha will automatically use the ring attention algorithm (forward and backward).

import os
# must be set before jax initializes its backend
os.environ["XLA_FLAGS"] = '--xla_gpu_enable_latency_hiding_scheduler=true'

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
with Mesh(devices, axis_names=('len',)) as mesh:
    sharding = NamedSharding(mesh, P(None, 'len'))  # shard [n, l] along l
    tokens = jax.device_put(tokens, sharding)
    # invoke your jax.jit'd transformer.forward

The latency hiding seems to be reliable now that some bugs have been fixed, as long as you enable the latency hiding scheduler as above.

GPU support

FlashAttention-2 currently supports:

  1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing GPUs (T4, RTX 2080) is coming soon; for now, please use FlashAttention 1.x on Turing GPUs.
  2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
  3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

flash_attn_jax-0.6.2.tar.gz (23.3 MB)

Uploaded: source

Built Distributions

If you're not sure about the file name format, learn more about wheel file names.

flash_attn_jax-0.6.2-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (77.2 MB)

Uploaded: CPython 3.14, manylinux (glibc 2.24+ / 2.28+), x86-64

flash_attn_jax-0.6.2-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (77.2 MB)

Uploaded: CPython 3.13, manylinux (glibc 2.24+ / 2.28+), x86-64

flash_attn_jax-0.6.2-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (77.2 MB)

Uploaded: CPython 3.12, manylinux (glibc 2.24+ / 2.28+), x86-64

flash_attn_jax-0.6.2-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (77.2 MB)

Uploaded: CPython 3.11, manylinux (glibc 2.24+ / 2.28+), x86-64

File details

Details for the file flash_attn_jax-0.6.2.tar.gz.

File metadata

  • Download URL: flash_attn_jax-0.6.2.tar.gz
  • Upload date:
  • Size: 23.3 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for flash_attn_jax-0.6.2.tar.gz
Algorithm Hash digest
SHA256 15a260db6692804d7a2995eba35737fd324e73320a1a9742c121dc64b9969485
MD5 1689e8545a13e0aa3a62c631f189ca84
BLAKE2b-256 9aff44db680d6de147102176b46a3922ad136cc19476dacbc29df7d727cc0c01

See more details on using hashes here.
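To check a downloaded file against the SHA256 digest above, you can hash it locally. sha256_of is just an illustrative helper, not a tool shipped with the package:

```python
import hashlib

def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, read in 1 MiB chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            h.update(chunk)
    return h.hexdigest()

# compare the result against the digest listed on this page, e.g.:
# sha256_of("flash_attn_jax-0.6.2.tar.gz") == "15a260db6692..."
```

pip also verifies hashes automatically when you pin them in a requirements file with --require-hashes.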

Provenance

The following attestation bundles were made for flash_attn_jax-0.6.2.tar.gz:

Publisher: publish.yml on nshepperd/flash_attn_jax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file flash_attn_jax-0.6.2-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for flash_attn_jax-0.6.2-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 b040f1635ed831bf070ca6a9d5f19aeae3d69a988a1815cb198a366ca76be71f
MD5 a3b7d151c56f8a436678e232be5685bf
BLAKE2b-256 cccb70cf98f22fe11a81fdf635f666f586300e5f6a09456a0de73a8339182486

See more details on using hashes here.

Provenance

The following attestation bundles were made for flash_attn_jax-0.6.2-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl:

Publisher: publish.yml on nshepperd/flash_attn_jax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file flash_attn_jax-0.6.2-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for flash_attn_jax-0.6.2-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 cd3ab0a53254ec3e13662b9258f4729e07e8cabe11e5f0bf4e9b37b0c2e70169
MD5 ccbfdf423c30ca94daa1bf002f098b8d
BLAKE2b-256 4da784e6ca7947d7264c0e762b4dd4ad3d14ab9e79ec5da6859b326983e788c9

See more details on using hashes here.

Provenance

The following attestation bundles were made for flash_attn_jax-0.6.2-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl:

Publisher: publish.yml on nshepperd/flash_attn_jax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file flash_attn_jax-0.6.2-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for flash_attn_jax-0.6.2-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 b260d07d2747c8cefd32fb123e5daf5da922474849f8aa2680add001b5b1bb8d
MD5 ef27191be83580010411b5c73860c300
BLAKE2b-256 6dc29383d1d92bcf3ff7e1b02a51b802ac6c066c4085608e9d79725d4f4890a7

See more details on using hashes here.

Provenance

The following attestation bundles were made for flash_attn_jax-0.6.2-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl:

Publisher: publish.yml on nshepperd/flash_attn_jax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file flash_attn_jax-0.6.2-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for flash_attn_jax-0.6.2-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 e2619cc29ca3c083e40191cea3400c97f2a88e705dac95e48588c68b5a81b7b5
MD5 87d1f068d576c766cdb9c77689636e00
BLAKE2b-256 e76acaa42f3a382b70fc9e1ee0b142ead06e4417ce9c8148756f0b7146bf54a3

See more details on using hashes here.

Provenance

The following attestation bundles were made for flash_attn_jax-0.6.2-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl:

Publisher: publish.yml on nshepperd/flash_attn_jax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
