Fast, memory-efficient attention column reduction (e.g., sum, mean, max)

Flash-ColReduce

Flash-ColReduce provides highly optimized Triton kernels for computing column-wise reductions of the attention matrix (such as sum, mean, or max) without materializing the full $O(N^2)$ attention weights.

This primitive is essential for KV-cache pruning, token importance estimation, and attention analysis in Large Language Models (LLMs) and Vision-Language Models (VLMs). It powers the visual token pruning in SparseVILA.

Highlights

  • 🚀 Efficient: Fused kernels compute column reductions in $O(N)$ memory.
  • 🧩 Flexible: Supports causal and non-causal attention with irregular shapes ($M \neq N$).
  • ✅ Exact: Uses online softmax for numerical precision and correct causal masking.
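To make the memory claim concrete, here is a minimal plain-Python sketch of one way to compute column sums without storing the $M \times N$ matrix. It is an illustration under stated assumptions, not the library's Triton kernel: the two-pass structure, the `streaming_colsum` name, the `block` parameter, and the $1/\sqrt{d}$ scaling are all hypothetical choices made for this example. Pass 1 uses online softmax to obtain each row's running maximum and normalizer; pass 2 recomputes the scores and accumulates the normalized weights per column.

```python
import math

def streaming_colsum(q, k, block=2):
    """Column sums of softmax(q k^T / sqrt(d)) for one head.

    q is an M x D list of lists, k is an N x D list of lists.
    Only O(M + N) extra values are kept; the M x N matrix is never stored.
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)  # assumption: standard attention scaling
    n = len(k)

    def score(qi, kj):
        return scale * sum(a * b for a, b in zip(qi, kj))

    # Pass 1: online softmax over key tiles gives each row's max and normalizer.
    row_max = [-math.inf] * len(q)
    row_sum = [0.0] * len(q)
    for start in range(0, n, block):
        for i, qi in enumerate(q):
            tile = [score(qi, kj) for kj in k[start:start + block]]
            new_max = max(row_max[i], max(tile))
            # Rescale the old normalizer to the new max before adding the tile.
            row_sum[i] = row_sum[i] * math.exp(row_max[i] - new_max) + sum(
                math.exp(s - new_max) for s in tile
            )
            row_max[i] = new_max

    # Pass 2: recompute scores and accumulate normalized weights per column.
    col_sum = [0.0] * n
    for j, kj in enumerate(k):
        for i, qi in enumerate(q):
            col_sum[j] += math.exp(score(qi, kj) - row_max[i]) / row_sum[i]
    return col_sum

q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # M = 3 queries
k = [[1.0, 0.0], [0.0, 1.0]]              # N = 2 keys
result = streaming_colsum(q, k)
```

Because every softmax row sums to 1, the entries of `result` add up to the number of queries, which is a convenient sanity check for any implementation.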

Prerequisites

  • Python: 3.10+
  • PyTorch: 2.1+ (with CUDA support)
  • Triton: 3.0.0+
  • GPU: NVIDIA GPU with Compute Capability 8.0+ (Ampere or newer recommended)

Installation

Install from PyPI:

pip install flash-colreduce

Or build from source:

git clone https://github.com/z-lab/flash-colreduce.git
cd flash-colreduce
pip install -e .

Usage

1. Non-Causal Attention

Compute a column-wise reduction of the attention matrix over the query dimension.

import torch
from flash_colreduce import flash_colreduce

q = torch.randn(8, 16, 512, 64, device="cuda", dtype=torch.float16)
k = torch.randn(8, 16, 512, 64, device="cuda", dtype=torch.float16)

flash_colreduce(q, k, reduction="sum")  # Shape: (8, 16, 512)
flash_colreduce(q, k, reduction="mean")  # Shape: (8, 16, 512)
flash_colreduce(q, k, reduction="max")  # Shape: (8, 16, 512)
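For reference, the quantity being reduced can be written out naively for a single (batch, head) slice. The sketch below is plain Python so it runs without a GPU; `naive_colreduce` is a hypothetical helper, and it assumes the usual $1/\sqrt{d}$ score scaling and that `"mean"` divides by the number of queries $M$, neither of which is confirmed by the text above.

```python
import math

def naive_colreduce(q, k, reduction="sum"):
    """Materialize softmax(q k^T / sqrt(d)) and reduce over the query axis.

    q is an M x D list of lists, k is an N x D list of lists.
    Returns a list of length N (one value per key/column).
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)  # assumption: standard attention scaling
    probs = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)  # subtract the row max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        probs.append([e / z for e in exps])  # one softmax row of length N
    cols = list(zip(*probs))  # N columns, each holding M weights
    if reduction == "sum":
        return [sum(c) for c in cols]
    if reduction == "mean":
        return [sum(c) / len(c) for c in cols]  # assumption: divide by M
    if reduction == "max":
        return [max(c) for c in cols]
    raise ValueError(f"unknown reduction: {reduction}")

q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # M = 3 queries
k = [[1.0, 0.0], [0.0, 1.0]]              # N = 2 keys
col_sum = naive_colreduce(q, k, "sum")
col_mean = naive_colreduce(q, k, "mean")
col_max = naive_colreduce(q, k, "max")
```

The output has one entry per key, which is why the shapes above drop the query dimension: `(8, 16, 512, 64)` inputs yield `(8, 16, 512)` results.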

2. Causal Attention

Handle autoregressive attention, including the case $M \neq N$. The kernel applies a right-aligned causal mask, which matches KV-cached decoding behavior.

import torch
from flash_colreduce import flash_colreduce

q = torch.randn(1, 32, 128, 128, device="cuda", dtype=torch.float16)
k = torch.randn(1, 32, 4096, 128, device="cuda", dtype=torch.float16)

flash_colreduce(q, k, reduction="sum", is_causal=True)  # Shape: (1, 32, 4096)
flash_colreduce(q, k, reduction="mean", is_causal=True)  # Shape: (1, 32, 4096)
flash_colreduce(q, k, reduction="max", is_causal=True)  # Shape: (1, 32, 4096)
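As an illustration of the masking convention, the sketch below builds a right-aligned causal mask for a tiny case. It assumes "right-aligned" means the last query lines up with the last key, so query $i$ may attend to key $j$ when $j \le i + (N - M)$; the helper name `right_aligned_causal_mask` is hypothetical and not part of the library.

```python
def right_aligned_causal_mask(m, n):
    """mask[i][j] is True when query i may attend to key j."""
    offset = n - m  # queries are aligned to the last m key positions
    return [[j <= i + offset for j in range(n)] for i in range(m)]

mask = right_aligned_causal_mask(2, 5)
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
# prints:
# xxxx.
# xxxxx
```

With $M = N$ this reduces to the familiar lower-triangular mask. With $M < N$, as in the example above, the earliest $N - M$ keys are visible to every query.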

Performance

Flash-ColReduce achieves significant speedups and memory savings over naïve implementations. By fusing softmax and reduction into a single kernel, it avoids writing the $B \times H \times M \times N$ attention matrix to GPU memory.
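A quick back-of-the-envelope calculation shows what is being avoided. The sketch below assumes the naive path stores the attention weights in FP16 (2 bytes per element) and uses the tensor shapes from the usage examples; the helper names are illustrative only.

```python
def attention_bytes(b, h, m, n, bytes_per_element=2):
    """Size of a materialized B x H x M x N attention matrix in bytes."""
    return b * h * m * n * bytes_per_element

def output_elements(b, h, n):
    """Number of values in the B x H x N column-reduction result."""
    return b * h * n

MIB = 2**20
noncausal_mib = attention_bytes(8, 16, 512, 512) / MIB   # 64.0 MiB
causal_mib = attention_bytes(1, 32, 128, 4096) / MIB     # 32.0 MiB
print(noncausal_mib, output_elements(8, 16, 512))        # 64.0 65536
print(causal_mib, output_elements(1, 32, 4096))          # 32.0 131072
```

The materialized matrix grows with $M \times N$, while the reduction result grows only with $N$, so the gap widens as sequences get longer.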

[Figure: Benchmark results on NVIDIA RTX Pro 6000 Blackwell with FP16 precision]

Development

Running Tests

pip install -e ".[test]"
pytest -v

Running Benchmarks

pip install -e ".[bench]"
python benchmarks/run.py

Citation

If you use Flash-ColReduce in your research, please cite the SparseVILA paper:

@inproceedings{khaki2025sparsevila,
  title = {{SparseVILA: Decoupling Visual Sparsity for Efficient VLM Inference}},
  author = {Khaki, Samir and Guo, Junxian and Tang, Jiaming and Yang, Shang and Chen, Yukang and Plataniotis, Konstantinos N and Lu, Yao and Han, Song and Liu, Zhijian},
  booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  year = {2025}
}

License

MIT License

Acknowledgments

  • FlashAttention: The tiling and online softmax approach is heavily inspired by FlashAttention.
  • SparseVILA: The original project that motivated this primitive.
