FlashAttention-3 JAX binding using XLA FFI

Project description

FlashAttention3 for JAX

This library provides JAX bindings of FlashAttention 3 with support for:

  • Multi-Query Attention and Grouped-Query Attention
  • Causal masking and sliding window attention
  • Ring Attention
  • Variable length sequences within batches
  • KV cache with paged attention support

Requirements

  • CUDA: 12.3 or higher
  • OS: Linux
  • JAX: >= 0.6.0
    • Version 0.6.0+ required for the new FFI custom call API
    • Note: JAX >= 0.7.1 disables mixed-precision collective permute operations, which affects ring attention (see the corresponding upstream issue)

Note: the bindings have only been tested on the Hopper architecture.

Installation

From PyPI

Pre-built wheels are available for Linux x86_64 with CUDA 12.4+ and Python 3.11/3.12:

pip install flash-attn3-jax

The published wheels are built for Hopper (SM90) architecture with head dimension 128 only. For other configurations, build from source.

Building from Source

For custom configurations or other GPU architectures, build from source using uv:

# Basic build
uv build --wheel

# Parallel build
CMAKE_BUILD_PARALLEL_LEVEL=32 uv build --wheel

# Install the wheel
uv pip install dist/flash_attn3_jax_*.whl

Advanced Build Options

Advanced build options are available; see their default values in pyproject.toml:

# Target specific GPU architectures
FLASH_ATTN_CUDA_ARCHS="80;90" uv build --wheel

Quick Start

import jax.numpy as jnp
from flash_attn3_jax import flash_mha

# Inputs with dimensions (batch, seqlen, num_heads, head_dim)
q = jnp.ones((2, 1024, 32, 128), dtype=jnp.float16)
k = jnp.ones((2, 1024, 32, 128), dtype=jnp.float16)
v = jnp.ones((2, 1024, 32, 128), dtype=jnp.float16)

output = flash_mha(q, k, v)

# Causal attn
output = flash_mha(q, k, v, is_causal=True)

# Sliding window attn
output = flash_mha(q, k, v, window_size=(256, 256))

API Reference

flash_mha

Main function for flash attention with fixed length sequences.

flash_mha(
    q,                      # (b, l, h_q, d)
    k,                      # (b, l, h_k, d)
    v,                      # (b, l, h_k, d)
    softmax_scale=None,     # default to 1/sqrt(d)
    is_causal=False,
    window_size=(-1, -1)
)

MQA and GQA are automatically detected when h_q != h_k (h_q must be divisible by h_k). For example, q.shape = (b, l, 32, 128) with k.shape = (b, l, 8, 128) gives GQA with 4 groups, as in the sketch below.
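
A minimal GQA sketch (shapes are illustrative only; the call is the same flash_mha shown above):

import jax.numpy as jnp
from flash_attn3_jax import flash_mha

# 32 query heads share 8 KV heads -> 4 query heads per KV head (GQA)
q = jnp.ones((2, 1024, 32, 128), dtype=jnp.float16)  # (b, l, h_q, d)
k = jnp.ones((2, 1024, 8, 128), dtype=jnp.float16)   # (b, l, h_k, d)
v = jnp.ones((2, 1024, 8, 128), dtype=jnp.float16)   # (b, l, h_k, d)

output = flash_mha(q, k, v, is_causal=True)  # GQA is inferred from the head counts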

flash_mha_varlen

Flash attention for variable length sequences within a batch.

from flash_attn3_jax import flash_mha_varlen

flash_mha_varlen(
    q,                     # (total_q, h, d)
    k,                     # (total_k, h, d)
    v,                     # (total_k, h, d)
    cu_seqlens_q,          # Cumulative sequence lengths for Q (b+1,)
    cu_seqlens_k,          # Cumulative sequence lengths for K (b+1,)
    max_seqlen_q,          # Maximum sequence length in batch for Q
    max_seqlen_k,          # Maximum sequence length in batch for K
    softmax_scale=None,
    is_causal=False,
    window_size=(-1, -1)
)

Example:

# Batch of 2 sequences: lengths [512, 1024], total 1536
q = jnp.ones((512 + 1024, 32, 128), dtype=jnp.float16)
k = jnp.ones((512 + 1024, 32, 128), dtype=jnp.float16)
v = jnp.ones((512 + 1024, 32, 128), dtype=jnp.float16)

cu_seqlens = jnp.array([0, 512, 1536], dtype=jnp.int32)

output = flash_mha_varlen(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=1024,
    max_seqlen_k=1024
)
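
If you start from a Python list of per-sequence lengths, the cumulative offsets can be derived with a prefix sum. The snippet below is an illustrative helper, not part of the library:

import jax.numpy as jnp

lengths = [512, 1024]
cu_seqlens = jnp.concatenate([
    jnp.zeros((1,), dtype=jnp.int32),
    jnp.cumsum(jnp.array(lengths, dtype=jnp.int32)),
])                          # -> [0, 512, 1536]
max_seqlen = max(lengths)   # -> 1024, passed as max_seqlen_q / max_seqlen_k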

flash_mha_with_kvcache

Flash attention with KV cache supporting both contiguous and paged cache modes.

from flash_attn3_jax import flash_mha_with_kvcache

flash_mha_with_kvcache(
    q,                      # (batch, seqlen_q, h_q, d)
    k_cache,                # Contiguous: (batch_cache, seqlen_cache, h_k, d)
                            # Paged: (num_blocks, page_size, h_k, d)
    v_cache,                # Same shape as k_cache
    k=None,                 # Optional new keys: (batch, seqlen_new, h_k, d)
    v=None,                 # Optional new values: (batch, seqlen_new, h_k, d)
    cache_seqlens=None,     # int or (batch,) - length of valid cache data
    cache_batch_idx=None,   # (batch,) - map query batch to cache batch
    cache_leftpad=None,     # (batch,) - left padding per sequence
    page_table=None,        # (batch, max_num_pages) - for paged KV cache
    rotary_cos=None,        # (seqlen_ro, rotary_dim/2) - rotary embeddings
    rotary_sin=None,        # (seqlen_ro, rotary_dim/2) - rotary embeddings
    softmax_scale=None,     # default to 1/sqrt(d)
    is_causal=False,
    window_size=(-1, -1),
    num_splits=1,
    return_softmax_lse=False
)

Important note: the KV cache is NOT modified in place (unlike the PyTorch version); maintaining the cache between steps is up to the caller, as sketched below.
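
Because the call is purely functional, appending the new keys/values to the cache for the next decoding step is left to you. A minimal sketch, assuming every batch entry currently has the same valid length (the helper name is illustrative, not part of the library):

import jax

def append_to_cache(k_cache, v_cache, k_new, v_new, cache_len):
    # Write k_new/v_new at offset cache_len along the sequence axis (axis=1).
    # Assumes all batch entries share the same current length cache_len.
    k_cache = jax.lax.dynamic_update_slice_in_dim(k_cache, k_new, cache_len, axis=1)
    v_cache = jax.lax.dynamic_update_slice_in_dim(v_cache, v_new, cache_len, axis=1)
    return k_cache, v_cache

With per-sequence lengths you would instead update each batch entry at its own offset, e.g. via a vmap over jax.lax.dynamic_update_slice.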

Basic KV Cache Usage

import jax.numpy as jnp
from flash_attn3_jax import flash_mha_with_kvcache

batch_size = 4
query_seqlen = 1  # Single token
cache_seqlen = 2048
num_heads_q = 32
num_heads_kv = 8  # GQA with 4 groups
head_dim = 128

# Query for current token
q = jnp.ones((batch_size, query_seqlen, num_heads_q, head_dim), dtype=jnp.float16)

# Pre-filled KV cache from previous tokens
k_cache = jnp.ones((batch_size, cache_seqlen, num_heads_kv, head_dim), dtype=jnp.float16)
v_cache = jnp.ones((batch_size, cache_seqlen, num_heads_kv, head_dim), dtype=jnp.float16)

# Varying cache lengths per batch item
cache_seqlens = jnp.array([1024, 1536, 2048, 512], dtype=jnp.int32)

# Run attention
output = flash_mha_with_kvcache(
    q=q,
    k_cache=k_cache,
    v_cache=v_cache,
    cache_seqlens=cache_seqlens,
    is_causal=True
)

With new Keys/Values

# Query tokens
q = jnp.ones((batch_size, 8, num_heads_q, head_dim), dtype=jnp.float16)

# New K/V to append to cache
k_new = jnp.ones((batch_size, 8, num_heads_kv, head_dim), dtype=jnp.float16)
v_new = jnp.ones((batch_size, 8, num_heads_kv, head_dim), dtype=jnp.float16)

# Current cache length before appending
cache_seqlens = jnp.array([1024, 1024, 1024, 1024], dtype=jnp.int32)

# Attention includes both cache and new K/V
output = flash_mha_with_kvcache(
    q=q,
    k_cache=k_cache,
    v_cache=v_cache,
    k=k_new,
    v=v_new,
    cache_seqlens=cache_seqlens,
    is_causal=True
)
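
The rotary_cos / rotary_sin tables let the kernel apply rotary position embeddings to q and the newly appended k at their cache positions, mirroring the PyTorch flash_attn_with_kvcache interface. A sketch continuing the example above, assuming rotary is applied over the full head dimension (check the kernel's dtype and length constraints before relying on this):

import jax.numpy as jnp

# Precomputed rotary tables of shape (seqlen_ro, rotary_dim / 2)
seqlen_ro = 4096                 # must cover the largest cache position
rotary_dim = head_dim            # assumption: rotary over the full head dimension
inv_freq = 1.0 / (10000.0 ** (jnp.arange(0, rotary_dim, 2) / rotary_dim))
freqs = jnp.outer(jnp.arange(seqlen_ro), inv_freq)
rotary_cos = jnp.cos(freqs).astype(jnp.float16)
rotary_sin = jnp.sin(freqs).astype(jnp.float16)

output = flash_mha_with_kvcache(
    q=q,
    k_cache=k_cache,
    v_cache=v_cache,
    k=k_new,
    v=v_new,
    cache_seqlens=cache_seqlens,
    rotary_cos=rotary_cos,
    rotary_sin=rotary_sin,
    is_causal=True
)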

Paged KV Cache

page_size = 64  # Tokens per page/block
num_blocks = 1000  # Total blocks in memory pool
batch_size = 4
max_pages_per_seq = 32  # Max pages/blocks allocated to each sequence, in this case 2048 tokens per sequence

# Paged KV cache as a shared memory pool
k_cache_paged = jnp.ones((num_blocks, page_size, num_heads_kv, head_dim), dtype=jnp.float16)
v_cache_paged = jnp.ones((num_blocks, page_size, num_heads_kv, head_dim), dtype=jnp.float16)

# Page table maps each sequence to its allocated blocks (shape: (batch, max_pages_per_seq))
page_table = jnp.array([
    [0, 1, 2, 3, ...],    # Sequence 0 uses blocks 0,1,2,3,...
    [10, 15, 20, 25, ...], # Sequence 1 uses blocks 10,15,20,25,...
    [5, 6, 7, 8, ...],    # Sequence 2 uses blocks 5,6,7,8,...
    [30, 31, 32, 33, ...], # Sequence 3 uses blocks 30,31,32,33,...
], dtype=jnp.int32)

# Cache lengths must be divisible by page_size for paged mode
cache_seqlens = jnp.array([512, 1024, 2048, 1536], dtype=jnp.int32)

output = flash_mha_with_kvcache(
    q=q,
    k_cache=k_cache_paged,
    v_cache=v_cache_paged,
    cache_seqlens=cache_seqlens,
    page_table=page_table,
    is_causal=True
)

Paged cache requirements:

  • page_table must have shape (batch, max_num_pages_per_seq)
  • Each page table entry contains the block index in the memory pool (see the indexing sketch after this list)
  • cache_seqlens should be divisible by page_size
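
In other words, token t of sequence b lives in pool block page_table[b, t // page_size], at row t % page_size inside that block. A small indexing sketch (not a call you need to make yourself):

def locate_token(page_table, b, t, page_size):
    # Which memory-pool block, and which row within it, holds token t of sequence b.
    block = page_table[b, t // page_size]
    offset = t % page_size
    return block, offset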

Advanced Features

Batch index mapping lets you map queries to different cache slots:

# 3 queries but 5 cache slots
cache_batch_idx = jnp.array([0, 2, 4], dtype=jnp.int32)

output = flash_mha_with_kvcache(
    q=q,  # Shape: (3, seqlen_q, h_q, d)
    k_cache=k_cache,  # Shape: (5, seqlen_cache, h_k, d)
    v_cache=v_cache,
    cache_batch_idx=cache_batch_idx,
    cache_seqlens=cache_seqlens
)

Ring Attention

When flash_mha detects that inputs are sharded along the sequence dimension, it automatically dispatches to ring attention:

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from flash_attn3_jax import flash_mha

devices = mesh_utils.create_device_mesh((8,))
mesh = Mesh(devices, axis_names=('seq',))

# Shard along sequence dimension
sharding = NamedSharding(mesh, P(None, 'seq', None, None))
q_sharded = jax.device_put(q, sharding)
k_sharded = jax.device_put(k, sharding)
v_sharded = jax.device_put(v, sharding)
output = flash_mha(q_sharded, k_sharded, v_sharded)

For optimal performance, enable XLA's latency-hiding scheduler to overlap communication with computation:

import os
os.environ["XLA_FLAGS"] = "--xla_gpu_enable_latency_hiding_scheduler=true"
# Must be set before importing JAX
import jax

Benchmarks

Forward Pass Comparison

Batch  SeqLen  | Native (Flax)      | Flash Attn 2       | Flash Attn 3       | FA3 vs Flax | FA2 vs Flax
               | Time (ms)  TFLOPS  | Time (ms)  TFLOPS  | Time (ms)  TFLOPS  | Speedup     | Speedup
32     512     | 1.28       53.8    | 0.55       124.9   | 0.48       142.9   | 2.65x       | 2.32x
16     1,024   | 1.85       74.4    | 0.71       194.0   | 0.57       239.1   | 3.21x       | 2.64x
8      2,048   | 3.31       83.0    | 1.03       266.0   | 0.77       355.0   | 4.28x       | 3.23x
4      4,096   | 5.70       96.4    | 1.72       319.0   | 1.17       468.9   | 4.87x       | 3.33x
2      8,192   | 9.62       114.3   | 3.06       359.6   | 1.96       561.6   | 4.91x       | 3.15x
1      16,384  | 20.07      109.5   | 5.78       380.6   | 3.57       616.2   | 5.63x       | 3.48x

Backward Pass Comparison

Batch  SeqLen  | Native (Flax)      | Flash Attn 2       | Flash Attn 3       | FA3 vs Flax | FA2 vs Flax
               | Time (ms)  TFLOPS  | Time (ms)  TFLOPS  | Time (ms)  TFLOPS  | Speedup     | Speedup
32     512     | 3.16       76.2    | 2.71       88.9    | 2.22       108.5   | 1.42x       | 1.17x
16     1,024   | 4.52       106.3   | 3.40       141.6   | 2.47       195.0   | 1.83x       | 1.33x
8      2,048   | 7.61       126.5   | 4.72       203.6   | 2.88       333.6   | 2.64x       | 1.60x
4      4,096   | 13.11      146.8   | 7.48       257.3   | 3.87       497.8   | 3.39x       | 1.76x
2      8,192   | 23.34      164.9   | 13.05      294.9   | 6.52       590.0   | 3.58x       | 1.79x
1      16,384  | 46.05      167.2   | 23.92      321.8   | 11.99      641.8   | 3.84x       | 1.93x

Log2 scaling is applied to the sequence-length and time axes to make the trends easier to see across all configurations.

Credits

This project is based on the official FlashAttention repository. It is also heavily inspired by the FlashAttention 2 implementation for JAX.

Download files

Download the file for your platform.

Source Distributions

No source distribution files are available for this release.

Built Distributions


flash_attn3_jax-0.2.3-cp312-cp312-manylinux_2_28_x86_64.whl (23.0 MB)

  • CPython 3.12, manylinux: glibc 2.28+, x86-64

flash_attn3_jax-0.2.3-cp311-cp311-manylinux_2_28_x86_64.whl (23.0 MB)

  • CPython 3.11, manylinux: glibc 2.28+, x86-64

File details

Details for the file flash_attn3_jax-0.2.3-cp312-cp312-manylinux_2_28_x86_64.whl.

File hashes for flash_attn3_jax-0.2.3-cp312-cp312-manylinux_2_28_x86_64.whl:

  • SHA256: 5bfa0c5f41ee6146d68d491311de39d228c805660be41124605e4c7ba3eb25d3
  • MD5: 285c9f9b83a8d5cf505d8336b5b01a50
  • BLAKE2b-256: 916541640ec21b679e0db0cabd46b462647135487de26b105ab394fa8661482d

Provenance

The following attestation bundles were made for flash_attn3_jax-0.2.3-cp312-cp312-manylinux_2_28_x86_64.whl:

Publisher: publish.yml on kyutai-labs/flash-attn3-jax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file flash_attn3_jax-0.2.3-cp311-cp311-manylinux_2_28_x86_64.whl.

File hashes for flash_attn3_jax-0.2.3-cp311-cp311-manylinux_2_28_x86_64.whl:

  • SHA256: 02c9c6e61a5de8923ee6b049caf886326ebb7c14cef2e8705f6caf98ad42de33
  • MD5: 5fb297fa5a03f1edc8d827e7a4993d15
  • BLAKE2b-256: 6a7ec63ceac767dad8b38ea037131e4317f0348e5891c9b7e42d697cbe89cff2

Provenance

The following attestation bundles were made for flash_attn3_jax-0.2.3-cp311-cp311-manylinux_2_28_x86_64.whl:

Publisher: publish.yml on kyutai-labs/flash-attn3-jax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
