
Fused Triton kernels for late-interaction (MaxSim) scoring — ColBERT, ColPali, ModernColBERT.


⚡ late-interaction-kernels

Fused Triton kernels for MaxSim scoring. ColBERT · ColPali · ModernColBERT · LateOn · LateOn-Code · ColBERTv2 · PyLate-native.




Drop-in, numerically identical replacements for the MaxSim math behind late-interaction training, reranking and retrieval. One line patches PyLate; the rest of the library exposes nn.Module and function-level APIs for custom pipelines.

Not a search engine. For end-to-end retrieval use FastPlaid, NextPlaid / ColGrep, or PyLate. This library implements the MaxSim math that their reranking and training compile down to.


Install

pip install late-interaction-kernels
Execution path by platform:
  • Linux + CUDA (sm_75+): fused Triton kernels, with the full speedups listed under Speedups.
  • macOS (Apple Silicon, MPS): fused Metal simdgroup_matrix kernel for inference, torch.compile for training.
  • CPU / Windows / anything else: eager pure-PyTorch reference, autograd-aware.

MaxSimScorer, retrieve and late_interaction_kernels.reference import and run on every platform, so training and retrieval code is unit-testable on a laptop before renting a GPU. The PyLate drop-in targets PyLate ≥ 1.3.
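For orientation, the quantity every one of these paths computes fits in a few lines. The sketch below is a stdlib-only illustration of ColBERT-style MaxSim (each query token keeps its best-matching document token, and those maxima are summed); it is not the contents of late_interaction_kernels.reference, and the list-of-lists layout, the helper names and the mask convention (True marks a real token) are assumptions made for illustration.

```python
import math

def _dot(u, v):
    return sum(x * y for x, y in zip(u, v))

def _l2_normalize(v):
    norm = math.sqrt(_dot(v, v))
    return [x / norm for x in v] if norm > 0 else list(v)

def maxsim_reference(Q, D, q_mask=None, d_mask=None, normalize=False):
    """Nq x Nd MaxSim scores (hypothetical helper, not the library's API).

    Q: Nq queries, each a list of token vectors; D: Nd documents, same layout.
    q_mask[i][a] / d_mask[j][b]: True keeps a token, False drops it (padding).
    Assumes every document keeps at least one token.
    """
    if normalize:
        Q = [[_l2_normalize(t) for t in q] for q in Q]
        D = [[_l2_normalize(t) for t in d] for d in D]
    scores = []
    for i, q in enumerate(Q):
        q_toks = [t for a, t in enumerate(q) if q_mask is None or q_mask[i][a]]
        row = []
        for j, d in enumerate(D):
            d_toks = [t for b, t in enumerate(d) if d_mask is None or d_mask[j][b]]
            # each query token keeps its best-matching doc token; the maxima are summed
            row.append(sum(max(_dot(qt, dt) for dt in d_toks) for qt in q_toks))
        scores.append(row)
    return scores

Q = [[[1.0, 0.0], [0.0, 1.0]]]                # 1 query, 2 tokens, dim 2
D = [[[1.0, 0.0]], [[0.6, 0.8], [0.0, 1.0]]]  # 2 docs with 1 and 2 tokens
print([[round(s, 6) for s in row] for row in maxsim_reference(Q, D)])  # [[1.0, 1.6]]
```

The fused kernels avoid exactly the nested Python loops above: they tile Q and D on-chip and never build the full token-by-token similarity tensor.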


Quickstart

Speed up PyLate, one line:

from late_interaction_kernels import patch_pylate

patch_pylate()
# PyLate training / rerank code is unchanged

LIK_DISABLE=1 in the environment makes patched entry points fall back to vanilla PyLate at runtime.
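One way such a runtime kill switch can work is to read the environment variable on every call of the patched function, so it can be flipped without re-importing anything. The sketch below shows that pattern on a stand-in namespace; the `patch` / `unpatch` helpers, the `colbert_scores` attribute and the exact `"1"` check are hypothetical and do not describe how patch_pylate() is actually implemented.

```python
import os
import types

_originals = {}  # (id(target), name) -> original function

def patch(target, name, fused_fn):
    """Swap target.name for a dispatcher that honours LIK_DISABLE at call time (hypothetical)."""
    original = getattr(target, name)
    _originals[(id(target), name)] = original

    def dispatch(*args, **kwargs):
        # Checked on every call, so setting LIK_DISABLE=1 later still takes effect.
        if os.environ.get("LIK_DISABLE") == "1":
            return original(*args, **kwargs)
        return fused_fn(*args, **kwargs)

    setattr(target, name, dispatch)

def unpatch(target, name):
    """Restore the original function saved by patch()."""
    setattr(target, name, _originals.pop((id(target), name)))

# Usage on a stand-in for a PyLate module (hypothetical names)
os.environ.pop("LIK_DISABLE", None)
scoring = types.SimpleNamespace(colbert_scores=lambda q, d: "vanilla")
patch(scoring, "colbert_scores", lambda q, d: "fused")
print(scoring.colbert_scores(None, None))  # fused
os.environ["LIK_DISABLE"] = "1"
print(scoring.colbert_scores(None, None))  # vanilla
del os.environ["LIK_DISABLE"]
unpatch(scoring, "colbert_scores")
```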

Score MaxSim in any training loop:

from late_interaction_kernels import MaxSimScorer

scorer = MaxSimScorer(normalize=True)                # nn.Module, no parameters
scores = scorer(Q, D, q_mask=q_mask, d_mask=d_mask)  # [Nq, Nd] fp32
scores.mean().backward()

Top-k retrieval:

from late_interaction_kernels import retrieve

scores, indices = retrieve(Q, D, top_k=100, chunk=4096)
# both [Nq, 100] — chunk= bounds peak HBM at Nq · (chunk + top_k)
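The memory bound in the comment above comes from merging a running top-k with one chunk of fresh scores at a time. Below is a minimal sketch of that scheme, assuming nothing about retrieve() beyond what the snippet shows: `chunked_topk` and its `score_fn(i, j)` callback are hypothetical names, and the real function works on batched tensors rather than Python lists.

```python
import heapq

def chunked_topk(score_fn, num_queries, num_docs, top_k, chunk):
    """Running top-k per query while scanning the corpus chunk by chunk (sketch).

    Per query, at most chunk + top_k candidate scores are alive at once; a batched
    version holds that for all queries, giving Nq * (chunk + top_k).
    """
    best = [[] for _ in range(num_queries)]  # per query: list of (score, doc_idx)
    for start in range(0, num_docs, chunk):
        stop = min(start + chunk, num_docs)
        for i in range(num_queries):
            candidates = best[i] + [(score_fn(i, j), j) for j in range(start, stop)]
            best[i] = heapq.nlargest(top_k, candidates)  # sorted, highest first
    scores = [[s for s, _ in row] for row in best]
    indices = [[j for _, j in row] for row in best]
    return scores, indices

def toy_score(i, j):  # stand-in for MaxSim(Q[i], D[j])
    return float((j * 37 + i * 11) % 101)

scores, indices = chunked_topk(toy_score, num_queries=2, num_docs=1000, top_k=5, chunk=128)
```

Because each merge keeps the exact top-k of everything seen so far, the result does not depend on the chunk size; the chunk only trades peak memory against the number of passes.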

PLAID / ColBERTv2 rerank on compressed, ragged docs:

from late_interaction_kernels import maxsim_residual_varlen

# fast-plaid / ColBERTv2 on-disk layout
scores = maxsim_residual_varlen(
    Q, codes_flat, residuals_flat, cu_seqlens_d,
    centroids=centroids, bucket_weights=bucket_weights,
    nbits=2, normalize=True,
)  # [Nd] fp32 — one kernel does decompress + L2-normalize + MaxSim
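To make the packed layout and the decompression step concrete, here is a plain-Python sketch of roughly what such a call computes for one query. It assumes ColBERTv2-style residual compression (each token is a centroid id plus one bucket index per dimension, with 2**nbits buckets whose values live in bucket_weights) and keeps residuals unpacked, whereas the real on-disk format packs nbits per dimension into bytes. `maxsim_residual_varlen_sketch` is a hypothetical name, and unlike the fused kernel it rebuilds each document's embeddings in memory before scoring.

```python
import math

def _normalize(v):
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v] if n > 0 else list(v)

def maxsim_residual_varlen_sketch(q, codes_flat, residuals_flat, cu_seqlens_d,
                                  centroids, bucket_weights, normalize=True):
    """Score one query (Lq token vectors) against Nd packed, compressed documents.

    Token t belongs to document j when cu_seqlens_d[j] <= t < cu_seqlens_d[j + 1].
    residuals_flat[t] holds one bucket index per dimension (unpacked, assumed).
    """
    if normalize:
        q = [_normalize(t) for t in q]
    scores = []
    for j in range(len(cu_seqlens_d) - 1):
        doc = []
        for t in range(cu_seqlens_d[j], cu_seqlens_d[j + 1]):
            # decompress: centroid vector + per-dimension bucket value
            tok = [c + bucket_weights[b]
                   for c, b in zip(centroids[codes_flat[t]], residuals_flat[t])]
            doc.append(_normalize(tok) if normalize else tok)
        scores.append(sum(max(sum(x * y for x, y in zip(qt, dt)) for dt in doc)
                          for qt in q))
    return scores  # [Nd] scores, one per document

centroids = [[1.0, 0.0], [0.0, 1.0]]
bucket_weights = [-0.1, -0.03, 0.03, 0.1]  # 2**nbits = 4 levels for nbits=2 (illustrative)
codes_flat = [0, 1, 1]
residuals_flat = [[2, 1], [1, 3], [0, 0]]
cu_seqlens_d = [0, 1, 3]                   # doc 0: token 0; doc 1: tokens 1-2
scores = maxsim_residual_varlen_sketch([[1.0, 0.0]], codes_flat, residuals_flat,
                                       cu_seqlens_d, centroids, bucket_weights)
```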

Speedups on H100

1×H100 80 GB SXM, bf16 / fp16 compute, fp32 accumulator, 50-iter median. Every baseline is the same operation in plain PyTorch.

  • Reranking / inference vs naive einsum: 7–23×
  • Long-context (Ld ≥ 8k) reranking: runs where the naive path OOMs
  • PyLate cached-contrastive MaxSim + backward: up to 13.8×
  • PLAID rerank vs fast_plaid.engine.search(): 19–30×
  • Fused D-side head (training): 1.5–4.6×
  • FP8 MaxSim inference (Hopper): up to 1.4×
  • LateOn-Code-edge end-to-end training (17 M params): 1.04–1.27×
  • LateOn / ModernColBERT end-to-end training (149 M params): 1.00–1.06× (free)

ModernBERT-class encoders dominate step time, so on a full 149 M-parameter training run the kernel is essentially a free swap. Wherever MaxSim stops being negligible (inference, reranking, long documents, large effective batches, knowledge distillation, compressed indices, small encoders), the fused path pays off.
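The "free swap" is Amdahl's law at work: speeding up one part of a step helps only in proportion to that part's share of the step. The sketch below plugs in the 13.8× cached-contrastive figure with a few assumed MaxSim shares of step time; the fractions are illustrative, not measured.

```python
def end_to_end_speedup(maxsim_fraction, kernel_speedup):
    """Amdahl's law: only the MaxSim share of a training step gets faster."""
    return 1.0 / ((1.0 - maxsim_fraction) + maxsim_fraction / kernel_speedup)

# Assumed shares of step time spent in MaxSim (illustrative, not measured)
for frac in (0.02, 0.05, 0.20):
    print(f"MaxSim at {frac:.0%} of the step -> {end_to_end_speedup(frac, 13.8):.2f}x overall")
# prints 1.02x, 1.05x and 1.23x
```

Even an infinitely fast kernel caps the gain at 1 / (1 - f), which is why small encoders and MaxSim-heavy workloads see much larger end-to-end wins.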

Full tables, shapes and reproduction commands: docs/benchmarks.md.


API

Most users only need patch_pylate(), MaxSimScorer or retrieve.

  • patch_pylate() / unpatch_pylate(): one-line PyLate drop-in; LIK_DISABLE=1 kill switch.
  • MaxSimScorer(normalize=, backward=): stateless nn.Module, autograd-aware.
  • retrieve(Q, D, top_k, chunk=): top-k retrieval, chunked for huge corpora.
  • maxsim / maxsim_inference: core MaxSim, dense layout (autograd / forward-only).
  • maxsim_varlen: packed (cu_seqlens) layout, autograd-aware.
  • maxsim_inference_scatter: pair-list reranking on packed batches (vLLM-style scheduling).
  • maxsim_from_hidden(_train): fused D-side Linear → Normalize → MaxSim, with no [Nd, Ld, d_out] scratch tensor.
  • maxsim_residual(_varlen): fused PLAID / ColBERTv2 decompress + normalize + MaxSim.
  • plaid_approx_score: IVF prune step for ColBERTv2.
  • maxsim_inference_fp8: FP8 tensor-core MaxSim (Hopper / Blackwell), with automatic fallback to bf16.
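The "no [Nd, Ld, d_out] scratch" idea behind maxsim_from_hidden can be illustrated by projecting, normalizing and scoring one document token at a time while keeping only a running max per query token. This is a stdlib-only sketch under assumptions (a bias-free Linear head, an already-normalized query, unmasked inputs); `maxsim_from_hidden_sketch` is a hypothetical name and says nothing about the real signature.

```python
import math

def maxsim_from_hidden_sketch(q, doc_hidden, weight):
    """Score one query against docs given raw encoder hidden states (sketch).

    q: Lq query vectors of size d_out (assumed already L2-normalized).
    doc_hidden: Nd docs, each a list of Ld hidden vectors of size d_hidden.
    weight: d_out rows of length d_hidden (bias-free Linear head, assumed).
    Only one projected doc token exists at a time; no [Nd, Ld, d_out] buffer.
    """
    scores = []
    for doc in doc_hidden:
        running_max = [-math.inf] * len(q)  # one running max per query token
        for h in doc:
            e = [sum(w * x for w, x in zip(row, h)) for row in weight]  # Linear
            n = math.sqrt(sum(x * x for x in e)) or 1.0
            e = [x / n for x in e]                                       # L2-normalize
            for a, qt in enumerate(q):                                   # MaxSim update
                running_max[a] = max(running_max[a], sum(x * y for x, y in zip(qt, e)))
        scores.append(sum(running_max))
    return scores

weight = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]  # d_out=2, d_hidden=3
q = [[1.0, 0.0], [0.0, 1.0]]
doc_hidden = [[[2.0, 0.0, 0.0], [0.0, 1.0, 1.0]], [[1.0, 1.0, 0.0]]]
print(maxsim_from_hidden_sketch(q, doc_hidden, weight))  # approximately [2.0, 1.414]
```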

Submodules:

  • late_interaction_kernels.fp8 — FP8 quantize / dequantize helpers (per-tensor and per-token).
  • late_interaction_kernels.experimental — research kernels (soft_maxsim, smooth_maxsim, maxsim_matryoshka, maxsim_xtr).
  • late_interaction_kernels.reference — pure-PyTorch reference implementations, importable on any platform.
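Per-tensor and per-token FP8 quantization differ only in how many scales are computed. The sketch below assumes the E4M3 format, whose largest finite value is 448, and shows the scale computation only; the final rounding onto the FP8 grid is omitted, and the helper names are hypothetical rather than the contents of late_interaction_kernels.fp8.

```python
E4M3_MAX = 448.0  # largest finite value of float8_e4m3fn

def per_tensor_scale(x):
    """One scale for the whole tensor (x is a list of token rows)."""
    amax = max(abs(v) for row in x for v in row)
    return amax / E4M3_MAX if amax > 0 else 1.0

def per_token_scales(x):
    """One scale per token row, so an outlier token does not squash the others."""
    scales = []
    for row in x:
        amax = max(abs(v) for v in row)
        scales.append(amax / E4M3_MAX if amax > 0 else 1.0)
    return scales

def scale_rows(x, scales):
    """Map each row into [-448, 448]; the actual cast to FP8 (rounding) is omitted."""
    return [[v / s for v in row] for row, s in zip(x, scales)]

x = [[0.01, -0.02], [100.0, 3.0]]  # a small token next to an outlier token
t = per_tensor_scale(x)
print(scale_rows(x, [t, t])[0])              # tiny values, far below the 448 limit
print(scale_rows(x, per_token_scales(x))[0])  # approximately [224.0, -448.0]
```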

Config:

  • maxsim(..., backward="auto" | "unified" | "atomic" | "csr"): per-call grad_D strategy; "auto" picks one per shape.
  • set_backward_method(...) / get_backward_method(): process-wide default (kept for back-compat; prefer the per-call kwarg).
  • LIK_DISABLE=1: patched entry points delegate to vanilla PyLate.
  • LIK_SUPPRESS_NORM_WARN=1: silences the one-shot "looks unnormalized" warning.
  • LIK_DISABLE_COMPILE=1: skips torch.compile on the MPS path (eager fallback).
  • LIK_FORCE_MPS_BACKEND={metal,compile,reference}: pins the MPS dispatch (default: shape-based heuristic).
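Whatever the strategy, the gradient with respect to D has the same shape: only the document token that wins each query token's max receives gradient, and many (query, query-token) pairs can land on the same document row. The sketch below shows a scatter-add version and a version that first groups contributions by destination row. Reading "atomic" and "csr" as those two write patterns is an assumption (docs/design.md is authoritative), the helper names are hypothetical, and the code covers only unmasked, unnormalized MaxSim.

```python
def _dot(u, v):
    return sum(x * y for x, y in zip(u, v))

def _contributions(Q, D, grad_scores):
    """Yield (doc j, doc token b, gradient vector) for each query token's winning doc token."""
    for i, q in enumerate(Q):
        for j, d in enumerate(D):
            g = grad_scores[i][j]  # upstream dL/dscores[i][j]
            for qt in q:
                dots = [_dot(qt, dt) for dt in d]
                b = max(range(len(d)), key=dots.__getitem__)  # argmax; first wins on ties
                yield j, b, [g * x for x in qt]  # d<qt, D[j][b]> / dD[j][b] = qt

def _zeros_like(D):
    return [[[0.0] * len(tok) for tok in d] for d in D]

def grad_D_scatter(Q, D, grad_scores):
    """Add each contribution straight into its row, in arrival order (atomic-add style)."""
    grad = _zeros_like(D)
    for j, b, vec in _contributions(Q, D, grad_scores):
        for k, v in enumerate(vec):
            grad[j][b][k] += v
    return grad

def grad_D_grouped(Q, D, grad_scores):
    """Sort contributions by destination row first, then reduce row by row (CSR style)."""
    grad = _zeros_like(D)
    for j, b, vec in sorted(_contributions(Q, D, grad_scores), key=lambda c: (c[0], c[1])):
        for k, v in enumerate(vec):
            grad[j][b][k] += v
    return grad

Q = [[[1.0, 0.0], [2.0, 0.0], [0.0, 1.0]]]  # tokens 0 and 1 both pick doc token 0
D = [[[1.0, 0.0], [0.5, 0.5]]]
print(grad_D_scatter(Q, D, [[1.0]]))  # [[[3.0, 0.0], [0.0, 1.0]]]
```

The contention on doc token 0 in the example is the case the strategies handle differently: scatter-add needs atomic writes when run in parallel, while grouping lets one worker own each destination row.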

Walk-through of every kernel, the autograd graph, the backward variants and the numerics: docs/design.md.


Hardware

Primary target: H100 / H200 (autotuned, FP8 WGMMA, warp-specialized on Triton ≥ 3.2). Also tuned for A100, other Ampere GPUs (A10 / A40 / 3090) and Ada (L4 / L40 / 4090). Older or unknown CUDA GPUs fall back to a conservative autotune shortlist.

Apple Silicon (MPS) ships two paths and picks per call:

  • a fused Metal simdgroup_matrix kernel (forward-only): 1.9–3.2× faster than plain PyTorch (1.1–2.0× over torch.compile) on realistic inference shapes, with ~300× less peak memory on big corpora because it never materialises the [Nq · Nd · Lq · Ld] similarity tensor. Persistent threadgroups each serve 8 consecutive documents (j indices) per launch and keep Q register-resident across every (j, d-chunk) pair;
  • a torch.compile-fused reference (autograd-aware): handles every training call, plus small-batch inference where the Metal kernel's launch overhead doesn't amortise (still 1.4× faster than eager).
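A quick size estimate shows why never materialising the full similarity tensor matters. The shapes below are illustrative assumptions, not the benchmark shapes, and the estimate counts only the [Nq · Nd · Lq · Ld] scratch tensor and the [Nq, Nd] output, not inputs or allocator overhead.

```python
def naive_scratch_bytes(nq, nd, lq, ld, bytes_per_elem=2):
    """Bytes for the [Nq, Nd, Lq, Ld] similarity tensor an unfused einsum materialises (bf16)."""
    return nq * nd * lq * ld * bytes_per_elem

def fused_output_bytes(nq, nd, bytes_per_elem=4):
    """Bytes for the [Nq, Nd] fp32 score matrix a fused forward kernel needs to write."""
    return nq * nd * bytes_per_elem

# Illustrative shapes: 32 queries of 32 tokens against 10,000 docs of 300 tokens
print(naive_scratch_bytes(32, 10_000, 32, 300) / 1e9, "GB")  # 6.144 GB
print(fused_output_bytes(32, 10_000) / 1e6, "MB")            # 1.28 MB
```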

See docs/benchmarks.md for shapes and numbers.

Autotune runs once per unique (Lq, Ld, d, masks) signature on CUDA and caches the winner; the MPS compile cache keys on (dtype, normalize, has_q_mask, has_d_mask) and amortises after the first call.


Development

git clone https://github.com/hcompai/late-interaction-kernels
cd late-interaction-kernels
pip install -e ".[dev,pylate]"

pytest -q                               # auto-skips CUDA tests without a GPU
ruff check . && ruff format --check .

python benchmarks/bench_forward.py      # see benchmarks/ for the full set

See CONTRIBUTING.md for the contribution workflow and CHANGELOG.md for the kernel-by-kernel release history.


Citation

@software{late_interaction_kernels_2026,
  author  = {Lac, Aurélien and Wu, Tony},
  title   = {{late-interaction-kernels}: Fused Triton kernels for late-interaction scoring},
  year    = {2026},
  url     = {https://github.com/hcompai/late-interaction-kernels},
}

Aurélien Lac · Tony Wu — H Company · 2026

License

Apache 2.0 — see LICENSE. Copyright 2026 Aurélien Lac and Tony Wu.

Related

PyLate · FastPlaid · NextPlaid / ColGrep · flash-maxsim (IBM, the first public Triton MaxSim — direct inspiration) · FlashAttention.



Download files


Source Distribution

late_interaction_kernels-0.1.0.tar.gz (373.9 kB)

Uploaded Source

Built Distribution


late_interaction_kernels-0.1.0-py3-none-any.whl (85.1 kB)

Uploaded Python 3

File details

Details for the file late_interaction_kernels-0.1.0.tar.gz.

File metadata

  • Download URL: late_interaction_kernels-0.1.0.tar.gz
  • Size: 373.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for late_interaction_kernels-0.1.0.tar.gz
Algorithm Hash digest
SHA256 6d6bfbcf5a1afa3d4aa13b8ddae4bfd4b6bf9cb963b4b23556bf50fe7f744072
MD5 2bf36012bfd8c808b03d760d9ff22957
BLAKE2b-256 608b006172b59f9d89ecb8a0e0d8e35dd92a17d0ddaa5646a58b5b0e45ad51fa


Provenance

The following attestation bundles were made for late_interaction_kernels-0.1.0.tar.gz:

Publisher: publish.yml on hcompai/late-interaction-kernels

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file late_interaction_kernels-0.1.0-py3-none-any.whl.


File hashes

Hashes for late_interaction_kernels-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 f75e0423e7a91e8e59d3c15c83eff62b74f3423d05825a178252121bf15a50ee
MD5 0f8c5cf0bbb51091a3616fc77d9f1a67
BLAKE2b-256 cadfc056c1437bcb55bdeff2d0504213dcab34aaccaa7b2b6c0f9d8ab62cb1a0


Provenance

The following attestation bundles were made for late_interaction_kernels-0.1.0-py3-none-any.whl:

Publisher: publish.yml on hcompai/late-interaction-kernels

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
