
A library for training and retrieval with ColBERT.


PyLate

Flexible Training and Retrieval for Late Interaction Models


 

PyLate is a library built on top of Sentence Transformers, designed to simplify and optimize fine-tuning, inference, and retrieval with state-of-the-art ColBERT models. It enables easy fine-tuning on both single and multiple GPUs, providing flexibility for various hardware setups. PyLate also streamlines document retrieval and allows you to load a wide range of models, enabling you to construct ColBERT models from most pre-trained language models.

 

Installation

You can install PyLate using pip:

pip install pylate

For evaluation dependencies, use:

pip install "pylate[eval]"

Documentation

The complete documentation includes in-depth guides, examples, and API references.

 

Training

Contrastive Training

Here’s a simple example of training a ColBERT model on the MS MARCO triplet dataset using PyLate. This script demonstrates training with a contrastive loss and evaluating the model on a held-out eval set:

import torch
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)

from pylate import evaluation, losses, models, utils

# Define model parameters for contrastive training
model_name = "bert-base-uncased"  # Choose the pre-trained model you want to use as base
batch_size = 32  # Larger batch size often improves results, but requires more memory

num_train_epochs = 1  # Adjust based on your requirements
# Set the run name for logging and output directory
run_name = "contrastive-bert-base-uncased"
output_dir = f"output/{run_name}"

# Define the ColBERT model. If the base checkpoint is not already a ColBERT model, a linear layer is added on top of the base encoder.
model = models.ColBERT(model_name_or_path=model_name)

# Compiling the model makes the training faster
model = torch.compile(model)

# Load dataset
dataset = load_dataset("sentence-transformers/msmarco-bm25", "triplet", split="train")
# Split the dataset (this dataset does not have a validation set, so we split the training set)
splits = dataset.train_test_split(test_size=0.01)
train_dataset = splits["train"]
eval_dataset = splits["test"]

# Define the loss function
train_loss = losses.Contrastive(model=model)

# Initialize the evaluator
dev_evaluator = evaluation.ColBERTTripletEvaluator(
    anchors=eval_dataset["query"],
    positives=eval_dataset["positive"],
    negatives=eval_dataset["negative"],
)

# Configure the training arguments (e.g., batch size, evaluation strategy, logging steps)
args = SentenceTransformerTrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
    learning_rate=3e-6,
)

# Initialize the trainer for the contrastive training
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=train_loss,
    evaluator=dev_evaluator,
    data_collator=utils.ColBERTCollator(model.tokenize),
)
# Start the training process
trainer.train()

After training, the model can be loaded using the output directory path:

from pylate import models

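# Point this at the directory where the trained model was saved (for example, a checkpoint folder under the output directory)
model = models.ColBERT(model_name_or_path="contrastive-bert-base-uncased")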

Note that the temperature parameter is very important in contrastive learning; a temperature of around 0.02 is often used in the literature:

train_loss = losses.Contrastive(model=model, temperature=0.02)
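To see why the temperature matters, the sketch below computes a softmax cross-entropy over similarity scores divided by the temperature. It is a plain-Python illustration, not PyLate's implementation; the function name `contrastive_loss` and the example scores are made up for the demonstration.

```python
import math


def contrastive_loss(scores, positive_index, temperature=1.0):
    """Cross-entropy over temperature-scaled similarity scores.

    `scores` holds one similarity per candidate document for a single query.
    """
    logits = [s / temperature for s in scores]
    # Subtract the max logit for numerical stability before exponentiating.
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return log_sum_exp - logits[positive_index]


# Hypothetical similarities: the document at index 0 is only slightly ahead.
scores = [0.80, 0.75, 0.60]

# Case 1: the positive is the best-scored document (index 0).
loss_t1 = contrastive_loss(scores, positive_index=0, temperature=1.0)
loss_t002 = contrastive_loss(scores, positive_index=0, temperature=0.02)

# Case 2: the positive is the worst-scored document (index 2).
wrong_t1 = contrastive_loss(scores, positive_index=2, temperature=1.0)
wrong_t002 = contrastive_loss(scores, positive_index=2, temperature=0.02)

print(f"positive ranked first: T=1.0 -> {loss_t1:.4f}, T=0.02 -> {loss_t002:.4f}")
print(f"positive ranked last:  T=1.0 -> {wrong_t1:.4f}, T=0.02 -> {wrong_t002:.4f}")
```

With a small temperature, small score gaps become large logit gaps, so the loss reacts much more strongly to whether the positive is ranked above the negatives.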

Gradient accumulation does not help contrastive learning, because accumulated steps do not share in-batch negatives. Instead, you can leverage GradCache to emulate bigger batch sizes without requiring more memory: use the CachedContrastive loss, define a mini_batch_size, and increase the per_device_train_batch_size:

train_loss = losses.CachedContrastive(
    model=model,
    mini_batch_size=mini_batch_size,  # Define this beforehand; choose a value that fits in GPU memory
)

Finally, if you are in a multi-GPU setting, you can gather all the elements from the different GPUs to create even bigger batch sizes by setting gather_across_devices to True (for both Contrastive and CachedContrastive losses):

train_loss = losses.Contrastive(model=model, gather_across_devices=True)
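The effect of these options on the number of candidates each query is scored against can be sketched with simple arithmetic. The helper below is hypothetical and rests on two assumptions: the data are triplets (one positive and one negative per query), and every document in the batch, or in the gathered batch, serves as a candidate for every query.

```python
def candidates_per_query(
    per_device_batch_size,
    docs_per_query=2,
    num_devices=1,
    gather_across_devices=False,
):
    """Number of documents each query is compared against in one loss computation.

    Assumes every document in the (possibly gathered) batch acts as a candidate.
    """
    devices = num_devices if gather_across_devices else 1
    return per_device_batch_size * docs_per_query * devices


# Gradient accumulation: 4 steps of 32 still score each query against one batch of 32.
accumulated = candidates_per_query(per_device_batch_size=32)
# One large batch of 128, which GradCache makes affordable in memory.
large_batch = candidates_per_query(per_device_batch_size=128)
# The same large batch on 4 GPUs with gathering enabled.
gathered = candidates_per_query(
    per_device_batch_size=128, num_devices=4, gather_across_devices=True
)

print(accumulated, large_batch, gathered)
```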

 

Knowledge Distillation

To get the best performance when training a ColBERT model, you should use knowledge distillation to train the model using the scores of a strong teacher model. Here's a simple example of how to train a model using knowledge distillation in PyLate on MS MARCO:

import torch
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)

from pylate import losses, models, utils

# Load the datasets required for knowledge distillation (train, queries, documents)
train = load_dataset(
    path="lightonai/ms-marco-en-bge",
    name="train",
)

queries = load_dataset(
    path="lightonai/ms-marco-en-bge",
    name="queries",
)

documents = load_dataset(
    path="lightonai/ms-marco-en-bge",
    name="documents",
)

# Set the transformation to load the documents/queries texts using the corresponding ids on the fly
train.set_transform(
    utils.KDProcessing(queries=queries, documents=documents).transform,
)

# Define the base model, training parameters, and output directory
model_name = "bert-base-uncased"  # Choose the pre-trained model you want to use as base
batch_size = 16
num_train_epochs = 1
# Set the run name for logging and output directory
run_name = "knowledge-distillation-bert-base"
output_dir = f"output/{run_name}"

# Initialize the ColBERT model from the base model
model = models.ColBERT(model_name_or_path=model_name)

# Compiling the model to make the training faster
model = torch.compile(model)

# Configure the training arguments (e.g., epochs, batch size, learning rate)
args = SentenceTransformerTrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    run_name=run_name,
    learning_rate=1e-5,
)

# Use the Distillation loss function for training
train_loss = losses.Distillation(model=model)

# Initialize the trainer
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train,
    loss=train_loss,
    data_collator=utils.ColBERTCollator(tokenize_fn=model.tokenize),
)

# Start the training process
trainer.train()
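To illustrate what training against teacher scores means, the sketch below computes a KL divergence between the teacher's and the student's score distributions over the candidate documents of one query. This is one common formulation of score distillation, written in plain Python. It is an assumption made for illustration and not necessarily what `losses.Distillation` does internally. All scores are made up.

```python
import math


def softmax(scores):
    """Turn raw scores into a probability distribution."""
    max_score = max(scores)
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(teacher_scores, student_scores):
    """KL(teacher || student) over the candidate documents of one query."""
    p = softmax(teacher_scores)
    q = softmax(student_scores)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


teacher = [0.45, 0.66, 0.27]    # hypothetical teacher scores
good_student = [9.1, 9.3, 8.9]  # orders the documents like the teacher
bad_student = [9.3, 8.9, 9.1]   # orders them differently

kl_good = kl_divergence(teacher, good_student)
kl_bad = kl_divergence(teacher, bad_student)
print(f"agreeing student: {kl_good:.6f}, disagreeing student: {kl_bad:.6f}")
```

Because the scores go through a softmax, only the relative gaps between candidates matter. The student is free to use a different score scale from the teacher.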

NanoBEIR evaluator

If you are training an English retrieval model, you can use the NanoBEIR evaluator, which runs a small version of BEIR to give quick validation results. Pass it to the trainer through the evaluator argument:

evaluator=evaluation.NanoBEIREvaluator(),

 

Datasets

PyLate supports Hugging Face Datasets for both triplet-based and knowledge-distillation-based training. For contrastive training, you can use any of the existing Sentence Transformers triplet datasets. Below is an example of creating a custom triplet dataset for training:

from datasets import Dataset

dataset = [
    {
        "query": "example query 1",
        "positive": "example positive document 1",
        "negative": "example negative document 1",
    },
    {
        "query": "example query 2",
        "positive": "example positive document 2",
        "negative": "example negative document 2",
    },
    {
        "query": "example query 3",
        "positive": "example positive document 3",
        "negative": "example negative document 3",
    },
]

dataset = Dataset.from_list(mapping=dataset)

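# train_test_split returns a DatasetDict, so select the splits by key
splits = dataset.train_test_split(test_size=0.3)
train_dataset, test_dataset = splits["train"], splits["test"]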

Note that PyLate supports more than one negative per query. Simply add the additional negatives after the first one in the row:

{
    "query": "example query 1",
    "positive": "example positive document 1",
    "negative_1": "example negative document 1",
    "negative_2": "example negative document 2",
}
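A small helper can build such rows from a list of negatives while keeping the column order. This is a minimal sketch: `make_training_row` is a hypothetical function, not part of PyLate, and the `negative_1`, `negative_2` column names simply follow the row shown above.

```python
def make_training_row(query, positive, negatives):
    """Build one training row with the negatives placed after the positive.

    Column names follow the example row (`negative_1`, `negative_2`, ...).
    """
    if not negatives:
        raise ValueError("at least one negative document is required")
    row = {"query": query, "positive": positive}
    for i, negative in enumerate(negatives, start=1):
        row[f"negative_{i}"] = negative
    return row


row = make_training_row(
    query="example query 1",
    positive="example positive document 1",
    negatives=["example negative document 1", "example negative document 2"],
)
print(list(row))
```

A list of such rows can then be passed to `Dataset.from_list` as in the previous snippet, provided every row has the same columns.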

To create a knowledge distillation dataset, you can use the following snippet:

from datasets import Dataset

dataset = [
    {
        "query_id": 54528,
        "document_ids": [
            6862419,
            335116,
            339186,
        ],
        "scores": [
            0.4546215673141326,
            0.6575686537173476,
            0.26825184192900203,
        ],
    },
    {
        "query_id": 749480,
        "document_ids": [
            6862419,
            335116,
            339186,
        ],
        "scores": [
            0.2546215673141326,
            0.7575686537173476,
            0.96825184192900203,
        ],
    },
]


dataset = Dataset.from_list(mapping=dataset)

documents = [
    {"document_id": 6862419, "text": "example doc 1"},
    {"document_id": 335116, "text": "example doc 2"},
    {"document_id": 339186, "text": "example doc 3"},
]

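# Every query_id referenced in the training samples needs an entry here
queries = [
    {"query_id": 54528, "text": "example query 1"},
    {"query_id": 749480, "text": "example query 2"},
]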

documents = Dataset.from_list(mapping=documents)

queries = Dataset.from_list(mapping=queries)
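The three tables are linked only by ids, so at training time each sample has to be resolved into texts. The sketch below shows that lookup conceptually in plain Python. `resolve_sample` and its output field names are hypothetical and do not reflect the actual implementation of `utils.KDProcessing`.

```python
def resolve_sample(sample, queries_by_id, documents_by_id):
    """Replace the ids of one distillation sample with the corresponding texts."""
    return {
        "query": queries_by_id[sample["query_id"]],
        "documents": [documents_by_id[doc_id] for doc_id in sample["document_ids"]],
        "scores": sample["scores"],
    }


# Lookup tables keyed by id (toy data mirroring the snippet above).
documents_by_id = {
    6862419: "example doc 1",
    335116: "example doc 2",
    339186: "example doc 3",
}
queries_by_id = {749480: "example query"}

sample = {
    "query_id": 749480,
    "document_ids": [6862419, 335116, 339186],
    "scores": [0.25, 0.76, 0.97],
}

resolved = resolve_sample(sample, queries_by_id, documents_by_id)
print(resolved["query"], resolved["documents"])
```

Each teacher score stays aligned with the document at the same position, which is why the order of `document_ids` and `scores` must match.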

 

Retrieval

PyLate provides an efficient index with FastPLAID. Simply load a ColBERT model and initialize the index to perform retrieval.

from pylate import indexes, models, retrieve

model = models.ColBERT(
    model_name_or_path="lightonai/GTE-ModernColBERT-v1",
)

index = indexes.PLAID(
    index_folder="pylate-index",
    index_name="index",
    override=True,
)

retriever = retrieve.ColBERT(index=index)

Once the model and index are set up, we can add documents to the index using their embeddings and corresponding ids:

documents_ids = ["1", "2", "3"]

documents = [
    "ColBERT’s late-interaction keeps token-level embeddings to deliver cross-encoder-quality ranking at near-bi-encoder speed, enabling fine-grained relevance, robustness across domains, and hardware-friendly scalable search.",

    "PLAID compresses ColBERT token vectors via product quantization to shrink storage by 10×, uses two-stage centroid scoring for sub-200 ms latency, and plugs directly into existing ColBERT pipelines.",

    "PyLate is a library built on top of Sentence Transformers, designed to simplify and optimize fine-tuning, inference, and retrieval with state-of-the-art ColBERT models. It enables easy fine-tuning on both single and multiple GPUs, providing flexibility for various hardware setups. PyLate also streamlines document retrieval and allows you to load a wide range of models, enabling you to construct ColBERT models from most pre-trained language models.",
]

# Encode the documents
documents_embeddings = model.encode(
    documents,
    batch_size=32,
    is_query=False, # Encoding documents
    show_progress_bar=True,
)

# Add the documents ids and embeddings to the PLAID index
index.add_documents(
    documents_ids=documents_ids,
    documents_embeddings=documents_embeddings,
)

Then we can retrieve the top-k documents for a given set of queries:

queries_embeddings = model.encode(
    ["query for document 3", "query for document 1"],
    batch_size=32,
    is_query=True, # Encoding queries
    show_progress_bar=True,
)

scores = retriever.retrieve(
    queries_embeddings=queries_embeddings,
    k=10,
)

print(scores)

Sample Output:

[
    [
        {"id": "3", "score": 11.266985893249512},
        {"id": "1", "score": 10.303335189819336},
        {"id": "2", "score": 9.502392768859863},
    ],
    [
        {"id": "1", "score": 10.88800048828125},
        {"id": "3", "score": 9.950843811035156},
        {"id": "2", "score": 9.602447509765625},
    ],
]
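The result is a list with one entry per query, each holding that query's hits. As a minimal sketch of post-processing it, the snippet below pairs each query with its ranked document ids. It assumes, as the sample output suggests, that results come back in query order and sorted by descending score. The scores are rounded copies of the sample values.

```python
# Result structure copied from the sample output (scores rounded).
scores = [
    [
        {"id": "3", "score": 11.27},
        {"id": "1", "score": 10.30},
        {"id": "2", "score": 9.50},
    ],
    [
        {"id": "1", "score": 10.89},
        {"id": "3", "score": 9.95},
        {"id": "2", "score": 9.60},
    ],
]
queries = ["query for document 3", "query for document 1"]

# Pair every query with the ids of its retrieved documents, best first.
ranked_ids = {
    query: [hit["id"] for hit in hits]
    for query, hits in zip(queries, scores)
}
# Keep only the best hit per query.
top_hit = {query: ids[0] for query, ids in ranked_ids.items()}
print(top_hit)
```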

 

Reranking

If you want to use a ColBERT model to rerank the results of a first-stage retrieval pipeline without building an index, you can use the rank.rerank function, which takes the query and document embeddings along with the document ids:

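from pylate import models, rank

# Load the ColBERT model used to encode queries and documents
model = models.ColBERT(
    model_name_or_path="lightonai/GTE-ModernColBERT-v1",
)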

queries = [
    "query A",
    "query B",
]

documents = [
    ["document A", "document B"],
    ["document 1", "document C", "document B"],
]

documents_ids = [
    [1, 2],
    [1, 3, 2],
]

queries_embeddings = model.encode(
    queries,
    is_query=True,
)

documents_embeddings = model.encode(
    documents,
    is_query=False,
)

reranked_documents = rank.rerank(
    documents_ids=documents_ids,
    queries_embeddings=queries_embeddings,
    documents_embeddings=documents_embeddings,
)
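Conceptually, late-interaction reranking scores each candidate with the ColBERT MaxSim operator: for every query token, take the best dot product with any document token, then sum over query tokens. The sketch below shows this in plain Python on toy embeddings. It is not PyLate's implementation. `maxsim_score` and `rerank_by_maxsim` are hypothetical names, and the output format mirrors the retrieval sample output as an assumption.

```python
def maxsim_score(query_embeddings, document_embeddings):
    """Sum, over query tokens, of the best dot product with any document token."""
    score = 0.0
    for q in query_embeddings:
        score += max(
            sum(qi * di for qi, di in zip(q, d)) for d in document_embeddings
        )
    return score


def rerank_by_maxsim(documents_ids, query_embeddings, documents_embeddings):
    """Sort one query's candidate ids by descending MaxSim score."""
    scored = [
        {"id": doc_id, "score": maxsim_score(query_embeddings, doc_emb)}
        for doc_id, doc_emb in zip(documents_ids, documents_embeddings)
    ]
    return sorted(scored, key=lambda item: item["score"], reverse=True)


# Toy 2-dimensional token embeddings (made up for the example).
query = [[1.0, 0.0], [0.0, 1.0]]
candidates = [
    [[0.5, 0.5], [0.4, 0.1]],  # document id 1
    [[1.0, 0.0], [0.0, 0.9]],  # document id 2
]

reranked = rerank_by_maxsim([1, 2], query, candidates)
print(reranked)
```

Because each query token picks its own best-matching document token, a document can score well by matching different query terms in different places. This is what distinguishes late interaction from comparing two pooled vectors.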

 

Contributing

We welcome contributions! To get started:

  1. Install the development dependencies:
pip install "pylate[dev]"
  2. Run tests:
make test
  3. Format code with Ruff:
make lint

Citation

You can refer to the library with this BibTeX:

@inproceedings{DBLP:conf/cikm/ChaffinS25,
  author       = {Antoine Chaffin and
                  Rapha{\"{e}}l Sourty},
  editor       = {Meeyoung Cha and
                  Chanyoung Park and
                  Noseong Park and
                  Carl Yang and
                  Senjuti Basu Roy and
                  Jessie Li and
                  Jaap Kamps and
                  Kijung Shin and
                  Bryan Hooi and
                  Lifang He},
  title        = {PyLate: Flexible Training and Retrieval for Late Interaction Models},
  booktitle    = {Proceedings of the 34th {ACM} International Conference on Information
                  and Knowledge Management, {CIKM} 2025, Seoul, Republic of Korea, November
                  10-14, 2025},
  pages        = {6334--6339},
  publisher    = {{ACM}},
  year         = {2025},
  url          = {https://github.com/lightonai/pylate},
  doi          = {10.1145/3746252.3761608},
}

DeepWiki

PyLate is indexed on DeepWiki, so you can ask LLM-backed Deep Research questions to explore the codebase and get help adding new features.
