Flash-ColSum
Efficient attention column-sum primitives with Triton kernels
Flash-ColSum computes attention column means (how much attention each key position receives, averaged over heads and query positions) without materializing the full attention matrix.
Originally developed for SparseVILA, Flash-ColSum is a general-purpose library for computing mean column statistics of attention weights (token importance, attention analysis, etc.).
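To make "without materializing the full attention matrix" concrete, here is a minimal pure-Python sketch of one way to do it for a single head: a first pass streams over key blocks to get each query row's log-sum-exp (online softmax), and a second pass adds the normalized probabilities into one accumulator per key column. It is an illustration only, not the library's Triton kernels; the name `streaming_col_mean` and the `block` parameter are hypothetical.

```python
import math

def streaming_col_mean(Q, K, scale, block=2):
    """Column means of softmax(scale * Q K^T) for one head.

    Q and K are lists of equal-length vectors. Only O(len(Q) + len(K))
    extra values are stored; the score matrix is never built.
    """
    n_q, n_k = len(Q), len(K)

    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    # Pass 1: per-query log-sum-exp, streamed over key blocks.
    lse = []
    for q in Q:
        m, l = float("-inf"), 0.0  # running max and running sum of exponentials
        for start in range(0, n_k, block):
            s = [scale * dot(q, k) for k in K[start:start + block]]
            m_new = max(m, max(s))
            # Rescale the old sum to the new max before adding this block.
            l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in s)
            m = m_new
        lse.append(m + math.log(l))

    # Pass 2: recompute each score and add its probability to the key's column.
    col = [0.0] * n_k
    for q, z in zip(Q, lse):
        for j, k in enumerate(K):
            col[j] += math.exp(scale * dot(q, k) - z)

    return [c / n_q for c in col]
```

Because every softmax row sums to 1, the returned column means also sum to 1, which is a handy sanity check.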
Installation
Install from PyPI:
pip install flash-colsum
From source:
git clone https://github.com/z-lab/flash-colsum.git
cd flash-colsum
pip install -e .
Quick Start
import torch
from flash_colsum import flash_colsum
# Non-causal (ViT, BERT, etc.)
Q = torch.randn(8, 16, 2048, 64, device="cuda", dtype=torch.float16)
K = Q.clone()
col_mean = flash_colsum(Q, K) # (8, 2048)
# Non-causal with CLS tokens (e.g., CLIP-style, first position is CLS)
# Average how much attention each token receives from the CLS token(s)
cls_col_mean = flash_colsum(Q, K, cls_len=1) # (8, 2048)
# Causal (GPT, retrieval, etc.)
Q = torch.randn(1, 32, 128, 128, device="cuda", dtype=torch.float16)
K = torch.randn(1, 32, 4096, 128, device="cuda", dtype=torch.float16)
col_mean = flash_colsum(Q, K, is_causal=True) # (1, 4096)
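As an illustration of the token-importance use case, the sketch below keeps the highest-scoring key positions from a `(B, S)` column-mean tensor. The helper `select_top_tokens` and its `keep_ratio` parameter are hypothetical and not part of Flash-ColSum; a small hand-written tensor stands in for the output of `flash_colsum` so the snippet runs on CPU.

```python
import torch

def select_top_tokens(col_mean, keep_ratio=0.25):
    """Return indices (B, k) of the k highest-scoring key positions, in ascending order."""
    _, S = col_mean.shape
    k = max(1, int(S * keep_ratio))
    top = torch.topk(col_mean, k, dim=-1).indices
    # Sort so the kept tokens retain their original sequence order.
    return torch.sort(top, dim=-1).values

# Stand-in for `flash_colsum(Q, K)`; in practice pass the library output here.
col_mean = torch.tensor([[0.1, 0.4, 0.2, 0.3]])
keep = select_top_tokens(col_mean, keep_ratio=0.5)  # positions 1 and 3
```

The returned indices can then be used with `torch.gather` or plain indexing to prune keys, values, or hidden states.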
API
flash_colsum(query, key, scale=None, is_causal=False, cls_len=None)
Compute attention column means efficiently without materializing the full attention matrix.
Parameters:
- query (Tensor): Query tensor of shape (B, H, S, D), or (1, H, Q_len, D) for causal
- key (Tensor): Key tensor (same shape as query for non-causal), or K_len >= Q_len for causal
- scale (float, optional): Attention scale. Default: 1/sqrt(D)
- is_causal (bool): Apply causal masking. Default: False
- cls_len (int, optional): In the non-causal case, average only over the first cls_len query positions (e.g., CLS tokens). If None, averages over all query positions.
Returns:
- Tensor:
  - Non-causal: (B, S) mean per key position
    - with cls_len=None: averaged over all query positions and heads
    - with cls_len>0: averaged over the first cls_len query positions and all heads
  - Causal: (1, K_len) mean per key position (no cls_len support)
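For reference, the non-causal result described above can be written as a dense PyTorch computation that does materialize the attention matrix. This is a sketch for checking shapes and semantics on small inputs, not the library's `naive_colsum`; the name `reference_colsum` is hypothetical, and the causal case is omitted because its mask alignment and normalization are not specified here.

```python
import math
import torch

def reference_colsum(query, key, scale=None, cls_len=None):
    """Dense non-causal column mean: (B, H, S, D) inputs -> (B, S) output."""
    D = query.shape[-1]
    if scale is None:
        scale = 1.0 / math.sqrt(D)
    if cls_len is not None:
        # Average only over the first cls_len query positions.
        query = query[:, :, :cls_len, :]
    # Full attention matrix, shape (B, H, S_q, S_k); memory grows as S_q * S_k.
    scores = torch.matmul(query.float(), key.float().transpose(-1, -2)) * scale
    attn = torch.softmax(scores, dim=-1)
    # Mean over heads and query positions leaves one value per key position.
    return attn.mean(dim=(1, 2))
```

Each output row sums to 1, since it is an average of softmax rows.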
Performance
Flash-ColSum achieves significant speedups and memory savings over naïve implementations:
Benchmarked on NVIDIA RTX A6000 with FP16 precision
Benchmarked on NVIDIA GeForce RTX 5090 with FP16 precision
Development
Package Structure
Top-level layout:
flash-colsum/
├── flash_colsum/ # Library code
│ ├── __init__.py
│ ├── ops.py # Public API (flash_colsum, naive_colsum)
│ ├── baselines.py # Naive/reference implementations
│ ├── kernel_causal.py
│ ├── kernel_noncausal.py
│ └── kernel_noncausal_batched.py
├── benchmarks/ # Benchmark script
│ ├── __init__.py
│ └── benchmark_colsum.py
├── assets/ # Benchmark figures and other assets
├── tests/ # Pytest-based tests
│ ├── test_core.py
│ └── test_benchmarks.py
└── pyproject.toml
Evaluation (Tests & Benchmarks)
1. Evaluate correctness (pytest)
# Fast unit tests (correctness + error handling)
pytest -v -s
2. Evaluate efficiency (benchmarks, via pytest)
# Run only the benchmark sweeps (plot under benchmarks/out)
FLASH_COLSUM_RUN_BENCH=1 pytest tests/test_benchmarks.py -v -s
# Or: run full test suite + benchmark sweeps together
FLASH_COLSUM_RUN_BENCH=1 pytest -v -s
For more fine-grained control (single-point runs, custom sweeps), you can also call
the benchmark driver directly via python -m benchmarks.benchmark_colsum and pass
flags such as --sweep {noncausal_batched,noncausal,causal,all} and --out PATH.
Citation
If you use Flash-ColSum in your research, please cite our SparseVILA paper:
@InProceedings{Khaki_2025_ICCV,
author = {Khaki, Samir and Guo, Junxian and Tang, Jiaming and Yang, Shang and Chen, Yukang and Plataniotis, Konstantinos N. and Lu, Yao and Han, Song and Liu, Zhijian},
title = {SparseVILA: Decoupling Visual Sparsity for Efficient VLM Inference},
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
month = {October},
year = {2025},
pages = {23784-23794}
}
License
MIT License. See LICENSE for details.
Acknowledgments
Flash-ColSum builds on ideas from:
- FlashAttention - Efficient attention kernels
- SparseVILA - Token Sparsity for vision-language models