Fused GPU kernel for ColBERT/ColPali MaxSim scoring

Flash-MaxSim

Fused Triton GPU kernel for ColBERT/ColPali MaxSim scoring. Up to 6.5x faster at matched precision, zero memory overhead. Drop-in replacement — same API, no configuration.

pip install flash-maxsim

from flash_maxsim import flash_maxsim

scores = flash_maxsim(Q, D)  # that's it

Why Flash-MaxSim?

Every existing MaxSim implementation computes and stores the full similarity matrix in GPU memory. Flash-MaxSim eliminates it — the matrix never exists outside the chip.

1. Memory: the similarity matrix is gone

The standard einsum / bmm approach allocates B × Lq × Ld × 2 bytes (FP16) for the similarity matrix. For ColPali at 10K docs, that's 21 GB: instant OOM.

| ColPali (Lq=1024, Ld=1024) | Naive sim matrix | Flash-MaxSim |
|---|---|---|
| B=1,000 | 2,097 MB | 0 MB |
| B=5,000 | 10,486 MB | 0 MB |
| B=10,000 | 20,972 MB | 0 MB |

That's 21 GB of temporary memory just for scoring — on top of model weights, KV cache, and the embeddings themselves. On a 40 GB GPU, this OOMs. On 80 GB, it eats a quarter of your memory for a tensor that gets immediately reduced and thrown away.
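The table figures follow directly from the B × Lq × Ld × 2 bytes formula. A minimal sketch of that arithmetic, where the helper name `sim_matrix_mb` is made up for illustration and MB is taken to mean 10^6 bytes:

```python
def sim_matrix_mb(num_docs: int, lq: int, ld: int, bytes_per_elem: int = 2) -> float:
    """Size in MB (10**6 bytes) of a full [B, Lq, Ld] similarity matrix."""
    return num_docs * lq * ld * bytes_per_elem / 1e6


for b in (1_000, 5_000, 10_000):
    # ColPali-sized inputs: 1024 query tokens x 1024 document tokens, FP16.
    print(f"B={b:>6,}: {sim_matrix_mb(b, 1024, 1024):,.0f} MB")
```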

Flash-MaxSim uses zero extra HBM. The similarity is computed tile-by-tile in SRAM and reduced on the fly.

2. Speed: up to 6.5x faster at matched precision (A100)

Flash-MaxSim uses FP32 accumulation — more precise than FP16 naive. For a fair comparison, naive must also reduce in FP32 (.float().max().sum()), which adds the cost of casting the sim matrix. At matched precision:

| Config (A100) | Naive (matched) | Flash | Speedup |
|---|---|---|---|
| ColBERT B=10K | 1.61 ms | 0.51 ms | 3.2x |
| ColBERT B=100K | 16.36 ms | 4.28 ms | 3.8x |
| ColPali B=100 | 1.11 ms | 0.27 ms | 4.0x |
| ColPali B=1K | 10.09 ms | 1.63 ms | 6.2x |
| ColPali B=10K | 100.63 ms | 15.49 ms | 6.5x |
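For reference, the "naive (matched)" baseline described above can be sketched in plain PyTorch. This is an illustration of the einsum followed by `.float().max().sum()` pattern, not the package's own `maxsim_naive`; the function name `maxsim_naive_fp32_reduce` is hypothetical.

```python
import torch


def maxsim_naive_fp32_reduce(Q: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
    """Naive MaxSim. Q is [Lq, d], D is [B, Ld, d]; returns [B] scores in FP32."""
    # Full similarity matrix [B, Lq, Ld], materialized in the input dtype.
    sim = torch.einsum("qd,bkd->bqk", Q, D)
    # Cast before reducing so that both the max and the sum run in FP32.
    return sim.float().max(dim=-1).values.sum(dim=-1)


# Small CPU example in FP32; on a GPU the inputs would typically be FP16.
Q = torch.randn(32, 128)
D = torch.randn(10, 300, 128)
print(maxsim_naive_fp32_reduce(Q, D).shape)  # torch.Size([10])
```

The cast on the last line of the function is the extra cost mentioned above: it touches every element of the similarity matrix a second time.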

Even vs the fastest FP16 naive (less precise), flash is 2.5–2.9x faster:

| B (ColPali) | vs FP16 naive | vs matched naive | Sim matrix eliminated |
|---|---|---|---|
| 1,000 | 2.6x | 6.2x | 2.1 GB |
| 5,000 | 2.9x | 6.4x | 10.5 GB |
| 10,000 | 2.9x | 6.5x | 21.0 GB |

3. Zero parameters — no chunk size to tune

Production systems (vLLM, etc.) chunk documents into mini-batches to avoid OOM. Too large → OOM. Too small → launch overhead. Flash-MaxSim has zero configuration — same code on a 16 GB GPU and an 80 GB GPU.
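To make the tuning problem concrete, here is a rough sketch of the chunked naive approach. It is not taken from vLLM or any other system; `maxsim_chunked` and its `chunk_size` parameter are illustrative names.

```python
import torch


def maxsim_chunked(Q: torch.Tensor, D: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Naive MaxSim over D in mini-batches of `chunk_size` documents."""
    scores = []
    for start in range(0, D.shape[0], chunk_size):
        chunk = D[start:start + chunk_size]
        # Peak temporary memory is chunk_size * Lq * Ld elements per iteration.
        sim = torch.einsum("qd,bkd->bqk", Q, chunk)
        scores.append(sim.float().max(dim=-1).values.sum(dim=-1))
    return torch.cat(scores)
```

The result does not depend on `chunk_size`, but peak memory and the number of kernel launches do, which is why the value has to be re-tuned per GPU.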

4. Variable-length documents — zero padding waste

Real collections have variable doc lengths. Padding wastes compute. Flash-MaxSim supports packed variable-length documents:

from flash_maxsim import flash_maxsim_packed, pack_docs

D_packed, cu_seqlens, max_ld = pack_docs(doc_embeddings)
# Call the packed entry point imported above (keyword kept from the original call)
scores = flash_maxsim_packed(Q, D_packed, doc_lengths=cu_seqlens)

| Regime | N | Speedup | Padding saved |
|---|---|---|---|
| ColBERT skewed (avg_Ld≈49) | 100K | 5.1x | 39% |
| ColBERT uniform | 100K | 2.7x | 42% |
| ColPali uniform | 500 | 4.2x | 37% |
| ColPali skewed | 5K | 3.9x | 19% |
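The packed layout can be illustrated in plain PyTorch. The sketch below assumes the common convention of one concatenated [T, d] tensor plus B + 1 cumulative offsets; whether `pack_docs` uses exactly this layout is an assumption, and both function names here are hypothetical.

```python
import torch


def pack_docs_sketch(docs):
    """Concatenate [Ld_i, d] tensors into [T, d] plus cumulative offsets [B + 1]."""
    lengths = torch.tensor([doc.shape[0] for doc in docs])
    cu_seqlens = torch.zeros(len(docs) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
    return torch.cat(docs, dim=0), cu_seqlens, int(lengths.max())


def maxsim_packed_reference(Q, D_packed, cu_seqlens):
    """Score Q [Lq, d] against each packed document, one slice at a time."""
    scores = []
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        sim = Q.float() @ D_packed[start:end].float().T  # [Lq, Ld_i]
        scores.append(sim.max(dim=1).values.sum())
    return torch.stack(scores)
```

Document i occupies rows `cu_seqlens[i]` to `cu_seqlens[i + 1]` of the packed tensor, so no padding tokens are stored or scored.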

5. INT8 index — half storage, faster, more precise

Store embeddings as INT8 (2x compression). The kernel uses INT8 tensor cores (624 TOPS on A100 — 2x FP16 throughput). No dequantization in HBM.

from flash_maxsim import flash_maxsim_int8x8, quantize_int8_symmetric

# Index time: quantize once (50% storage savings)
D_int8, scales = quantize_int8_symmetric(D)

# Query time: drop-in
scores = flash_maxsim_int8x8(Q, D_int8, scales)

| Method (ColPali B=5K) | Latency | D Storage | Extra HBM | Error (lower is better) |
|---|---|---|---|---|
| Naive dequant+einsum | 30.9 ms | 1 byte/dim | D_fp16 copy + sim matrix | 0.065 |
| Flash FP16 | 8.0 ms | 2 bytes/dim | ~0 | 0.00008 |
| Flash INT8×INT8 | 6.6 ms | 1 byte/dim | ~0 | 0.023 |

Flash INT8×INT8 is 4.7x faster than naive dequant, stores D in half the space of FP16, and is about 3x more precise than the naive dequant path (FP32 accumulation vs FP16 einsum).
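A rough PyTorch emulation of per-token symmetric INT8 scoring is shown below. It is a sketch of the general technique, not the package's `quantize_int8_symmetric` or `flash_maxsim_int8x8`; the function names and the choice to quantize Q inside the scoring function are assumptions.

```python
import torch


def quantize_per_token_sketch(x: torch.Tensor):
    """Symmetric per-token INT8: x ~= q * scale, with q in [-127, 127]."""
    scale = x.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.round(x.float() / scale).clamp(-127, 127).to(torch.int8)
    return q, scale


def maxsim_int8_reference(Q: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
    """Emulate INT8 x INT8 scoring: integer dot products, rescaled in FP32."""
    q_int, q_scale = quantize_per_token_sketch(Q)  # [Lq, d], [Lq, 1]
    d_int, d_scale = quantize_per_token_sketch(D)  # [B, Ld, d], [B, Ld, 1]
    # FP32 represents these integer sums exactly as long as they stay below 2**24.
    acc = torch.einsum("qd,bkd->bqk", q_int.float(), d_int.float())
    sim = acc * q_scale.view(1, -1, 1) * d_scale.transpose(1, 2)  # [B, Lq, Ld]
    return sim.max(dim=-1).values.sum(dim=-1)
```

Because each token has a single scale, the scales factor out of the dot product and are applied once per (query token, document token) pair after the integer accumulation.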

6. Training — autograd backward pass

Full gradient support via saved argmax indices. Sparse backward — no full matrix in either direction:

from flash_maxsim import flash_maxsim_train

scores = flash_maxsim_train(Q, D)  # Q: [Lq, d], D: [B, Ld, d]
scores.sum().backward()            # gradients to both Q and D

Verified correct against naive backward (max gradient error < 0.001).
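Why argmax indices are enough for the backward pass can be sketched in plain PyTorch. The function below is hypothetical and, unlike the kernel described above, it materializes the similarity matrix to find the argmax; it only illustrates that the gradients of `scores.sum()` depend on the winning document token for each query token and nothing else.

```python
import torch


def maxsim_sparse_grads(Q: torch.Tensor, D: torch.Tensor):
    """Gradients of scores.sum() w.r.t. Q [Lq, d] and D [B, Ld, d] from argmax indices."""
    sim = torch.einsum("qd,bkd->bqk", Q, D)  # [B, Lq, Ld], reference only
    idx = sim.argmax(dim=-1)                 # [B, Lq] winning doc token per query token
    B, Lq = idx.shape
    d = D.shape[-1]
    index = idx.unsqueeze(-1).expand(B, Lq, d).contiguous()
    # dQ[q] = sum over documents of the document token that won for query token q.
    grad_Q = torch.gather(D, 1, index).sum(dim=0)
    # dD[b, k] = sum of the query tokens whose max landed on token k of document b.
    grad_D = torch.zeros_like(D)
    grad_D.scatter_add_(1, index, Q.expand(B, Lq, d).contiguous())
    return grad_Q, grad_D
```

Each (document, query token) pair contributes exactly one row to each gradient, so the backward work scales with B × Lq × d rather than B × Lq × Ld × d.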

7. 800x more precise

Flash-MaxSim uses FP32 accumulation for the running max and score sum. The standard FP16 einsum has compounding rounding errors:

| Method | Mean error vs FP32 | Top-20 overlap | Spearman |
|---|---|---|---|
| FP16 naive (einsum) | 6.2×10⁻² | 95% | 0.993 |
| Flash FP16 | 7.6×10⁻⁵ | 100% | 1.000 |
| Flash INT8×INT8 | 2.3×10⁻² | 100% | 0.999 |
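The ranking metrics in the table can be computed along the following lines. How the authors computed them is not stated, so this is only one plausible sketch; the helper names are hypothetical and the Spearman version assumes there are no tied scores.

```python
import torch


def topk_overlap(scores: torch.Tensor, reference: torch.Tensor, k: int = 20) -> float:
    """Fraction of the reference top-k documents that also appear in the scores' top-k."""
    top = set(torch.topk(scores, k).indices.tolist())
    ref = set(torch.topk(reference, k).indices.tolist())
    return len(top & ref) / k


def spearman(scores: torch.Tensor, reference: torch.Tensor) -> float:
    """Spearman rank correlation: Pearson correlation of the ranks (no ties assumed)."""
    rank_a = scores.argsort().argsort().double()
    rank_b = reference.argsort().argsort().double()
    return torch.corrcoef(torch.stack([rank_a, rank_b]))[0, 1].item()
```

Here `reference` would be the FP32 scores and `scores` the output of the method under test.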

Quick Start

import torch
from flash_maxsim import flash_maxsim, flash_maxsim_batched

# Score one query against 1000 documents
Q = torch.randn(32, 128, device="cuda", dtype=torch.float16)   # query: 32 tokens
D = torch.randn(1000, 300, 128, device="cuda", dtype=torch.float16)  # 1000 docs, 300 tokens each
scores = flash_maxsim(Q, D)  # [1000]

# ColPali (long query) — automatic chunking, no configuration needed
Q_colpali = torch.randn(1024, 128, device="cuda", dtype=torch.float16)
D_colpali = torch.randn(1000, 1024, 128, device="cuda", dtype=torch.float16)
scores = flash_maxsim(Q_colpali, D_colpali)  # [1000]

# Batched: 16 queries vs same corpus (up to 15x faster than serial loop)
Q_batch = torch.randn(16, 32, 128, device="cuda", dtype=torch.float16)
scores = flash_maxsim_batched(Q_batch, D, shared_docs=True)  # [16, 1000]

Variable-Length Documents

from flash_maxsim import flash_maxsim_varlen, pack_pairs

# Each pair has different lengths — zero padding waste
q_embs = [torch.randn(32, 128, ...), torch.randn(48, 128, ...)]
d_embs = [torch.randn(180, 128, ...), torch.randn(250, 128, ...)]

Q_packed, D_packed, cu_q, cu_d, max_lq, max_ld = pack_pairs(q_embs, d_embs)
scores = flash_maxsim_varlen(Q_packed, D_packed, cu_q, cu_d, max_lq, max_ld)

INT8 Index

from flash_maxsim import flash_maxsim_int8x8, quantize_int8_symmetric

# Index time (once): 50% smaller storage
D_int8, scales = quantize_int8_symmetric(D)

# Query time: INT8 tensor cores, zero overhead
scores = flash_maxsim_int8x8(Q, D_int8, scales)

Training

import torch
from flash_maxsim import flash_maxsim_train

Q = torch.randn(32, 128, device="cuda", dtype=torch.float16, requires_grad=True)
D = torch.randn(100, 300, 128, device="cuda", dtype=torch.float16)
scores = flash_maxsim_train(Q, D)
loss = scores.sum()
loss.backward()  # Q.grad computed via sparse argmax backward

Zero-Copy Reranking

Score documents directly from a model's output tensor — zero additional memory:

from flash_maxsim import flash_maxsim_rerank_direct

scores = flash_maxsim_rerank_direct(
    Q, batch_tensor, doc_offsets, doc_lengths, max_ld
)  # 0 bytes allocated for scoring

How It Works

Q_block = load(Q)                      # SRAM (small — one query)
m = [-inf] * Lq                        # registers (running max per query token)

for tile in D.tiles(BLOCK_D):
    D_tile = load(tile)                # SRAM
    S = tl.dot(Q_block, D_tile.T)     # tensor cores — stays in SRAM
    m = max(m, S.max(axis=1))         # online max reduction
    # S dies here — never written to HBM

score = sum(m)                          # one scalar per doc → HBM

Same principle as Flash Attention, but simpler: max is trivially composable across tiles (no log-sum-exp rescaling needed).
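The tiling logic of the pseudocode can be checked on any device with a plain PyTorch reference. This sketch is not the Triton kernel: it still allocates one [B, Lq, block] tile per step in ordinary memory, and the name `maxsim_tiled_reference` and the `block_d` default are illustrative.

```python
import torch


def maxsim_tiled_reference(Q: torch.Tensor, D: torch.Tensor, block_d: int = 64) -> torch.Tensor:
    """MaxSim with a running max over document tiles; no [B, Lq, Ld] tensor is built."""
    B, Ld, _ = D.shape
    # Running max per (document, query token), kept in FP32.
    m = torch.full((B, Q.shape[0]), float("-inf"), dtype=torch.float32, device=Q.device)
    for start in range(0, Ld, block_d):
        tile = D[:, start:start + block_d]                        # [B, block, d]
        s = torch.einsum("qd,bkd->bqk", Q.float(), tile.float())  # [B, Lq, block]
        m = torch.maximum(m, s.max(dim=-1).values)                # online max
    return m.sum(dim=-1)                                          # [B]
```

Because max is associative, the running maximum over tiles equals the maximum over the full row, so the result matches the naive computation for any tile size.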

API Reference

Core Scoring

| Function | Signature | Description |
|---|---|---|
| flash_maxsim | [Lq,d] × [B,Ld,d] → [B] | Single query, auto-chunking for long queries |
| flash_maxsim_batched | [Nq,Lq,d] × [B,Ld,d] → [Nq,B] | Multi-query (shared or per-query docs) |
| flash_maxsim_varlen | packed Q,D + cu_seqlens → [N] | Variable-length pairs, zero padding |
| flash_maxsim_packed | [Lq,d] × packed [T,d] + cu_seqlens → [B] | Shared Q + variable-length packed D |

INT8 Quantization

| Function | Description |
|---|---|
| flash_maxsim_int8x8 | True INT8×INT8 tensor core scoring (recommended) |
| quantize_int8_symmetric | Per-token symmetric INT8 quantization for D |
| quantize_query_int8 | Per-token INT8 quantization for Q |
| flash_maxsim_int8 | Legacy: fused affine INT8 dequant+scoring |

Training & Utilities

| Function | Description |
|---|---|
| flash_maxsim_train | MaxSim with autograd backward (sparse argmax) |
| flash_maxsim_rerank_direct | Zero-copy scoring from scattered batch tensor |
| pack_pairs | Pack variable-length (Q, D) pairs into cu_seqlens format |
| pack_docs | Pack variable-length docs for flash_maxsim_packed |
| maxsim_naive | Pure PyTorch reference (FP16 einsum) |

Requirements

  • NVIDIA GPU (Ampere or newer recommended)
  • PyTorch >= 2.0
  • Triton >= 3.4
  • CUDA

Tested on: H100 80GB, A100 80GB/40GB, V100.

Authors

IBM Research Israel

  • Roi Pony
  • Adi Raz Goldfarb
  • Idan Friedman
  • Udi Barzelay

License

Apache 2.0
