Fused Triton kernels for late-interaction (MaxSim) scoring — ColBERT, ColPali, ModernColBERT.

[!NOTE] Full algorithmic walkthrough, animations and benchmark plots live on the docs site: hcompai.github.io/late-interaction-kernels.

Introduction

late-interaction-kernels provides fused Triton kernels for MaxSim, the late-interaction scoring used by ColBERT, ColPali, ModernColBERT, LateOn and ColBERTv2. The kernels are numerically identical to plain PyTorch and come with three APIs:

  • a one-line PyLate drop-in (patch_pylate()),
  • a stateless nn.Module (MaxSimScorer) for custom training loops,
  • function-level entry points (maxsim, maxsim_varlen, maxsim_padded, ...) for everything else.

This is not a search engine. For end-to-end training or retrieval use PyLate, FastPlaid or NextPlaid. This library is the MaxSim math they compile down to.
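
For reference, MaxSim as it is usually defined in the ColBERT line of work is written out below. The symbols are chosen for this note and are not taken from the library: q_i is the i-th of L_q query token embeddings and d_j the j-th of L_d document token embeddings.

```latex
S(Q, D) \;=\; \sum_{i=1}^{L_q} \; \max_{1 \le j \le L_d} \; \langle q_i, d_j \rangle
```

Each query token keeps only its best-matching document token, and the per-token maxima are summed into one score per (query, document) pair.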

Install

pip install late-interaction-kernels

| Platform | Backend |
| --- | --- |
| Linux + CUDA (sm_75+) | Fused Triton kernels (autotuned, FP8 on Hopper). |
| macOS (Apple Silicon, MPS) | Fused Metal simdgroup_matrix for inference, torch.compile for training. |
| CPU / Windows | Autograd-aware pure-PyTorch reference. |

[!NOTE] The PyLate drop-in targets PyLate >= 1.3. The pure-PyTorch reference imports on every platform, so training and retrieval code is unit-testable on a laptop before you rent a GPU.

Quickstart

Patch PyLate (one line)

from late_interaction_kernels import patch_pylate

patch_pylate()
# PyLate training / rerank code is unchanged

Set LIK_DISABLE=1 in the environment to fall back to vanilla PyLate at runtime.

Custom training loop

from late_interaction_kernels import MaxSimScorer

scorer = MaxSimScorer(normalize=True)                # nn.Module, no parameters
scores = scorer(Q, D, q_mask=q_mask, d_mask=d_mask)  # [Nq, Nd] fp32
scores.mean().backward()
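
As a reading aid, here is a minimal sketch of what a masked, normalized MaxSim score plausibly computes for one (query, document) pair. It uses plain Python lists so it runs without PyTorch. It is not the library's implementation, and the mask convention (True marks a real token, False marks padding) and the helper names `l2_normalize` and `maxsim_pair` are assumptions made for illustration.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length (returned unchanged if it is all zeros)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def maxsim_pair(q_tokens, d_tokens, q_mask=None, d_mask=None, normalize=True):
    """MaxSim for one query and one document, each a list of token vectors.

    Assumed mask convention: True keeps a token, False marks padding.
    """
    q_mask = q_mask or [True] * len(q_tokens)
    d_mask = d_mask or [True] * len(d_tokens)
    if normalize:
        q_tokens = [l2_normalize(q) for q in q_tokens]
        d_tokens = [l2_normalize(d) for d in d_tokens]

    score = 0.0
    for q, q_keep in zip(q_tokens, q_mask):
        if not q_keep:
            continue  # padded query tokens add nothing to the sum
        # Dot product against every unmasked document token.
        sims = [
            sum(a * b for a, b in zip(q, d))
            for d, d_keep in zip(d_tokens, d_mask)
            if d_keep
        ]
        if sims:
            score += max(sims)  # keep only the best-matching document token
    return score
```

Calling `maxsim_pair` for every query against every document yields an [Nq, Nd] grid of the same shape as the scorer output above.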

Top-k retrieval

from late_interaction_kernels import retrieve

scores, indices = retrieve(Q, D, top_k=100, chunk=4096)
# both [Nq, 100]; chunk= bounds peak HBM at Nq * (chunk + top_k)
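
The memory bound in the comment above follows from a standard chunked top-k pattern, sketched here for a single query in plain Python. This illustrates the general technique, not the library's kernel; `chunked_top_k` and the `score_fn(start, stop)` callback are hypothetical names.

```python
import heapq


def chunked_top_k(score_fn, num_docs, top_k, chunk):
    """Return (scores, indices) of the top_k documents for one query.

    score_fn(start, stop) is a hypothetical callback that returns the scores
    of documents start..stop-1, so only one chunk of scores exists at a time.
    """
    best = []  # running winners as (score, doc_index), at most top_k entries
    for start in range(0, num_docs, chunk):
        stop = min(start + chunk, num_docs)
        chunk_scores = score_fn(start, stop)
        # Merge previous winners with this chunk: at most chunk + top_k items.
        candidates = best + [(s, start + i) for i, s in enumerate(chunk_scores)]
        best = heapq.nlargest(top_k, candidates)
    scores = [s for s, _ in best]
    indices = [i for _, i in best]
    return scores, indices
```

Because the candidate list never holds more than `chunk + top_k` entries per query, the working set is independent of the corpus size.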

PLAID / ColBERTv2 on compressed, ragged docs

from late_interaction_kernels.plaid import maxsim_residual_varlen

scores = maxsim_residual_varlen(
    Q, codes_flat, residuals_flat, cu_seqlens_d,
    centroids=centroids, bucket_weights=bucket_weights,
    nbits=2, normalize=True,
)  # [Nd] fp32; one kernel does decompress + L2-normalize + MaxSim
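
To make the "decompress + L2-normalize" step concrete, the sketch below rebuilds one token embedding in the ColBERTv2 style: a centroid plus a per-dimension value looked up from `bucket_weights` by an nbits-wide code. The bit order inside each byte and the helper names `unpack_codes` and `decompress_token` are assumptions for illustration; this page does not state how the library packs its residuals.

```python
import math


def unpack_codes(packed_bytes, nbits):
    """Split each byte into 8 // nbits bucket indices.

    Assumption for this sketch: the most significant bits come first.
    """
    per_byte = 8 // nbits
    mask = (1 << nbits) - 1
    out = []
    for byte in packed_bytes:
        for slot in range(per_byte):
            shift = 8 - nbits * (slot + 1)
            out.append((byte >> shift) & mask)
    return out


def decompress_token(centroid, packed_residual, bucket_weights,
                     nbits=2, normalize=True):
    """Rebuild one embedding as centroid + bucket value per dimension."""
    buckets = unpack_codes(packed_residual, nbits)
    vec = [c + bucket_weights[b] for c, b in zip(centroid, buckets)]
    if normalize:
        norm = math.sqrt(sum(x * x for x in vec))
        if norm > 0:
            vec = [x / norm for x in vec]
    return vec
```

With nbits=2, each byte carries four dimensions, so a 128-dimensional residual occupies 32 bytes under this packing.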

Benchmarks

Measured on 1x H100 in bf16/fp16, taking the median of 50 iterations, against the same operation in plain PyTorch:

| Workload | Speedup |
| --- | --- |
| Reranking / inference vs naive einsum | 7-23x |
| Long-context (Ld >= 8k) reranking | runs; naive OOMs |
| PyLate cached-contrastive MaxSim + backward | up to 13.8x |
| PLAID rerank vs fast_plaid.engine.search() | 19-30x |
| Fused D-side head (training) | 1.5-4.6x |
| FP8 MaxSim inference (Hopper) | up to 1.4x |
| End-to-end training of a 149M encoder | 1.00-1.06x (free) |

Full tables and reproduction commands: docs/benchmarks.md.
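
The median-of-50-iterations protocol quoted above can be approximated with a small timing helper like the one below. This is a generic sketch, not the project's benchmark harness (the actual commands are in docs/benchmarks.md); `median_latency`, `speedup` and the `sync` hook are hypothetical names. On a GPU, `sync` would need to be something like `torch.cuda.synchronize`, because kernels launch asynchronously.

```python
import statistics
import time


def median_latency(fn, iters=50, warmup=5, sync=None):
    """Median wall-clock seconds of fn() over `iters` timed runs.

    sync is an optional callable run around each timed call so that
    asynchronous device work is included in the measurement.
    """
    for _ in range(warmup):
        fn()  # untimed runs let autotuning and caches settle
    samples = []
    for _ in range(iters):
        if sync is not None:
            sync()
        t0 = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        samples.append(time.perf_counter() - t0)
    return statistics.median(samples)


def speedup(baseline_fn, candidate_fn, **kwargs):
    """Ratio of baseline median latency to candidate median latency."""
    return median_latency(baseline_fn, **kwargs) / median_latency(candidate_fn, **kwargs)
```

The median is less sensitive than the mean to the occasional slow iteration, which is why it is a common choice for kernel benchmarks.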

Choose a kernel

Not sure which entry point fits your stack? The docs site ships an interactive decision tree that narrows the public API down to the right function in four questions (stack · phase · layout · workload):

👉 hcompai.github.io/late-interaction-kernels/choose-a-kernel.html

API

| Symbol | What it does |
| --- | --- |
| patch_pylate() / unpatch_pylate() | One-line PyLate drop-in. LIK_DISABLE=1 kill switch. |
| MaxSimScorer(normalize=, backward=) | Stateless nn.Module, autograd-aware. |
| retrieve(Q, D, top_k, chunk=) | Top-k retrieval, chunked for huge corpora. |
| maxsim | Core MaxSim, dense layout. Autograd-aware; auto-skips argmax save when no input requires grad. |
| maxsim_varlen | Packed (cu_seqlens) layout. Autograd-aware. |
| maxsim_padded | Padded reranking wrapper: packs internally, returns [B, C] fp32. |

Other kernels are in submodules: padded, score_pairs, fused_head, plaid, fp8, experimental, reference. See docs/design.md for details on every kernel, the autograd graph and the backward variants.
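
The packed (cu_seqlens) layout mentioned for maxsim_varlen stores ragged documents back to back, with a cumulative-length vector marking the boundaries. The sketch below shows the convention as it is commonly used by variable-length attention kernels; that the library follows exactly this convention is an assumption, and `pack_ragged` / `unpack_ragged` are hypothetical helpers.

```python
from itertools import accumulate


def pack_ragged(docs):
    """Concatenate variable-length docs; return (flat_tokens, cu_seqlens).

    cu_seqlens has len(docs) + 1 entries, and document i occupies
    flat_tokens[cu_seqlens[i]:cu_seqlens[i + 1]].
    """
    flat_tokens = [tok for doc in docs for tok in doc]
    cu_seqlens = [0] + list(accumulate(len(doc) for doc in docs))
    return flat_tokens, cu_seqlens


def unpack_ragged(flat_tokens, cu_seqlens):
    """Invert pack_ragged by slicing between consecutive boundaries."""
    return [flat_tokens[a:b] for a, b in zip(cu_seqlens, cu_seqlens[1:])]
```

Compared with padding every document to the longest length, the packed form stores no padding tokens, which matters when document lengths vary widely.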

🔽 Configuration knobs (env vars + kwargs)
| Knob | Effect |
| --- | --- |
| maxsim(..., backward=) | Per-call grad_D strategy: "auto" (picks per shape), "unified", "atomic" or "csr". |
| patch_colpali_engine() / unpatch_colpali_engine() | colpali_engine drop-in: loss + scoring routes through the kernel. |
| LIK_DISABLE=1 | Patched entry points delegate to vanilla PyLate. |
| LIK_SUPPRESS_NORM_WARN=1 | Silence the "looks unnormalized" one-shot warning. |
| LIK_DISABLE_COMPILE=1 | Skip torch.compile on the MPS path (eager fallback). |
| LIK_FORCE_MPS_BACKEND={metal,compile,reference} | Pin the MPS dispatch. |

Development

git clone https://github.com/hcompai/late-interaction-kernels
cd late-interaction-kernels
uv sync --extra dev --extra pylate --extra torch-cuda   # GPU dev; use --extra torch-cpu on CPU-only boxes
uv run pytest -q                                        # CUDA tests auto-skip without a GPU
uv run ruff check . && uv run ruff format --check .

[!NOTE] Pick exactly one of --extra torch-cuda (pulls torch from the CUDA index — cu124) or --extra torch-cpu (CPU-only wheel, what CI uses). The two are declared as conflicting in pyproject.toml so the lockfile resolves cleanly for both. On macOS, --extra torch-cpu falls back to PyPI's default (MPS-capable) wheel automatically.

GPU tests run automatically on every push to main. To run them on a PR, apply the run-gpu-tests label.

See CONTRIBUTING.md for the contribution workflow.

Citation

@software{late_interaction_kernels_2026,
  author  = {Lac, Aurélien and Wu, Tony},
  title   = {{late-interaction-kernels}: Fused Triton kernels for late-interaction scoring},
  year    = {2026},
  url     = {https://github.com/hcompai/late-interaction-kernels},
}
