Flash Attention port for JAX

Project description

FlashAttention JAX

This repository provides a JAX binding to https://github.com/Dao-AILab/flash-attention. It is a fork of the official repo, to avoid depending on PyTorch: torch and JAX installations often conflict.

Please see Tri Dao's repo for more information about flash attention, and for how to cite the authors if you use flash attention in your work.

FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE). Please cite (see below) and credit FlashAttention if you use it.

Installation

Requirements:

  • CUDA 12.3 or above.
  • Linux. Same story as with the PyTorch repo: I haven't tested compiling the JAX bindings on Windows.
  • JAX >= 0.5. The custom-call API changed in this version.

To install: pip install flash-attn-jax will get the latest release from PyPI, which gives you the CUDA 12.8 build. CUDA 11 is no longer supported (since JAX dropped support for it).

Installing from source

Flash attention takes a long time to compile unless you have a powerful machine. If you do want to compile from source, I use cibuildwheel to build the releases, and you can do the same. For example, for Python 3.12:

git clone https://github.com/nshepperd/flash-attn-jax
cd flash-attn-jax
cibuildwheel --only cp312-manylinux_x86_64 # may need superuser privileges on some systems, since it runs docker

This will create a wheel in the wheelhouse directory, which you can then install with pip install wheelhouse/flash_attn_jax_*.whl. Alternatively, build without docker using uv build --wheel; that requires a local CUDA installation.

Usage

Interface: src/flash_attn_jax/flash.py

from flash_attn_jax import flash_mha

# flash_mha : [n, l, h, d] x [n, lk, hk, d] x [n, lk, hk, d] -> [n, l, h, d]
flash_mha(q,k,v,softmax_scale=None, is_causal=False, window_size=(-1,-1))

This supports multi-query and grouped-query attention (when hk != h). softmax_scale is the multiplier applied to the attention scores before the softmax, defaulting to 1/sqrt(d). Set window_size to positive values for sliding-window attention.
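To make the semantics concrete, here is a pure-numpy reference sketch of what flash_mha computes, including the default scale, the GQA head mapping, and the causal/window masks. This is illustrative only, not the library's implementation; the name mha_reference and the bottom-right alignment of the causal mask when lk != l are my assumptions.

```python
import numpy as np

def mha_reference(q, k, v, softmax_scale=None, is_causal=False, window_size=(-1, -1)):
    """Reference semantics sketch: q [n, l, h, d], k/v [n, lk, hk, d],
    h a multiple of hk (grouped-query attention). Returns [n, l, h, d]."""
    n, l, h, d = q.shape
    _, lk, hk, _ = k.shape
    assert h % hk == 0
    if softmax_scale is None:
        softmax_scale = 1.0 / np.sqrt(d)  # default scale described above
    # Repeat k/v heads so each query head has a matching key/value head.
    k = np.repeat(k, h // hk, axis=2)
    v = np.repeat(v, h // hk, axis=2)
    # Attention scores: [n, h, l, lk]
    s = np.einsum('nlhd,nkhd->nhlk', q, k) * softmax_scale
    # Masks; -1 on either side of window_size disables that side.
    # Assumption: query positions align to the end of the key sequence.
    qi = np.arange(l)[:, None] + (lk - l)
    ki = np.arange(lk)[None, :]
    mask = np.ones((l, lk), dtype=bool)
    if is_causal:
        mask &= ki <= qi
    left, right = window_size
    if left >= 0:
        mask &= ki >= qi - left
    if right >= 0:
        mask &= ki <= qi + right
    s = np.where(mask, s, -np.inf)
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    return np.einsum('nhlk,nkhd->nlhd', p, v)
```

With a single key position every query just returns that value vector, which makes a quick sanity check easy.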

Now Supports Ring Attention

Use jax.Array and shard your tensors along the length dimension, and flash_mha will automatically use the ring attention algorithm (forward and backward).

import os
os.environ["XLA_FLAGS"] = '--xla_gpu_enable_latency_hiding_scheduler=true'

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

#...
with Mesh(devices, axis_names=('len',)) as mesh:
    sharding = NamedSharding(mesh, P(None,'len')) # n l
    tokens = jax.device_put(tokens, sharding)
    # invoke your jax.jit'd transformer.forward

The latency hiding seems to be reliable now that some bugs have been fixed, as long as you enable the latency hiding scheduler as above.
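Ring attention works because softmax attention can be accumulated one key/value chunk at a time, keeping a running row-maximum and normalizer (the online-softmax trick), so each device only ever needs the chunk currently passing through the ring. A minimal single-head numpy sketch of that accumulation rule, under my own names and independent of the actual sharded implementation:

```python
import numpy as np

def attention_chunked(q, ks, vs):
    """Accumulate attention over key/value chunks with online softmax.
    q: [l, d]; ks, vs: lists of [lc, d] chunks, as each step of the
    ring would see them. Illustrative sketch only."""
    l, d = q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full((l, 1), -np.inf)  # running row max
    z = np.zeros((l, 1))          # running softmax normalizer
    o = np.zeros((l, d))          # running unnormalized output
    for k, v in zip(ks, vs):
        s = q @ k.T * scale                           # [l, lc]
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        corr = np.exp(m - m_new)                      # rescale old accumulators
        p = np.exp(s - m_new)
        z = z * corr + p.sum(axis=-1, keepdims=True)
        o = o * corr + p @ v
        m = m_new
    return o / z

def attention_full(q, k, v):
    """Unchunked reference for comparison."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v
```

The chunked result matches the unchunked one exactly (up to floating point), which is why sharding the length dimension doesn't change the output.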

GPU support

FlashAttention-2 currently supports:

  1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing GPUs (T4, RTX 2080) is coming soon; for now, please use FlashAttention 1.x on Turing GPUs.
  2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
  3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.


Download files

Download the file for your platform.

Source Distribution

flash_attn_jax-0.4.1.tar.gz (23.3 MB)

Uploaded: Source

Built Distributions

flash_attn_jax-0.4.1-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (47.7 MB)

Uploaded: CPython 3.14, manylinux glibc 2.24+ x86-64, manylinux glibc 2.28+ x86-64

flash_attn_jax-0.4.1-cp313-cp313-manylinux_2_24_x86_64.whl (47.7 MB)

Uploaded: CPython 3.13, manylinux glibc 2.24+ x86-64

flash_attn_jax-0.4.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (47.7 MB)

Uploaded: CPython 3.13, manylinux glibc 2.24+ x86-64, manylinux glibc 2.28+ x86-64

flash_attn_jax-0.4.1-cp312-cp312-manylinux_2_24_x86_64.whl (47.7 MB)

Uploaded: CPython 3.12, manylinux glibc 2.24+ x86-64

flash_attn_jax-0.4.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (47.7 MB)

Uploaded: CPython 3.12, manylinux glibc 2.24+ x86-64, manylinux glibc 2.28+ x86-64

flash_attn_jax-0.4.1-cp311-cp311-manylinux_2_24_x86_64.whl (47.7 MB)

Uploaded: CPython 3.11, manylinux glibc 2.24+ x86-64

flash_attn_jax-0.4.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (47.7 MB)

Uploaded: CPython 3.11, manylinux glibc 2.24+ x86-64, manylinux glibc 2.28+ x86-64

File details

Details for the file flash_attn_jax-0.4.1.tar.gz.

File metadata

  • Download URL: flash_attn_jax-0.4.1.tar.gz
  • Upload date:
  • Size: 23.3 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.12

File hashes

Hashes for flash_attn_jax-0.4.1.tar.gz:

  • SHA256: d925a76a407648bbcbe8761af28562bc21925b8d04170465e57d74b3076e043f
  • MD5: 4d780f12b11fe12e1c06b553b4730d5a
  • BLAKE2b-256: 83e0a7ed77da6a7392a4a4ff20d76934b33616c32c07807df379101c87431708

File details

Details for the file flash_attn_jax-0.4.1-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File hashes

Hashes for flash_attn_jax-0.4.1-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl:

  • SHA256: fa8d1c95567371a806cbe285dc4cddef88beb76c2aa569617fc4d6aee0c8961c
  • MD5: a7c3bdfb84194ac2d0dc7aa035182c1e
  • BLAKE2b-256: 6adf8010519423c11476537309df27e2d031fe580cc33a8d417c22849ed44a3a

Provenance

The following attestation bundles were made for flash_attn_jax-0.4.1-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl:

Publisher: publish.yml on nshepperd/flash_attn_jax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file flash_attn_jax-0.4.1-cp313-cp313-manylinux_2_24_x86_64.whl.

File hashes

Hashes for flash_attn_jax-0.4.1-cp313-cp313-manylinux_2_24_x86_64.whl:

  • SHA256: 9d9b613ac38ab37876b66358d586f58d8d90677fdf8a5d3220e86abf9947779b
  • MD5: d97f8eb2efcabd5aaae141d9c638604c
  • BLAKE2b-256: 7f1bcc25cede2b19a2bacfd4a0822c5d0f1f06307b4b1360d55be50f59ac2cd7

File details

Details for the file flash_attn_jax-0.4.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File hashes

Hashes for flash_attn_jax-0.4.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl:

  • SHA256: af2ae1317b6e2429fe02bf456ee519fd5edc81afe6f8741429aa498bd94762e5
  • MD5: 20180beb0218f1f4ad228763f38ffa42
  • BLAKE2b-256: 79cde17d02e70a678f3ab72ac92c714d01990e63f2a7b3c1b4d6b7fe7bd90644

Provenance

The following attestation bundles were made for flash_attn_jax-0.4.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl:

Publisher: publish.yml on nshepperd/flash_attn_jax

File details

Details for the file flash_attn_jax-0.4.1-cp312-cp312-manylinux_2_24_x86_64.whl.

File hashes

Hashes for flash_attn_jax-0.4.1-cp312-cp312-manylinux_2_24_x86_64.whl:

  • SHA256: 437372a0df45a9069db9d7e4823d7de200b8df9878ea1d18b864f839faa84c8a
  • MD5: 66bbca3f9b21ae604342574ca6b714bf
  • BLAKE2b-256: 944e0e8f51a304c109f466672b1eaf2c23a1a9ddf5293ab79b8a8d9eed094059

File details

Details for the file flash_attn_jax-0.4.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File hashes

Hashes for flash_attn_jax-0.4.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl:

  • SHA256: 6248f3300c2e8916bcac73391c8b013e53a74043f74689436371f3830b1accc7
  • MD5: 71c7fc795302d156be49a175dcd276b4
  • BLAKE2b-256: effafff3566eafb123717cec814517940bc5ebe52a7dab43a905900e7da611f6

Provenance

The following attestation bundles were made for flash_attn_jax-0.4.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl:

Publisher: publish.yml on nshepperd/flash_attn_jax

File details

Details for the file flash_attn_jax-0.4.1-cp311-cp311-manylinux_2_24_x86_64.whl.

File hashes

Hashes for flash_attn_jax-0.4.1-cp311-cp311-manylinux_2_24_x86_64.whl:

  • SHA256: ab00cf1b9ed331de9c7e2cf00f450dfcfb7dfb0946467c8a175cd01695ae3600
  • MD5: bcc994903c6e360d0538a2b73a9e17cc
  • BLAKE2b-256: f4b8036a45ea16f3bbbe371127ed91632552adb8f53529c2604d0ed4f76db01c

File details

Details for the file flash_attn_jax-0.4.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File hashes

Hashes for flash_attn_jax-0.4.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl:

  • SHA256: 668517a71aaeb375cefa0e593a6576ba38c7e5bcb13e16d975ecad90efa449e7
  • MD5: 4173605b90f68617b3664c77088371b0
  • BLAKE2b-256: a98bd3bc2d89b73c064099d56e8b6248f35f7ff307a724c3e9f3957738b37080

Provenance

The following attestation bundles were made for flash_attn_jax-0.4.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl:

Publisher: publish.yml on nshepperd/flash_attn_jax
