Fused Triton kernels for late-interaction (MaxSim) scoring — ColBERT, ColPali, ModernColBERT.
[!NOTE] Full algorithmic walkthrough, animations and benchmark plots live on the docs site: hcompai.github.io/late-interaction-kernels.
Introduction
late-interaction-kernels provides fused Triton kernels for MaxSim, the late-interaction scoring used by ColBERT, ColPali, ModernColBERT, LateOn and ColBERTv2. The kernels are numerically identical to plain PyTorch and come with three APIs:
- a one-line PyLate drop-in (`patch_pylate()`),
- a stateless `nn.Module` (`MaxSimScorer`) for custom training loops,
- function-level entry points (`maxsim`, `maxsim_varlen`, `maxsim_padded`, ...) for everything else.
This is not a search engine. For end-to-end training or retrieval use PyLate, FastPlaid or NextPlaid. This library is the MaxSim math they compile down to.
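For readers new to late interaction, the operation being fused is small. The sketch below is a dependency-free Python illustration of MaxSim for one query and one document, not this library's code: each query token takes its best dot product over the document's tokens, and those maxima are summed. The helper names `dot` and `maxsim_reference` are made up for this example.

```python
from typing import Sequence

Vector = Sequence[float]


def dot(a: Vector, b: Vector) -> float:
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def maxsim_reference(query: Sequence[Vector], doc: Sequence[Vector]) -> float:
    """Sum over query tokens of the best similarity to any document token."""
    return sum(max(dot(q, d) for d in doc) for q in query)


# Two query tokens, three document tokens, embedding dim 2.
query = [[1.0, 0.0], [0.0, 1.0]]
doc = [[0.5, 0.5], [1.0, 0.0], [0.0, -1.0]]

score = maxsim_reference(query, doc)
print(score)  # 1.0 (query token 0, doc token 1) + 0.5 (query token 1, doc token 0) = 1.5
```

Scoring `Nq` queries against `Nd` documents repeats this for every pair, which is where the `[Nq, Nd]` score matrix in the examples below comes from.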
Install
pip install late-interaction-kernels
| Platform | Backend |
|---|---|
| Linux + CUDA (sm_75+) | Fused Triton kernels (autotuned, FP8 on Hopper). |
| macOS (Apple Silicon, MPS) | Fused Metal simdgroup_matrix for inference, torch.compile for training. |
| CPU / Windows | Autograd-aware pure-PyTorch reference. |
[!NOTE] The PyLate drop-in targets PyLate >= 1.3. The pure-PyTorch reference imports on every platform, so training and retrieval code is unit-testable on a laptop before you rent a GPU.
Quickstart
Patch PyLate (one line)
from late_interaction_kernels import patch_pylate
patch_pylate()
# PyLate training / rerank code is unchanged
Set LIK_DISABLE=1 in the environment to fall back to vanilla PyLate at runtime.
Custom training loop
from late_interaction_kernels import MaxSimScorer
scorer = MaxSimScorer(normalize=True) # nn.Module, no parameters
scores = scorer(Q, D, q_mask=q_mask, d_mask=d_mask) # [Nq, Nd] fp32
scores.mean().backward()
Top-k retrieval
from late_interaction_kernels import retrieve
scores, indices = retrieve(Q, D, top_k=100, chunk=4096)
# both [Nq, 100]; chunk= bounds peak HBM at Nq * (chunk + top_k)
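To see why chunking bounds memory, here is a rough single-query sketch in plain Python. It is an illustration of the general chunk-and-merge idea, not the library's implementation; `chunked_top_k` and `score_chunk` are hypothetical names. Only one chunk of scores plus the running top-k ever exist at the same time, so the working set per query is at most `chunk + top_k` candidates.

```python
import heapq
from typing import Callable, List, Sequence, Tuple


def chunked_top_k(
    score_chunk: Callable[[int, int], Sequence[float]],
    num_docs: int,
    top_k: int,
    chunk: int,
) -> List[Tuple[float, int]]:
    """Return the top_k (score, doc_index) pairs, scanning the corpus chunk by chunk."""
    best: List[Tuple[float, int]] = []
    for start in range(0, num_docs, chunk):
        stop = min(start + chunk, num_docs)
        block = score_chunk(start, stop)  # scores for docs[start:stop] only
        # Merge the running winners with this chunk: at most chunk + top_k candidates.
        candidates = best + [(s, start + i) for i, s in enumerate(block)]
        best = heapq.nlargest(top_k, candidates)
    return best


# Stand-in for a real scorer: slices a precomputed list of scores.
all_scores = [0.1, 0.9, 0.4, 0.7, 0.3, 0.8, 0.2]
top = chunked_top_k(lambda a, b: all_scores[a:b], len(all_scores), top_k=3, chunk=3)
print(top)  # [(0.9, 1), (0.8, 5), (0.7, 3)]
```

A smaller `chunk` lowers peak memory at the cost of more merge steps; the result is the same as a full sort as long as every chunk is visited.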
PLAID / ColBERTv2 on compressed, ragged docs
from late_interaction_kernels.plaid import maxsim_residual_varlen
scores = maxsim_residual_varlen(
Q, codes_flat, residuals_flat, cu_seqlens_d,
centroids=centroids, bucket_weights=bucket_weights,
nbits=2, normalize=True,
) # [Nd] fp32; one kernel does decompress + L2-normalize + MaxSim
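The ragged layout can be pictured with a small pure-Python sketch. It assumes `cu_seqlens_d` follows the common cumulative-offset convention (a leading 0 followed by running token counts, so it has `Nd + 1` entries); the README does not spell this out, so treat that as an assumption. Decompression of codes and residuals is left out, and `pack_docs` and `maxsim_packed` are hypothetical helpers, not library functions.

```python
from itertools import accumulate
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def pack_docs(docs: Sequence[Sequence[Vector]]) -> Tuple[List[Vector], List[int]]:
    """Flatten variable-length docs into one token list plus cumulative offsets."""
    flat = [tok for doc in docs for tok in doc]
    cu_seqlens = [0, *accumulate(len(doc) for doc in docs)]
    return flat, cu_seqlens


def maxsim_packed(
    query: Sequence[Vector], flat: Sequence[Vector], cu_seqlens: Sequence[int]
) -> List[float]:
    """MaxSim of one query against every doc in the packed layout (no padding)."""
    scores = []
    for start, stop in zip(cu_seqlens, cu_seqlens[1:]):
        doc = flat[start:stop]  # tokens of one document
        scores.append(
            sum(max(sum(a * b for a, b in zip(q, d)) for d in doc) for q in query)
        )
    return scores


docs = [[[1.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]]  # lengths 1 and 2
query = [[1.0, 0.0], [0.0, 1.0]]

flat, cu_seqlens = pack_docs(docs)
print(cu_seqlens)                             # [0, 1, 3]
print(maxsim_packed(query, flat, cu_seqlens))  # [1.0, 2.0]
```

Packing avoids spending memory and compute on padding tokens when document lengths vary widely. This sketch assumes every document has at least one token.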
Benchmarks
1xH100, bf16/fp16, 50-iter median, vs the same op in plain PyTorch:
| Workload | Speedup |
|---|---|
| Reranking / inference vs naive einsum | 7-23x |
| Long-context (Ld >= 8k) reranking | runs; naive OOMs |
| PyLate cached-contrastive MaxSim + backward | up to 13.8x |
| PLAID rerank vs fast_plaid.engine.search() | 19-30x |
| Fused D-side head (training) | 1.5-4.6x |
| FP8 MaxSim inference (Hopper) | up to 1.4x |
| End-to-end training of a 149M encoder | 1.00-1.06x (free) |
Full tables and reproduction commands: docs/benchmarks.md.
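For orientation, a "50-iter median" measurement can be sketched as below. This is a generic stdlib harness, not the project's benchmark script (that lives in docs/benchmarks.md); `median_time_ms`, `speedup` and the warmup count are assumptions made for illustration. When timing GPU work, the callable also has to wait for the device to finish (for example by calling `torch.cuda.synchronize()` inside it), otherwise only the kernel launch is measured.

```python
import statistics
import time
from typing import Callable


def median_time_ms(fn: Callable[[], object], iters: int = 50, warmup: int = 5) -> float:
    """Median wall-clock time of fn() in milliseconds over `iters` timed runs."""
    for _ in range(warmup):  # untimed runs to absorb one-off costs (caches, autotuning)
        fn()
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - t0) * 1e3)
    return statistics.median(samples)


def speedup(baseline_ms: float, candidate_ms: float) -> float:
    """How many times faster the candidate is than the baseline."""
    return baseline_ms / candidate_ms


# Toy workloads standing in for "plain PyTorch" and "fused kernel".
slow = median_time_ms(lambda: sum(i * i for i in range(20_000)))
fast = median_time_ms(lambda: sum(i * i for i in range(2_000)))
print(f"speedup: {speedup(slow, fast):.1f}x")
```

The median is less sensitive than the mean to occasional slow iterations, which is presumably why the table reports it.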
Choose a kernel
Not sure which entry point fits your stack? The docs site ships an interactive decision tree that narrows the public API down to the right function in four questions (stack · phase · layout · workload): 👉 hcompai.github.io/late-interaction-kernels/choose-a-kernel.html
API
| Symbol | What it does |
|---|---|
| `patch_pylate()` / `unpatch_pylate()` | One-line PyLate drop-in. `LIK_DISABLE=1` kill switch. |
| `MaxSimScorer(normalize=, backward=)` | Stateless `nn.Module`, autograd-aware. |
| `retrieve(Q, D, top_k, chunk=)` | Top-k retrieval, chunked for huge corpora. |
| `maxsim` | Core MaxSim, dense layout. Autograd-aware; auto-skips argmax save when no input requires grad. |
| `maxsim_varlen` | Packed (`cu_seqlens`) layout. Autograd-aware. |
| `maxsim_padded` | Padded reranking wrapper: packs internally, returns [B, C] fp32. |
Other kernels are in submodules: padded, score_pairs, fused_head, plaid, fp8, experimental, reference. See docs/design.md for details on every kernel, the autograd graph and the backward variants.
🔽 Configuration knobs (env vars + kwargs)
| Knob | Effect |
|---|---|
| `maxsim(..., backward="auto" \| "unified" \| "atomic" \| "csr")` | Per-call grad_D strategy. `"auto"` picks per shape. |
| `patch_colpali_engine()` / `unpatch_colpali_engine()` | colpali_engine drop-in: loss + scoring routes through the kernel. |
| `LIK_DISABLE=1` | Patched entry points delegate to vanilla PyLate. |
| `LIK_SUPPRESS_NORM_WARN=1` | Silence the "looks unnormalized" one-shot warning. |
| `LIK_DISABLE_COMPILE=1` | Skip torch.compile on the MPS path (eager fallback). |
| `LIK_FORCE_MPS_BACKEND={metal,compile,reference}` | Pin the MPS dispatch. |
Development
git clone https://github.com/hcompai/late-interaction-kernels
cd late-interaction-kernels
uv sync --extra dev --extra pylate --extra torch-cuda # GPU dev; use --extra torch-cpu on CPU-only boxes
uv run pytest -q # CUDA tests auto-skip without a GPU
uv run ruff check . && uv run ruff format --check .
[!NOTE] Pick exactly one of `--extra torch-cuda` (pulls torch from the CUDA index, `cu124`) or `--extra torch-cpu` (CPU-only wheel, what CI uses). The two are declared as conflicting in `pyproject.toml` so the lockfile resolves cleanly for both. On macOS, `--extra torch-cpu` falls back to PyPI's default (MPS-capable) wheel automatically.
GPU tests run automatically on every push to main. To run them on a PR, apply the run-gpu-tests label.
See CONTRIBUTING.md for the contribution workflow.
Citation
@software{late_interaction_kernels_2026,
author = {Lac, Aurélien and Wu, Tony},
title = {{late-interaction-kernels}: Fused Triton kernels for late-interaction scoring},
year = {2026},
url = {https://github.com/hcompai/late-interaction-kernels},
}