PyLate

Flexible Training and Retrieval for Late Interaction Models

PyLate is a library built on top of Sentence Transformers, designed to simplify and optimize fine-tuning, inference, and retrieval with state-of-the-art ColBERT models. It enables easy fine-tuning on both single and multiple GPUs, providing flexibility for various hardware setups. PyLate also streamlines document retrieval and allows you to load a wide range of models, enabling you to construct ColBERT models from most pre-trained language models.
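
For example, initializing a ColBERT model from an existing pre-trained encoder takes a single call. This uses the same models.ColBERT constructor as the training example below; the base checkpoint here is just an illustration:

from pylate import models

# Build a ColBERT model on top of a pre-trained Sentence Transformers encoder
model = models.ColBERT(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")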

Installation

You can install PyLate using pip:

pip install pylate

For evaluation dependencies, use:

pip install "pylate[eval]"

Documentation

The complete documentation, which includes in-depth guides, examples, and API references, is available on the PyLate documentation site.

Datasets

PyLate supports Hugging Face Datasets, enabling seamless triplet- and knowledge distillation-based training. Below is an example of creating a custom triplet dataset for training (a knowledge distillation layout is sketched after the split):

from datasets import Dataset

dataset = [
    {
        "query": "example query 1",
        "positive": "example positive document 1",
        "negative": "example negative document 1",
    },
    {
        "query": "example query 2",
        "positive": "example positive document 2",
        "negative": "example negative document 2",
    },
    {
        "query": "example query 3",
        "positive": "example positive document 3",
        "negative": "example negative document 3",
    },
]

dataset = Dataset.from_list(mapping=dataset)

# train_test_split returns a DatasetDict with "train" and "test" splits
splits = dataset.train_test_split(test_size=0.3)
train_dataset, test_dataset = splits["train"], splits["test"]
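
For knowledge distillation, each training row typically pairs a query with several candidate documents and the scores a teacher model assigned to them. The exact schema expected by PyLate's distillation loss is described in its documentation; the query/documents/scores columns below are only an illustrative assumption:

from datasets import Dataset

# Hypothetical knowledge-distillation layout: one query, a list of candidate
# documents, and the teacher's relevance score for each candidate.
kd_dataset = Dataset.from_list(
    mapping=[
        {
            "query": "example query 1",
            "documents": ["candidate document A", "candidate document B"],
            "scores": [0.87, 0.12],
        },
    ]
)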

Training

Here’s a simple example of training a ColBERT model on the MS MARCO dataset using PyLate. The script trains with a contrastive loss over triplets and evaluates the model on a held-out test set.

from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.training_args import BatchSamplers

from pylate import evaluation, losses, models, utils

# Define the model
model = models.ColBERT(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")

# Load dataset
dataset = load_dataset("sentence-transformers/msmarco-bm25", "triplet", split="train")

# Split the dataset to create a test set (train_test_split returns a DatasetDict)
splits = dataset.train_test_split(test_size=0.01)
train_dataset, eval_dataset = splits["train"], splits["test"]

# Shuffle and select a subset of the dataset for demonstration purposes
MAX_TRAIN_SIZE, MAX_EVAL_SIZE = 100, 100
train_dataset = train_dataset.shuffle(seed=21).select(range(MAX_TRAIN_SIZE))
eval_dataset = eval_dataset.shuffle(seed=21).select(range(MAX_EVAL_SIZE))

# Define the loss function
train_loss = losses.Contrastive(model=model)

args = SentenceTransformerTrainingArguments(
    output_dir="colbert-training",
    num_train_epochs=1,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    fp16=False,  # Set to True if your GPU supports FP16 (faster than FP32)
    bf16=False,  # Set to True if your GPU supports BF16 (Ampere or newer)
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    # Tracking parameters:
    eval_strategy="steps",
    eval_steps=0.1,
    save_strategy="steps",
    save_steps=5000,
    save_total_limit=2,
    learning_rate=3e-6,
)

# Evaluation procedure
dev_evaluator = evaluation.ColBERTTripletEvaluator(
    anchors=eval_dataset["query"],
    positives=eval_dataset["positive"],
    negatives=eval_dataset["negative"],
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=train_loss,
    evaluator=dev_evaluator,
    data_collator=utils.ColBERTCollator(tokenize_fn=model.tokenize),
)

trainer.train()

model.save_pretrained("custom-colbert-model")

After training, the model can be loaded like this:

from pylate import models

model = models.ColBERT(model_name_or_path="custom-colbert-model")

Retrieve

PyLate makes it easy to retrieve the top documents for a set of queries using a trained ColBERT model and a Voyager (HNSW) index.

from pylate import indexes, models, retrieve

model = models.ColBERT(
    model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
)

index = indexes.Voyager(
    index_folder="pylate-index",
    index_name="index",
    override=True,
)

retriever = retrieve.ColBERT(index=index)

Once the model and index are set up, we can add documents to the index:

documents_ids = ["1", "2", "3"]

documents = [
    "document 1 text", "document 2 text", "document 3 text"
]

# Encode the documents
documents_embeddings = model.encode(
    documents,
    batch_size=32,
    is_query=False, # Encoding documents
    show_progress_bar=True,
)

# Add the documents ids and embeddings to the Voyager index
index.add_documents(
    documents_ids=documents_ids,
    documents_embeddings=documents_embeddings,
)

Then we can retrieve the top-k documents for a given query set:

queries_embeddings = model.encode(
    ["query for document 3", "query for document 1"],
    batch_size=32,
    is_query=True, # Encoding queries
    show_progress_bar=True,
)

scores = retriever.retrieve(
    queries_embeddings=queries_embeddings, 
    k=10,
)

print(scores)

Sample Output:

[
    [
        {"id": "3", "score": 11.266985893249512},
        {"id": "1", "score": 10.303335189819336},
        {"id": "2", "score": 9.502392768859863},
    ],
    [
        {"id": "1", "score": 10.88800048828125},
        {"id": "3", "score": 9.950843811035156},
        {"id": "2", "score": 9.602447509765625},
    ],
]
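
If you already have candidate documents for each query (for example, from a first-stage BM25 retriever), the candidates can also be scored without building an index. Below is a minimal sketch, assuming a rank.rerank helper that takes per-query candidate ids and embeddings; check the PyLate documentation for the exact signature:

from pylate import models, rank

model = models.ColBERT(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")

queries = ["query for document 3"]
documents = [["document 1 text", "document 2 text", "document 3 text"]]
documents_ids = [["1", "2", "3"]]

queries_embeddings = model.encode(queries, is_query=True)
documents_embeddings = model.encode(documents, is_query=False)

# Assumed signature: rerank each query's candidates by late-interaction score
reranked = rank.rerank(
    documents_ids=documents_ids,
    queries_embeddings=queries_embeddings,
    documents_embeddings=documents_embeddings,
)
print(reranked)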

Contributing

We welcome contributions! To get started:

  1. Install the development dependencies:

pip install "pylate[dev]"

  2. Run tests:

make test

  3. Format code with Ruff:

make ruff

  4. Build the documentation:

make livedoc
