
Flash Attention Jax

Project description

FlashAttention JAX

This repository provides a JAX binding to https://github.com/Dao-AILab/flash-attention. It is a fork of the official repo rather than a dependency on it, to avoid pulling in PyTorch, since Torch and JAX installations often conflict.

Please see Tri Dao's repo for more information about FlashAttention.

FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE). Please cite (see below) and credit FlashAttention if you use it.

Installation

Requirements:

  • CUDA 11.8 and above.
  • Linux. Same story as with the PyTorch repo: I haven't tested compiling the JAX bindings on Windows.
  • JAX >= 0.4.24. The custom sharding used for ring attention requires somewhat advanced features.

To install: pip install flash-attn-jax will get the latest release from PyPI. This gives you the CUDA 12.3 build. If you want the CUDA 11.8 build, you can install it from the releases page (though, according to JAX's documentation, CUDA 11.8 will stop being supported by newer versions of JAX).

Installing from source

Flash attention takes a long time to compile unless you have a powerful machine. If you do want to compile from source, I use cibuildwheel to build the releases, and you can do the same. Something like this (for Python 3.12):

git clone https://github.com/nshepperd/flash-attn-jax
cd flash-attn-jax
cibuildwheel --only cp312-manylinux_x86_64 # cibuildwheel may need superuser privileges on some systems, since it uses docker

This will create a wheel in the wheelhouse directory. You can then install it with pip install wheelhouse/flash_attn_jax-0.2.0-cp312-cp312-manylinux_x86_64.whl. Alternatively, you can build and install the wheel with setup.py, in which case you need the CUDA toolkit installed.

Usage

Interface: src/flash_attn_jax/flash.py

from flash_attn_jax import flash_mha

# flash_mha : [n, l, h, d] x [n, lk, hk, d] x [n, lk, hk, d] -> [n, l, h, d]
flash_mha(q, k, v, softmax_scale=None, is_causal=False, window_size=(-1,-1))

This supports multi-query and grouped-query attention (when hk != h). The softmax_scale is the multiplier for the softmax, defaulting to 1/sqrt(d). Set window_size to positive values for sliding window attention.
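
To make the interface concrete, here is a small usage sketch. The shapes, dtypes, and the grouped-query configuration below are illustrative assumptions, not requirements of the library:

import jax
import jax.numpy as jnp
from flash_attn_jax import flash_mha

n, l, d = 2, 1024, 64   # batch, sequence length, head dimension
h, hk = 8, 2            # 8 query heads sharing 2 key/value heads (GQA)

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (n, l, h, d), dtype=jnp.float16)
k = jax.random.normal(kk, (n, l, hk, d), dtype=jnp.float16)
v = jax.random.normal(kv, (n, l, hk, d), dtype=jnp.float16)

# Causal attention; softmax_scale=None defaults to 1/sqrt(d).
out = flash_mha(q, k, v, is_causal=True)                 # [n, l, h, d]

# Sliding-window (local) attention over 128 tokens on each side.
out_local = flash_mha(q, k, v, window_size=(128, 128))   # [n, l, h, d]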

Now Supports Ring Attention

Use jax.Array and shard your tensors along the length dimension, and flash_mha will automatically use the ring attention algorithm (forward and backward).

os.environ["XLA_FLAGS"] = '--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_collectives=true'
#...
with Mesh(devices, axis_names=('len',)) as mesh:
        sharding = NamedSharding(mesh, P(None,'len')) # n l
        tokens = jax.device_put(tokens, sharding)
        # invoke your jax.jit'd transformer.forward

It's not entirely reliable at hiding the communication latency, though; it depends on the whims of the XLA optimizer. I'm waiting for https://github.com/google/jax/issues/20864 to be fixed, after which I can make it better.
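
Putting the pieces together, a fuller sketch might look like the following; the mesh construction, array shapes, and the use of jit here are illustrative assumptions rather than requirements:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from flash_attn_jax import flash_mha

devices = jax.devices()   # e.g. all local GPUs
n, l, h, d = 1, 8192, 8, 64

q = jnp.zeros((n, l, h, d), dtype=jnp.float16)
k = jnp.zeros((n, l, h, d), dtype=jnp.float16)
v = jnp.zeros((n, l, h, d), dtype=jnp.float16)

with Mesh(devices, axis_names=('len',)) as mesh:
    # Shard the length dimension (axis 1 of [n, l, h, d]) across devices;
    # as noted above, flash_mha then uses the ring attention algorithm.
    sharding = NamedSharding(mesh, P(None, 'len'))
    q, k, v = [jax.device_put(x, sharding) for x in (q, k, v)]
    out = jax.jit(flash_mha)(q, k, v)   # [n, l, h, d], sharded along l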

GPU support

FlashAttention-2 currently supports:

  1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing GPUs (T4, RTX 2080) is coming soon; for now, please use FlashAttention 1.x for Turing GPUs.
  2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
  3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.
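
In practice, this mostly means ensuring q, k and v are fp16 or bf16 and the head dimension is within range before calling flash_mha. A minimal guard might look like this; the wrapper below is purely illustrative and not part of the library:

import jax.numpy as jnp
from flash_attn_jax import flash_mha

def mha_checked(q, k, v, **kwargs):
    # The kernels only handle fp16/bf16 (bf16 needs Ampere or newer GPUs)
    # and head dimensions up to 256, so cast float32 inputs down and check d.
    assert q.shape[-1] <= 256, "head dimension must be <= 256"
    if q.dtype == jnp.float32:
        q, k, v = (x.astype(jnp.bfloat16) for x in (q, k, v))
    return flash_mha(q, k, v, **kwargs)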

Citation

If you use this codebase, or otherwise find our work valuable, please cite:

@inproceedings{dao2022flashattention,
  title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
  author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
  booktitle={Advances in Neural Information Processing Systems},
  year={2022}
}
@article{dao2023flashattention2,
  title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
  author={Dao, Tri},
  journal={arXiv preprint arXiv:2307.08691},
  year={2023}
}

Download files

Download the file for your platform.

Source Distribution

  • flash_attn_jax-0.2.2.tar.gz (2.3 MB, source)

Built Distributions

  • flash_attn_jax-0.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (75.4 MB, CPython 3.12, manylinux: glibc 2.17+ x86-64)
  • flash_attn_jax-0.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (75.4 MB, CPython 3.11, manylinux: glibc 2.17+ x86-64)
  • flash_attn_jax-0.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (75.3 MB, CPython 3.10, manylinux: glibc 2.17+ x86-64)
  • flash_attn_jax-0.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (75.3 MB, CPython 3.9, manylinux: glibc 2.17+ x86-64)
File details

flash_attn_jax-0.2.2.tar.gz (source, 2.3 MB, uploaded with Trusted Publishing via twine/5.0.0 on CPython 3.12.3)

  • SHA256: a97353523b845789fd1608eac704644f211e1604a1e5e75918a9809ad1226df5
  • MD5: 93a06022ed36d4bba46b01a18d098339
  • BLAKE2b-256: 505f5bd4f53756a76b11da740285b94faa5de26b1752c4fa2bd66252cedcdc78

flash_attn_jax-0.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

  • SHA256: 11183d52e1e89ed8b77a1d95d3fbb5d55cb237a4186d9571de6e7dc1ed0581f8
  • MD5: 66fbbbe8b3893b294740a36f56f92968
  • BLAKE2b-256: 8f40aed689d67272a2e85c75d81f9c7965e2216ab6a8714b2669fbb4e06202e8

flash_attn_jax-0.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

  • SHA256: 1a7c5cee3237426264344dc2ed9110e3cee286eafec2e76383d3e84545f97a58
  • MD5: 962e0c498b7aff14da3d28446d87a727
  • BLAKE2b-256: 43b41c7af1c6d5f02c95478e7454e378890ba9126af98c6a9a79d3e2c0d1d836

flash_attn_jax-0.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

  • SHA256: 0a9ad5287219907f2cd3ace17f3a325fe58568cf62bdfa8cc1f218ae7f838b38
  • MD5: 1be54e41eed92fc66955b862924a2d0a
  • BLAKE2b-256: 92493051051daf1a5d83b048c958c9d56577a7d939305b2685b0529f0e1c0eaa

flash_attn_jax-0.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

  • SHA256: 967e21e81a1350140c56f594c6ff4bb10b603e28e8e78854380c0b4a042c63af
  • MD5: 78d3ec720d4a598774c3ab8893a20d8b
  • BLAKE2b-256: 1b74a9b255b13d4f0b1671e79ffcf568fd7796037ddab4eb76b16f562fd0638a
