
Fused Triton kernels for late-interaction (MaxSim) scoring — ColBERT, ColPali, ModernColBERT.


⚡ late-interaction-kernels

Fused Triton kernels for MaxSim scoring. ColBERT · ColPali · ModernColBERT · LateOn · LateOn-Code · ColBERTv2 · PyLate-native.




Drop-in, numerically identical replacements for the MaxSim math behind late-interaction training, reranking and retrieval. One call patches PyLate; for custom pipelines, the library also exposes nn.Module and function-level APIs.

Not a search engine. For end-to-end retrieval use FastPlaid, NextPlaid / ColGrep, or PyLate. This library is the MaxSim math their reranking / training compiles down to.
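For orientation, the quantity these kernels compute is the standard ColBERT MaxSim score. Assuming query token embeddings q_1, ..., q_{L_q} and document token embeddings d_1, ..., d_{L_d} (L2-normalized when normalize=True), each query token is matched to its most similar document token and the matches are summed:

```latex
s(q, d) \;=\; \sum_{i=1}^{L_q} \; \max_{1 \le j \le L_d} \; q_i^{\top} d_j
```

Padding is assumed to follow the usual convention: masked query tokens are dropped from the sum, and masked document tokens are excluded from the max.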


Install

pip install late-interaction-kernels

The fused kernels require Linux + CUDA. On macOS, Windows or CPU-only machines, MaxSimScorer, retrieve and late_interaction_kernels.reference fall back to a pure-PyTorch implementation, so training and retrieval code can be run locally before renting a GPU. The PyLate drop-in targets PyLate ≥ 1.3.


Quickstart

Speed up PyLate with one line:

from late_interaction_kernels import patch_pylate

patch_pylate()
# PyLate training / rerank code is unchanged

LIK_DISABLE=1 in the environment makes patched entry points fall back to vanilla PyLate at runtime.
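The fallback happens at runtime. One plausible way to get that behaviour, sketched here as an assumption rather than the library's actual code, is to re-read the flag from os.environ on every call. with_kill_switch is a hypothetical helper, not an API exported by late_interaction_kernels.

```python
import os
from typing import Any, Callable

def with_kill_switch(fused: Callable[..., Any], vanilla: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical helper: route each call to `vanilla` whenever LIK_DISABLE=1 is set."""
    def dispatch(*args: Any, **kwargs: Any) -> Any:
        # Read the environment on every call, so toggling the flag works after patching.
        if os.environ.get("LIK_DISABLE") == "1":
            return vanilla(*args, **kwargs)
        return fused(*args, **kwargs)
    return dispatch

score = with_kill_switch(lambda x: ("fused", x), lambda x: ("vanilla", x))
os.environ["LIK_DISABLE"] = "1"
print(score(3))   # ('vanilla', 3)
del os.environ["LIK_DISABLE"]
print(score(3))   # ('fused', 3)
```

Checking per call costs one dictionary lookup, which is negligible next to a kernel launch.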

Score MaxSim in any training loop:

from late_interaction_kernels import MaxSimScorer

scorer = MaxSimScorer(normalize=True)                # nn.Module, no parameters
scores = scorer(Q, D, q_mask=q_mask, d_mask=d_mask)  # [Nq, Nd] fp32
scores.mean().backward()
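To make the dense-layout semantics concrete, here is a pure-Python reference, assuming Q is [Nq][Lq][d], D is [Nd][Ld][d], masks mark real tokens with True, and normalize=True means per-token L2 normalization. maxsim_reference is a hypothetical helper for illustration, not the library's late_interaction_kernels.reference module; it costs O(Nq·Nd·Lq·Ld·d) and is only meant for checking small cases.

```python
import math
from typing import List, Optional

Vector = List[float]

def l2_normalize(v: Vector) -> Vector:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else list(v)

def maxsim_reference(
    Q: List[List[Vector]],                       # [Nq][Lq][d] query token embeddings
    D: List[List[Vector]],                       # [Nd][Ld][d] document token embeddings
    q_mask: Optional[List[List[bool]]] = None,   # [Nq][Lq], True = real token
    d_mask: Optional[List[List[bool]]] = None,   # [Nd][Ld], True = real token
    normalize: bool = True,
) -> List[List[float]]:
    """scores[a][b] = sum over valid query tokens of the max dot product over valid doc tokens."""
    if normalize:
        Q = [[l2_normalize(t) for t in q] for q in Q]
        D = [[l2_normalize(t) for t in doc] for doc in D]
    scores: List[List[float]] = []
    for a, q in enumerate(Q):
        row: List[float] = []
        for b, doc in enumerate(D):
            total = 0.0
            for i, qi in enumerate(q):
                if q_mask is not None and not q_mask[a][i]:
                    continue  # padded query token contributes nothing
                best = -math.inf
                for j, dj in enumerate(doc):
                    if d_mask is not None and not d_mask[b][j]:
                        continue  # padded doc token can never win the max
                    best = max(best, sum(x * y for x, y in zip(qi, dj)))
                if best > -math.inf:  # skip if every doc token was masked
                    total += best
            row.append(total)
        scores.append(row)
    return scores

# One query (2 tokens) against two docs; the second doc's last token is padding.
Q = [[[1.0, 0.0], [0.0, 1.0]]]
D = [[[1.0, 0.0], [0.6, 0.8]], [[0.0, 2.0], [0.0, 0.0]]]
d_mask = [[True, True], [True, False]]
print(maxsim_reference(Q, D, d_mask=d_mask))  # approximately [[1.8, 1.0]]
```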

Top-k retrieval:

from late_interaction_kernels import retrieve

scores, indices = retrieve(Q, D, top_k=100, chunk=4096)
# both [Nq, 100] — chunk= bounds peak HBM at Nq · (chunk + top_k)
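The chunk= memory bound comes from scoring the corpus one slice at a time and merging each slice into a running top-k buffer, so only one chunk of scores plus top_k survivors exist per query. A pure-Python sketch of that merge for a single query, assuming per-chunk scores arrive as (offset, scores) pairs; chunked_topk is a hypothetical name and this mirrors the idea, not the library's GPU implementation.

```python
import heapq
from typing import Iterable, List, Sequence, Tuple

def chunked_topk(
    score_chunks: Iterable[Tuple[int, Sequence[float]]],  # (offset, scores of docs offset..offset+len-1)
    top_k: int,
) -> List[Tuple[float, int]]:
    """Streaming top-k for ONE query: at most top_k + one chunk of scores are alive."""
    best: List[Tuple[float, int]] = []  # min-heap of (score, doc_index); root = weakest survivor
    for offset, chunk in score_chunks:
        for j, s in enumerate(chunk):
            item = (s, offset + j)
            if len(best) < top_k:
                heapq.heappush(best, item)
            elif item > best[0]:
                heapq.heapreplace(best, item)  # evict the weakest survivor
    return sorted(best, reverse=True)  # highest score first

all_scores = [0.1, 0.9, 0.4, 0.7, 0.3, 0.8, 0.2]
chunks = [(i, all_scores[i:i + 3]) for i in range(0, len(all_scores), 3)]  # chunk = 3
print(chunked_topk(chunks, top_k=3))  # [(0.9, 1), (0.8, 5), (0.7, 3)]
```

The result equals a global sort because any document in the true top-k is never evicted by a lower-scoring one.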

PLAID / ColBERTv2 rerank on compressed, ragged docs:

from late_interaction_kernels import maxsim_residual_varlen

# fast-plaid / ColBERTv2 on-disk layout
scores = maxsim_residual_varlen(
    Q, codes_flat, residuals_flat, cu_seqlens_d,
    centroids=centroids, bucket_weights=bucket_weights,
    nbits=2, normalize=True,
)  # [Nd] fp32 — one kernel does decompress + L2-normalize + MaxSim
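What the fused kernel replaces can be spelled out in plain Python. This sketch assumes the ColBERTv2 residual scheme (token = centroid[code] + a per-dimension bucket value from bucket_weights, then L2-normalize) and the usual cu_seqlens convention (doc b owns packed rows cu_seqlens_d[b] to cu_seqlens_d[b+1]). For readability it takes residuals as already-unpacked bucket indices; the real on-disk layout packs nbits-bit indices into bytes, and that unpacking is omitted. decompress_token and residual_varlen_scores are hypothetical names.

```python
import math
from typing import List, Sequence

Vector = List[float]

def decompress_token(code: int, residual_idx: Sequence[int],
                     centroids: Sequence[Sequence[float]],
                     bucket_weights: Sequence[float]) -> Vector:
    """ColBERTv2-style decode: centroid + per-dimension bucket value, then L2-normalize."""
    v = [c + bucket_weights[r] for c, r in zip(centroids[code], residual_idx)]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else v

def residual_varlen_scores(Q: Sequence[Vector], codes_flat: Sequence[int],
                           residual_idx_flat: Sequence[Sequence[int]],
                           cu_seqlens_d: Sequence[int],
                           centroids: Sequence[Sequence[float]],
                           bucket_weights: Sequence[float]) -> List[float]:
    """Score one pre-normalized query against every non-empty doc in a packed batch."""
    scores: List[float] = []
    for b in range(len(cu_seqlens_d) - 1):
        start, end = cu_seqlens_d[b], cu_seqlens_d[b + 1]  # doc b owns packed rows [start, end)
        doc = [decompress_token(codes_flat[t], residual_idx_flat[t], centroids, bucket_weights)
               for t in range(start, end)]
        scores.append(sum(max(sum(x * y for x, y in zip(q, d)) for d in doc) for q in Q))
    return scores

# nbits=2 -> 2**2 = 4 residual buckets per dimension (illustrative values).
centroids = [[1.0, 0.0], [0.0, 1.0]]
bucket_weights = [-0.1, -0.03, 0.03, 0.1]
codes_flat = [0, 1, 1]                         # nearest-centroid id per packed token
residual_idx_flat = [[2, 1], [3, 3], [0, 2]]   # unpacked bucket index per dimension
cu_seqlens_d = [0, 1, 3]                       # doc 0 = row 0, doc 1 = rows 1..2
print(residual_varlen_scores([[1.0, 0.0]], codes_flat, residual_idx_flat,
                             cu_seqlens_d, centroids, bucket_weights))
```

The reference materializes every decompressed token; the point of a fused kernel is to decode tiles on the fly so the full-precision [total_tokens, d] tensor never exists.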

Speedups on H100

1×H100 80 GB SXM, bf16 / fp16 compute, fp32 accumulator, 50-iter median. Every baseline is the same operation in plain PyTorch.

| Workload | Speedup |
|---|---|
| Reranking / inference vs naive einsum | 7–23× |
| Long-context (Ld ≥ 8k) reranking | runs; naive OOMs |
| PyLate cached-contrastive MaxSim + backward | up to 13.8× |
| PLAID rerank vs fast_plaid.engine.search() | 19–30× |
| Fused D-side head (training) | 1.5–4.6× |
| FP8 MaxSim inference (Hopper) | up to 1.4× |
| LateOn-Code-edge end-to-end training (17 M) | 1.04–1.27× |
| LateOn / ModernColBERT end-to-end training (149 M) | 1.00–1.06× (free) |

ModernBERT-class encoders dominate step time, so on a full 149 M-parameter training run swapping in the kernel is essentially free (1.00–1.06×). Wherever MaxSim stops being negligible (inference, reranking, long documents, large effective batches, knowledge distillation, compressed indices, small encoders), the fused path delivers a measurable speedup.
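The gap between kernel-level and end-to-end numbers follows Amdahl's law: if MaxSim (forward + backward) takes a fraction f of a training step and the kernel speeds that part up by s, the whole step speeds up by 1 / ((1 - f) + f / s). A quick calculation with illustrative, unmeasured fractions shows why even a 10× kernel is nearly invisible when the encoder dominates.

```python
def end_to_end_speedup(maxsim_fraction: float, kernel_speedup: float) -> float:
    """Amdahl's law: only the MaxSim share of the step gets faster."""
    return 1.0 / ((1.0 - maxsim_fraction) + maxsim_fraction / kernel_speedup)

# Illustrative fractions only; these are not measurements from this project.
for f in (0.02, 0.05, 0.25):
    print(f"MaxSim at {f:.0%} of step time, 10x faster kernel -> "
          f"{end_to_end_speedup(f, 10.0):.3f}x end to end")
```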

Full tables, shapes and reproduction commands: docs/benchmarks.md.


API

Most users only need patch_pylate(), MaxSimScorer or retrieve.

| Symbol | What it does |
|---|---|
| patch_pylate() / unpatch_pylate() | One-line PyLate drop-in. LIK_DISABLE=1 kill switch. |
| MaxSimScorer(normalize=, backward=) | Stateless nn.Module, autograd-aware. |
| retrieve(Q, D, top_k, chunk=) | Top-k retrieval, chunked for huge corpora. |
| maxsim / maxsim_inference | Core MaxSim, dense layout (autograd / forward-only). |
| maxsim_varlen | Packed (cu_seqlens) layout. Autograd-aware. |
| maxsim_inference_scatter | Pair-list reranking on packed batches (vLLM-style scheduling). |
| maxsim_from_hidden(_train) | Fused D-side Linear → Normalize → MaxSim, no [Nd, Ld, d_out] scratch. |
| maxsim_residual(_varlen) | Fused PLAID / ColBERTv2 decompress + normalize + MaxSim. |
| plaid_approx_score | IVF prune step for ColBERTv2. |
| maxsim_inference_fp8 | FP8 tensor-core MaxSim (Hopper / Blackwell). Auto-fallback to bf16. |
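The "no [Nd, Ld, d_out] scratch" property of maxsim_from_hidden comes from consuming each projected doc token immediately instead of storing the projected document. A pure-Python sketch for a single (query, document) pair, assuming a bias-free Linear layer and pre-normalized query tokens; maxsim_from_hidden_ref is a hypothetical name, and the real kernel tiles this loop on the GPU rather than walking it token by token.

```python
import math
from typing import List, Sequence

def maxsim_from_hidden_ref(Q: Sequence[Sequence[float]],
                           H: Sequence[Sequence[float]],
                           W: Sequence[Sequence[float]]) -> float:
    """MaxSim for one (query, doc) pair straight from doc hidden states.

    Q: [Lq][d_out] normalized query tokens; H: [Ld][d_in] doc hidden states;
    W: [d_out][d_in] bias-free projection. Each projected doc token is used
    immediately and discarded; only one running max per query token is kept.
    """
    running_max: List[float] = [-math.inf] * len(Q)
    for h in H:
        e = [sum(w * x for w, x in zip(row, h)) for row in W]   # Linear
        norm = math.sqrt(sum(x * x for x in e))
        if norm > 0:
            e = [x / norm for x in e]                            # L2-normalize
        for i, q in enumerate(Q):
            running_max[i] = max(running_max[i], sum(a * b for a, b in zip(q, e)))
    return sum(running_max)

W = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]   # d_in = 3 -> d_out = 2
H = [[2.0, 0.0, 0.0], [0.0, 3.0, 0.0]]   # two doc tokens
print(maxsim_from_hidden_ref([[0.6, 0.8]], H, W))  # 0.8
```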

Submodules:

  • late_interaction_kernels.fp8 — FP8 quantize / dequantize helpers (per-tensor and per-token).
  • late_interaction_kernels.experimental — research kernels (soft_maxsim, smooth_maxsim, maxsim_matryoshka, maxsim_xtr).
  • late_interaction_kernels.reference — pure-PyTorch reference implementations, importable on any platform.
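Per-tensor vs per-token scaling matters because FP8 has a narrow dynamic range. A sketch of standard absmax scaling, assuming the E4M3 format (largest finite magnitude 448) and scale = amax / 448 so that x / scale fits the format; the function names are illustrative, not the signatures in late_interaction_kernels.fp8. It shows how one outlier token inflates a per-tensor scale and pushes quieter tokens toward FP8's coarse low end, which per-token scales avoid.

```python
from typing import List, Sequence

E4M3_MAX = 448.0  # largest finite magnitude in the FP8 E4M3 format

def per_tensor_scale(x: Sequence[Sequence[float]]) -> float:
    """One absmax scale for the whole [tokens, dim] matrix."""
    amax = max(abs(v) for row in x for v in row)
    return amax / E4M3_MAX if amax > 0 else 1.0

def per_token_scales(x: Sequence[Sequence[float]]) -> List[float]:
    """One absmax scale per token (row)."""
    out: List[float] = []
    for row in x:
        amax = max(abs(v) for v in row)
        out.append(amax / E4M3_MAX if amax > 0 else 1.0)
    return out

# A quiet token next to an outlier token.
x = [[0.01, -0.02], [100.0, 3.0]]
s = per_tensor_scale(x)
print([round(v / s, 4) for v in x[0]])      # quiet row squeezed near zero
t = per_token_scales(x)
print([round(v / t[0], 4) for v in x[0]])   # quiet row spans the full FP8 range
```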

Config:

| Knob | Effect |
|---|---|
| maxsim(..., backward="auto" \| "unified" \| "atomic" \| "csr") | Per-call grad_D strategy. "auto" picks per shape. |
| set_backward_method(...) / get_backward_method() | Process-wide default (back-compat; prefer per-call kwarg). |
| LIK_DISABLE=1 | Patched entry points delegate to vanilla PyLate. |
| LIK_SUPPRESS_NORM_WARN=1 | Silence the "looks unnormalized" one-shot warning. |
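The grad_D strategies differ in how gradient is routed to document tokens. With normalize=False, the derivative of one MaxSim score with respect to doc token d_j is the sum of the query tokens whose max landed on j, so the gradient is sparse. Below is a sketch of two ways to accumulate it: a scatter, which on a GPU needs atomic adds when several query tokens pick the same d_j, and a grouped pass that first inverts argmax into per-token lists, CSR-style. Mapping these to the "atomic" and "csr" options is an assumption based on the names; docs/design.md describes the actual variants.

```python
from collections import defaultdict
from typing import Dict, List, Sequence

Matrix = Sequence[Sequence[float]]

def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))

def argmax_doc_token(Q: Matrix, D: Matrix) -> List[int]:
    """For each query token, the index of the doc token that wins the max (first one on ties)."""
    return [max(range(len(D)), key=lambda j: dot(q, D[j])) for q in Q]

def grad_D_scatter(Q: Matrix, D: Matrix, g: float) -> List[List[float]]:
    """Scatter: each query token adds g * q_i into its winning row (atomics on a GPU)."""
    grad = [[0.0] * len(D[0]) for _ in D]
    for q, j in zip(Q, argmax_doc_token(Q, D)):
        for k, v in enumerate(q):
            grad[j][k] += g * v
    return grad

def grad_D_grouped(Q: Matrix, D: Matrix, g: float) -> List[List[float]]:
    """Grouped: invert argmax into per-doc-token lists (CSR-like); each row reduces its own list."""
    owners: Dict[int, List[int]] = defaultdict(list)
    for i, j in enumerate(argmax_doc_token(Q, D)):
        owners[j].append(i)
    return [[g * sum(Q[i][k] for i in owners[j]) for k in range(len(D[0]))]
            for j in range(len(D))]

Q = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
D = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
print(grad_D_scatter(Q, D, 1.0))  # doc token 2 never wins a max, so its row stays zero
```

The grouped form trades an extra inversion pass for write-conflict-free reductions, which is why a shape-dependent "auto" choice between them is plausible.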

Walk-through of every kernel, the autograd graph, the backward variants and the numerics: docs/design.md.


Hardware

Primary target: H100 / H200 (autotuned, FP8 WGMMA, warp-specialized on Triton ≥ 3.2). Also tuned for A100, Ada (L4 / L40 / 4090) and other Ampere GPUs (A10 / A40 / 3090). Older or unrecognized CUDA GPUs fall back to a conservative autotune shortlist. CPU / macOS / Windows get the pure-PyTorch reference.

Autotuning runs once per unique (Lq, Ld, d, masks) signature and caches the winning config, so there is no tuning overhead after warmup.
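Triton's own @triton.autotune decorator works on this principle, re-benchmarking only when the values of its key arguments change. The plain-Python sketch below illustrates the caching logic, not the library's code: autotuned, the config dicts and the key layout (Lq, Ld, d, has_masks) are all hypothetical.

```python
import time
from typing import Callable, Dict, Hashable, Sequence, Tuple

_best_config: Dict[Tuple[Hashable, ...], dict] = {}

def autotuned(key: Tuple[Hashable, ...], configs: Sequence[dict],
              run: Callable[[dict], None], reps: int = 3) -> dict:
    """Benchmark every candidate config the first time `key` is seen; afterwards reuse the winner."""
    cached = _best_config.get(key)
    if cached is not None:
        return cached                        # warm path: one dict lookup, no benchmarking
    def elapsed(cfg: dict) -> float:
        start = time.perf_counter()
        for _ in range(reps):
            run(cfg)
        return time.perf_counter() - start
    winner = min(configs, key=elapsed)       # cold path: time each candidate once
    _best_config[key] = winner
    return winner

calls = []
cfgs = [{"BLOCK_D": 64}, {"BLOCK_D": 128}]    # hypothetical tile sizes
key = (32, 8192, 128, True)                  # (Lq, Ld, d, has_masks)
first = autotuned(key, cfgs, lambda cfg: calls.append(cfg["BLOCK_D"]))
second = autotuned(key, cfgs, lambda cfg: calls.append(cfg["BLOCK_D"]))
print(first is second, len(calls))           # True 6: the second call never re-benchmarks
```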


Development

git clone https://github.com/hcompai/late-interaction-kernels
cd late-interaction-kernels
pip install -e ".[dev,pylate]"

pytest -q                               # auto-skips CUDA tests without a GPU
ruff check . && ruff format --check .

python benchmarks/bench_forward.py      # see benchmarks/ for the full set

See CONTRIBUTING.md for the contribution workflow and CHANGELOG.md for the kernel-by-kernel release history.


Citation

@software{late_interaction_kernels_2026,
  author  = {Lac, Aurélien and Wu, Tony},
  title   = {{late-interaction-kernels}: Fused Triton kernels for late-interaction scoring},
  year    = {2026},
  url     = {https://github.com/hcompai/late-interaction-kernels},
}

Aurélien Lac · Tony Wu — H Company · 2026

License

Apache 2.0 — see LICENSE. Copyright 2026 Aurélien Lac and Tony Wu.

Related

PyLate · FastPlaid · NextPlaid / ColGrep · flash-maxsim (IBM, the first public Triton MaxSim — direct inspiration) · FlashAttention.



Download files


Source Distribution

late_interaction_kernels-0.0.1.tar.gz (425.8 kB)

Uploaded Source

Built Distribution


late_interaction_kernels-0.0.1-py3-none-any.whl (75.1 kB)

Uploaded Python 3

File details

Details for the file late_interaction_kernels-0.0.1.tar.gz.

File metadata

  • Download URL: late_interaction_kernels-0.0.1.tar.gz
  • Size: 425.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for late_interaction_kernels-0.0.1.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 91f3da411f4eba8636dd84dd1a170ff0144972b356cbccd635928101f935b714 |
| MD5 | 192e9282db5ac001df3d5dc4a10d408f |
| BLAKE2b-256 | 60952f90621a9233215feedfc69014b1e99d533b2d0b5adeac0285d59b4b1140 |


Provenance

The following attestation bundles were made for late_interaction_kernels-0.0.1.tar.gz:

Publisher: publish.yml on hcompai/late-interaction-kernels


File details

Details for the file late_interaction_kernels-0.0.1-py3-none-any.whl.


File hashes

Hashes for late_interaction_kernels-0.0.1-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3601b2f48e90089ac92920643bf81019eacd71f6dbf0c825720c85afd0c1a046 |
| MD5 | a5a41b167ffd92d67834c8bd9d498aa9 |
| BLAKE2b-256 | 7788ef4be038da927dc1bc9308cff7616b8be03ceb1c9e754eaddc6d8728d66e |


Provenance

The following attestation bundles were made for late_interaction_kernels-0.0.1-py3-none-any.whl:

Publisher: publish.yml on hcompai/late-interaction-kernels

