jax-flash-attn2
A flexible and efficient implementation of Flash Attention 2.0 for JAX, supporting multiple backends (GPU/TPU/CPU) and platforms (Triton/Pallas/JAX).
Features
- 🚀 Multiple backend support: GPU, TPU, and CPU
- 🔧 Multiple platform implementations: Triton, Pallas, and JAX
- ⚡ Efficient caching of attention instances
- 🔄 Support for Grouped Query Attention (GQA) and head dimensions up to 256
- 📊 JAX sharding-friendly implementation
- 🎯 Automatic platform selection based on backend
- 🧩 Compatible with existing JAX mesh patterns
Installation
pip install jax-flash-attn2
Quick Start
from jax_flash_attn2 import get_cached_flash_attention

# Get a cached attention instance
attention = get_cached_flash_attention(
    backend="gpu",                   # 'gpu', 'tpu', or 'cpu'
    platform="triton",               # 'triton', 'pallas', or 'jax'
    blocksize_q=64,                  # block size along the query dimension
    blocksize_k=128,                 # block size along the key dimension
    softmax_scale=headdim ** -0.5,   # optional scaling factor (headdim is your per-head dimension)
)

# Use with your tensors
outputs = attention(
    query=query_states,
    key=key_states,
    value=value_states,
    bias=attention_bias,  # optional
)
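For a fully self-contained run, the sketch below fabricates random inputs with jax.numpy. The [batch, seq_len, num_heads, head_dim] layout, the head counts, and the float16 dtype are illustrative assumptions for this sketch rather than requirements documented by the library; for GQA, key/value may carry fewer heads than query.

import jax
import jax.numpy as jnp

from jax_flash_attn2 import get_cached_flash_attention

# Illustrative shapes (assumed layout [batch, seq_len, num_heads, head_dim]).
batch, seq_len, num_heads, head_dim = 2, 1024, 8, 128
q_key, k_key, v_key = jax.random.split(jax.random.PRNGKey(0), 3)
query_states = jax.random.normal(q_key, (batch, seq_len, num_heads, head_dim), dtype=jnp.float16)
key_states   = jax.random.normal(k_key, (batch, seq_len, num_heads, head_dim), dtype=jnp.float16)
value_states = jax.random.normal(v_key, (batch, seq_len, num_heads, head_dim), dtype=jnp.float16)

attention = get_cached_flash_attention(
    backend="gpu",
    platform="triton",
    blocksize_q=64,
    blocksize_k=128,
    softmax_scale=head_dim ** -0.5,  # 1/sqrt(head_dim)
)
outputs = attention(query=query_states, key=key_states, value=value_states)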
Usage with JAX Sharding
with mesh:
    attention_outputs = get_cached_flash_attention(
        backend="gpu",
        platform="triton",
        blocksize_q=128,
        blocksize_k=128,
        softmax_scale=None,
    )(
        query=with_sharding_constraint(query_states, qps).astype(dtype),
        key=with_sharding_constraint(key_states, kps).astype(dtype),
        value=with_sharding_constraint(value_states, vps).astype(dtype),
        bias=with_sharding_constraint(bias, bps).astype(dtype),
    )
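The names mesh, qps, kps, vps, bps, dtype, and with_sharding_constraint in the snippet above come from your own sharding setup; the library does not define them. One possible construction is sketched below, assuming a [batch, seq_len, num_heads, head_dim] layout and a single data-parallel mesh axis.

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.lax import with_sharding_constraint
from jax.sharding import Mesh, PartitionSpec

# Assumption: a 1D mesh over all available devices with a single "dp" (data-parallel) axis.
mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), axis_names=("dp",))

# Assumption: shard only the batch dimension; all other dimensions stay replicated.
qps = kps = vps = bps = PartitionSpec("dp", None, None, None)
dtype = jnp.float16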
Supported Configurations
Backends
- gpu: CUDA-capable GPUs
- tpu: Google Cloud TPUs
- cpu: CPU fallback
Platforms
- triton: optimized for NVIDIA GPUs
- pallas: optimized for TPUs and supported on GPUs
- jax: universal fallback, supports all backends
Valid Backend-Platform Combinations
Backend | Supported Platforms
---|---
GPU | Triton, Pallas, JAX
TPU | Pallas, JAX
CPU | JAX
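For reference, the table above can be encoded as a small lookup that rejects unsupported pairs before any attention instance is built. This helper is purely illustrative and is not part of the jax-flash-attn2 API.

# Hypothetical helper mirroring the compatibility table above (not part of jax-flash-attn2).
SUPPORTED_PLATFORMS = {
    "gpu": {"triton", "pallas", "jax"},
    "tpu": {"pallas", "jax"},
    "cpu": {"jax"},
}

def check_combination(backend: str, platform: str) -> None:
    """Raise ValueError for backend/platform pairs outside the table above."""
    if platform not in SUPPORTED_PLATFORMS.get(backend, set()):
        raise ValueError(f"platform={platform!r} is not supported on backend={backend!r}")

check_combination("tpu", "pallas")    # ok
# check_combination("cpu", "triton")  # would raise ValueError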
Advanced Configuration
Custom Block Sizes
attention = get_cached_flash_attention(
    backend="gpu",
    platform="triton",
    blocksize_q=128,    # customize query block size
    blocksize_k=128,    # customize key block size
    softmax_scale=1.0,  # custom softmax scaling
)
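Because the best block sizes depend on your hardware and shapes, a simple grid search over a few candidates is often enough. The loop below is a rough sketch: the candidate values are placeholders, and query_states, key_states, and value_states are the tensors from the Quick Start example.

import itertools
import time

import jax

# Assumption: query_states / key_states / value_states stand in for your own workload.
for bq, bk in itertools.product([64, 128, 256], repeat=2):
    attn = get_cached_flash_attention(
        backend="gpu",
        platform="triton",
        blocksize_q=bq,
        blocksize_k=bk,
        softmax_scale=None,
    )
    # Warm-up call to trigger compilation, then time a second call.
    jax.block_until_ready(attn(query=query_states, key=key_states, value=value_states))
    start = time.perf_counter()
    jax.block_until_ready(attn(query=query_states, key=key_states, value=value_states))
    print(f"blocksize_q={bq:>3} blocksize_k={bk:>3}: {time.perf_counter() - start:.5f}s")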
Environment Variables
- FORCE_MHA: set to "true", "1", or "on" to force the MHA implementation even for GQA cases
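For example, to force the MHA path from Python, set the variable before the attention instance is created (exactly where the library reads FORCE_MHA is an assumption in this sketch):

import os

# Any of "true", "1", or "on" enables the forced MHA code path.
# Assumption: FORCE_MHA is read when the attention instance is created.
os.environ["FORCE_MHA"] = "1"

from jax_flash_attn2 import get_cached_flash_attention

attention = get_cached_flash_attention(
    backend="gpu",
    platform="triton",
    blocksize_q=128,
    blocksize_k=128,
    softmax_scale=None,
)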
Performance Tips
- Block Sizes: The default block sizes (128) work well for most cases, but you may want to tune them for your specific hardware and model architecture.
- Platform Selection:
  - For NVIDIA GPUs: prefer triton
  - For TPUs: prefer pallas
  - For CPU or fallback: use jax
- Caching: The get_cached_flash_attention function automatically caches instances based on its parameters, so there is no need to manage caching manually (see the sketch below).
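A quick way to observe the caching behaviour (assuming, as this sketch does, that the cache is keyed on the keyword arguments shown throughout this README):

# Assumption: identical parameters return the same cached instance.
kwargs = dict(
    backend="gpu",
    platform="triton",
    blocksize_q=128,
    blocksize_k=128,
    softmax_scale=None,
)
a1 = get_cached_flash_attention(**kwargs)
a2 = get_cached_flash_attention(**kwargs)
assert a1 is a2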
Requirements
- JAX
- einops
- chex
- jax.experimental.pallas (for TPU support)
- triton (for GPU optimized implementation)
Limitations
- Triton platform is only available on NVIDIA GPUs.
- Some platform-backend combinations are not supported (see table above).
- Custom attention masks are not yet supported (use bias instead).
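If you have a boolean mask, the usual workaround is to fold it into an additive bias that is strongly negative wherever attention is disallowed. The sketch below assumes a mask where True means "attend"; the function name is illustrative, not part of the library.

import jax.numpy as jnp

def mask_to_bias(mask: jnp.ndarray, dtype=jnp.float16) -> jnp.ndarray:
    """Convert a boolean attention mask (True = attend) into an additive bias."""
    return jnp.where(mask, 0.0, jnp.finfo(dtype).min).astype(dtype)

# attention_bias = mask_to_bias(attention_mask)  # then pass it as bias=...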
Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
Citation
If you use this implementation in your research, please cite:
@software{jax_flash_attn2,
  title = {JAX Flash Attention 2.0},
  year  = {2024},
  url   = {https://github.com/erfanzar/jax-flash-attn2}
}
Acknowledgments and References
- This implementation (MHA) is based on:
  - the Flash Attention 2.0 paper
  - JAX ecosystem tools and libraries
  - Triton and Pallas optimization frameworks
- The custom Triton implementation uses JAX-Triton.
- All of the kernels are copied from the EasyDeL project.