Fused Triton kernels for late-interaction (MaxSim) scoring — ColBERT, ColPali, ModernColBERT.
⚡ late-interaction-kernels
Fused Triton kernels for MaxSim scoring. ColBERT · ColPali · ModernColBERT · LateOn · LateOn-Code · ColBERTv2 · PyLate-native.
Drop-in, numerically identical replacements for the MaxSim math behind
late-interaction training, reranking and retrieval. A single call patches
PyLate; everything else is exposed as nn.Module and function-level APIs
for custom pipelines.
Not a search engine. For end-to-end retrieval use FastPlaid, NextPlaid / ColGrep, or PyLate. This library is the MaxSim math their reranking / training compiles down to.
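For reference, the score every kernel here computes is the standard ColBERT MaxSim: each query token embedding q_i is matched with its most similar document token embedding d_j, and those maxima are summed over the query tokens. With normalize=True the embeddings are presumably L2-normalized first, so each term becomes a cosine similarity.

$$
s(q, d) = \sum_{i=1}^{L_q} \max_{1 \le j \le L_d} q_i^\top d_j
$$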
Install
```bash
pip install late-interaction-kernels
```
The fused kernels need Linux + CUDA. macOS, Windows and CPU-only machines
still work: MaxSimScorer, retrieve and late_interaction_kernels.reference
fall back to a pure-PyTorch implementation, so training and retrieval code
can run locally before you rent a GPU. The PyLate drop-in targets
PyLate ≥ 1.3.
Quickstart
Speed up PyLate with one line:
```python
from late_interaction_kernels import patch_pylate

patch_pylate()
# PyLate training / rerank code is unchanged
```
Setting LIK_DISABLE=1 in the environment makes the patched entry points
fall back to vanilla PyLate at runtime.
Score MaxSim in any training loop:
```python
from late_interaction_kernels import MaxSimScorer

scorer = MaxSimScorer(normalize=True)  # nn.Module, no parameters
scores = scorer(Q, D, q_mask=q_mask, d_mask=d_mask)  # [Nq, Nd] fp32
scores.mean().backward()
```
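The scorer's semantics can be pinned down with a slow, dependency-free sketch. This is not the library's code, and it rests on assumptions: each query or document is a list of [L][d] token vectors, a mask marks valid (non-padding) tokens, padded query tokens are skipped in the sum, padded document tokens are skipped in the max, and normalize=True means per-token L2 normalization. late_interaction_kernels.reference remains the authoritative pure-PyTorch version.

```python
import math
from typing import List, Optional, Sequence

Tokens = List[List[float]]  # [L][d] embeddings of one query or one document


def l2_normalize(tokens: Tokens) -> Tokens:
    """Scale every token vector to unit L2 norm (all-zero vectors are kept as-is)."""
    out = []
    for v in tokens:
        n = math.sqrt(sum(x * x for x in v))
        out.append([x / n for x in v] if n > 0 else list(v))
    return out


def maxsim_pair(q: Tokens, d: Tokens,
                q_mask: Optional[Sequence[bool]] = None,
                d_mask: Optional[Sequence[bool]] = None) -> float:
    """Sum over valid query tokens of the best dot product over valid doc tokens."""
    q_mask = [True] * len(q) if q_mask is None else q_mask
    d_mask = [True] * len(d) if d_mask is None else d_mask
    valid_d = [dj for dj, keep in zip(d, d_mask) if keep]
    if not valid_d:
        return 0.0  # assumption: a fully padded document scores 0
    total = 0.0
    for qi, keep in zip(q, q_mask):
        if keep:  # padded query tokens contribute nothing
            total += max(sum(a * b for a, b in zip(qi, dj)) for dj in valid_d)
    return total


def maxsim_scores(Q: List[Tokens], D: List[Tokens], q_masks=None, d_masks=None,
                  normalize: bool = True) -> List[List[float]]:
    """All-pairs [Nq][Nd] score matrix, the same shape MaxSimScorer returns."""
    if normalize:
        Q = [l2_normalize(q) for q in Q]
        D = [l2_normalize(d) for d in D]
    q_masks = [None] * len(Q) if q_masks is None else q_masks
    d_masks = [None] * len(D) if d_masks is None else d_masks
    return [[maxsim_pair(q, d, qm, dm) for d, dm in zip(D, d_masks)]
            for q, qm in zip(Q, q_masks)]


Q = [[[1.0, 0.0], [0.0, 1.0]]]                # one query with two tokens, d = 2
D = [[[1.0, 0.0]], [[0.0, 2.0], [3.0, 0.0]]]  # two documents of length 1 and 2
print(maxsim_scores(Q, D, normalize=False))   # [[1.0, 5.0]]
print(maxsim_scores(Q, D, normalize=True))    # [[1.0, 2.0]]
```

Normalization matters: the second document wins on raw dot products partly because its vectors are long, and most of that advantage disappears once every token is unit-length.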
Top-k retrieval:
```python
from late_interaction_kernels import retrieve

scores, indices = retrieve(Q, D, top_k=100, chunk=4096)
# both [Nq, 100]; chunk= bounds peak HBM at Nq · (chunk + top_k)
```
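The chunk= bound follows from streaming: score one chunk of documents, merge it with the running top-k, and discard the rest. A minimal stdlib sketch of that merge, assuming a caller-supplied score_fn (a hypothetical stand-in for the on-device MaxSim block; neither name is part of the library):

```python
import heapq
from typing import Callable, List, Tuple


def chunked_topk(
    score_fn: Callable[[int, int], List[List[float]]],
    num_queries: int,
    num_docs: int,
    top_k: int,
    chunk: int,
) -> Tuple[List[List[float]], List[List[int]]]:
    """Stream docs in chunks, keeping only a running top-k per query.

    score_fn(start, stop) returns a [num_queries][stop - start] score block, so at
    any moment we hold roughly num_queries * (chunk + top_k) scores.
    """
    best: List[List[Tuple[float, int]]] = [[] for _ in range(num_queries)]
    for start in range(0, num_docs, chunk):
        stop = min(start + chunk, num_docs)
        block = score_fn(start, stop)
        for qi in range(num_queries):
            candidates = best[qi] + [(s, start + j) for j, s in enumerate(block[qi])]
            best[qi] = heapq.nlargest(top_k, candidates)  # drop everything else
    scores = [[s for s, _ in row] for row in best]
    indices = [[i for _, i in row] for row in best]
    return scores, indices


def toy_scores(start: int, stop: int) -> List[List[float]]:
    """Hypothetical score block: 2 queries x (stop - start) documents."""
    return [[float((q + 1) * ((doc * 7) % 10)) for doc in range(start, stop)]
            for q in range(2)]


scores, indices = chunked_topk(toy_scores, num_queries=2, num_docs=10, top_k=3, chunk=4)
print(indices)  # [[7, 4, 1], [7, 4, 1]]
```

The result does not depend on chunk, only the peak footprint does. As a worked number, if the bound is counted in fp32 scores, Nq = 1,000, chunk = 4096 and top_k = 100 give 1,000 × 4,196 ≈ 4.2 M scores, about 16.8 MB.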
PLAID / ColBERTv2 rerank on compressed, ragged docs:
```python
from late_interaction_kernels import maxsim_residual_varlen

# fast-plaid / ColBERTv2 on-disk layout
scores = maxsim_residual_varlen(
    Q, codes_flat, residuals_flat, cu_seqlens_d,
    centroids=centroids, bucket_weights=bucket_weights,
    nbits=2, normalize=True,
)  # [Nd] fp32: one kernel does decompress + L2-normalize + MaxSim
```
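For readers new to the packed layout, here is a slow sketch of what the call above computes, following the ColBERTv2 residual scheme (centroid plus a per-dimension bucketed residual) and the FlashAttention-style cu_seqlens convention. The byte layout of residuals_flat (nbits-wide bucket ids packed MSB-first into bytes) and the *_ref function name are assumptions for illustration, not the fast-plaid on-disk format; the query is assumed to be normalized already.

```python
import math
from itertools import accumulate
from typing import List, Sequence

Vector = List[float]


def build_cu_seqlens(lengths: Sequence[int]) -> List[int]:
    """Prefix offsets: document k owns flat rows cu[k]:cu[k + 1]."""
    return [0] + list(accumulate(lengths))


def unpack_buckets(packed: Sequence[int], nbits: int, dim: int) -> List[int]:
    """Unpack `dim` bucket ids of `nbits` bits each, MSB-first (assumed bit order)."""
    per_byte, mask = 8 // nbits, (1 << nbits) - 1
    ids = [(byte >> (8 - nbits * (k + 1))) & mask
           for byte in packed for k in range(per_byte)]
    return ids[:dim]


def decompress_token(code: int, packed: Sequence[int], centroids: List[Vector],
                     bucket_weights: Sequence[float], nbits: int) -> Vector:
    """Centroid + per-dimension bucketed residual, then L2-normalize."""
    c = centroids[code]
    buckets = unpack_buckets(packed, nbits, len(c))
    v = [ci + bucket_weights[b] for ci, b in zip(c, buckets)]
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v] if n > 0 else v


def maxsim_residual_varlen_ref(q, codes_flat, residuals_flat, cu_seqlens_d,
                               centroids, bucket_weights, nbits) -> List[float]:
    """Hypothetical slow reference: one MaxSim score per packed (non-empty) doc."""
    scores = []
    for k in range(len(cu_seqlens_d) - 1):
        doc = [decompress_token(codes_flat[t], residuals_flat[t], centroids,
                                bucket_weights, nbits)
               for t in range(cu_seqlens_d[k], cu_seqlens_d[k + 1])]
        scores.append(sum(max(sum(a * b for a, b in zip(qi, dj)) for dj in doc)
                          for qi in q))
    return scores


centroids = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
bucket_weights = [0.0, 0.0, 0.0, 1.0]      # 2**nbits reconstruction values (toy)
cu = build_cu_seqlens([2, 1])              # doc 0 = rows 0-1, doc 1 = row 2
codes_flat = [0, 1, 1]
residuals_flat = [[0b11000000], [0], [0]]  # one byte holds four 2-bit ids (d = 4)
q = [[1.0, 0.0, 0.0, 0.0]]                 # one already-normalized query token
print(maxsim_residual_varlen_ref(q, codes_flat, residuals_flat, cu,
                                 centroids, bucket_weights, nbits=2))  # [1.0, 0.0]
```

The fused kernel's point is that the decompressed vectors above never leave registers; this sketch materializes each document only to make the arithmetic visible.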
Speedups on H100
1×H100 80 GB SXM, bf16 / fp16 compute, fp32 accumulator, 50-iter median. Every baseline is the same operation in plain PyTorch.
| Workload | Speedup |
|---|---|
| Reranking / inference vs naive einsum | 7–23× |
| Long-context (Ld ≥ 8k) reranking | runs; naive OOMs |
| PyLate cached-contrastive MaxSim + backward | up to 13.8× |
| PLAID rerank vs fast_plaid.engine.search() | 19–30× |
| Fused D-side head (training) | 1.5–4.6× |
| FP8 MaxSim inference (Hopper) | up to 1.4× |
| LateOn-Code-edge end-to-end training (17 M) | 1.04–1.27× |
| LateOn / ModernColBERT end-to-end training (149 M) | 1.00–1.06× (free) |
ModernBERT-class encoders dominate step time, so on a full 149 M-parameter training run the kernel is essentially a free swap. Wherever MaxSim stops being negligible (inference, reranking, long documents, large effective batches, knowledge distillation, compressed indices, small encoders), the fused path delivers measurable speedups.
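The "free swap" remark is Amdahl's law: only the MaxSim share of a training step gets faster. A quick way to reason about it; the step-time shares below are hypothetical, not measurements from this library.

```python
def end_to_end_speedup(maxsim_fraction: float, kernel_speedup: float) -> float:
    """Amdahl's law: overall speedup when only the MaxSim share of a step gets faster."""
    if not 0.0 <= maxsim_fraction <= 1.0:
        raise ValueError("maxsim_fraction must be in [0, 1]")
    if kernel_speedup <= 0.0:
        raise ValueError("kernel_speedup must be positive")
    return 1.0 / ((1.0 - maxsim_fraction) + maxsim_fraction / kernel_speedup)


# Hypothetical step-time shares, not measurements from this library.
for share in (0.02, 0.10, 0.50):
    print(f"MaxSim = {share:.0%} of the step, 10x kernel -> "
          f"{end_to_end_speedup(share, 10.0):.2f}x overall")
```

With a 10× kernel, a 2 % share yields about 1.02× overall, 10 % about 1.10× and 50 % about 1.82×. This is why the training rows in the table stay close to 1× while the reranking rows do not.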
Full tables, shapes and reproduction commands:
docs/benchmarks.md.
API
Most users only need patch_pylate(), MaxSimScorer or retrieve.
| Symbol | What it does |
|---|---|
| patch_pylate() / unpatch_pylate() | One-line PyLate drop-in. LIK_DISABLE=1 kill switch. |
| MaxSimScorer(normalize=, backward=) | Stateless nn.Module, autograd-aware. |
| retrieve(Q, D, top_k, chunk=) | Top-k retrieval, chunked for huge corpora. |
| maxsim / maxsim_inference | Core MaxSim, dense layout (autograd / forward-only). |
| maxsim_varlen | Packed (cu_seqlens) layout. Autograd-aware. |
| maxsim_inference_scatter | Pair-list reranking on packed batches (vLLM-style scheduling). |
| maxsim_from_hidden(_train) | Fused D-side Linear → Normalize → MaxSim, no [Nd, Ld, d_out] scratch. |
| maxsim_residual(_varlen) | Fused PLAID / ColBERTv2 decompress + normalize + MaxSim. |
| plaid_approx_score | IVF prune step for ColBERTv2. |
| maxsim_inference_fp8 | FP8 tensor-core MaxSim (Hopper / Blackwell). Auto-fallback to bf16. |
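Why maxsim_from_hidden can avoid the [Nd, Ld, d_out] scratch: MaxSim only needs one running max per query token, so each projected document token can be consumed and discarded. The toy sketch below shows that equivalence with a bias-free Linear; it is a conceptual illustration, not the kernel's tiling or its signature.

```python
import math
from typing import List

Matrix = List[List[float]]


def project_normalize(h: List[float], W: Matrix) -> List[float]:
    """One document token: bias-free Linear (d_in -> d_out), then L2-normalize."""
    p = [sum(w * x for w, x in zip(row, h)) for row in W]
    n = math.sqrt(sum(x * x for x in p))
    return [x / n for x in p] if n > 0 else p


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def maxsim_unfused(q: Matrix, H: Matrix, W: Matrix) -> float:
    """Materializes the whole projected document ([Ld][d_out]) before scoring."""
    P = [project_normalize(h, W) for h in H]
    return sum(max(dot(qi, pj) for pj in P) for qi in q)


def maxsim_fused(q: Matrix, H: Matrix, W: Matrix) -> float:
    """Streams document tokens; the only state is one running max per query token."""
    running_max = [-math.inf] * len(q)
    for h in H:
        p = project_normalize(h, W)  # lives only for this iteration
        for i, qi in enumerate(q):
            running_max[i] = max(running_max[i], dot(qi, p))
    return sum(running_max)


q = [[1.0, 0.0], [0.0, 1.0]]             # two normalized query tokens, d_out = 2
H = [[1.0, 2.0, 0.5], [3.0, -1.0, 0.0]]  # two document hidden states, d_in = 3
W = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]   # toy projection d_in = 3 -> d_out = 2
print(maxsim_fused(q, H, W), maxsim_unfused(q, H, W))
```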
Submodules:
- late_interaction_kernels.fp8: FP8 quantize / dequantize helpers (per-tensor and per-token).
- late_interaction_kernels.experimental: research kernels (soft_maxsim, smooth_maxsim, maxsim_matryoshka, maxsim_xtr).
- late_interaction_kernels.reference: pure-PyTorch reference implementations, importable on any platform.
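The per-tensor vs per-token distinction matters because FP8 E4M3 tops out at 448, so a single loud token can squeeze every other token into a sliver of the range. A stdlib sketch of the scale computation, assuming the common amax / 448 recipe; the library's actual scaling rule is not documented here.

```python
from typing import List

Matrix = List[List[float]]
E4M3_MAX = 448.0  # largest finite value of the float8 e4m3fn format


def per_tensor_scale(x: Matrix) -> float:
    """One scale for the whole tensor, chosen so its absolute max maps to 448."""
    amax = max(abs(v) for row in x for v in row)
    return amax / E4M3_MAX if amax > 0 else 1.0


def per_token_scales(x: Matrix) -> List[float]:
    """One scale per token (row), so quiet tokens are not crushed by loud ones."""
    scales = []
    for row in x:
        amax = max(abs(v) for v in row)
        scales.append(amax / E4M3_MAX if amax > 0 else 1.0)
    return scales


def scale_into_range(x: Matrix, scales: List[float]) -> Matrix:
    """Divide row r by scales[r] and clamp to +-448.

    A real FP8 cast would also round each value to 3 mantissa bits; that step is
    omitted, so this only shows how much of the FP8 range each row gets to use.
    """
    return [[max(-E4M3_MAX, min(E4M3_MAX, v / s)) for v in row]
            for row, s in zip(x, scales)]


x = [[0.01, -0.02], [4.48, 1.0]]  # one quiet token, one loud token
t = per_tensor_scale(x)
print(scale_into_range(x, [t, t]))               # quiet row lands near [1, -2]
print(scale_into_range(x, per_token_scales(x)))  # quiet row becomes roughly [224, -448]
```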
Config:
| Knob | Effect |
|---|---|
| maxsim(..., backward="auto" \| "unified" \| "atomic" \| "csr") | Per-call grad_D strategy. "auto" picks per shape. |
| set_backward_method(...) / get_backward_method() | Process-wide default (back-compat; prefer per-call kwarg). |
| LIK_DISABLE=1 | Patched entry points delegate to vanilla PyLate. |
| LIK_SUPPRESS_NORM_WARN=1 | Silence the "looks unnormalized" one-shot warning. |
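Why there is a grad_D strategy to choose at all: the MaxSim gradient with respect to a document token is the sum of the query tokens that picked it as their max, so several query tokens may write to the same row. The sketch below contrasts two textbook ways to do that reduction. Mapping them onto the "atomic" and "csr" options is an assumption based on the names, and "unified" is not modeled.

```python
from collections import defaultdict
from typing import Dict, List

Matrix = List[List[float]]


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def argmax_per_query(q: Matrix, d: Matrix) -> List[int]:
    """Index of the doc token each query token picks (first index wins ties)."""
    return [max(range(len(d)), key=lambda j: dot(qi, d[j])) for qi in q]


def grad_d_atomic(q: Matrix, d: Matrix, grad_out: float = 1.0) -> Matrix:
    """Scatter-add: every query token adds itself into its argmax row."""
    grad = [[0.0] * len(d[0]) for _ in d]
    for qi, j in zip(q, argmax_per_query(q, d)):
        for k, x in enumerate(qi):
            grad[j][k] += grad_out * x  # concurrent writes to row j need atomics on a GPU
    return grad


def grad_d_csr(q: Matrix, d: Matrix, grad_out: float = 1.0) -> Matrix:
    """Group query tokens by argmax first; each row then reduces its own group."""
    groups: Dict[int, List[int]] = defaultdict(list)
    for i, j in enumerate(argmax_per_query(q, d)):
        groups[j].append(i)
    return [[grad_out * sum(q[i][k] for i in groups.get(j, [])) for k in range(len(d[0]))]
            for j in range(len(d))]


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # three query tokens
d = [[2.0, 0.0], [0.0, 3.0]]              # two document tokens
print(grad_d_atomic(q, d))  # [[1.0, 0.0], [1.0, 2.0]]
```

Both give the same gradient; they differ in write contention (atomics) versus an extra grouping pass (CSR), which is the kind of trade-off a per-shape "auto" choice can exploit.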
Walk-through of every kernel, the autograd graph, the backward variants
and the numerics: docs/design.md.
Hardware
Primary target: H100 / H200 (autotuned, FP8 WGMMA, warp-specialized on Triton ≥ 3.2). Also tuned for A100 and the other Ampere GPUs (A10 / A40 / 3090), plus Ada (L4 / L40 / 4090). Older or unknown CUDA GPUs fall back to a conservative shortlist. CPU / macOS / Windows get the pure-PyTorch reference.
Autotuning runs once per unique (Lq, Ld, d, masks) signature and caches
the winner, so later calls with the same signature pay no tuning cost.
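The caching behaviour described above is the usual autotune-and-memoize pattern; Triton's own triton.autotune decorator similarly re-tunes only when the arguments named in its key list change. A hypothetical stdlib sketch, with pick_config, the config names and the signature tuple invented for illustration:

```python
import time
from typing import Callable, Dict, Hashable

# Hypothetical process-wide cache: signature -> name of the winning config.
_best_config: Dict[Hashable, str] = {}


def pick_config(signature: Hashable, candidates: Dict[str, Callable[[], object]]) -> str:
    """Benchmark every candidate the first time a signature is seen, then reuse."""
    if signature in _best_config:
        return _best_config[signature]  # warm path: a dict lookup, no benchmarking
    timings = {}
    for name, run in candidates.items():
        start = time.perf_counter()
        run()  # one timed run for brevity; real autotuners repeat and take a median
        timings[name] = time.perf_counter() - start
    _best_config[signature] = min(timings, key=timings.get)
    return _best_config[signature]


# Signature mirrors the (Lq, Ld, d, masks) key from the text, masks as two flags.
configs = {
    "BLOCK_64": lambda: sum(range(20_000)),
    "BLOCK_128": lambda: sum(range(10_000)),
}
winner = pick_config((32, 300, 128, True, False), configs)  # benchmarks both
print(winner, pick_config((32, 300, 128, True, False), configs) == winner)  # cached
```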
Development
```bash
git clone https://github.com/hcompai/late-interaction-kernels
cd late-interaction-kernels
pip install -e ".[dev,pylate]"
pytest -q                            # auto-skips CUDA tests without a GPU
ruff check . && ruff format --check .
python benchmarks/bench_forward.py   # see benchmarks/ for the full set
```
See CONTRIBUTING.md for the contribution workflow and CHANGELOG.md for the
kernel-by-kernel release history.
Citation
```bibtex
@software{late_interaction_kernels_2026,
  author = {Lac, Aurélien and Wu, Tony},
  title  = {{late-interaction-kernels}: Fused Triton kernels for late-interaction scoring},
  year   = {2026},
  url    = {https://github.com/hcompai/late-interaction-kernels},
}
```
Aurélien Lac · Tony Wu — H Company · 2026
License
Apache 2.0 — see LICENSE. Copyright 2026 Aurélien Lac and Tony Wu.
Related
PyLate · FastPlaid · NextPlaid / ColGrep · flash-maxsim (IBM, the first public Triton MaxSim — direct inspiration) · FlashAttention.
Download files
File details
Details for the file late_interaction_kernels-0.0.1.tar.gz.
File metadata
- Download URL: late_interaction_kernels-0.0.1.tar.gz
- Size: 425.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 91f3da411f4eba8636dd84dd1a170ff0144972b356cbccd635928101f935b714 |
| MD5 | 192e9282db5ac001df3d5dc4a10d408f |
| BLAKE2b-256 | 60952f90621a9233215feedfc69014b1e99d533b2d0b5adeac0285d59b4b1140 |

Provenance
The following attestation bundles were made for late_interaction_kernels-0.0.1.tar.gz:
- Publisher: publish.yml on hcompai/late-interaction-kernels
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: late_interaction_kernels-0.0.1.tar.gz
- Subject digest: 91f3da411f4eba8636dd84dd1a170ff0144972b356cbccd635928101f935b714
- Sigstore transparency entry: 1429485180
- Permalink: hcompai/late-interaction-kernels@9211a0112065da9549f44f27c228eea48ff3cc43
- Branch / Tag: refs/tags/v0.0.1
- Owner: https://github.com/hcompai
- Access: private
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@9211a0112065da9549f44f27c228eea48ff3cc43
- Trigger Event: release
File details
Details for the file late_interaction_kernels-0.0.1-py3-none-any.whl.
File metadata
- Download URL: late_interaction_kernels-0.0.1-py3-none-any.whl
- Size: 75.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 3601b2f48e90089ac92920643bf81019eacd71f6dbf0c825720c85afd0c1a046 |
| MD5 | a5a41b167ffd92d67834c8bd9d498aa9 |
| BLAKE2b-256 | 7788ef4be038da927dc1bc9308cff7616b8be03ceb1c9e754eaddc6d8728d66e |

Provenance
The following attestation bundles were made for late_interaction_kernels-0.0.1-py3-none-any.whl:
- Publisher: publish.yml on hcompai/late-interaction-kernels
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: late_interaction_kernels-0.0.1-py3-none-any.whl
- Subject digest: 3601b2f48e90089ac92920643bf81019eacd71f6dbf0c825720c85afd0c1a046
- Sigstore transparency entry: 1429485181
- Permalink: hcompai/late-interaction-kernels@9211a0112065da9549f44f27c228eea48ff3cc43
- Branch / Tag: refs/tags/v0.0.1
- Owner: https://github.com/hcompai
- Access: private
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@9211a0112065da9549f44f27c228eea48ff3cc43
- Trigger Event: release