Flash Sparse Attention: Fast and Memory-Efficient Trainable Sparse Attention
English | 简体中文
Flash-Sparse-Attention is a high-performance trainable sparse attention implementation that combines Flash Attention's memory efficiency with sparse computation for handling extremely long sequences in Transformer models.
Key Features
[!NOTE] Support for arbitrary mask and bias shapes is available in this branch. The current main branch no longer maintains that feature set.
Supported Features
- Forward and backward passes for dense attention, sparse attention, and gated attention
- Regular batched inputs and varlen inputs
- Causal attention and local window attention
- Arbitrary combinations of Q and KV sequence lengths, with head dimensions up to 256
- Grouped Query Attention and Multi Query Attention
- Sparse softmax threshold control
- Gated attention with gate inputs and configurable gating sparsity
- Split-KV path optimization for decoding workloads
Features We Aim to Support
- Paged Attention
- TMA, WGMMA, and FP8 low precision
- Sequence parallelism
Installation
Requirements
- Linux: Ubuntu 22.04 or later
- NVIDIA GPU: Compute Capability 8.0 or higher
- Runtime: an NVIDIA driver and CUDA runtime compatible with your PyTorch and Triton installation
- Python: 3.9 or later
- PyTorch: 2.5.1 or later
- Triton: Installed automatically as a default dependency
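To quickly verify that an environment meets these requirements, the sketch below uses standard PyTorch and Triton version queries; the check itself is illustrative and not part of this package:

```python
# Illustrative environment check against the requirements listed above.
import sys
import torch
import triton

print(f"Python:             {sys.version.split()[0]}")  # 3.9 or later
print(f"PyTorch:            {torch.__version__}")       # 2.5.1 or later
print(f"Triton:             {triton.__version__}")
major, minor = torch.cuda.get_device_capability()       # compute capability of the current GPU
print(f"Compute capability: {major}.{minor}")           # 8.0 or higher
```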
Install
Install from PyPI:

```bash
pip install flash-sparse-attn
```
To install from source:

```bash
git clone https://github.com/flash-algo/flash-sparse-attn.git
cd flash-sparse-attn
pip install .
```
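After installing, a minimal import check (using the same interface module as the Quick Start below) confirms the package is usable:

```python
# Smoke test: import the public attention functions used throughout this README.
from flash_sparse_attn.ops.triton.interface import (
    flash_dense_attn_func,
    flash_sparse_attn_func,
    flash_gated_attn_func,
)

print("flash-sparse-attn imported successfully")
```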
Install via HuggingFace Kernel
You can also load the kernels directly from HuggingFace Kernel without installing the package:
```python
from kernels import get_kernel

fsa = get_kernel("JingzeShi/flash-sparse-attention", version=1)

out = fsa.flash_dense_attn_func(q, k, v, is_causal=True)
out = fsa.flash_sparse_attn_func(q, k, v, is_causal=True, softmax_threshold=0.01)
out = fsa.flash_gated_attn_func(q, k, v, alpha, delta, is_causal=True)
```

This requires `pip install kernels`.
Quick Start
Basic Usage
Below are examples for the three attention variants. First, construct the inputs shared by all of them:
```python
import torch
from flash_sparse_attn.ops.triton.interface import (
    flash_dense_attn_func,
    flash_sparse_attn_func,
    flash_gated_attn_func,
)

dtype = torch.bfloat16
device = torch.device("cuda")

batch_size, seqlen_q, seqlen_k, num_heads, num_kv_heads, head_dim = 2, 1024, 1024, 8, 2, 64

query = torch.randn(batch_size, seqlen_q, num_heads, head_dim, dtype=dtype, device=device)
key = torch.randn(batch_size, seqlen_k, num_kv_heads, head_dim, dtype=dtype, device=device)
value = torch.randn(batch_size, seqlen_k, num_kv_heads, head_dim, dtype=dtype, device=device)
```
Dense Attention
Use this when you do not need explicit sparsification but still want an efficient attention kernel.
```python
output_dense = flash_dense_attn_func(
    query=query,
    key=key,
    value=value,
    is_causal=True,
)
print(output_dense.shape)
```
Sparse Attention
Use this when you want to skip low-contribution attention weights through softmax_threshold and reduce effective compute on long sequences.
```python
output_sparse = flash_sparse_attn_func(
    query=query,
    key=key,
    value=value,
    is_causal=True,
    softmax_threshold=1.0,
)
print(output_sparse.shape)
```
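To see how softmax_threshold trades accuracy for sparsity on your own inputs, a small sweep against the dense baseline can help. The threshold values below are arbitrary examples rather than recommended settings, and the deviation you observe will depend on your data:

```python
# Illustrative sweep (reuses query/key/value and output_dense from above).
# The threshold values are arbitrary examples, not recommended settings.
for threshold in (0.01, 0.1, 0.5, 1.0):
    out = flash_sparse_attn_func(
        query=query,
        key=key,
        value=value,
        is_causal=True,
        softmax_threshold=threshold,
    )
    max_diff = (out - output_dense).abs().max().item()
    print(f"softmax_threshold={threshold}: max abs deviation from dense = {max_diff:.3e}")
```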
Gated Attention
Use this when you need explicit gating signals for sparse attention. alpha controls query-side gating and delta controls key-side gating.
```python
alpha = torch.randn(batch_size, num_heads, seqlen_q, device=device, dtype=dtype)
delta = torch.randn(batch_size, num_kv_heads, seqlen_k, device=device, dtype=dtype)

output_gated = flash_gated_attn_func(
    query=query,
    key=key,
    value=value,
    alpha=alpha,
    delta=delta,
    is_causal=True,
    softmax_threshold=1.0,
    gate_threshold=1.0,
)
print(output_gated.shape)
```
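All three variants support backward passes (see Supported Features), so their outputs can be backpropagated directly. A minimal training-style sketch with the sparse variant, reusing the tensors above:

```python
# Minimal backward-pass sketch: mark the inputs as trainable and backpropagate a scalar loss.
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)

out = flash_sparse_attn_func(
    query=query,
    key=key,
    value=value,
    is_causal=True,
    softmax_threshold=1.0,
)
loss = out.float().mean()  # placeholder loss for illustration only
loss.backward()

print(query.grad.shape, key.grad.shape, value.grad.shape)
```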
Performance
The following benchmarks were collected on an SM120 GPU and cover forward, backward, and decoding workloads. They compare the Dense, Sparse, and Gated implementations against FlashAttention as a baseline.
Forward Performance
Backward Performance
Decode Performance
Benchmarking
Benchmark scripts are located under tests, covering forward, backward, and decoding performance.
By default, these scripts use the attention projection layers from the Qwen model family to generate Q, K, and V states with distributions closer to real LLM workloads, and they build input sequences from the Needle-in-a-Haystack dataset.
Forward Performance
```bash
python tests/benchmark_forward.py
```
Backward Performance
```bash
python tests/benchmark_backward.py
```
Decode Performance
```bash
python tests/benchmark_decode.py
```
Citation
If you use Flash Sparse Attention (FSA) in your research, please cite:
```bibtex
@misc{shi2025trainabledynamicmasksparse,
  title={Trainable Dynamic Mask Sparse Attention},
  author={Jingze Shi and Yifan Wu and Bingheng Wu and Yiran Peng and Liangdong Wang and Guang Liu and Yuyu Luo},
  year={2025},
  eprint={2508.02124},
  archivePrefix={arXiv},
  primaryClass={cs.AI},
  url={https://arxiv.org/abs/2508.02124},
}
```
Acknowledgments
This project builds upon and integrates several excellent works:
- OpenSeek - Kernel development support
- Flash-Attention - Memory-efficient attention computation
- NVIDIA CUTLASS - High-performance matrix operations library
We thank the open-source community for its contributions to efficient Transformer implementations.
Download files
File details

Details for the file flash_sparse_attn-2.0.1.tar.gz (source distribution).

File metadata

- Download URL: flash_sparse_attn-2.0.1.tar.gz
- Size: 394.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.20

File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 5585f9da85dbe25e57027391ffcdeba30161ac304a4b98ce3274b80f7b45860b |
| MD5 | 99d62bac92f3e82ab5e0283c19407f4b |
| BLAKE2b-256 | 60eb25cbbc4a7ee5ac622d0bf790334b77671514662da173b5c66e2110b97607 |
File details

Details for the file flash_sparse_attn-2.0.1-py3-none-any.whl (built distribution).

File metadata

- Download URL: flash_sparse_attn-2.0.1-py3-none-any.whl
- Size: 376.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.20

File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | e07e68fd3177952c19b1ab7fd76d31e6f2ec49c2ed45e1434502e084a1c22a78 |
| MD5 | b305a6d4d5bd8c90ef83176333b23f99 |
| BLAKE2b-256 | 46efc4fbc0d612dc5b3244e2e76bc7f8eec8d409aa6ed6e684d4f314f72c499f |