
Flash Sparse Attention: Fast and Memory-Efficient Trainable Sparse Attention


Flash-Sparse-Attention is a high-performance trainable sparse attention implementation that combines Flash Attention's memory efficiency with sparse computation for handling extremely long sequences in Transformer models.

Key Features

Note: Support for arbitrary mask and bias shapes is available in a separate branch. The current main branch no longer maintains that feature set.

Supported Features

  • Forward and backward passes for dense attention, sparse attention, and gated attention
  • Regular batched inputs and varlen inputs
  • Causal attention and local window attention
  • Arbitrary combinations of Q and KV sequence lengths, with head dimensions up to 256
  • Grouped Query Attention and Multi Query Attention
  • Sparse softmax threshold control
  • Gated attention with gate inputs and configurable gating sparsity
  • Split-KV path optimization for decoding workloads

Features We Aim to Support

  • Paged Attention
  • TMA, WGMMA, and FP8 low precision
  • Sequence parallelism

Installation

Requirements

  • Linux: Ubuntu 22.04 or later
  • NVIDIA GPU: Compute Capability 8.0 or higher
  • Runtime: NVIDIA driver and runtime compatible with your PyTorch and Triton installation
  • Python: 3.9 or later
  • PyTorch: 2.5.1 or later
  • Triton: Installed automatically as a default dependency
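
The version floors listed above can be checked before installing. The following is a minimal sketch and not part of the package: `parse_version`, `check_environment`, and `probe_current_environment` are hypothetical helper names, and the only PyTorch calls used are `torch.cuda.is_available()` and `torch.cuda.get_device_capability()`.

```python
import sys

# Minimums copied from the requirements list above.
MIN_PYTHON = (3, 9)
MIN_TORCH = (2, 5, 1)
MIN_COMPUTE_CAPABILITY = (8, 0)


def parse_version(text):
    """Turn a string such as '2.5.1+cu124' into (2, 5, 1)."""
    parts = []
    for piece in text.split("+")[0].split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes such as 'a0' or 'rc1'
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def check_environment(python_version, torch_version, compute_capability):
    """Return a list of problems; an empty list means every check passed."""
    problems = []
    if tuple(python_version[:2]) < MIN_PYTHON:
        problems.append("Python 3.9 or later is required")
    if parse_version(torch_version) < MIN_TORCH:
        problems.append("PyTorch 2.5.1 or later is required")
    if tuple(compute_capability) < MIN_COMPUTE_CAPABILITY:
        problems.append("Compute Capability 8.0 or higher is required")
    return problems


def probe_current_environment():
    """Check the live environment; returns None without PyTorch or a GPU."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return check_environment(
        sys.version_info, torch.__version__, torch.cuda.get_device_capability()
    )


# Example with explicit values, so it runs on any machine.
print(check_environment((3, 10), "2.5.1+cu124", (8, 0)))
```

Calling `probe_current_environment()` on the target machine performs the same checks against the installed interpreter, PyTorch build, and GPU.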

Install

Install from PyPI:

pip install flash-sparse-attn

To install from source:

git clone https://github.com/flash-algo/flash-sparse-attn.git
cd flash-sparse-attn
pip install .

Quick Start

Basic Usage

Below are examples for the three common attention variants:

import torch
from flash_sparse_attn.ops.triton.interface import (
    flash_dense_attn_func,
    flash_sparse_attn_func,
    flash_gated_attn_func,
)

dtype = torch.bfloat16
device = torch.device("cuda")
# 8 query heads share 2 KV heads (Grouped Query Attention)
batch_size, seqlen_q, seqlen_k = 2, 1024, 1024
num_heads, num_kv_heads, head_dim = 8, 2, 64

# Layout: (batch, seqlen, heads, head_dim)
query = torch.randn(batch_size, seqlen_q, num_heads, head_dim, dtype=dtype, device=device)
key = torch.randn(batch_size, seqlen_k, num_kv_heads, head_dim, dtype=dtype, device=device)
value = torch.randn(batch_size, seqlen_k, num_kv_heads, head_dim, dtype=dtype, device=device)
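
The tensors above use 8 query heads and 2 KV heads, which is a Grouped Query Attention setup. The sketch below shows how query heads are usually grouped onto KV heads. It assumes the common convention that consecutive query heads share one KV head; the kernels' internal layout is not described here, and `kv_head_for_query_head` is a hypothetical helper, not a library function.

```python
def kv_head_for_query_head(query_head, num_heads, num_kv_heads):
    """Map a query head index to the KV head it reads from (assumed convention)."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    group_size = num_heads // num_kv_heads  # query heads per KV head
    return query_head // group_size


num_heads, num_kv_heads = 8, 2
mapping = [kv_head_for_query_head(h, num_heads, num_kv_heads) for h in range(num_heads)]
print(mapping)  # [0, 0, 0, 0, 1, 1, 1, 1]
```

Multi Query Attention is the special case `num_kv_heads = 1`, where every query head maps to KV head 0.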

Dense Attention

Use this when you do not need explicit sparsification but still want an efficient attention kernel.

output_dense = flash_dense_attn_func(
    query=query,
    key=key,
    value=value,
    is_causal=True,
)

print(output_dense.shape)

Sparse Attention

Use this when you want to skip low-contribution attention weights through softmax_threshold and reduce effective compute on long sequences.

output_sparse = flash_sparse_attn_func(
    query=query,
    key=key,
    value=value,
    is_causal=True,
    softmax_threshold=1.0,
)

print(output_sparse.shape)
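
To make the idea of skipping low-contribution weights concrete, here is a conceptual sketch in plain Python. It is an illustration only: the cutoff `eps` is a hypothetical absolute probability threshold, and this section does not specify how the library interprets `softmax_threshold` (including the value `1.0` used above), so the sketch makes no claim about the kernel's actual rule. Renormalizing the surviving weights is likewise just one possible choice.

```python
import math


def sparsified_softmax(scores, eps):
    """Softmax over one row, dropping weights below eps and renormalizing.

    Assumption for illustration: eps is an absolute probability cutoff.
    """
    row_max = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - row_max) for s in scores]
    total = sum(exps)
    probs = [e / total for e in exps]

    kept = [p >= eps for p in probs]
    kept_total = sum(p for p, k in zip(probs, kept) if k)
    # Dropped positions contribute nothing; survivors are rescaled to sum to 1.
    weights = [p / kept_total if k else 0.0 for p, k in zip(probs, kept)]
    return weights, kept


weights, kept = sparsified_softmax([4.0, 3.5, 0.0, -2.0], eps=0.05)
print(kept)  # [True, True, False, False]
```

In this toy row, two of the four positions carry almost all of the probability mass, so the value rows for the other two would not need to be read.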

Gated Attention

Use this when you need explicit gating signals for sparse attention. alpha controls query-side gating and delta controls key-side gating.

alpha = torch.randn(batch_size, num_heads, seqlen_q, device=device, dtype=dtype)
delta = torch.randn(batch_size, num_kv_heads, seqlen_k, device=device, dtype=dtype)

output_gated = flash_gated_attn_func(
    query=query,
    key=key,
    value=value,
    alpha=alpha,
    delta=delta,
    is_causal=True,
    softmax_threshold=1.0,
    gate_threshold=1.0,
)

print(output_gated.shape)
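
The gate tensors put the head dimension before the sequence dimension, unlike `query` and `key`. The sketch below derives the gate shapes from the Q and K shapes, following the layout used in the example above. `expected_gate_shapes` is a hypothetical helper, not a library function.

```python
def expected_gate_shapes(query_shape, key_shape):
    """Return (alpha_shape, delta_shape) for the layout used in the example.

    query_shape: (batch, seqlen_q, num_heads, head_dim)
    key_shape:   (batch, seqlen_k, num_kv_heads, head_dim)
    """
    batch, seqlen_q, num_heads, _ = query_shape
    batch_k, seqlen_k, num_kv_heads, _ = key_shape
    if batch != batch_k:
        raise ValueError("query and key must share the batch dimension")
    alpha_shape = (batch, num_heads, seqlen_q)     # query-side gate
    delta_shape = (batch, num_kv_heads, seqlen_k)  # key-side gate
    return alpha_shape, delta_shape


alpha_shape, delta_shape = expected_gate_shapes((2, 1024, 8, 64), (2, 1024, 2, 64))
print(alpha_shape, delta_shape)  # (2, 8, 1024) (2, 2, 1024)
```

Comparing `tuple(alpha.shape)` and `tuple(delta.shape)` against these values before the call can catch a transposed gate tensor early.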

Performance

The following benchmarks were collected on SM120 and cover forward, backward, and decoding workloads. They include Dense, Sparse, and Gated implementations, with FlashAttention as a baseline.

Forward Performance

Figure: Attention forward speed, head dim 128

Backward Performance

Figure: Attention backward speed, head dim 128

Decode Performance

Figure: Attention decode speed, head dim 128

Benchmarking

Benchmark scripts are located under tests, covering forward, backward, and decoding performance.

By default, these scripts use the attention projection layers from the Qwen model family to generate Q, K, and V states with distributions closer to real LLM workloads, and they build input sequences from the Needle-in-a-Haystack dataset.

Forward Performance

python tests/benchmark_forward.py

Backward Performance

python tests/benchmark_backward.py

Decode Performance

python tests/benchmark_decode.py
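
For ad hoc timing outside the provided scripts, a small harness like the one below may be enough. It is a sketch under stated assumptions and is unrelated to how the scripts in tests measure performance: `time_callable` is a hypothetical helper, and `sync` is an optional hook that lets the timer wait for asynchronous GPU work.

```python
import statistics
import time


def time_callable(fn, warmup=3, repeats=10, sync=None):
    """Return the median wall-clock time of fn() in milliseconds.

    sync, if given, is called after each run so that asynchronous work
    (for example queued GPU kernels) has finished before the clock stops.
    """
    def run_once():
        fn()
        if sync is not None:
            sync()

    for _ in range(warmup):  # warm-up runs absorb compilation and caching
        run_once()

    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        run_once()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


# CPU-only demonstration so the sketch runs anywhere.
median_ms = time_callable(lambda: sum(range(10_000)))
print(f"{median_ms:.3f} ms")
```

With the Quick Start tensors, one could pass a lambda that calls `flash_dense_attn_func(query=query, key=key, value=value, is_causal=True)` together with `sync=torch.cuda.synchronize`.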

Citation

If you use Flash-Sparse-Attention (FSA) in your research, please cite:

@misc{shi2025trainabledynamicmasksparse,
      title={Trainable Dynamic Mask Sparse Attention},
      author={Jingze Shi and Yifan Wu and Bingheng Wu and Yiran Peng and Liangdong Wang and Guang Liu and Yuyu Luo},
      year={2025},
      eprint={2508.02124},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2508.02124},
}

Acknowledgments

This project builds upon and integrates several excellent works:

We thank the open-source community for its contributions to efficient Transformer implementations.
