
Flash Attention port for JAX

Project description

FlashAttention JAX

This repository provides a JAX binding to https://github.com/Dao-AILab/flash-attention. To avoid depending on PyTorch (torch and JAX installations often conflict), it is a fork of the official repo rather than a dependency on it.

Please see Tri Dao's repo for more information about flash attention. Also check there for how to cite the authors if you use flash attention in your work.

FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE). Please cite (see below) and credit FlashAttention if you use it.

Installation

Requirements:

  • CUDA 12.3 or newer.
  • Linux. As with the PyTorch repo, I haven't tested compiling the JAX bindings on Windows.
  • JAX >= 0.5. The custom call API changed in this version.

To install: pip install flash-attn-jax will get the latest release from PyPI. This gives you the CUDA 12.8 build. CUDA 11 is no longer supported (since JAX stopped supporting it).

Installing from source

Flash attention takes a long time to compile unless you have a powerful machine. If you do want to compile from source, the releases are built with cibuildwheel, and you can do the same. Something like this (for Python 3.12):

git clone https://github.com/nshepperd/flash-attn-jax
cd flash-attn-jax
cibuildwheel --only cp312-manylinux_x86_64 # may need superuser privileges on some systems, since it uses docker

This creates a wheel in the wheelhouse directory, which you can install with pip install wheelhouse/flash_attn_jax_*.whl. Alternatively, build without docker using uv build --wheel; that requires a local CUDA installation.

Usage

Interface: src/flash_attn_jax/flash.py

from flash_attn_jax import flash_mha

# flash_mha : [n, l, h, d] x [n, lk, hk, d] x [n, lk, hk, d] -> [n, l, h, d]
flash_mha(q, k, v, softmax_scale=None, is_causal=False, window_size=(-1, -1))

This supports multi-query and grouped-query attention (when hk != h). softmax_scale is the multiplier applied to the attention scores before the softmax, defaulting to 1/sqrt(d). Set window_size to positive values for sliding window attention.
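For reference, the computation flash_mha performs can be sketched in plain numpy: bottom-right-aligned causal masking, grouped-query head repetition, and the default softmax_scale. The function name mha_reference is hypothetical (a semantics sketch, not the library's implementation), and sliding-window masking is omitted:

```python
import numpy as np

def mha_reference(q, k, v, softmax_scale=None, is_causal=False):
    """Pure-numpy sketch of flash_mha's semantics.

    Shapes: q [n, l, h, d], k/v [n, lk, hk, d]. Grouped-query attention
    is handled by repeating the k/v heads when hk divides h.
    """
    n, l, h, d = q.shape
    _, lk, hk, _ = k.shape
    if softmax_scale is None:
        softmax_scale = 1.0 / np.sqrt(d)      # default scaling
    if hk != h:                               # grouped-query: repeat kv heads
        k = np.repeat(k, h // hk, axis=2)
        v = np.repeat(v, h // hk, axis=2)
    # attention scores: [n, h, l, lk]
    s = np.einsum('nlhd,nmhd->nhlm', q, k) * softmax_scale
    if is_causal:
        # causal mask aligned to the bottom-right corner (lk may exceed l)
        mask = np.tril(np.ones((l, lk), dtype=bool), k=lk - l)
        s = np.where(mask, s, -np.inf)
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p = p / p.sum(axis=-1, keepdims=True)
    return np.einsum('nhlm,nmhd->nlhd', p, v)
```

Since the softmax weights sum to one per query, feeding a constant v returns that constant, which makes the sketch easy to sanity-check.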

Now Supports Ring Attention

Use jax.Array sharding: shard your tensors along the length dimension, and flash_mha will automatically use the ring attention algorithm (forward and backward).

import os
os.environ["XLA_FLAGS"] = '--xla_gpu_enable_latency_hiding_scheduler=true'

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

with Mesh(devices, axis_names=('len',)) as mesh:
    sharding = NamedSharding(mesh, P(None, 'len'))  # n l
    tokens = jax.device_put(tokens, sharding)
    # invoke your jax.jit'd transformer forward

Latency hiding (overlapping the ring communication with computation) seems to be reliable now that some bugs have been fixed, as long as you enable the latency-hiding scheduler as above.
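Conceptually, each device computes attention over its local k/v block while the blocks rotate around the ring, merging partial results with the online-softmax rescaling trick. A single-head numpy sketch of that merge (names hypothetical; this is the idea, not the actual implementation):

```python
import numpy as np

def attn(q, k, v, scale):
    """Plain full attention, used as the ground truth."""
    s = q @ k.T * scale
    p = np.exp(s - s.max(-1, keepdims=True))
    return p @ v / p.sum(-1, keepdims=True)

def ring_attn(q, k_chunks, v_chunks, scale):
    """Online-softmax accumulation over k/v chunks, as each device
    would do while k/v blocks rotate around the ring."""
    m = np.full((q.shape[0], 1), -np.inf)   # running row max
    l = np.zeros((q.shape[0], 1))           # running softmax normalizer
    o = np.zeros((q.shape[0], v_chunks[0].shape[1]))  # unnormalized output
    for k, v in zip(k_chunks, v_chunks):
        s = q @ k.T * scale
        m_new = np.maximum(m, s.max(-1, keepdims=True))
        p = np.exp(s - m_new)
        correction = np.exp(m - m_new)      # rescale old accumulators
        l = l * correction + p.sum(-1, keepdims=True)
        o = o * correction + p @ v
        m = m_new
    return o / l
```

Because the rescaling is exact, processing the chunks one at a time gives the same result as full attention, regardless of how k/v are split.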

GPU support

FlashAttention-2 currently supports:

  1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing GPUs (T4, RTX 2080) is coming soon; please use FlashAttention 1.x for Turing GPUs for now.
  2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
  3. All head dimensions up to 256. Backward for head dim > 192 requires A100/A800 or H100/H800. Backward for head dim 256 works on consumer GPUs (without dropout) as of flash-attn 2.5.5.
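Since only fp16 and bf16 inputs are supported, a common pattern is to cast activations down before the kernel and back up afterwards. A minimal sketch of that wrapper, using numpy and a stand-in attention function (the wrapper name is hypothetical; with the real library you would pass flash_mha and JAX arrays):

```python
import numpy as np

def call_in_half_precision(attn_fn, q, k, v, dtype=np.float16):
    """Cast q/k/v to a supported half-precision dtype, run the
    attention kernel, and cast the output back to q's original dtype.

    attn_fn stands in for flash_mha here.
    """
    out = attn_fn(q.astype(dtype), k.astype(dtype), v.astype(dtype))
    return np.asarray(out).astype(q.dtype)
```

Keeping the cast in one place makes it easy to switch between fp16 and bf16 depending on the GPU generation.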

Project details


Download files

Download the file for your platform.

Source Distribution

flash_attn_jax-0.5.4.tar.gz (23.4 MB)

Uploaded: Source

Built Distributions


flash_attn_jax-0.5.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (47.7 MB)

Uploaded: CPython 3.14, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64

flash_attn_jax-0.5.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (47.7 MB)

Uploaded: CPython 3.13, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64

flash_attn_jax-0.5.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (47.7 MB)

Uploaded: CPython 3.12, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64

flash_attn_jax-0.5.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (47.7 MB)

Uploaded: CPython 3.11, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64

File details

Details for the file flash_attn_jax-0.5.4.tar.gz.

File metadata

  • Download URL: flash_attn_jax-0.5.4.tar.gz
  • Upload date:
  • Size: 23.4 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for flash_attn_jax-0.5.4.tar.gz
  • SHA256: 3064f3d2097dfdbc98fb5b06616eded184704354cf81d9af2ed4ef0f8686a6ae
  • MD5: 8c1e0f4eb9d20f1dab66ec16c9b7d6ad
  • BLAKE2b-256: 1b6629084764cbc6e715c3c81c9e56f7743dc6e7d0f2402f5be6b3b5316914d7

See more details on using hashes here.

Provenance

The following attestation bundles were made for flash_attn_jax-0.5.4.tar.gz:

Publisher: publish.yml on nshepperd/flash_attn_jax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file flash_attn_jax-0.5.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for flash_attn_jax-0.5.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
  • SHA256: 0efa9ee3e9779b6378929891460e74f4b36e5751f5655a68901cbfbeb484aed3
  • MD5: fcf2cd460033aa104867e1296c6741ee
  • BLAKE2b-256: 6a2efeb26999b6559adf7d662a992f531378573cfb45419a9a2f03a3e1709384


Provenance

The following attestation bundles were made for flash_attn_jax-0.5.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl:

Publisher: publish.yml on nshepperd/flash_attn_jax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file flash_attn_jax-0.5.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for flash_attn_jax-0.5.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
  • SHA256: 2eba450fc31e659c00fa7347760c6d74077cf9a5fe79d9bab9f82766996cc5a3
  • MD5: 3d8ae6c5cded9b55fc3e8917f0b4708c
  • BLAKE2b-256: 0526e31d0fc8f86d82fbb0a780b6eca2a94e75dd87552651e1379751792b16e3


Provenance

The following attestation bundles were made for flash_attn_jax-0.5.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl:

Publisher: publish.yml on nshepperd/flash_attn_jax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file flash_attn_jax-0.5.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for flash_attn_jax-0.5.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
  • SHA256: c3c9d0d3405204301bcea019815b76c37834bc323a269ddb9a97da71afd51608
  • MD5: 63dde52e3d82a095ebf70522557187f8
  • BLAKE2b-256: 4006f95475d732fdc4639b49de8e70dbdea536e1889cf685727acc13ce9d189c


Provenance

The following attestation bundles were made for flash_attn_jax-0.5.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl:

Publisher: publish.yml on nshepperd/flash_attn_jax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file flash_attn_jax-0.5.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for flash_attn_jax-0.5.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
  • SHA256: dc304801fd946de6a5353deba4d94732469cee7b59ec622324d58ad292b6dce2
  • MD5: 9f09e4a6a3f35e8672474053c6bd5d55
  • BLAKE2b-256: 9e95bcf5ff3fbb1946f9383224c21207161b505e1b3363cc62762c9bb6551169


Provenance

The following attestation bundles were made for flash_attn_jax-0.5.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl:

Publisher: publish.yml on nshepperd/flash_attn_jax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
