Fused Triton kernels for late-interaction (MaxSim) scoring — ColBERT, ColPali, ModernColBERT.
late-interaction-kernels
[!NOTE] Full algorithmic walkthrough, animations and benchmark plots live on the docs site: hcompai.github.io/late-interaction-kernels.
Introduction
late-interaction-kernels provides fused Triton kernels for MaxSim, the late-interaction scoring used by ColBERT, ColPali, ModernColBERT, LateOn and ColBERTv2. The kernels are numerically identical to plain PyTorch and come with three APIs:
- a one-line PyLate drop-in (`patch_pylate()`),
- a stateless `nn.Module` (`MaxSimScorer`) for custom training loops,
- function-level entry points (`maxsim`, `maxsim_varlen`, `maxsim_padded`, ...) for everything else.
This is not a search engine. For end-to-end training or retrieval use PyLate, FastPlaid or NextPlaid. This library is the MaxSim math they compile down to.
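For reference, the score these entry points compute is the usual late-interaction MaxSim: every valid query token keeps only its best-matching valid document token, and those maxima are summed. Written out under the assumption that the masks are 0/1 vectors ($m^Q_i = 1$ marks a real query token, $m^D_j = 1$ a real document token):

```latex
s(Q, D) \;=\; \sum_{i=1}^{L_q} m^{Q}_{i} \, \max_{\substack{1 \le j \le L_d \\ m^{D}_{j} = 1}} \langle q_i, d_j \rangle
```

Assuming normalize=True L2-normalises the token embeddings first (as the PLAID path below describes), each inner product becomes a cosine similarity.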
Install
```bash
pip install late-interaction-kernels
```
| Platform | Backend |
|---|---|
| Linux + CUDA (sm_75+) | Fused Triton kernels (autotuned, FP8 on Hopper). |
| macOS (Apple Silicon, MPS) | Fused Metal simdgroup_matrix kernels for inference and training (fp16 / bf16, d ≤ 128); torch.compile fallback otherwise. |
| CPU / Windows | Autograd-aware pure-PyTorch reference. |
Quickstart
Score directly (maxsim / maxsim_pairs)
maxsim is the lowest-level public entry point — autograd-aware, mask-aware, and dispatches on D.dim() so the same call covers in-batch and knowledge-distillation layouts in one fused launch. The argmax buffer for the backward is skipped automatically when neither input has requires_grad=True, so the same function is the inference path too.
```python
from late_interaction_kernels import maxsim, maxsim_pairs

# in-batch: Q[Nq, Lq, d] × D[Nd, Ld, d] → [Nq, Nd]
scores = maxsim(Q, D, q_mask=q_mask, d_mask=d_mask, normalize=True)

# KD / hard-negative: D is 4D [Nq, K, Ld, d] → [Nq, K]
# Single launch, no Python loop, no [Nq, Nq] cross product.
scores = maxsim(Q, D_kd, q_mask=q_mask, d_mask=d_mask_kd)

# pairwise (diagonal): Q[B, Lq, d] × D[B, Ld, d] → [B]
scores = maxsim_pairs(Q, D, q_mask=q_mask, d_mask=d_mask)
```
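To check shapes and mask semantics by hand, here is a dependency-free sketch of what the in-batch and KD layouts compute. It is not the library's `reference` submodule: `maxsim_reference` and `maxsim_kd_reference` are hypothetical names, and the mask convention (1 = real token) and the L2 normalisation behind `normalize=True` are assumptions.

```python
import math


def _l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else list(v)


def maxsim_reference(Q, D, q_mask=None, d_mask=None, normalize=False):
    """In-batch MaxSim on nested lists: Q is [Nq][Lq][d], D is [Nd][Ld][d] -> [Nq][Nd].

    Masks are 0/1 nested lists shaped [Nq][Lq] and [Nd][Ld]; 1 means "real token".
    """
    if normalize:
        Q = [[_l2_normalize(t) for t in q] for q in Q]
        D = [[_l2_normalize(t) for t in doc] for doc in D]
    scores = []
    for qi, q in enumerate(Q):
        row = []
        for di, doc in enumerate(D):
            total = 0.0
            for a, q_tok in enumerate(q):
                if q_mask is not None and not q_mask[qi][a]:
                    continue  # padded query tokens contribute nothing
                best = -math.inf
                for b, d_tok in enumerate(doc):
                    if d_mask is not None and not d_mask[di][b]:
                        continue  # padded document tokens can never win the max
                    best = max(best, sum(x * y for x, y in zip(q_tok, d_tok)))
                total += best
            row.append(total)
        scores.append(row)
    return scores


def maxsim_kd_reference(Q, D_kd, q_mask=None, d_mask=None, normalize=False):
    """KD layout: D_kd is [Nq][K][Ld][d]; query i is scored only against its own K candidates -> [Nq][K]."""
    out = []
    for i, (q, cands) in enumerate(zip(Q, D_kd)):
        qm = None if q_mask is None else [q_mask[i]]
        dm = None if d_mask is None else d_mask[i]
        out.append(maxsim_reference([q], cands, q_mask=qm, d_mask=dm, normalize=normalize)[0])
    return out


Q = [[[1.0, 0.0], [0.0, 1.0]]]                            # Nq=1, Lq=2, d=2
D = [[[1.0, 0.0], [0.5, 0.5]], [[0.0, 2.0], [0.0, 0.0]]]  # Nd=2, Ld=2
print(maxsim_reference(Q, D))       # [[1.5, 2.0]]
print(maxsim_kd_reference(Q, [D]))  # [[1.5, 2.0]]: the same two docs as K=2 candidates
```

`maxsim_pairs` corresponds to the K=1 case of the KD layout, i.e. the diagonal of the in-batch matrix, which is why it never needs the `[B, B]` cross product.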
Top-k retrieval
Score Q against a large corpus and return the top-k per query without materialising the full [Nq, Nd] matrix — chunk= streams documents in tiles so peak HBM stays bounded.
```python
from late_interaction_kernels import retrieve

scores, indices = retrieve(Q, D, top_k=100, chunk=4096)
# both [Nq, 100]; chunk= bounds the peak score buffer in HBM at Nq * (chunk + top_k) elements
```
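The `Nq * (chunk + top_k)` bound falls out of a streaming merge: keep the current `top_k` survivors per query, score one tile of `chunk` documents, merge, and repeat. Below is a single-query, pure-Python sketch of that idea, assuming `retrieve` works roughly this way internally (the source only states the bound). `retrieve_chunked` and `score_fn` are hypothetical; `score_fn` stands in for the fused kernel scoring one tile.

```python
import heapq
import random


def retrieve_chunked(score_fn, n_docs, top_k, chunk):
    """Top-k over n_docs documents for one query, scoring `chunk` documents at a time.

    score_fn(start, stop) must return the scores of documents start..stop-1.
    The working set never exceeds top_k survivors plus one chunk of fresh scores.
    """
    top = []  # (score, doc_index) pairs, highest first, at most top_k of them
    for start in range(0, n_docs, chunk):
        stop = min(start + chunk, n_docs)
        tile = [(s, start + j) for j, s in enumerate(score_fn(start, stop))]
        top = heapq.nlargest(top_k, top + tile)  # merge survivors with the new tile
    return [s for s, _ in top], [i for _, i in top]


random.seed(0)
corpus_scores = [random.random() for _ in range(10_000)]  # stand-in for per-document MaxSim scores
scores, indices = retrieve_chunked(
    lambda a, b: corpus_scores[a:b], len(corpus_scores), top_k=100, chunk=4096
)
```

The result matches a full sort of all scores, but only `top_k + chunk` candidates are ever alive at once.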
PLAID / ColBERTv2 on compressed, ragged docs
Use this for PLAID-style indexes, where documents are stored as centroid codes plus residuals at variable lengths. A single kernel fuses decompression, L2 normalisation and MaxSim, so no decoded tensor is ever written back to HBM.
```python
from late_interaction_kernels.plaid import maxsim_residual_varlen

scores = maxsim_residual_varlen(
    Q, codes_flat, residuals_flat, cu_seqlens_d,
    centroids=centroids, bucket_weights=bucket_weights,
    nbits=2, normalize=True,
)  # [Nd] fp32; one kernel does decompress + L2-normalize + MaxSim
```
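To make the fused decode step concrete, here is a scalar, unfused sketch of the three stages the kernel is described as fusing, for a single query. The storage layout is an assumption modelled on ColBERTv2-style residual compression: each token has a centroid id, and each residual dimension is an `nbits` bucket id into `bucket_weights`, packed high bits first. The real packing order and tensor shapes in `late_interaction_kernels.plaid` may differ; `unpack_buckets`, `decode_token` and `maxsim_residual_varlen_ref` are hypothetical.

```python
import math


def unpack_buckets(packed, nbits, dim):
    """Unpack `dim` bucket ids of `nbits` bits each (nbits in 1, 2, 4, 8) from bytes, high bits first."""
    per_byte, mask = 8 // nbits, (1 << nbits) - 1
    ids = [(byte >> (8 - nbits * (slot + 1))) & mask for byte in packed for slot in range(per_byte)]
    return ids[:dim]


def decode_token(code, packed_residual, centroids, bucket_weights, nbits):
    """Centroid + per-dimension residual, then L2-normalise."""
    c = centroids[code]
    buckets = unpack_buckets(packed_residual, nbits, len(c))
    v = [ck + bucket_weights[b] for ck, b in zip(c, buckets)]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else v


def maxsim_residual_varlen_ref(q, codes_flat, residuals_flat, cu_seqlens_d, centroids, bucket_weights, nbits):
    """One query q [Lq][d] against ragged compressed docs; returns one score per document.

    Document k owns tokens cu_seqlens_d[k] .. cu_seqlens_d[k + 1] - 1 (assumed non-empty).
    """
    scores = []
    for lo, hi in zip(cu_seqlens_d, cu_seqlens_d[1:]):
        toks = [decode_token(codes_flat[t], residuals_flat[t], centroids, bucket_weights, nbits)
                for t in range(lo, hi)]
        scores.append(sum(max(sum(a * b for a, b in zip(q_tok, d_tok)) for d_tok in toks) for q_tok in q))
    return scores
```

The fused kernel avoids the intermediate `toks` list entirely: each decoded token lives only in registers/SRAM long enough to be compared against the query tokens.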
Custom training loop
A stateless nn.Module wrapper around maxsim — drop it into any training loop that needs autograd-aware late-interaction scoring without touching PyLate.
```python
from late_interaction_kernels import MaxSimScorer

scorer = MaxSimScorer(normalize=True)  # nn.Module, no parameters
scores = scorer(Q, D, q_mask=q_mask, d_mask=d_mask)  # [Nq, Nd] fp32
scores.mean().backward()
```
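Why the forward keeps an argmax buffer when gradients are needed: `max` is piecewise linear, so the gradient of a MaxSim score with respect to each query token is simply the document token that won its max, and every other document token receives nothing from that query token. A dependency-free sketch with a finite-difference check follows; the function names are hypothetical and this illustrates the math only, not how `MaxSimScorer`'s backward is implemented.

```python
def maxsim_with_argmax(q, doc):
    """Score one query [Lq][d] against one document [Ld][d].

    Also returns, per query token, the index of the document token that won its max.
    """
    total, winners = 0.0, []
    for q_tok in q:
        dots = [sum(a * b for a, b in zip(q_tok, d_tok)) for d_tok in doc]
        j = max(range(len(dots)), key=dots.__getitem__)
        winners.append(j)
        total += dots[j]
    return total, winners


def grad_wrt_q(q, doc):
    """d score / d q[i] equals the winning document token for query token i."""
    _, winners = maxsim_with_argmax(q, doc)
    return [list(doc[j]) for j in winners]


def finite_difference_grad(q, doc, eps=1e-6):
    """Central differences, used only to check grad_wrt_q numerically."""
    grad = []
    for i in range(len(q)):
        row = []
        for k in range(len(q[i])):
            plus = [list(t) for t in q]
            minus = [list(t) for t in q]
            plus[i][k] += eps
            minus[i][k] -= eps
            row.append((maxsim_with_argmax(plus, doc)[0] - maxsim_with_argmax(minus, doc)[0]) / (2 * eps))
        grad.append(row)
    return grad


q = [[1.0, 0.2], [-0.3, -0.4]]
doc = [[0.5, 1.0], [-1.0, 0.25], [0.9, 0.1]]
print(maxsim_with_argmax(q, doc)[1])  # [2, 1]
```

When no input requires grad, the winners never need to be stored, which matches the inference behaviour described for `maxsim` above.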
Patch PyLate (one line)
Monkey-patches PyLate's scoring + loss to route through the fused kernel. Existing PyLate training and rerank scripts run unchanged; set LIK_DISABLE=1 to fall back to vanilla PyLate at runtime.
```python
from late_interaction_kernels import patch_pylate

patch_pylate()
# PyLate training / rerank code is unchanged
```
Benchmarks
1×H100 80GB SXM, bf16 inputs / fp32 accumulator, 50-iter median. All
speedups are measured at matched numerics — every baseline runs the
einsum with an fp32 accumulator (same as the fused kernel), and parity
is asserted at atol=1e-2 before timing.
| Workload | Speedup |
|---|---|
| Reranking / inference (vs eager fp32-acc and torch.compile) | 1.7-15× |
| Long-context (Ld ≥ 8k) MaxSim fwd+bwd | runs; naive OOMs |
| PyLate cached-contrastive MaxSim + backward (vs vanilla) | 5-6.5× |
| PLAID rerank vs fast_plaid.engine.search() (incl. top-k) | 8-23× full / 18-51× partial |
| Fused D-side head (training) | 1.5-4.5× at large Nd · Ld |
| FP8 MaxSim inference vs same kernel in bf16 (Hopper) | 1.1-1.3× at Ld ≥ 256 |
| LateOn-Code-edge training (real MS MARCO triplets) | 1.00-1.06× end-to-end |
torch.compile is within ±5% of eager on every forward shape because
Inductor still has to materialise the [Nq · Nd · Lq · Ld] similarity
tensor before the max(-1) reduction; that materialisation is exactly what
the fused kernel exists to skip. Full tables and reproduction commands
live in docs/benchmarks.md. For how the bench scripts themselves are
organised (CLI conventions such as --only and --variants, per-script
summaries, and how to run one bench, the whole sweep, or a
RUN_ONLY-filtered subset on a SkyPilot cluster), see benchmarks/README.md.
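To put that materialisation cost in numbers, here is a back-of-the-envelope sketch of the tensor sizes involved, assuming fp32 (4-byte) elements and an illustrative long-context shape chosen for this example (not one of the benchmarked configurations).

```python
def nbytes(*shape, bytes_per_elem=4):
    """Size of a dense tensor of the given shape, fp32 by default."""
    n = 1
    for s in shape:
        n *= s
    return n * bytes_per_elem


Nq, Nd, Lq, Ld = 32, 1000, 32, 8192  # illustrative long-context shape, not a benchmarked config

similarity = nbytes(Nq, Nd, Lq, Ld)  # what an unfused einsum materialises before max(-1)
per_token_max = nbytes(Nq, Nd, Lq)   # left after reducing over document tokens
scores = nbytes(Nq, Nd)              # the only output a fused kernel needs to write

print(f"{similarity / 1e9:.1f} GB -> {per_token_max / 1e6:.1f} MB -> {scores / 1e3:.0f} kB")
# 33.6 GB -> 4.1 MB -> 128 kB
```

The first term grows linearly in Ld, which is why long documents are where skipping it matters most.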
Choose a kernel
Not sure which entry point fits your stack? The docs site ships an interactive decision tree that narrows the public API down to the right function in four questions (stack · phase · layout · workload): 👉 hcompai.github.io/late-interaction-kernels/choose-a-kernel.html
API
| Symbol | What it does |
|---|---|
| `patch_pylate()` / `unpatch_pylate()` | One-line PyLate drop-in. `LIK_DISABLE=1` kill switch. |
| `patch_colpali_engine()` / `unpatch_colpali_engine()` | One-line colpali_engine drop-in (loss + scoring route through the kernel). |
| `MaxSimScorer(normalize=, backward=)` | Stateless `nn.Module`, autograd-aware. |
| `retrieve(Q, D, top_k, chunk=)` | Top-k retrieval, chunked for huge corpora. |
| `maxsim` | Core MaxSim. Dispatches on `D.dim()`: 3D → in-batch `[Nq, Nd]`, 4D → per-query KD candidates `[Nq, K]` (one fused launch, no Python loop). Autograd-aware. |
| `maxsim_pairs` | Diagonal pairs `Q[B, Lq, d] × D[B, Ld, d] → [B]`. K=1 case of the KD path; never builds the `[B, B]` cross product. Autograd-aware. |
| `maxsim_varlen` | Packed (`cu_seqlens`) layout. Autograd-aware. |
| `maxsim_padded` | Padded reranking wrapper: packs internally, returns `[B, C]` fp32. |
Other kernels are in submodules: padded, score_pairs, fused_head, plaid, fp8, reference. See docs/design.md for details on every kernel, the autograd graph and the backward variants.
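The packed (`cu_seqlens`) layout that `maxsim_varlen` takes, and that `maxsim_padded` builds internally, is presumably the common varlen convention popularised by FlashAttention: all valid tokens concatenated along one axis, plus a cumulative-length array with one more entry than there are sequences. A list-based sketch of the padded-to-packed conversion under that assumption; `pack_padded` and `unpack` are hypothetical helpers, and the dtype/device requirements of the real `cu_seqlens` tensor are not covered.

```python
from itertools import accumulate


def pack_padded(batch, mask):
    """Padded batch [B][L][d] + 0/1 mask [B][L] -> (packed tokens [total][d], cu_seqlens [B + 1]).

    Sequence b occupies packed[cu_seqlens[b]:cu_seqlens[b + 1]].
    """
    lengths = [sum(m) for m in mask]
    cu_seqlens = [0, *accumulate(lengths)]
    packed = [tok for seq, m in zip(batch, mask) for tok, keep in zip(seq, m) if keep]
    return packed, cu_seqlens


def unpack(packed, cu_seqlens):
    """Inverse view: recover each sequence's valid tokens without any padding."""
    return [packed[a:b] for a, b in zip(cu_seqlens, cu_seqlens[1:])]


batch = [[[1.0], [2.0], [0.0]], [[3.0], [0.0], [0.0]], [[4.0], [5.0], [6.0]]]  # B=3, L=3, d=1
mask = [[1, 1, 0], [1, 0, 0], [1, 1, 1]]
packed, cu_seqlens = pack_padded(batch, mask)
print(cu_seqlens)  # [0, 2, 3, 6]
```

The payoff of packing is that no compute or memory goes to padding, which matters most when sequence lengths are very uneven.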
Configuration knobs (env vars + kwargs)
| Knob | Effect |
|---|---|
| `maxsim(..., backward=...)` | Per-call grad_D strategy: `"auto"`, `"unified"`, `"atomic"` or `"csr"`. `"auto"` picks per shape. |
| `LIK_DISABLE=1` | Patched entry points delegate to vanilla PyLate / colpali_engine. |
| `LIK_SUPPRESS_NORM_WARN=1` | Silence the one-shot "looks unnormalized" warning. |
| `LIK_DISABLE_COMPILE=1` | Skip torch.compile on the MPS path (eager fallback). |
| `LIK_FORCE_MPS_BACKEND={metal,compile,reference}` | Pin the MPS dispatch. |
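Some context for the `backward=` choices: in MaxSim's backward, several query tokens can pick the same document token as their max, so grad_D needs a reduction over colliding writes. The option names suggest two textbook ways to do it, sketched below in pure Python: scatter-add (atomic adds on a GPU) and a CSR-style grouping that gives each document token exclusive ownership of its sum. This mapping from names to algorithms is an assumption, `"unified"` is not covered, and docs/design.md is the authoritative description of the backward variants. `winners` plays the role of the argmax buffer.

```python
from collections import defaultdict


def grad_d_scatter(q, winners, ld):
    """Scatter-add: each query token adds itself to the doc token that won its max.

    On a GPU, colliding writes to the same doc token would need atomic adds.
    """
    d = len(q[0])
    grad = [[0.0] * d for _ in range(ld)]
    for i, j in enumerate(winners):
        for k in range(d):
            grad[j][k] += q[i][k]
    return grad


def grad_d_csr(q, winners, ld):
    """CSR-style: group query tokens by winning doc token (indptr + indices),
    then each doc token sums its own group, so no two writers touch the same row."""
    groups = defaultdict(list)
    for i, j in enumerate(winners):
        groups[j].append(i)
    indptr, indices = [0], []
    for j in range(ld):
        indices.extend(groups[j])
        indptr.append(len(indices))
    d = len(q[0])
    grad = []
    for j in range(ld):
        rows = indices[indptr[j]:indptr[j + 1]]
        grad.append([sum((q[i][k] for i in rows), 0.0) for k in range(d)])
    return grad


q = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # Lq=3 query tokens, d=2
winners = [0, 2, 0]                       # query tokens 0 and 2 both pick doc token 0
print(grad_d_scatter(q, winners, ld=3))   # [[6.0, 8.0], [0.0, 0.0], [3.0, 4.0]]
```

Both produce the same gradient; the trade-off is contention on hot document tokens (scatter) versus the extra pass that builds the grouping (CSR).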
Development
```bash
git clone https://github.com/hcompai/late-interaction-kernels
cd late-interaction-kernels
uv sync --extra dev --extra pylate --extra torch-cuda  # GPU dev; use --extra torch-cpu on CPU-only boxes
uv run pytest -q                                       # CUDA tests auto-skip without a GPU
uv run ruff check . && uv run ruff format --check .
```
[!NOTE] Pick exactly one of `--extra torch-cuda` (pulls torch from the CUDA index, `cu124`) or `--extra torch-cpu` (CPU-only wheel, what CI uses). The two are declared as conflicting in `pyproject.toml`, so the lockfile resolves cleanly for both. On macOS, `--extra torch-cpu` falls back to PyPI's default (MPS-capable) wheel automatically.
GPU tests run on AWS CodeBuild (A10G). They do not run on pushes to main (to limit CodeBuild spend); they run automatically on v* tag pushes and on PRs carrying the run-gpu-tests label (applying the label requires triage access or higher, so ping a maintainer if your PR needs it). Maintainers can also trigger an on-demand run through the GPU CI workflow's workflow_dispatch trigger.
See CONTRIBUTING.md for the contribution workflow.
Related projects
- roipony/flash-maxsim — fused Triton kernel that tiles the similarity matrix in SRAM instead of materialising it in HBM.
- erikkaum/maxsim — exact MaxSim with hand-written CUDA (NVIDIA) and Metal (Apple Silicon) kernels; avoids materialising the similarity matrix on either backend.
- mixedbread-ai/maxsim-cpu — Rust + SIMD CPU implementation (libxsmm on x86, Accelerate on ARM) for environments without a GPU.
Citation
```bibtex
@software{late_interaction_kernels_2026,
  author = {Lac, Aurélien and Wu, Tony},
  title  = {{late-interaction-kernels}: Fused Triton kernels for late-interaction scoring},
  year   = {2026},
  url    = {https://github.com/hcompai/late-interaction-kernels},
}
```
Download files
- Source distribution: late_interaction_kernels-0.3.0.tar.gz
- Built distribution: late_interaction_kernels-0.3.0-py3-none-any.whl
File details
Details for the file late_interaction_kernels-0.3.0.tar.gz.
File metadata
- Download URL: late_interaction_kernels-0.3.0.tar.gz
- Size: 611.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ac10b20f36c95dde0b9b663ea9fb961fcd0a7123da7f6404ed535cf41cc900a2 |
| MD5 | 5e0f0aa91eb988d51818dfc66d3c71a2 |
| BLAKE2b-256 | d5bbf90a3b77737fe913de8d69d8e893af0409d0359deb0543956b7f44daeabe |
Provenance
The following attestation bundles were made for late_interaction_kernels-0.3.0.tar.gz:
- Publisher: publish.yml on hcompai/late-interaction-kernels
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: late_interaction_kernels-0.3.0.tar.gz
- Subject digest: ac10b20f36c95dde0b9b663ea9fb961fcd0a7123da7f6404ed535cf41cc900a2
- Sigstore transparency entry: 1656649257
- Permalink: hcompai/late-interaction-kernels@95ed05b818c8dca1567ca36100327941c99a3f4c
- Branch / Tag: refs/tags/v0.3.0
- Owner: https://github.com/hcompai
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@95ed05b818c8dca1567ca36100327941c99a3f4c
- Trigger Event: release
File details
Details for the file late_interaction_kernels-0.3.0-py3-none-any.whl.
File metadata
- Download URL: late_interaction_kernels-0.3.0-py3-none-any.whl
- Size: 95.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7edd5a5372be1d83fc335452cd78f98b10f526a5081f1c48e05536018a5be741 |
| MD5 | bd380b51202a46748688c43eacb39991 |
| BLAKE2b-256 | ad55ae9a1f31c077e92d870bf8a59633565c2468063532bcf2b86b7e9f75f63f |
Provenance
The following attestation bundles were made for late_interaction_kernels-0.3.0-py3-none-any.whl:
- Publisher: publish.yml on hcompai/late-interaction-kernels
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: late_interaction_kernels-0.3.0-py3-none-any.whl
- Subject digest: 7edd5a5372be1d83fc335452cd78f98b10f526a5081f1c48e05536018a5be741
- Sigstore transparency entry: 1656649402
- Permalink: hcompai/late-interaction-kernels@95ed05b818c8dca1567ca36100327941c99a3f4c
- Branch / Tag: refs/tags/v0.3.0
- Owner: https://github.com/hcompai
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@95ed05b818c8dca1567ca36100327941c99a3f4c
- Trigger Event: release