
Fast, memory-efficient attention column reduction (e.g., sum, mean, max)


Flash-ColReduce


Flash-ColReduce provides highly optimized Triton kernels for computing column-wise reductions of the attention matrix, such as sum, mean, or max, without materializing the full $M \times N$ attention weights ($O(N^2)$ memory when $M = N$).
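
For reference, the quantity being reduced can be written out as follows. This is a sketch under the usual scaled dot-product convention: the $1/\sqrt{d}$ scaling and the symbols $A$ and $r_j$ are assumptions made here for illustration, not taken from the library's documentation. With queries $Q \in \mathbb{R}^{M \times d}$ and keys $K \in \mathbb{R}^{N \times d}$, and the softmax taken over each row (the key dimension):

$$
A = \operatorname{softmax}\!\left(\frac{Q K^\top}{\sqrt{d}}\right) \in \mathbb{R}^{M \times N},
\qquad
r_j^{\text{sum}} = \sum_{i=1}^{M} A_{ij},
\qquad
r_j^{\text{mean}} = \frac{1}{M} \sum_{i=1}^{M} A_{ij},
\qquad
r_j^{\text{max}} = \max_{1 \le i \le M} A_{ij}
$$

Each reduction yields one value per key position $j$, which is consistent with the output shapes `(B, H, N)` shown in the usage examples. The mean is written for the non-causal case only.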

This primitive is essential for KV-cache pruning, token importance estimation, and attention analysis in Large Language Models (LLMs) and Vision-Language Models (VLMs). It powers the visual token pruning in SparseVILA.
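
To illustrate how per-key scores might feed a pruning step, here is a minimal, dependency-free sketch. The helper `keep_top_k` is hypothetical and is not part of Flash-ColReduce or SparseVILA; it only shows one plausible way to turn a vector of column-reduced attention scores into a set of key indices to retain.

```python
def keep_top_k(scores, k):
    """Return the indices of the k highest-scoring keys, in ascending order.

    `scores` stands in for one (batch, head) slice of a column reduction,
    i.e. one importance value per key position. Hypothetical helper.
    """
    ranked = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
    # Re-sort the survivors so the retained keys keep their original order.
    return sorted(ranked[:k])


importance = [0.10, 0.70, 0.05, 0.15]
kept = keep_top_k(importance, k=2)
```

With tensors, `torch.topk` along the last dimension plays the same role as the sort above.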

Highlights

  • 🚀 Efficient: Fused kernels compute column reductions in $O(N)$ memory.
  • 🧩 Flexible: Supports causal and non-causal attention with irregular shapes ($M \neq N$).
  • ✅ Exact: Uses online softmax for numerical precision and correct causal masking.
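
The idea behind the first and third highlights can be sketched in plain Python. This is an illustration under stated assumptions, not the library's Triton kernel: it assumes $1/\sqrt{d}$ scaling, handles only the non-causal sum, and uses a two-pass scheme (all function names are hypothetical). Pass 1 computes one log-sum-exp per query row with an online softmax; pass 2 accumulates normalized weights per key column. Extra memory is $O(M + N)$ rather than $O(MN)$.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def row_logsumexp(q_row, keys, scale):
    """Online softmax statistics for one query row, streaming over keys."""
    running_max = -math.inf
    running_sum = 0.0
    for k_row in keys:
        s = dot(q_row, k_row) * scale
        new_max = max(running_max, s)
        # Rescale the old sum to the new maximum before adding the new term.
        running_sum = running_sum * math.exp(running_max - new_max) + math.exp(s - new_max)
        running_max = new_max
    return running_max + math.log(running_sum)


def streaming_colsum(queries, keys):
    """Column sums of softmax(QK^T * scale) without storing the M x N matrix."""
    scale = 1.0 / math.sqrt(len(queries[0]))
    # Pass 1: one normalizer per query row (O(M) extra memory).
    lse = [row_logsumexp(q_row, keys, scale) for q_row in queries]
    # Pass 2: accumulate normalized weights per key column (O(N) extra memory).
    col = [0.0] * len(keys)
    for q_row, row_lse in zip(queries, lse):
        for j, k_row in enumerate(keys):
            col[j] += math.exp(dot(q_row, k_row) * scale - row_lse)
    return col


def naive_colsum(queries, keys):
    """Reference: materialize the full softmax matrix, then sum each column."""
    scale = 1.0 / math.sqrt(len(queries[0]))
    attn = []
    for q_row in queries:
        s = [dot(q_row, k_row) * scale for k_row in keys]
        m = max(s)
        e = [math.exp(x - m) for x in s]
        z = sum(e)
        attn.append([x / z for x in e])
    return [sum(row[j] for row in attn) for j in range(len(keys))]
```

Because every softmax row sums to 1, the column sums returned by either function add up to the number of queries, which makes a convenient sanity check.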

Prerequisites

  • Python: 3.10+
  • PyTorch: 2.1+ (with CUDA support)
  • Triton: 3.0.0+
  • GPU: NVIDIA GPU with Compute Capability 8.0+ (Ampere or newer recommended)

Installation

Install from PyPI:

pip install flash-colreduce

Or build from source:

git clone https://github.com/z-lab/flash-colreduce.git
cd flash-colreduce
pip install -e .

Usage

1. Non-Causal Attention

Compute a column-wise reduction of the attention matrix over the query dimension.

import torch
from flash_colreduce import flash_colreduce

q = torch.randn(8, 16, 512, 64, device="cuda", dtype=torch.float16)
k = torch.randn(8, 16, 512, 64, device="cuda", dtype=torch.float16)

flash_colreduce(q, k, reduction="sum")  # Shape: (8, 16, 512)
flash_colreduce(q, k, reduction="mean")  # Shape: (8, 16, 512)
flash_colreduce(q, k, reduction="max")  # Shape: (8, 16, 512)

2. Causal Attention

Handle autoregressive attention, including shapes where the number of queries $M$ differs from the number of keys $N$. The kernel applies a right-aligned causal mask, matching the behavior of KV-cached decoding.

import torch
from flash_colreduce import flash_colreduce

q = torch.randn(1, 32, 128, 128, device="cuda", dtype=torch.float16)
k = torch.randn(1, 32, 4096, 128, device="cuda", dtype=torch.float16)

flash_colreduce(q, k, reduction="sum", is_causal=True)  # Shape: (1, 32, 4096)
flash_colreduce(q, k, reduction="mean", is_causal=True)  # Shape: (1, 32, 4096)
flash_colreduce(q, k, reduction="max", is_causal=True)  # Shape: (1, 32, 4096)
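
A right-aligned causal mask can be pictured with the small sketch below. It assumes the common convention that the last query lines up with the last key, so query $i$ may attend to key $j$ when $j \le i + (N - M)$, and it assumes $N \ge M$. The helper names are hypothetical and the exact masking rule inside the kernel is not confirmed by this README.

```python
def right_aligned_causal_mask(m, n):
    """mask[i][j] is True when query i may attend to key j (assumes n >= m)."""
    offset = n - m  # number of cached keys that precede the first query
    return [[j <= i + offset for j in range(n)] for i in range(m)]


def visible_counts(m, n):
    """For each key column, how many queries are allowed to attend to it."""
    mask = right_aligned_causal_mask(m, n)
    return [sum(row[j] for row in mask) for j in range(n)]


mask = right_aligned_causal_mask(2, 4)
counts = visible_counts(2, 4)
```

Under this convention, early (cached) keys are visible to every query, while the most recent keys are visible to progressively fewer queries. That is worth keeping in mind when comparing `sum` scores across key positions.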

Performance

Flash-ColReduce achieves significant speedups and memory savings over naïve implementations. By fusing softmax and reduction into a single kernel, it avoids writing the $B \times H \times M \times N$ attention matrix to GPU memory.
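
To give a sense of scale, the following back-of-the-envelope calculation compares the size of the materialized attention matrix with the size of the reduced output for the non-causal usage example above. It assumes both are stored in FP16 (2 bytes per element); a naïve implementation may use a wider dtype for the softmax, in which case the gap is larger. The helper `fp16_bytes` is hypothetical.

```python
def fp16_bytes(*shape):
    """Bytes needed to store a tensor of the given shape in FP16."""
    count = 1
    for dim in shape:
        count *= dim
    return 2 * count


B, H, M, N = 8, 16, 512, 512        # shapes from the non-causal example
attn_bytes = fp16_bytes(B, H, M, N)  # full attention matrix
out_bytes = fp16_bytes(B, H, N)      # column-reduced output

print(f"attention matrix: {attn_bytes / 2**20:.0f} MiB")
print(f"reduced output:   {out_bytes / 2**10:.0f} KiB")
```

For these shapes the attention matrix takes 64 MiB while the output takes 128 KiB, a factor of $M = 512$.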

Figure: Benchmark results on NVIDIA RTX Pro 6000 Blackwell with FP16 precision.

Development

Running Tests

pip install -e ".[test]"
pytest -v

Running Benchmarks

pip install -e ".[bench]"
python benchmarks/run.py

Citation

If you use Flash-ColReduce in your research, please cite the SparseVILA paper:

@inproceedings{khaki2025sparsevila,
  title = {{SparseVILA: Decoupling Visual Sparsity for Efficient VLM Inference}},
  author = {Khaki, Samir and Guo, Junxian and Tang, Jiaming and Yang, Shang and Chen, Yukang and Plataniotis, Konstantinos N and Lu, Yao and Han, Song and Liu, Zhijian},
  booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  year = {2025}
}

License

MIT License

Acknowledgments

  • FlashAttention: The tiling and online softmax approach is heavily inspired by FlashAttention.
  • SparseVILA: The original project that motivated this primitive.
