Flash-ColSum

Efficient attention column-sum primitives with Triton kernels

Flash-ColSum provides efficient implementations for computing the mean attention column sum without materializing full attention matrices.

Originally developed for SparseVILA, Flash-ColSum is a general-purpose library for computing mean column statistics of attention weights (e.g., for token importance or attention analysis).
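To make the quantity concrete, here is a minimal reference sketch in plain PyTorch. It materializes the full attention matrix, which is exactly the cost the library is meant to avoid, and then averages each column over heads and query positions. The function name reference_colmean is hypothetical and used only for illustration; the packaged naive_colsum baseline may differ in details.

```python
import math
import torch

def reference_colmean(query, key, scale=None):
    """Naive reference: build softmax(QK^T * scale) and average its columns.

    query: (B, H, S_q, D), key: (B, H, S_k, D) -> returns (B, S_k).
    Hypothetical helper for illustration, not part of the library.
    """
    if scale is None:
        scale = 1.0 / math.sqrt(query.shape[-1])
    # Full attention matrix of shape (B, H, S_q, S_k).
    scores = torch.matmul(query.float(), key.float().transpose(-1, -2)) * scale
    attn = torch.softmax(scores, dim=-1)
    # Average over heads (dim 1) and query positions (dim 2).
    return attn.mean(dim=(1, 2))

torch.manual_seed(0)
Q = torch.randn(2, 4, 16, 8)
K = torch.randn(2, 4, 16, 8)
col_mean = reference_colmean(Q, K)
print(col_mean.shape)         # torch.Size([2, 16])
print(col_mean.sum(dim=-1))   # close to 1 per batch row
```

Because every attention row sums to 1, the column means for one batch element also sum to 1, which is a handy sanity check for any implementation.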

Installation

Install from PyPI:

pip install flash-colsum

From source:

git clone https://github.com/your-org/flash-colsum.git
cd flash-colsum
pip install -e .

Quick Start

import torch
from flash_colsum import flash_colsum

# Non-causal (ViT, BERT, etc.)
Q = torch.randn(8, 16, 2048, 64, device="cuda", dtype=torch.float16)
K = Q.clone()
col_mean = flash_colsum(Q, K)  # (8, 2048)

# Non-causal with CLS tokens (e.g., CLIP-style, first position is CLS)
# Average how much attention each token receives from the CLS token(s)
cls_col_mean = flash_colsum(Q, K, cls_len=1)  # (8, 2048)

# Causal (GPT, retrieval, etc.)
Q = torch.randn(1, 32, 128, 128, device="cuda", dtype=torch.float16)
K = torch.randn(1, 32, 4096, 128, device="cuda", dtype=torch.float16)
col_mean = flash_colsum(Q, K, is_causal=True)  # (1, 4096)
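One possible use of the returned scores is token selection by importance. The sketch below keeps the k highest-scoring key positions per batch row. The helper top_k_tokens is hypothetical, and the scores are hard-coded stand-ins so that the snippet runs without a GPU; in practice col_mean would come from flash_colsum.

```python
import torch

def top_k_tokens(col_mean, k):
    """Return the indices of the k highest-scoring key positions per batch row.

    col_mean: (B, S) importance scores. Hypothetical helper for illustration.
    """
    idx = torch.topk(col_mean, k, dim=-1).indices
    # Sort the kept indices so the tokens stay in their original order.
    return idx.sort(dim=-1).values

# Stand-in scores with shape (B=1, S=5).
col_mean = torch.tensor([[0.10, 0.40, 0.05, 0.30, 0.15]])
keep = top_k_tokens(col_mean, k=2)
print(keep)  # tensor([[1, 3]])
```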

API

flash_colsum(query, key, scale=None, is_causal=False, cls_len=None)

Compute attention column means efficiently without materializing the full attention matrix.

Parameters:

  • query (Tensor): Query tensor of shape (B, H, S, D) for non-causal, or (1, H, Q_len, D) for causal
  • key (Tensor): Key tensor. Non-causal: same shape as query. Causal: (1, H, K_len, D) with K_len >= Q_len
  • scale (float, optional): Attention scale. Default: 1/sqrt(D)
  • is_causal (bool): Apply causal masking. Default: False
  • cls_len (int, optional): In the non-causal case, average only over the first cls_len query positions (e.g., CLS tokens). If None, averages over all query positions.

Returns:

  • Tensor:
    • Non-causal: (B, S) mean per key position
      • with cls_len=None: averaged over all query positions and heads
      • with cls_len>0: averaged over the first cls_len query positions and all heads
    • Causal: (1, K_len) mean per key position (no cls_len support)
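The cls_len and causal modes can be sketched as naive references in plain PyTorch. Both function names are hypothetical. The cls_len variant follows the description above directly. The causal variant rests on two assumptions that this README does not confirm: the Q_len queries are taken to be the last Q_len positions of the K_len-long sequence (a bottom-right aligned mask), and the column sums are divided by H * Q_len.

```python
import math
import torch

def reference_colmean_cls(query, key, cls_len, scale=None):
    """Average attention columns over only the first cls_len query rows."""
    if scale is None:
        scale = 1.0 / math.sqrt(query.shape[-1])
    q_cls = query[:, :, :cls_len].float()                 # (B, H, cls_len, D)
    scores = torch.matmul(q_cls, key.float().transpose(-1, -2)) * scale
    attn = torch.softmax(scores, dim=-1)                  # (B, H, cls_len, S)
    return attn.mean(dim=(1, 2))                          # (B, S)

def reference_colmean_causal(query, key, scale=None):
    """Causal sketch. ASSUMPTION: bottom-right aligned mask, mean over H * Q_len."""
    if scale is None:
        scale = 1.0 / math.sqrt(query.shape[-1])
    q_len, k_len = query.shape[-2], key.shape[-2]
    scores = torch.matmul(query.float(), key.float().transpose(-1, -2)) * scale
    # Query i sits at absolute position (k_len - q_len + i) and may only
    # attend to keys at or before that position.
    q_pos = torch.arange(q_len, device=query.device).unsqueeze(1) + (k_len - q_len)
    k_pos = torch.arange(k_len, device=query.device).unsqueeze(0)
    scores = scores.masked_fill(k_pos > q_pos, float("-inf"))
    attn = torch.softmax(scores, dim=-1)                  # (B, H, Q_len, K_len)
    return attn.mean(dim=(1, 2))                          # (B, K_len)

torch.manual_seed(0)
Q = torch.randn(1, 2, 4, 8)
K_same = torch.randn(1, 2, 4, 8)   # non-causal: same shape as Q
K_long = torch.randn(1, 2, 6, 8)   # causal: K_len >= Q_len
cls_mean = reference_colmean_cls(Q, K_same, cls_len=1)
causal_mean = reference_colmean_causal(Q, K_long)
print(cls_mean.shape, causal_mean.shape)
```

If the library aligns the causal mask or normalizes differently, only reference_colmean_causal needs to change; the cls_len sketch does not depend on those assumptions.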

Performance

Flash-ColSum achieves significant speedups and memory savings over naïve implementations:

[Figure: benchmark results on NVIDIA RTX A6000, FP16 precision]

[Figure: benchmark results on NVIDIA GeForce RTX 5090, FP16 precision]
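The memory savings come from never holding the full attention matrix. For the Quick Start shape (8, 16, 2048, 64) in FP16, the full (8, 16, 2048, 2048) matrix alone would take 1 GiB. The sketch below shows the general idea in plain PyTorch by processing one block of queries at a time and accumulating column sums. It is an illustration only: chunked_colmean is a hypothetical name, and the actual Triton kernels are not shown in this README and may tile the computation differently.

```python
import math
import torch

def chunked_colmean(query, key, block=128, scale=None):
    """Accumulate attention column sums one query block at a time.

    Hypothetical helper; peak extra memory is one (B, H, block, S_k) slab
    instead of the full (B, H, S_q, S_k) attention matrix.
    """
    B, H, S_q, D = query.shape
    S_k = key.shape[-2]
    if scale is None:
        scale = 1.0 / math.sqrt(D)
    k_t = key.float().transpose(-1, -2)                    # (B, H, D, S_k)
    col_sum = torch.zeros(B, S_k, dtype=torch.float32, device=query.device)
    for start in range(0, S_q, block):
        q_blk = query[:, :, start:start + block].float()   # (B, H, b, D)
        attn = torch.softmax(torch.matmul(q_blk, k_t) * scale, dim=-1)
        col_sum += attn.sum(dim=(1, 2))                    # reduce heads and rows
    return col_sum / (H * S_q)

torch.manual_seed(0)
Q = torch.randn(2, 4, 50, 8)
K = torch.randn(2, 4, 50, 8)
chunked = chunked_colmean(Q, K, block=16)

# Full-matrix result for comparison.
full = torch.softmax(
    torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(Q.shape[-1]), dim=-1
).mean(dim=(1, 2))
print(torch.allclose(chunked, full, atol=1e-6))
```

Each query row's softmax is independent of the others, so splitting along the query dimension does not change the result; a smaller block trades speed for a lower peak memory footprint.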

Development

Package Structure

Top-level layout:

flash-colsum/
├── flash_colsum/          # Library code
│   ├── __init__.py
│   ├── ops.py             # Public API (flash_colsum, naive_colsum)
│   ├── baselines.py       # Naive/reference implementations
│   ├── kernel_causal.py
│   ├── kernel_noncausal.py
│   └── kernel_noncausal_batched.py
├── benchmarks/            # Benchmark script
│   ├── __init__.py
│   └── benchmark_colsum.py
├── assets/                # Benchmark figures and other assets
├── tests/                 # Pytest-based tests
│   ├── test_core.py
│   └── test_benchmarks.py
└── pyproject.toml

Evaluation (Tests & Benchmarks)

1. Evaluate correctness (pytest)

# Fast unit tests (correctness + error handling)
pytest -v -s

2. Evaluate efficiency (benchmarks, via pytest)

# Run only the benchmark sweeps (plots are written to benchmarks/out)
FLASH_COLSUM_RUN_BENCH=1 pytest tests/test_benchmarks.py -v -s

# Or: run full test suite + benchmark sweeps together
FLASH_COLSUM_RUN_BENCH=1 pytest -v -s

For more fine-grained control (single-point runs, custom sweeps), you can also call the benchmark driver directly via python -m benchmarks.benchmark_colsum and pass flags such as --sweep {noncausal_batched,noncausal,causal,all} and --out PATH.

Citation

If you use Flash-ColSum in your research, please cite our SparseVILA paper:

@InProceedings{Khaki_2025_ICCV,
    author    = {Khaki, Samir and Guo, Junxian and Tang, Jiaming and Yang, Shang and Chen, Yukang and Plataniotis, Konstantinos N. and Lu, Yao and Han, Song and Liu, Zhijian},
    title     = {SparseVILA: Decoupling Visual Sparsity for Efficient VLM Inference},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2025},
    pages     = {23784-23794}
}

License

MIT License. See LICENSE for details.

Acknowledgments

Flash-ColSum builds on ideas from:
