Fast, memory-efficient attention column reduction (e.g., sum, mean, max)
Flash-ColReduce
Flash-ColReduce provides highly optimized Triton kernels for computing column-wise reductions of the attention matrix, such as sum, mean, or max, without materializing the full $M \times N$ attention weights ($O(N^2)$ memory when $M = N$).
This primitive is essential for KV-cache pruning, token importance estimation, and attention analysis in Large Language Models (LLMs) and Vision-Language Models (VLMs). It powers the visual token pruning in SparseVILA.
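As an illustration of the pruning use case, the sketch below ranks cached tokens by a per-key score, such as the attention mass a column sum assigns to each key, and keeps the top-k. `select_tokens` is a hypothetical helper, not part of flash-colreduce, and real pruning policies differ (for example, in how they aggregate scores across heads).

```python
import heapq
from typing import List, Sequence


def select_tokens(scores: Sequence[float], keep: int) -> List[int]:
    """Hypothetical helper: indices of the `keep` highest-scoring tokens.

    `scores` stands in for one head's column reduction (one value per key).
    The surviving indices are returned in their original positional order.
    """
    if keep >= len(scores):
        return list(range(len(scores)))
    top = heapq.nlargest(keep, range(len(scores)), key=lambda j: scores[j])
    return sorted(top)


# Six cached tokens; keep the three that received the most attention.
column_sums = [0.9, 0.1, 2.3, 0.05, 1.4, 0.25]
print(select_tokens(column_sums, keep=3))  # [0, 2, 4]
```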
Highlights
- 🚀 Efficient: Fused kernels compute column reductions in $O(N)$ memory.
- 🧩 Flexible: Supports causal and non-causal attention with irregular shapes ($M \neq N$).
- ✅ Exact: Computes exact reductions, using online softmax for numerical stability and applying causal masks correctly.
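To put the memory claim in perspective, the arithmetic below compares the size of a materialized attention matrix with the size of a column-reduction output, for the two shapes used in the usage examples of this README. It assumes 2 bytes per element (FP16) for both; the library's output dtype may differ, and a naïve implementation may allocate further intermediates on top of this.

```python
def attention_matrix_bytes(b: int, h: int, m: int, n: int, elem_bytes: int = 2) -> int:
    """Bytes for a materialized B x H x M x N attention matrix."""
    return b * h * m * n * elem_bytes


def column_reduction_bytes(b: int, h: int, n: int, elem_bytes: int = 2) -> int:
    """Bytes for a B x H x N column-reduction output."""
    return b * h * n * elem_bytes


MiB = 2**20
# Non-causal example: B=8, H=16, M=N=512.
print(attention_matrix_bytes(8, 16, 512, 512) / MiB)   # 64.0
print(column_reduction_bytes(8, 16, 512) / MiB)        # 0.125
# Causal example: B=1, H=32, M=128, N=4096.
print(attention_matrix_bytes(1, 32, 128, 4096) / MiB)  # 32.0
print(column_reduction_bytes(1, 32, 4096) / MiB)       # 0.25
```

The ratio between the two is exactly $M$, the query length, so the savings grow as more queries are reduced.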
Prerequisites
- Python: 3.10+
- PyTorch: 2.1+ (with CUDA support)
- Triton: 3.0.0+
- GPU: NVIDIA GPU with Compute Capability 8.0+ (Ampere or newer recommended)
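Before installing, a quick environment check can confirm the requirements above. The sketch below is a hypothetical helper, not shipped with flash-colreduce: it compares the running Python version and, if PyTorch with CUDA is available, the GPU's compute capability against the listed minimums. It does not check the PyTorch or Triton versions.

```python
import sys
from typing import Optional, Tuple

MIN_PYTHON = (3, 10)
MIN_COMPUTE_CAPABILITY = (8, 0)  # Ampere


def python_ok(version: Optional[Tuple[int, int]] = None) -> bool:
    """True if the given (or running) Python version meets the minimum."""
    return tuple(version or sys.version_info[:2]) >= MIN_PYTHON


def capability_ok(capability: Tuple[int, int]) -> bool:
    """True if a (major, minor) CUDA compute capability meets the minimum."""
    return tuple(capability) >= MIN_COMPUTE_CAPABILITY


if __name__ == "__main__":
    print("Python OK:", python_ok())
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed")
    else:
        if torch.cuda.is_available():
            cap = torch.cuda.get_device_capability()
            print("Compute capability", cap, "OK:", capability_ok(cap))
        else:
            print("No CUDA device is visible to PyTorch")
```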
Installation
Install from PyPI:
```bash
pip install flash-colreduce
```
Or build from source:
```bash
git clone https://github.com/z-lab/flash-colreduce.git
cd flash-colreduce
pip install -e .
```
Usage
1. Non-Causal Attention
Compute a column-wise reduction of the attention matrix over the query dimension.
```python
import torch
from flash_colreduce import flash_colreduce

# Tensors are laid out as (batch, heads, seq_len, head_dim).
q = torch.randn(8, 16, 512, 64, device="cuda", dtype=torch.float16)
k = torch.randn(8, 16, 512, 64, device="cuda", dtype=torch.float16)

# Each call reduces over the query dimension, leaving one value per key.
col_sum = flash_colreduce(q, k, reduction="sum")    # Shape: (8, 16, 512)
col_mean = flash_colreduce(q, k, reduction="mean")  # Shape: (8, 16, 512)
col_max = flash_colreduce(q, k, reduction="max")    # Shape: (8, 16, 512)
```
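The three reductions can be stated precisely with a dense, pure-Python reference for a single (batch, head) slice. Unlike the kernel, it materializes the full $M \times N$ probability matrix, so it is only suitable for checking results on tiny inputs. It assumes a softmax scale of $1/\sqrt{d}$, as in standard scaled dot-product attention, and that `mean` divides each column sum by $M$; both are assumptions rather than documented library behavior.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def attention_probs(q: Matrix, k: Matrix, scale: Optional[float] = None) -> Matrix:
    """Row-wise softmax(q @ k^T * scale); each of the M rows sums to 1."""
    d = len(q[0])
    scale = 1.0 / math.sqrt(d) if scale is None else scale  # assumed default
    probs = []
    for qi in q:
        logits = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        peak = max(logits)  # subtract the row max for numerical stability
        exps = [math.exp(x - peak) for x in logits]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs


def colreduce_reference(q: Matrix, k: Matrix, reduction: str = "sum",
                        scale: Optional[float] = None) -> List[float]:
    """Reduce the attention probabilities over the query (row) dimension."""
    p = attention_probs(q, k, scale)
    m, n = len(p), len(p[0])
    cols = [[p[i][j] for i in range(m)] for j in range(n)]
    if reduction == "sum":
        return [sum(c) for c in cols]
    if reduction == "mean":
        return [sum(c) / m for c in cols]  # assumed: divide by M
    if reduction == "max":
        return [max(c) for c in cols]
    raise ValueError(f"unknown reduction: {reduction!r}")


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # M = 3 queries, d = 2
k = [[1.0, 0.0], [0.0, 1.0]]              # N = 2 keys
print(colreduce_reference(q, k, "sum"))
```

Because every row of the probability matrix sums to 1, the column sums always add up to $M$, which makes a convenient sanity check.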
2. Causal Attention
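Causal masking also works when the number of queries $M$ differs from the number of keys $N$, as in autoregressive decoding against a KV cache. The kernel applies a right-aligned causal mask: query $i$ may attend to key $j$ only if $j \le i + (N - M)$, so the last query sees every key. This matches KV-cached decoding behavior.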
```python
import torch
from flash_colreduce import flash_colreduce

# 128 new queries attending to a 4096-token KV cache (M = 128, N = 4096).
q = torch.randn(1, 32, 128, 128, device="cuda", dtype=torch.float16)
k = torch.randn(1, 32, 4096, 128, device="cuda", dtype=torch.float16)

# Outputs have one value per key position.
col_sum = flash_colreduce(q, k, reduction="sum", is_causal=True)    # Shape: (1, 32, 4096)
col_mean = flash_colreduce(q, k, reduction="mean", is_causal=True)  # Shape: (1, 32, 4096)
col_max = flash_colreduce(q, k, reduction="max", is_causal=True)    # Shape: (1, 32, 4096)
```
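The right-aligned mask can be made concrete with a dense, pure-Python column-sum reference for one (batch, head) slice. It is a sketch for checking small cases, assuming a $1/\sqrt{d}$ softmax scale; `causal_visible` and `causal_colsum_reference` are hypothetical names, not library functions. Rows with no visible key (possible only when $M > N$) are skipped here, which may not match how the library treats that case.

```python
import math
from typing import List, Optional


def causal_visible(i: int, j: int, m: int, n: int) -> bool:
    """Right-aligned causal mask: query i (of m) may attend key j (of n)
    iff j <= i + (n - m), so the last query sees every key."""
    return j <= i + (n - m)


def causal_colsum_reference(q: List[List[float]], k: List[List[float]],
                            scale: Optional[float] = None) -> List[float]:
    m, n, d = len(q), len(k), len(q[0])
    scale = 1.0 / math.sqrt(d) if scale is None else scale  # assumed default
    sums = [0.0] * n
    for i, qi in enumerate(q):
        visible = [j for j in range(n) if causal_visible(i, j, m, n)]
        if not visible:
            continue  # fully masked row (only possible when m > n)
        logits = [scale * sum(a * b for a, b in zip(qi, k[j])) for j in visible]
        peak = max(logits)  # softmax over the visible keys only
        exps = [math.exp(x - peak) for x in logits]
        total = sum(exps)
        for j, e in zip(visible, exps):
            sums[j] += e / total
    return sums


# M = 2 queries, N = 4 keys; zero queries give uniform attention per row.
q = [[0.0, 0.0], [0.0, 0.0]]
k = [[1.0, 0.0]] * 4
print(causal_colsum_reference(q, k))
```

Here query 0 sees keys 0 to 2 and query 1 sees all four, so key 3 receives mass from only one query. In general, later keys are visible to fewer queries, so raw causal column sums tend to favor earlier positions.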
Performance
Flash-ColReduce achieves significant speedups and memory savings over naïve implementations that compute the full softmax attention matrix and then reduce it. By fusing the softmax and the reduction into a single kernel, it never writes the $B \times H \times M \times N$ attention matrix to GPU memory.
Benchmarked on NVIDIA RTX Pro 6000 Blackwell with FP16 precision
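The fusion described above can be illustrated with a two-pass streaming scheme in plain Python. Pass 1 walks over key tiles and uses the online-softmax recurrence to obtain each row's log-sum-exp with $O(M)$ state; pass 2 revisits the tiles and adds the normalized probabilities into $O(N)$ column accumulators. This is a sketch of one way to get exact column sums without storing the $M \times N$ matrix, assuming a $1/\sqrt{d}$ scale; it is not the library's Triton kernel, whose tiling and pass structure may differ.

```python
import math
from typing import List, Optional


def streaming_colsum(q: List[List[float]], k: List[List[float]], tile: int = 2,
                     scale: Optional[float] = None) -> List[float]:
    """Column sums of softmax(q k^T * scale) without storing the M x N matrix."""
    m, n, d = len(q), len(k), len(q[0])
    scale = 1.0 / math.sqrt(d) if scale is None else scale  # assumed default

    def score(i: int, j: int) -> float:
        return scale * sum(a * b for a, b in zip(q[i], k[j]))

    # Pass 1: online softmax keeps a running max and denominator per row.
    lse = []
    for i in range(m):
        run_max, run_sum = -math.inf, 0.0
        for start in range(0, n, tile):
            block = [score(i, j) for j in range(start, min(start + tile, n))]
            new_max = max(run_max, max(block))
            # Rescale the old denominator to the new max, then add this tile.
            run_sum = run_sum * math.exp(run_max - new_max) + sum(
                math.exp(s - new_max) for s in block)
            run_max = new_max
        lse.append(run_max + math.log(run_sum))

    # Pass 2: accumulate normalized probabilities exp(s_ij - lse_i) per column.
    col = [0.0] * n
    for start in range(0, n, tile):
        for i in range(m):
            for j in range(start, min(start + tile, n)):
                col[j] += math.exp(score(i, j) - lse[i])
    return col


q = [[0.3, -1.2], [1.0, 0.5], [-0.7, 0.2]]
k = [[0.1, 0.4], [-0.5, 1.1], [0.9, -0.3], [0.2, 0.2], [1.5, 0.0]]
print(streaming_colsum(q, k, tile=2))
```

The result does not depend on the tile size, and a column max can be accumulated the same way by replacing the addition in pass 2 with `max`.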
Development
Running Tests
```bash
pip install -e ".[test]"
pytest -v
```
Running Benchmarks
```bash
pip install -e ".[bench]"
python benchmarks/run.py
```
Citation
If you use Flash-ColReduce in your research, please cite the SparseVILA paper:
```bibtex
@inproceedings{khaki2025sparsevila,
  title = {{SparseVILA: Decoupling Visual Sparsity for Efficient VLM Inference}},
  author = {Khaki, Samir and Guo, Junxian and Tang, Jiaming and Yang, Shang and Chen, Yukang and Plataniotis, Konstantinos N and Lu, Yao and Han, Song and Liu, Zhijian},
  booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  year = {2025}
}
```
License
Acknowledgments
- FlashAttention: The tiling and online softmax approach is heavily inspired by FlashAttention.
- SparseVILA: The original project that motivated this primitive.