
Fused GPU kernel for ColBERT/ColPali MaxSim scoring


Flash-MaxSim

Fused Triton GPU kernel for ColBERT/ColPali MaxSim scoring. Up to 6.5x faster than naive PyTorch at matched precision, and the similarity matrix is never stored in GPU memory. Drop-in replacement: same API, no configuration.

pip install flash-maxsim
from flash_maxsim import flash_maxsim

scores = flash_maxsim(Q, D)  # that's it

Why Flash-MaxSim?

Standard MaxSim implementations compute the full similarity matrix and store it in GPU memory. Flash-MaxSim eliminates it: the matrix never leaves the chip.

1. Memory: the similarity matrix is gone

The standard einsum / bmm approach allocates B × Lq × Ld × 2 bytes (FP16) for the similarity matrix. For ColPali (Lq = Ld = 1024) at 10K docs, that is 21 GB, enough for an instant OOM.

| ColPali (Lq=1024, Ld=1024) | Naive sim matrix | Flash-MaxSim |
|---|---|---|
| B=1,000 | 2,097 MB | 0 MB |
| B=5,000 | 10,486 MB | 0 MB |
| B=10,000 | 20,972 MB | 0 MB |

That's 21 GB of temporary memory just for scoring — on top of model weights, KV cache, and the embeddings themselves. On a 40 GB GPU, this OOMs. On 80 GB, it eats a quarter of your memory for a tensor that gets immediately reduced and thrown away.

Flash-MaxSim allocates no extra HBM beyond the output scores. The similarity is computed tile by tile in SRAM and reduced on the fly.
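The figures in the table follow directly from B × Lq × Ld × 2 bytes. A few lines of Python reproduce them; the helper name is hypothetical, and the table appears to use decimal megabytes (1 MB = 10⁶ bytes), which is inferred from its numbers:

```python
def sim_matrix_bytes(B: int, Lq: int, Ld: int, bytes_per_elem: int = 2) -> int:
    """Bytes needed to materialize a dense [B, Lq, Ld] similarity matrix.

    bytes_per_elem=2 corresponds to FP16; use 4 for an FP32 copy.
    """
    return B * Lq * Ld * bytes_per_elem


# ColPali-sized inputs (Lq = Ld = 1024), reported in decimal megabytes.
for B in (1_000, 5_000, 10_000):
    print(f"B={B:,}: {sim_matrix_bytes(B, 1024, 1024) / 1e6:,.0f} MB")
# B=1,000: 2,097 MB
# B=5,000: 10,486 MB
# B=10,000: 20,972 MB
```

Passing bytes_per_elem=4 gives the size of the FP32 copy that a `.float()` cast of the whole matrix would create.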

2. Speed: up to 6.5x faster at matched precision (A100)

Flash-MaxSim accumulates in FP32, which is more precise than the FP16 naive baseline. For a fair comparison, the naive baseline must also reduce in FP32 (.float().max().sum()), which adds the cost of casting the whole sim matrix. At matched precision:

| Config (A100) | Naive (matched) | Flash | Speedup |
|---|---|---|---|
| ColBERT B=10K | 1.61 ms | 0.51 ms | 3.2x |
| ColBERT B=100K | 16.36 ms | 4.28 ms | 3.8x |
| ColPali B=100 | 1.11 ms | 0.27 ms | 4.0x |
| ColPali B=1K | 10.09 ms | 1.63 ms | 6.2x |
| ColPali B=10K | 100.63 ms | 15.49 ms | 6.5x |

Even against the faster but less precise FP16 naive baseline, Flash-MaxSim is 2.5–2.9x faster:

| B (ColPali) | vs FP16 naive | vs matched naive | Sim matrix eliminated |
|---|---|---|---|
| 1,000 | 2.6x | 6.2x | 2.1 GB |
| 5,000 | 2.9x | 6.4x | 10.5 GB |
| 10,000 | 2.9x | 6.5x | 21.0 GB |
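The two naive baselines differ only in the dtype of the reduction. A minimal PyTorch sketch of both, assuming Q is [Lq, d] and D is [B, Ld, d] as in the API reference; the function names are hypothetical and not part of flash_maxsim:

```python
import torch


def maxsim_naive_fp16(Q: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
    """Naive MaxSim: materialize [B, Lq, Ld], reduce in the input dtype (FP16 on GPU)."""
    sim = torch.einsum("qd,bld->bql", Q, D)   # full similarity matrix in HBM
    return sim.amax(dim=-1).sum(dim=-1)       # [B]


def maxsim_naive_matched(Q: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
    """Same sim matrix, but max and sum run in FP32 to match FP32 accumulation."""
    sim = torch.einsum("qd,bld->bql", Q, D)
    # For FP16 inputs, .float() writes a second, FP32 copy of the whole
    # [B, Lq, Ld] tensor before reducing; this is the extra casting cost.
    return sim.float().amax(dim=-1).sum(dim=-1)
```

On GPU with FP16 inputs, the second variant pays for that extra pass over the similarity matrix, which is why matched-precision naive is slower than FP16 naive.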

3. Zero parameters — no chunk size to tune

Production systems (vLLM, etc.) chunk documents into mini-batches to avoid OOM. Too large → OOM. Too small → launch overhead. Flash-MaxSim has zero configuration — same code on a 16 GB GPU and an 80 GB GPU.
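The chunking workaround looks roughly like the sketch below (plain PyTorch; maxsim_chunked and chunk_size are hypothetical names, not a vLLM API). The scores do not depend on chunk_size; only peak memory and the number of launches do, and that is the knob Flash-MaxSim removes:

```python
import torch


def maxsim_chunked(Q: torch.Tensor, D: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Naive MaxSim over mini-batches of documents.

    Peak temporary memory is chunk_size * Lq * Ld elements instead of B * Lq * Ld,
    at the cost of one einsum (and its kernel launches) per chunk.
    """
    parts = []
    for start in range(0, D.shape[0], chunk_size):
        chunk = D[start:start + chunk_size]              # [c, Ld, d] view, no copy
        sim = torch.einsum("qd,bld->bql", Q, chunk)      # [c, Lq, Ld] temporary
        parts.append(sim.float().amax(dim=-1).sum(dim=-1))
    return torch.cat(parts)                              # [B]
```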

4. Variable-length documents — zero padding waste

Real collections have variable doc lengths. Padding wastes compute. Flash-MaxSim supports packed variable-length documents:

from flash_maxsim import flash_maxsim, pack_docs

# doc_embeddings: list of per-document [Ld_i, d] tensors
D_packed, cu_seqlens, max_ld = pack_docs(doc_embeddings)
scores = flash_maxsim(Q, D_packed, doc_lengths=cu_seqlens)

| Regime | N | Speedup | Padding saved |
|---|---|---|---|
| ColBERT skewed (avg_Ld≈49) | 100K | 5.1x | 39% |
| ColBERT uniform | 100K | 2.7x | 42% |
| ColPali uniform | 500 | 4.2x | 37% |
| ColPali skewed | 5K | 3.9x | 19% |
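In the cu_seqlens convention used by FlashAttention-style varlen kernels, documents are concatenated along the token axis and cu_seqlens holds cumulative offsets [0, L0, L0+L1, ...]. Assuming pack_docs follows that convention (not confirmed here), packing and a reference scorer can be sketched in plain PyTorch. All three helpers are hypothetical, and padding_fraction is one plausible reading of the "Padding saved" column:

```python
import torch


def pack_docs_reference(docs: list[torch.Tensor]):
    """Concatenate [Ld_i, d] docs into [T, d]; rows cu_seqlens[i]:cu_seqlens[i+1] are doc i."""
    lengths = torch.tensor([doc.shape[0] for doc in docs])
    cu_seqlens = torch.zeros(len(docs) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
    return torch.cat(docs, dim=0), cu_seqlens, int(lengths.max())


def maxsim_packed_reference(Q: torch.Tensor, D_packed: torch.Tensor,
                            cu_seqlens: torch.Tensor) -> torch.Tensor:
    """Shared query [Lq, d] against packed docs [T, d] -> [B], with no padded rows."""
    scores = []
    for i in range(len(cu_seqlens) - 1):
        start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        doc = D_packed[start:end]                         # [Ld_i, d] view
        scores.append((Q.float() @ doc.float().T).amax(dim=1).sum())
    return torch.stack(scores)


def padding_fraction(cu_seqlens: torch.Tensor) -> float:
    """Fraction of a padded [B, max_ld] layout that would be padding."""
    lengths = cu_seqlens[1:] - cu_seqlens[:-1]
    return 1.0 - lengths.sum().item() / (len(lengths) * lengths.max().item())
```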

5. INT8 index: half the storage, faster, and more precise than FP16 naive

Store embeddings as INT8 (2x compression vs FP16). The kernel uses INT8 tensor cores (624 TOPS on A100, 2x the FP16 tensor-core throughput), and D is never dequantized into HBM.

from flash_maxsim import flash_maxsim_int8x8, quantize_int8_symmetric

# Index time: quantize once (50% storage savings)
D_int8, scales = quantize_int8_symmetric(D)

# Query time: drop-in
scores = flash_maxsim_int8x8(Q, D_int8, scales)

| Method (ColPali B=5K) | Latency | D storage | Extra HBM | Mean error vs FP32 |
|---|---|---|---|---|
| Naive dequant+einsum | 30.9 ms | 1 byte/dim | D_fp16 copy + sim matrix | 0.065 |
| Flash FP16 | 8.0 ms | 2 bytes/dim | ~0 | 0.00008 |
| Flash INT8×INT8 | 6.6 ms | 1 byte/dim | ~0 | 0.023 |

Flash INT8×INT8 is 4.7x faster than naive dequant, uses half the storage, and is 3x more precise (FP32 accumulation vs FP16 einsum).
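The arithmetic behind per-token symmetric INT8 scoring: each token x is stored as q = round(x / s) with one scale s per token, so a similarity becomes s_q · s_d · (q_q · q_d), an integer dot product rescaled once. A plain-PyTorch sketch, assuming s = max|x| / 127; the library's exact scale choice, and whether flash_maxsim_int8x8 quantizes Q internally, are not stated here, and the helper names are hypothetical:

```python
import torch


def quantize_per_token_symmetric(x: torch.Tensor):
    """Per-token symmetric INT8: x ≈ q * scale, with one scale per row (token)."""
    scale = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12) / 127.0
    q = torch.round(x / scale).clamp(-127, 127).to(torch.int8)
    return q, scale.squeeze(-1)


def maxsim_int8x8_reference(Q: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
    """MaxSim with Q [Lq, d] and D [B, Ld, d] both quantized to INT8."""
    q_int, q_scale = quantize_per_token_symmetric(Q)    # [Lq, d], [Lq]
    d_int, d_scale = quantize_per_token_symmetric(D)    # [B, Ld, d], [B, Ld]
    # Integer dot products. float64 holds them exactly here (|sum| <= 127 * 127 * d),
    # standing in for the INT32 accumulators of INT8 tensor-core MMAs.
    dots = torch.einsum("qk,blk->bql", q_int.double(), d_int.double())   # [B, Lq, Ld]
    # Rescale once per (query token, doc token) pair, then reduce as usual.
    sim = dots * q_scale.double()[None, :, None] * d_scale.double()[:, None, :]
    return sim.amax(dim=-1).sum(dim=-1)                 # [B]
```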

6. Training — autograd backward pass

Full gradient support via saved argmax indices. Sparse backward — no full matrix in either direction:

from flash_maxsim import flash_maxsim_train

scores = flash_maxsim_train(Q, D)  # Q: [Lq, d], D: [B, Ld, d]
scores.sum().backward()            # gradients flow to Q and D (whichever has requires_grad=True)

Verified correct against naive backward (max gradient error < 0.001).
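The sparse backward follows from the forward: each score is a sum of Lq selected dot products, so the gradient for query token i is its winning document token D[b, argmax], and each document token receives the query tokens it won. Only the [B, Lq] argmax indices need to be saved. A PyTorch autograd.Function sketch of that idea (MaxSimArgmax is a hypothetical name; unlike the fused kernel, this forward still builds the full similarity matrix for clarity):

```python
import torch


class MaxSimArgmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, Q, D):
        # Q: [Lq, d], D: [B, Ld, d]
        sim = torch.einsum("qd,bld->bql", Q, D)   # [B, Lq, Ld]; a fused kernel keeps this on-chip
        best, idx = sim.max(dim=-1)               # [B, Lq] max values and their argmax
        ctx.save_for_backward(Q, D, idx)          # only idx is needed beyond the inputs
        return best.sum(dim=-1)                   # [B]

    @staticmethod
    def backward(ctx, grad_out):
        Q, D, idx = ctx.saved_tensors
        B, _, d = D.shape
        g = grad_out.reshape(B, 1, 1)
        index = idx.unsqueeze(-1).expand(-1, -1, d)           # [B, Lq, d]
        # dScore_b / dQ_i = D[b, argmax_bi]: gather the winning doc tokens.
        grad_Q = (g * torch.gather(D, 1, index)).sum(dim=0)   # [Lq, d]
        # dScore_b / dD[b, j] = sum of Q_i over query tokens whose argmax is j.
        grad_D = torch.zeros_like(D)
        grad_D.scatter_add_(1, index, g * Q.unsqueeze(0))     # src: [B, Lq, d]
        return grad_Q, grad_D
```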

7. ~800x lower error than FP16 naive

Flash-MaxSim uses FP32 accumulation for the running max and score sum. The standard FP16 einsum has compounding rounding errors:

| Method | Mean error vs FP32 | Top-20 overlap | Spearman |
|---|---|---|---|
| FP16 naive (einsum) | 6.2×10⁻² | 95% | 0.993 |
| Flash FP16 | 7.6×10⁻⁵ | 100% | 1.000 |
| Flash INT8×INT8 | 2.3×10⁻² | 100% | 0.999 |
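Why the accumulator dtype matters: a MaxSim score is a sum of Lq per-token maxima, and once the total reaches a few hundred, FP16 can only represent multiples of 0.25 or 0.5, so every addition rounds. The toy sketch below uses synthetic values (not the benchmark above, and not a model of how PyTorch implements its own FP16 reductions); it keeps one running sum in FP16 and one in FP32 and compares both with a float64 reference:

```python
import torch

torch.manual_seed(0)
# Hypothetical per-query-token max similarities in [0.5, 1), Lq = 1024 as in ColPali.
token_maxes = torch.rand(1024, dtype=torch.float64) * 0.5 + 0.5
exact = token_maxes.sum().item()

acc16 = torch.zeros((), dtype=torch.float16)
acc32 = torch.zeros((), dtype=torch.float32)
for v in token_maxes:
    acc16 = acc16 + v.to(torch.float16)   # running sum kept in FP16
    acc32 = acc32 + v.to(torch.float32)   # running sum kept in FP32

err16 = abs(acc16.item() - exact)
err32 = abs(acc32.item() - exact)
print(f"FP16 accumulator error: {err16:.3e}, FP32 accumulator error: {err32:.3e}")
```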

Quick Start

import torch
from flash_maxsim import flash_maxsim, flash_maxsim_batched

# Score one query against 1000 documents
Q = torch.randn(32, 128, device="cuda", dtype=torch.float16)   # query: 32 tokens
D = torch.randn(1000, 300, 128, device="cuda", dtype=torch.float16)  # 1000 docs, 300 tokens each
scores = flash_maxsim(Q, D)  # [1000]

# ColPali (long query) — automatic chunking, no configuration needed
Q_colpali = torch.randn(1024, 128, device="cuda", dtype=torch.float16)
D_colpali = torch.randn(1000, 1024, 128, device="cuda", dtype=torch.float16)
scores = flash_maxsim(Q_colpali, D_colpali)  # [1000]

# Batched: 16 queries vs same corpus (up to 15x faster than serial loop)
Q_batch = torch.randn(16, 32, 128, device="cuda", dtype=torch.float16)
scores = flash_maxsim_batched(Q_batch, D, shared_docs=True)  # [16, 1000]
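For small inputs, results can be sanity-checked against plain PyTorch. The sketch below mirrors the shapes in the API reference and assumes shared_docs=True means every query is scored against the same D; the function names are hypothetical, and the library's own maxsim_naive may have a different signature:

```python
import torch


def maxsim_reference(Q: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
    """[Lq, d] x [B, Ld, d] -> [B], reduced in FP32."""
    sim = torch.einsum("qd,bld->bql", Q.float(), D.float())          # [B, Lq, Ld]
    return sim.amax(dim=-1).sum(dim=-1)


def maxsim_batched_reference(Q_batch: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
    """[Nq, Lq, d] x shared [B, Ld, d] -> [Nq, B], reduced in FP32."""
    sim = torch.einsum("nqd,bld->nbql", Q_batch.float(), D.float())  # [Nq, B, Lq, Ld]
    return sim.amax(dim=-1).sum(dim=-1)
```

Both references materialize the full similarity tensor, so they are only practical for small corpora, and comparisons against FP16 kernel outputs need a tolerance suited to FP16 inputs.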

Variable-Length Documents

from flash_maxsim import flash_maxsim_varlen, pack_pairs

# Each pair has different lengths — zero padding waste
q_embs = [torch.randn(32, 128, ...), torch.randn(48, 128, ...)]
d_embs = [torch.randn(180, 128, ...), torch.randn(250, 128, ...)]

Q_packed, D_packed, cu_q, cu_d, max_lq, max_ld = pack_pairs(q_embs, d_embs)
scores = flash_maxsim_varlen(Q_packed, D_packed, cu_q, cu_d, max_lq, max_ld)
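Each pair i is scored on its own slices: query rows cu_q[i]:cu_q[i+1] against document rows cu_d[i]:cu_d[i+1]. A reference sketch, assuming pack_pairs uses the usual cumulative-offset cu_seqlens convention (the helper names are hypothetical):

```python
import torch


def cu_seqlens_from(tensors: list[torch.Tensor]) -> torch.Tensor:
    """Cumulative row offsets [0, L0, L0 + L1, ...] for tensors concatenated on dim 0."""
    offsets = [0]
    for t in tensors:
        offsets.append(offsets[-1] + t.shape[0])
    return torch.tensor(offsets, dtype=torch.int32)


def maxsim_varlen_reference(Q_packed: torch.Tensor, D_packed: torch.Tensor,
                            cu_q: torch.Tensor, cu_d: torch.Tensor) -> torch.Tensor:
    """Score pair i: its own query slice against its own doc slice -> [N]."""
    scores = []
    for i in range(len(cu_q) - 1):
        q = Q_packed[int(cu_q[i]):int(cu_q[i + 1])]    # [Lq_i, d]
        d = D_packed[int(cu_d[i]):int(cu_d[i + 1])]    # [Ld_i, d]
        scores.append((q.float() @ d.float().T).amax(dim=1).sum())
    return torch.stack(scores)
```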

INT8 Index

from flash_maxsim import flash_maxsim_int8x8, quantize_int8_symmetric

# Index time (once): 50% smaller storage
D_int8, scales = quantize_int8_symmetric(D)

# Query time: INT8 tensor cores, zero overhead
scores = flash_maxsim_int8x8(Q, D_int8, scales)

Training

from flash_maxsim import flash_maxsim_train

Q = torch.randn(32, 128, device="cuda", dtype=torch.float16, requires_grad=True)
D = torch.randn(100, 300, 128, device="cuda", dtype=torch.float16)
scores = flash_maxsim_train(Q, D)
loss = scores.sum()
loss.backward()  # Q.grad computed via sparse argmax backward

Zero-Copy Reranking

Score documents directly from a model's output tensor — zero additional memory:

from flash_maxsim import flash_maxsim_rerank_direct

scores = flash_maxsim_rerank_direct(
    Q, batch_tensor, doc_offsets, doc_lengths, max_ld
)  # 0 bytes allocated for scoring
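"Zero-copy" here means each document is read in place from the model output instead of being gathered into a new padded tensor. Assuming batch_tensor is a flat [T, d] token tensor and document i occupies rows doc_offsets[i] to doc_offsets[i] + doc_lengths[i] (an assumed layout; the real function may expect a different one), the idea can be sketched with PyTorch views, which share storage with the original tensor. The helper names are hypothetical:

```python
import torch


def doc_views(batch_tensor: torch.Tensor, doc_offsets, doc_lengths) -> list[torch.Tensor]:
    """[Ld_i, d] views into a flat [T, d] tensor; narrow() never copies data."""
    return [batch_tensor.narrow(0, int(o), int(n)) for o, n in zip(doc_offsets, doc_lengths)]


def maxsim_from_views(Q: torch.Tensor, views: list[torch.Tensor]) -> torch.Tensor:
    """Reference scores [B] computed on the views (still allocates a small [Lq, Ld_i] tile per doc)."""
    return torch.stack([(Q.float() @ v.float().T).amax(dim=1).sum() for v in views])
```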

How It Works

Q_block = load(Q)                      # SRAM (small — one query)
m = [-inf] * Lq                        # registers (running max per query token)

for tile in D.tiles(BLOCK_D):
    D_tile = load(tile)                # SRAM
    S = tl.dot(Q_block, D_tile.T)     # tensor cores — stays in SRAM
    m = max(m, S.max(axis=1))         # online max reduction
    # S dies here — never written to HBM

score = sum(m)                          # one scalar per doc → HBM

Same principle as Flash Attention, but simpler: max is trivially composable across tiles (no log-sum-exp rescaling needed).
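Because max(max(A), max(B)) = max(A ∪ B), taking the max tile by tile gives exactly the per-token maxima of a single pass over all Ld document tokens, so no rescaling state is needed. The loop can be emulated with ordinary tensor ops to check this; maxsim_tiled and block_d are hypothetical names, and this is an illustration, not the fused Triton kernel:

```python
import torch


def maxsim_tiled(Q: torch.Tensor, D: torch.Tensor, block_d: int = 64) -> torch.Tensor:
    """Plain-PyTorch mirror of the tile loop: [Lq, d] x [B, Ld, d] -> [B].

    Only one [Lq, block_d] tile of similarities exists at a time; the running
    max m (one value per query token) is the only state carried across tiles.
    """
    Qf = Q.float()
    scores = torch.empty(D.shape[0], dtype=torch.float32, device=D.device)
    for b in range(D.shape[0]):                                  # the pseudocode above handles one doc
        m = torch.full((Q.shape[0],), float("-inf"), device=D.device)
        for start in range(0, D.shape[1], block_d):
            tile = D[b, start:start + block_d].float()           # [<= block_d, d]
            S = Qf @ tile.T                                      # [Lq, <= block_d]
            m = torch.maximum(m, S.amax(dim=1))                  # online max, no rescaling
        scores[b] = m.sum()                                      # one scalar per doc
    return scores
```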

API Reference

Core Scoring

| Function | Signature | Description |
|---|---|---|
| flash_maxsim | [Lq,d] × [B,Ld,d] → [B] | Single query, auto-chunking for long queries |
| flash_maxsim_batched | [Nq,Lq,d] × [B,Ld,d] → [Nq,B] | Multi-query (shared or per-query docs) |
| flash_maxsim_varlen | packed Q,D + cu_seqlens → [N] | Variable-length pairs, zero padding |
| flash_maxsim_packed | [Lq,d] × packed [T,d] + cu_seqlens → [B] | Shared Q + variable-length packed D |

INT8 Quantization

| Function | Description |
|---|---|
| flash_maxsim_int8x8 | True INT8×INT8 tensor core scoring (recommended) |
| quantize_int8_symmetric | Per-token symmetric INT8 quantization for D |
| quantize_query_int8 | Per-token INT8 quantization for Q |
| flash_maxsim_int8 | Legacy: fused affine INT8 dequant+scoring |

Training & Utilities

| Function | Description |
|---|---|
| flash_maxsim_train | MaxSim with autograd backward (sparse argmax) |
| flash_maxsim_rerank_direct | Zero-copy scoring from scattered batch tensor |
| pack_pairs | Pack variable-length (Q, D) pairs into cu_seqlens format |
| pack_docs | Pack variable-length docs for flash_maxsim_packed |
| maxsim_naive | Pure PyTorch reference (FP16 einsum) |

Requirements

  • NVIDIA GPU (Ampere or newer recommended)
  • PyTorch >= 2.0
  • Triton >= 3.4
  • CUDA

Tested on: H100 80GB, A100 80GB/40GB, V100.

Authors

IBM Research Israel

  • Roi Pony
  • Adi Raz Goldfarb
  • Idan Friedman
  • Udi Barzelay

License

Apache 2.0

Download files

Source Distribution

flash_maxsim-0.2.1.tar.gz (70.1 kB)

Uploaded Source

Built Distribution

flash_maxsim-0.2.1-py3-none-any.whl (66.7 kB)

Uploaded Python 3

File details

Details for the file flash_maxsim-0.2.1.tar.gz.

File metadata

  • Download URL: flash_maxsim-0.2.1.tar.gz
  • Size: 70.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.19

File hashes

Hashes for flash_maxsim-0.2.1.tar.gz

| Algorithm | Hash digest |
|---|---|
| SHA256 | 37eef106f1e307dbcfc797d069a7c48e2ddbec7dcbf614ee89307900ce5b3cfa |
| MD5 | 8af49817cfcae1debc01ba34a4e961e2 |
| BLAKE2b-256 | 4db0c050abb1752a358a5841f8bf120839557da663583e1251bdf1de918ba70d |

File details

Details for the file flash_maxsim-0.2.1-py3-none-any.whl.

File metadata

  • Download URL: flash_maxsim-0.2.1-py3-none-any.whl
  • Size: 66.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.19

File hashes

Hashes for flash_maxsim-0.2.1-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | 73c0a269dc254b96bfbb84a77680231f8654cfafa9ccb1d26fa3f1a14912d195 |
| MD5 | a3c41744038f3438eb1b7bdec80b592c |
| BLAKE2b-256 | 04c4ec76cb1a24efb0f93f633434b0adc1caa0d26d0bf3fe8a3c549af3a91274 |
