
Flash Attention Jax

Project description

FlashAttention JAX

This repository provides a JAX binding to https://github.com/Dao-AILab/flash-attention. It is maintained as a fork of the official repo rather than as a wrapper that depends on it, to avoid pulling in PyTorch, since torch and JAX installations often conflict.

Please see Tri Dao's repo for more information about flash attention.

Usage

FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE). Please cite (see below) and credit FlashAttention if you use it.

Installation and features

Requirements:

  • CUDA 11.8 and above.
  • Linux. Same story as with the PyTorch repo: I haven't tested compilation of the JAX bindings on Windows.
  • JAX >=0.4.24. The custom sharding used for ring attention requires somewhat advanced sharding features.

To install: For now, download the appropriate release from the releases page and install it with pip.

Interface: src/flash_attn_jax/flash.py

from flash_attn_jax import flash_mha

flash_mha(q,k,v,softmax_scale=None, is_causal=False, window_size=(-1,-1))

Accepts q, k, v with shape [n, l, h, d] (batch, sequence length, heads, head dimension), and returns [n, l, h, d]. softmax_scale is the multiplier applied before the softmax, defaulting to 1/sqrt(d). Set window_size to positive values for sliding-window attention.
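
A minimal usage sketch (the shapes, dtype, and random inputs below are illustrative, and a supported GPU is assumed):

import jax
import jax.numpy as jnp
from flash_attn_jax import flash_mha

n, l, h, d = 2, 1024, 8, 64                        # batch, length, heads, head dim
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (n, l, h, d), dtype=jnp.float16)
k = jax.random.normal(kk, (n, l, h, d), dtype=jnp.float16)
v = jax.random.normal(kv, (n, l, h, d), dtype=jnp.float16)

out = flash_mha(q, k, v, is_causal=True)           # [n, l, h, d]

The call behaves like any other JAX op, so it can also sit inside a jax.jit'd function, as in the ring attention example below.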

Now Supports Ring Attention

Use jax.Array and shard your tensors along the length dimension, and flash_mha will automatically use the ring attention algorithm:

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

with Mesh(jax.devices(), axis_names=('len',)) as mesh:
    sharding = NamedSharding(mesh, P(None, 'len', None))  # n l d
    tokens = jax.device_put(tokens, sharding)
    # invoke your jax.jit'd transformer.forward
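
A more self-contained sketch of the same idea, sharding q, k, v along the length axis and calling flash_mha directly under jit (the shapes, dtype, and causal default here are illustrative assumptions; a sequence length divisible by the device count is assumed):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from flash_attn_jax import flash_mha

n, l, h, d = 1, 8192, 8, 64
q = jnp.zeros((n, l, h, d), dtype=jnp.float16)
k = jnp.zeros((n, l, h, d), dtype=jnp.float16)
v = jnp.zeros((n, l, h, d), dtype=jnp.float16)

with Mesh(jax.devices(), axis_names=('len',)) as mesh:
    sharding = NamedSharding(mesh, P(None, 'len', None, None))  # shard the l axis of [n, l, h, d]
    q, k, v = (jax.device_put(x, sharding) for x in (q, k, v))
    out = jax.jit(flash_mha)(q, k, v)  # length-sharded inputs take the ring attention path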

FlashAttention-2 currently supports:

  1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing GPUs (T4, RTX 2080) is coming soon; please use FlashAttention 1.x for Turing GPUs for now.
  2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
  3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.

Citation

If you use this codebase or otherwise find our work valuable, please cite:

@inproceedings{dao2022flashattention,
  title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
  author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
  booktitle={Advances in Neural Information Processing Systems},
  year={2022}
}
@article{dao2023flashattention2,
  title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
  author={Dao, Tri},
  year={2023}
}

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

flash_attn_jax-0.2.0.tar.gz (2.3 MB)

Uploaded Source

Built Distributions

flash_attn_jax-0.2.0-cp312-cp312-manylinux_2_28_x86_64.whl (62.6 MB)

Uploaded CPython 3.12 manylinux: glibc 2.28+ x86-64

flash_attn_jax-0.2.0-cp311-cp311-manylinux_2_28_x86_64.whl (62.6 MB)

Uploaded CPython 3.11 manylinux: glibc 2.28+ x86-64

flash_attn_jax-0.2.0-cp310-cp310-manylinux_2_28_x86_64.whl (62.6 MB)

Uploaded CPython 3.10 manylinux: glibc 2.28+ x86-64

flash_attn_jax-0.2.0-cp39-cp39-manylinux_2_28_x86_64.whl (62.6 MB)

Uploaded CPython 3.9 manylinux: glibc 2.28+ x86-64

File details

Details for the file flash_attn_jax-0.2.0.tar.gz.

File metadata

  • Download URL: flash_attn_jax-0.2.0.tar.gz
  • Upload date:
  • Size: 2.3 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.11.8

File hashes

Hashes for flash_attn_jax-0.2.0.tar.gz:

  • SHA256: bc06fe46d2550788bcd484177dc957835441754e08511b4b97204ab0043825ae
  • MD5: c9c5b3834ba3f55f7e71fcf35711574b
  • BLAKE2b-256: 127cabfaec2502074c43eae1cc912f8ba921fa55723b4ead7aeb7b0e26f29c3b


File details

Details for the file flash_attn_jax-0.2.0-cp312-cp312-manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for flash_attn_jax-0.2.0-cp312-cp312-manylinux_2_28_x86_64.whl:

  • SHA256: 474514e088aebd7b864e4f96fed9f7378f12a07b0b4dabc411728b4f8550b402
  • MD5: 7d6444c334c6158b9d5264eb6badcb30
  • BLAKE2b-256: 3f793ffa4359d38391b949955481be38b6bc04d5040f4a1f61fcc629c2a94b30


File details

Details for the file flash_attn_jax-0.2.0-cp311-cp311-manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for flash_attn_jax-0.2.0-cp311-cp311-manylinux_2_28_x86_64.whl:

  • SHA256: 6f0646ed6eb1bb2eb46e752979ce3f8b6f4f485b86587d1cbd40264ccd90f025
  • MD5: 81a5f30d42c1d39915135115384e030e
  • BLAKE2b-256: cc7becd41053d2e3a799172056ca32305164dddc99ee8dde4f4de44c07f10a38


File details

Details for the file flash_attn_jax-0.2.0-cp310-cp310-manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for flash_attn_jax-0.2.0-cp310-cp310-manylinux_2_28_x86_64.whl:

  • SHA256: ae5d65c63405aab077e558a0d229a57887eb4c60f767fd3907a0184b6ba77d97
  • MD5: d1216ba3fb174e7a7088f6545eb0183e
  • BLAKE2b-256: 924fc5faac156c4a5c7a18e5d43a319c35e5d895861b3fd89794d27bd5afd8eb


File details

Details for the file flash_attn_jax-0.2.0-cp39-cp39-manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for flash_attn_jax-0.2.0-cp39-cp39-manylinux_2_28_x86_64.whl:

  • SHA256: dbb44dd165b02553f2591c2fa1a17421606fd6a45a5757666634fe8e47500076
  • MD5: a9ef2eab7faf26572900a1462dde07ef
  • BLAKE2b-256: bf162c695472b5b89b9771f92c70647222ebbdd5649c9f4fe09e2780715023b8

