FlashAttention-3 JAX binding using XLA FFI
FlashAttention3 for JAX
This library provides JAX bindings of FlashAttention 3 with support for:
- Multi-Query Attention and Grouped-Query Attention
- Causal masking and sliding window attention
- Ring Attention
- Variable length sequences within batches
- KV cache with paged attention support
Requirements
- CUDA: 12.3 or higher
- OS: Linux
- JAX: >= 0.6.0
- Version 0.6.0+ is required for the new FFI custom call API
- Note: JAX >= 0.7.1 disables mixed-precision collective permute operations, which affects ring attention (cf. this issue)
Note: bindings have only been tested on Hopper architecture.
Installation
From PyPI
Pre-built wheels are available for Linux x86_64 with CUDA 12.4+ and Python 3.11/3.12:
pip install flash-attn3-jax
The published wheels are built for Hopper (SM90) architecture with head dimension 128 only. For other configurations, build from source.
Building from Source
For custom configurations or other GPU architectures, build from source using uv:
# Basic build
uv build --wheel
# Parallel build
CMAKE_BUILD_PARALLEL_LEVEL=32 uv build --wheel
# Install the wheel
uv pip install dist/flash_attn3_jax_*.whl
Advanced Build Options
Advanced build options are available; see pyproject.toml for their default values:
# Target specific GPU architectures
FLASH_ATTN_CUDA_ARCHS="80;90" uv build --wheel
Quick Start
import jax.numpy as jnp
from flash_attn3_jax import flash_mha
# Inputs with dimensions (batch, seqlen, num_heads, head_dim)
q = jnp.ones((2, 1024, 32, 128), dtype=jnp.float16)
k = jnp.ones((2, 1024, 32, 128), dtype=jnp.float16)
v = jnp.ones((2, 1024, 32, 128), dtype=jnp.float16)
output = flash_mha(q, k, v)
# Causal attn
output = flash_mha(q, k, v, is_causal=True)
# Sliding window attn
output = flash_mha(q, k, v, window_size=(256, 256))
API Reference
flash_mha
Main function for flash attention with fixed length sequences.
flash_mha(
q, # (b, l, h_q, d)
k, # (b, l, h_k, d)
v, # (b, l, h_k, d)
softmax_scale=None, # defaults to 1/sqrt(d)
is_causal=False,
window_size=(-1, -1)
)
MQA and GQA are automatically detected when h_q != h_k (h_q must be divisible by h_k).
For example, q.shape = (b, l, 32, 128), k.shape = (b, l, 8, 128) => GQA with 4 groups
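As a quick illustration of the divisibility rule, here is a minimal sketch of the head-count check that GQA/MQA dispatch relies on (the helper name `gqa_groups` is ours for illustration, not part of the library):

```python
def gqa_groups(h_q: int, h_k: int) -> int:
    """Query heads per KV head: 1 means MHA, h_q means MQA, otherwise GQA."""
    if h_q % h_k != 0:
        raise ValueError(f"h_q ({h_q}) must be divisible by h_k ({h_k})")
    return h_q // h_k

print(gqa_groups(32, 8))  # 4 groups, matching the example above
print(gqa_groups(32, 1))  # 32: all query heads share one KV head (MQA)
```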
flash_mha_varlen
Flash attention for variable length sequences within a batch.
from flash_attn3_jax import flash_mha_varlen
flash_mha_varlen(
q, # (total_q, h, d)
k, # (total_k, h, d)
v, # (total_k, h, d)
cu_seqlens_q, # Cumulative sequence lengths for Q (b+1,)
cu_seqlens_k, # Cumulative sequence lengths for K (b+1,)
max_seqlen_q, # Maximum sequence length in batch for Q
max_seqlen_k, # Maximum sequence length in batch for K
softmax_scale=None,
is_causal=False,
window_size=(-1, -1)
)
Example:
# Batch of 2 sequences: lengths [512, 1024], total 1536
q = jnp.ones((512 + 1024, 32, 128), dtype=jnp.float16)
k = jnp.ones((512 + 1024, 32, 128), dtype=jnp.float16)
v = jnp.ones((512 + 1024, 32, 128), dtype=jnp.float16)
cu_seqlens = jnp.array([0, 512, 1536], dtype=jnp.int32)
output = flash_mha_varlen(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=1024,
max_seqlen_k=1024
)
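The cumulative-offsets array can be built from per-sequence lengths with a running sum; a minimal sketch in plain Python (the helper name `make_cu_seqlens` is ours; convert the result with `jnp.array(..., dtype=jnp.int32)` before passing it in):

```python
def make_cu_seqlens(lengths):
    """Cumulative offsets in the (b+1,) layout used by cu_seqlens_q/cu_seqlens_k."""
    cu = [0]
    for n in lengths:
        cu.append(cu[-1] + n)
    return cu

# Same batch as above: sequences of length 512 and 1024
print(make_cu_seqlens([512, 1024]))  # [0, 512, 1536]
```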
flash_mha_with_kvcache
Flash attention with KV cache supporting both contiguous and paged cache modes.
from flash_attn3_jax import flash_mha_with_kvcache
flash_mha_with_kvcache(
q, # (batch, seqlen_q, h_q, d)
k_cache, # Contiguous: (batch_cache, seqlen_cache, h_k, d)
# Paged: (num_blocks, page_size, h_k, d)
v_cache, # Same shape as k_cache
k=None, # Optional new keys: (batch, seqlen_new, h_k, d)
v=None, # Optional new values: (batch, seqlen_new, h_k, d)
cache_seqlens=None, # int or (batch,) - length of valid cache data
cache_batch_idx=None, # (batch,) - map query batch to cache batch
cache_leftpad=None, # (batch,) - left padding per sequence
page_table=None, # (batch, max_num_pages) - for paged KV cache
rotary_cos=None, # (seqlen_ro, rotary_dim/2) - rotary embeddings
rotary_sin=None, # (seqlen_ro, rotary_dim/2) - rotary embeddings
softmax_scale=None, # defaults to 1/sqrt(d)
is_causal=False,
window_size=(-1, -1),
num_splits=1,
return_softmax_lse=False
)
Important note: the KV cache is NOT modified in place (unlike the PyTorch version).
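Because JAX arrays are immutable, the caller is responsible for writing new K/V into the cache and carrying the updated array forward. One way to sketch this functional update, assuming standard JAX only (the helper `append_to_cache` is ours, not part of the library):

```python
import jax
import jax.numpy as jnp

def append_to_cache(cache, new_kv, cache_seqlens):
    """Return a NEW cache with new_kv written at each sequence's current length.

    cache:         (batch, seqlen_cache, h_k, d)
    new_kv:        (batch, seqlen_new, h_k, d)
    cache_seqlens: (batch,) int32, valid length per sequence before the write
    """
    # Per-sequence dynamic write along the time axis, batched with vmap.
    write = lambda c, n, start: jax.lax.dynamic_update_slice_in_dim(c, n, start, axis=0)
    return jax.vmap(write)(cache, new_kv, cache_seqlens)
```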
Basic KV Cache Usage
import jax.numpy as jnp
from flash_attn3_jax import flash_mha_with_kvcache
batch_size = 4
query_seqlen = 1 # Single token
cache_seqlen = 2048
num_heads_q = 32
num_heads_kv = 8 # GQA with 4 groups
head_dim = 128
# Query for current token
q = jnp.ones((batch_size, query_seqlen, num_heads_q, head_dim), dtype=jnp.float16)
# Pre-filled KV cache from previous tokens
k_cache = jnp.ones((batch_size, cache_seqlen, num_heads_kv, head_dim), dtype=jnp.float16)
v_cache = jnp.ones((batch_size, cache_seqlen, num_heads_kv, head_dim), dtype=jnp.float16)
# Varying cache lengths per batch item
cache_seqlens = jnp.array([1024, 1536, 2048, 512], dtype=jnp.int32)
# Run attention
output = flash_mha_with_kvcache(
q=q,
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=cache_seqlens,
is_causal=True
)
With new Keys/Values
# Query tokens
q = jnp.ones((batch_size, 8, num_heads_q, head_dim), dtype=jnp.float16)
# New K/V to append to cache
k_new = jnp.ones((batch_size, 8, num_heads_kv, head_dim), dtype=jnp.float16)
v_new = jnp.ones((batch_size, 8, num_heads_kv, head_dim), dtype=jnp.float16)
# Current cache length before appending
cache_seqlens = jnp.array([1024, 1024, 1024, 1024], dtype=jnp.int32)
# Attention includes both cache and new K/V
output = flash_mha_with_kvcache(
q=q,
k_cache=k_cache,
v_cache=v_cache,
k=k_new,
v=v_new,
cache_seqlens=cache_seqlens,
is_causal=True
)
Paged KV Cache
page_size = 64 # Tokens per page/block
num_blocks = 1000 # Total blocks in memory pool
batch_size = 4
max_pages_per_seq = 32 # Max pages/blocks allocated to each sequence, in this case 2048 tokens per sequence
# Paged KV cache as a shared memory pool
k_cache_paged = jnp.ones((num_blocks, page_size, num_heads_kv, head_dim), dtype=jnp.float16)
v_cache_paged = jnp.ones((num_blocks, page_size, num_heads_kv, head_dim), dtype=jnp.float16)
# Page table maps each sequence to its allocated blocks (shape: (batch, max_pages_per_seq))
page_table = jnp.array([
[0, 1, 2, 3, ...], # Sequence 0 uses blocks 0,1,2,3,...
[10, 15, 20, 25, ...], # Sequence 1 uses blocks 10,15,20,25,...
[5, 6, 7, 8, ...], # Sequence 2 uses blocks 5,6,7,8,...
[30, 31, 32, 33, ...], # Sequence 3 uses blocks 30,31,32,33,...
], dtype=jnp.int32)
# Cache lengths must be divisible by page_size for paged mode
cache_seqlens = jnp.array([512, 1024, 2048, 1536], dtype=jnp.int32)
output = flash_mha_with_kvcache(
q=q,
k_cache=k_cache_paged,
v_cache=v_cache_paged,
cache_seqlens=cache_seqlens,
page_table=page_table,
is_causal=True
)
Paged cache requirements:
- page_table must have shape (batch, max_num_pages_per_seq); each entry contains the block index in the memory pool
- cache_seqlens must be divisible by page_size
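When sizing the page table, the number of blocks a sequence occupies is a ceiling division of its cache length by the page size; a minimal sketch (the helper name `pages_needed` is ours for illustration):

```python
def pages_needed(cache_seqlen: int, page_size: int) -> int:
    """Blocks required to hold cache_seqlen tokens at page_size tokens per block."""
    return -(-cache_seqlen // page_size)  # ceiling division

# With page_size=64 as above: 512 tokens -> 8 blocks, 2048 tokens -> 32 blocks
print(pages_needed(512, 64), pages_needed(2048, 64))
```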
Advanced Features
Batch index mapping, map queries to different cache slots:
# 3 queries but 5 cache slots
cache_batch_idx = jnp.array([0, 2, 4], dtype=jnp.int32)
output = flash_mha_with_kvcache(
q=q, # Shape: (3, seqlen_q, h_q, d)
k_cache=k_cache, # Shape: (5, seqlen_cache, h_k, d)
v_cache=v_cache,
cache_batch_idx=cache_batch_idx,
cache_seqlens=cache_seqlens
)
Ring Attention
When flash_mha detects that inputs are sharded along the sequence dimension, it automatically dispatches to ring attention:
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from flash_attn3_jax import flash_mha
devices = mesh_utils.create_device_mesh((8,))
mesh = Mesh(devices, axis_names=('seq',))
# Shard along sequence dimension
sharding = NamedSharding(mesh, P(None, 'seq', None, None))
q_sharded = jax.device_put(q, sharding)
k_sharded = jax.device_put(k, sharding)
v_sharded = jax.device_put(v, sharding)
output = flash_mha(q_sharded, k_sharded, v_sharded)
For optimal performance, enable XLA's latency-hiding scheduler to overlap communication with computation:
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_enable_latency_hiding_scheduler=true"
# Must be set before importing JAX
import jax
Benchmarks
Forward Pass Comparison
| Batch | SeqLen | Flax Time (ms) | Flax TFLOPS | FA2 Time (ms) | FA2 TFLOPS | FA3 Time (ms) | FA3 TFLOPS | FA3 vs Flax | FA2 vs Flax |
|---|---|---|---|---|---|---|---|---|---|
| 32 | 512 | 1.28 | 53.8 | 0.55 | 124.9 | 0.48 | 142.9 | 2.65x | 2.32x |
| 16 | 1,024 | 1.85 | 74.4 | 0.71 | 194.0 | 0.57 | 239.1 | 3.21x | 2.64x |
| 8 | 2,048 | 3.31 | 83.0 | 1.03 | 266.0 | 0.77 | 355.0 | 4.28x | 3.23x |
| 4 | 4,096 | 5.70 | 96.4 | 1.72 | 319.0 | 1.17 | 468.9 | 4.87x | 3.33x |
| 2 | 8,192 | 9.62 | 114.3 | 3.06 | 359.6 | 1.96 | 561.6 | 4.91x | 3.15x |
| 1 | 16,384 | 20.07 | 109.5 | 5.78 | 380.6 | 3.57 | 616.2 | 5.63x | 3.48x |
Backward Pass Comparison
| Batch | SeqLen | Flax Time (ms) | Flax TFLOPS | FA2 Time (ms) | FA2 TFLOPS | FA3 Time (ms) | FA3 TFLOPS | FA3 vs Flax | FA2 vs Flax |
|---|---|---|---|---|---|---|---|---|---|
| 32 | 512 | 3.16 | 76.2 | 2.71 | 88.9 | 2.22 | 108.5 | 1.42x | 1.17x |
| 16 | 1,024 | 4.52 | 106.3 | 3.40 | 141.6 | 2.47 | 195.0 | 1.83x | 1.33x |
| 8 | 2,048 | 7.61 | 126.5 | 4.72 | 203.6 | 2.88 | 333.6 | 2.64x | 1.60x |
| 4 | 4,096 | 13.11 | 146.8 | 7.48 | 257.3 | 3.87 | 497.8 | 3.39x | 1.76x |
| 2 | 8,192 | 23.34 | 164.9 | 13.05 | 294.9 | 6.52 | 590.0 | 3.58x | 1.79x |
| 1 | 16,384 | 46.05 | 167.2 | 23.92 | 321.8 | 11.99 | 641.8 | 3.84x | 1.93x |
In the accompanying plots, a log2 scale is applied to the sequence-length and time axes to make the trends easier to see across all configurations.
Credits
This project is based on the official FlashAttention repository. It is also heavily inspired by the FlashAttention 2 implementation for JAX.