Flash Attention Implementation with Multiple Backend Support and Sharding
This module provides a flexible implementation of Flash Attention with support for different backends (GPU, TPU, CPU) and platforms (Triton, Pallas, JAX).
JAX-Flash-Attention2
A flexible and efficient implementation of Flash Attention 2.0 for JAX, supporting multiple backends (GPU/TPU/CPU) and platforms (Triton/Pallas/JAX).
Installation
pip install jax-flash-attn2
Basic Usage
import jax
import jax.numpy as jnp
import jax_flash_attn2 as jfa
# Initialize the FlashAttention module with desired configuration
flash_attention = jfa.FlashAttention(
    jfa.AttentionConfig(
        platform=jfa.Platform.TRITON,  # Options: TRITON, PALLAS, JAX
        backend=jfa.Backend.GPU,       # Options: GPU, TPU, CPU
    )
)
# Create sample inputs
batch_size, num_heads, seq_len, head_dim = 2, 4, 512, 64
query = jax.random.normal(jax.random.PRNGKey(0), (batch_size, num_heads, seq_len, head_dim), dtype=jnp.float16)
key = jax.random.normal(jax.random.PRNGKey(1), (batch_size, num_heads, seq_len, head_dim), dtype=jnp.float16)
value = jax.random.normal(jax.random.PRNGKey(2), (batch_size, num_heads, seq_len, head_dim), dtype=jnp.float16)
# Compute attention
output = flash_attention(
    query=query,
    key=key,
    value=value,
    causal=True,  # Enable causal masking for decoder-only models
)
# output shape: (batch_size, num_heads, seq_len, head_dim)
Advanced Usage
With Attention Mask
# Create an attention mask (1 = attend, 0 = mask)
attention_mask = jnp.ones((batch_size, 1, seq_len, seq_len)) # Allow full attention
# For example, mask the first 100 tokens from attending to the last 100 tokens
attention_mask = attention_mask.at[:, :, :100, -100:].set(0)
output = flash_attention(
    query=query,
    key=key,
    value=value,
    attention_mask=attention_mask,
    causal=False,  # Using an explicit mask instead of causal masking
)
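Masks can also encode padding. Below is a minimal sketch of building a padding mask from per-example valid lengths (the lengths values here are illustrative, not part of the library API):
# Per-example number of valid (non-padding) tokens -- illustrative values
lengths = jnp.array([512, 300])
positions = jnp.arange(seq_len)
# 1 where the key position is within the valid length, 0 where it is padding
key_valid = (positions[None, :] < lengths[:, None]).astype(jnp.float32)  # (batch, seq_len)
padding_mask = jnp.broadcast_to(
    key_valid[:, None, None, :], (batch_size, 1, seq_len, seq_len)
)
output = flash_attention(
    query=query,
    key=key,
    value=value,
    attention_mask=padding_mask,
    causal=False,
)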
With Attention Bias
# Create a position-dependent attention bias: closer positions receive a larger bias
positions = jnp.arange(seq_len)
distance = jnp.abs(positions[:, None] - positions[None, :])
bias = 1.0 / (1.0 + distance.astype(jnp.float32))  # (seq_len, seq_len)
bias = jnp.broadcast_to(bias[None, None, :, :], (batch_size, 1, seq_len, seq_len))
output = flash_attention(
    query=query,
    key=key,
    value=value,
    bias=bias,
)
With Dropout
output = flash_attention(
    query=query,
    key=key,
    value=value,
    dropout_prob=0.1,
    dropout_seed=42,
    causal=True,
)
Flax Modules with JFA2
Here's an example of integrating jax-flash-attn2 into a Transformer-style attention module implemented with Flax NNX:
import typing as tp
from functools import partial
import chex
import flax.nnx as nn
import jax
import jax.numpy as jnp
import jax_flash_attn2 as jfa
class JFAttention2(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        head_dim: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs = None,
    ):
        if rngs is None:
            rngs = nn.Rngs(0)
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.rngs = rngs
        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        # Grouped-query attention: each KV head serves a group of query heads
        assert num_attention_heads % num_key_value_heads == 0
        self.num_key_value_groups = num_attention_heads // num_key_value_heads
        linear_class = partial(
            nn.Linear,
            dtype=dtype,
            param_dtype=param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(0.02),
            precision=precision,
            rngs=rngs,
        )
        self.q_proj = linear_class(hidden_size, num_attention_heads * self.head_dim)
        self.k_proj = linear_class(hidden_size, num_key_value_heads * self.head_dim)
        self.v_proj = linear_class(hidden_size, num_key_value_heads * self.head_dim)
        self.o_proj = linear_class(num_attention_heads * self.head_dim, hidden_size)
        config = jfa.AttentionConfig(platform=jfa.Platform.TRITON, backend=jfa.Backend.GPU)
        self.jfa2 = jfa.FlashAttention(config)

    def __call__(
        self,
        hidden_states: chex.Array,
        attention_mask: chex.Array,
        causal: bool = True,
    ) -> chex.Array:
        batch_size, sequence_length = hidden_states.shape[:2]
        query_states, key_states, value_states = (
            self.q_proj(hidden_states),
            self.k_proj(hidden_states),
            self.v_proj(hidden_states),
        )
        qshape = (
            batch_size,
            sequence_length,
            self.num_attention_heads,
            self.head_dim,
        )
        kv_shape = (
            batch_size,
            sequence_length,
            self.num_key_value_heads,
            self.head_dim,
        )
        query_states = query_states.reshape(qshape)
        key_states = key_states.reshape(kv_shape)
        value_states = value_states.reshape(kv_shape)
        # Convert the boolean/0-1 mask into an additive bias: 0 where attended, -inf where masked
        attn_output = self.jfa2.forward(
            query_states.astype(jnp.bfloat16),
            key_states.astype(jnp.bfloat16),
            value_states.astype(jnp.bfloat16),
            jnp.where(attention_mask, 0, jnp.finfo(query_states.dtype).min).astype(jnp.bfloat16),
            causal=causal,
        )
        attn_output = jnp.reshape(attn_output, (batch_size, sequence_length, -1))
        attn_output = self.o_proj(attn_output)
        return attn_output
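A quick usage sketch of the module above (shapes and hyperparameters are illustrative; it runs on the Triton/GPU configuration hard-coded in __init__):
attn = JFAttention2(
    hidden_size=512,
    head_dim=64,
    num_attention_heads=8,
    num_key_value_heads=8,
    dtype=jnp.bfloat16,
    rngs=nn.Rngs(0),
)
hidden_states = jax.random.normal(jax.random.PRNGKey(0), (2, 128, 512), dtype=jnp.bfloat16)
# Boolean mask broadcastable to (batch, 1, q_len, kv_len); True = attend
attention_mask = jnp.ones((2, 1, 128, 128), dtype=bool)
out = attn(hidden_states, attention_mask, causal=True)
# out shape: (2, 128, 512)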
Platform-Specific Examples
Using JAX Backend
jax_flash_attn = jfa.FlashAttention(
    jfa.AttentionConfig(
        platform=jfa.Platform.JAX,
        backend=jfa.Backend.CPU,  # Works on any hardware
    )
)
output = jax_flash_attn(query, key, value)
Using Pallas for TPU
tpu_flash_attn = jfa.FlashAttention(
    jfa.AttentionConfig(
        platform=jfa.Platform.PALLAS,
        backend=jfa.Backend.TPU,
    )
)
output = tpu_flash_attn(query, key, value)
Integration with JAX Transformations
@jax.jit
def attention_forward(q, k, v, mask=None):
    return flash_attention(
        query=q,
        key=k,
        value=v,
        attention_mask=mask,
        causal=True,
    )
# Call the JIT-compiled function
output = attention_forward(query, key, value)
# With gradient computation (argnums selects which inputs to differentiate)
def loss_fn(q, k, v):
    attn_output = flash_attention(q, k, v, causal=True)
    return jnp.mean(attn_output)
q_grad, k_grad, v_grad = jax.grad(loss_fn, argnums=(0, 1, 2))(query, key, value)
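Since flash_attention traces like any other JAX function, jit and grad also compose; a small sketch reusing loss_fn from above:
# Jitted loss value and gradients with respect to all three inputs in one call
loss_and_grads = jax.jit(jax.value_and_grad(loss_fn, argnums=(0, 1, 2)))
loss, (q_grad, k_grad, v_grad) = loss_and_grads(query, key, value)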
Limitations
- Triton platform is only available on NVIDIA GPUs.
- Some platform-backend combinations are not supported (see table above).
- Custom attention masks are not yet supported (use bias instead).
Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
Citation
If you use this implementation in your research, please cite:
@software{jax_flash_attn2,
title = {JAX Flash Attention 2.0},
year = {2024},
url = {https://github.com/erfanzar/jax-flash-attn2}
}
Reference Citations
@inproceedings{dao2022flashattention,
title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
year={2022}
}
@inproceedings{dao2023flashattention2,
title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
author={Dao, Tri},
booktitle={International Conference on Learning Representations (ICLR)},
year={2024}
}
Acknowledgments and References
- All kernels are copied from the EasyDeL project.