Project description
PyLate
Flexible Training and Retrieval for Late Interaction Models
PyLate is a library built on top of Sentence Transformers, designed to simplify and optimize fine-tuning, inference, and retrieval with state-of-the-art ColBERT models. It enables easy fine-tuning on both single and multiple GPUs, providing flexibility for various hardware setups. PyLate also streamlines document retrieval and lets you build ColBERT models from most pre-trained language models.
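For example, wrapping an existing pre-trained checkpoint as a ColBERT model is a one-liner. Here is a minimal sketch; the base checkpoint is only an illustration, and most encoder checkpoints should work:

from pylate import models

# Any compatible pre-trained encoder can serve as the backbone;
# all-MiniLM-L6-v2 is used here purely as an example.
model = models.ColBERT(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")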
Installation
You can install PyLate using pip:
pip install pylate
For evaluation dependencies, use:
pip install "pylate[eval]"
Documentation
The complete documentation is available here; it includes in-depth guides, examples, and API references.
Datasets
PyLate supports Hugging Face Datasets, enabling seamless triplet-based and knowledge-distillation-based training. Below is an example of creating a custom triplet dataset for training; a sketch of a distillation-style dataset follows it.
from datasets import Dataset
dataset = [
{
"query": "example query 1",
"positive": "example positive document 1",
"negative": "example negative document 1",
},
{
"query": "example query 2",
"positive": "example positive document 2",
"negative": "example negative document 2",
},
{
"query": "example query 3",
"positive": "example positive document 3",
"negative": "example negative document 3",
},
]
dataset = Dataset.from_list(mapping=dataset)
# train_test_split returns a DatasetDict with "train" and "test" splits
splits = dataset.train_test_split(test_size=0.3)
train_dataset, test_dataset = splits["train"], splits["test"]
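For knowledge distillation, each query is typically paired with several candidate documents and teacher relevance scores rather than a single positive/negative pair. The sketch below only illustrates that shape; the column names are assumptions for the example, and the exact format expected for distillation training is described in the documentation:

from datasets import Dataset

# Illustrative distillation-style rows: the column names ("query", "documents",
# "scores") are assumptions for this sketch, not a fixed PyLate schema.
kd_rows = [
    {
        "query": "example query 1",
        "documents": ["candidate document A", "candidate document B"],
        "scores": [0.92, 0.13],
    },
    {
        "query": "example query 2",
        "documents": ["candidate document C", "candidate document D"],
        "scores": [0.38, 0.75],
    },
]

kd_dataset = Dataset.from_list(mapping=kd_rows)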
Training
Here’s a simple example of training a ColBERT model on the MS MARCO dataset using PyLate. This script demonstrates training with a contrastive loss on triplets and evaluating the model on a held-out set.
from datasets import load_dataset
from sentence_transformers import (
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
)
from sentence_transformers.training_args import BatchSamplers
from pylate import evaluation, losses, models, utils
# Define the model
model = models.ColBERT(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")
# Load dataset
dataset = load_dataset("sentence-transformers/msmarco-bm25", "triplet", split="train")
# Split the dataset to create a test set
splits = dataset.train_test_split(test_size=0.01)  # returns a DatasetDict
train_dataset, eval_dataset = splits["train"], splits["test"]
# Shuffle and select a subset of the dataset for demonstration purposes
MAX_TRAIN_SIZE, MAX_EVAL_SIZE = 100, 100
train_dataset = train_dataset.shuffle(seed=21).select(range(MAX_TRAIN_SIZE))
eval_dataset = eval_dataset.shuffle(seed=21).select(range(MAX_EVAL_SIZE))
# Define the loss function
train_loss = losses.Contrastive(model=model)
args = SentenceTransformerTrainingArguments(
output_dir="colbert-training",
num_train_epochs=1,
per_device_train_batch_size=32,
per_device_eval_batch_size=32,
fp16=False,  # Set to True if your GPU supports FP16 (faster than FP32)
bf16=False,  # Set to True if your GPU supports BF16
batch_sampler=BatchSamplers.NO_DUPLICATES,
# Tracking parameters:
eval_strategy="steps",
eval_steps=0.1,
save_strategy="steps",
save_steps=5000,
save_total_limit=2,
learning_rate=3e-6,
)
# Evaluation procedure
dev_evaluator = evaluation.ColBERTTripletEvaluator(
anchors=eval_dataset["query"],
positives=eval_dataset["positive"],
negatives=eval_dataset["negative"],
)
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=train_loss,
evaluator=dev_evaluator,
data_collator=utils.ColBERTCollator(tokenize_fn=model.tokenize),
)
trainer.train()
model.save_pretrained("custom-colbert-model")
After training, the model can be loaded like this:
from pylate import models
model = models.ColBERT(model_name_or_path="custom-colbert-model")
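Unlike a standard bi-encoder, a ColBERT model produces one embedding per token, so each text maps to a matrix of embeddings rather than a single vector. A quick sketch (the exact shape depends on the tokenizer, query/document length handling, and the configured embedding size):

query_embeddings = model.encode(
    ["what is late interaction retrieval?"],
    is_query=True,  # queries and documents are encoded differently
)

# One set of token-level embeddings per input text, i.e. roughly a
# (number of tokens x embedding dimension) matrix, assuming numpy output.
print(query_embeddings[0].shape)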
Retrieve
PyLate allows easy retrieval of the top documents for a given query set using a trained ColBERT model and a Voyager index.
from pylate import indexes, models, retrieve
model = models.ColBERT(
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
)
index = indexes.Voyager(
index_folder="pylate-index",
index_name="index",
override=True,
)
retriever = retrieve.ColBERT(index=index)
Once the model and index are set up, we can add documents to the index:
documents_ids = ["1", "2", "3"]
documents = [
"document 1 text", "document 2 text", "document 3 text"
]
# Encode the documents
documents_embeddings = model.encode(
documents,
batch_size=32,
is_query=False, # Encoding documents
show_progress_bar=True,
)
# Add the documents ids and embeddings to the Voyager index
index.add_documents(
documents_ids=documents_ids,
documents_embeddings=documents_embeddings,
)
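Note that override=True rebuilds the index from scratch. To query an index that was already built and stored in the index folder, instantiate it again without overriding. A minimal sketch, assuming override=False keeps the existing data on disk:

from pylate import indexes, retrieve

# Re-open the previously built index instead of recreating it.
# override=False is assumed to preserve the data already on disk.
index = indexes.Voyager(
    index_folder="pylate-index",
    index_name="index",
    override=False,
)
retriever = retrieve.ColBERT(index=index)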
Then we can retrieve the top-k documents for a given query set:
queries_embeddings = model.encode(
["query for document 3", "query for document 1"],
batch_size=32,
is_query=True, # Encoding queries
show_progress_bar=True,
)
scores = retriever.retrieve(
queries_embeddings=queries_embeddings,
k=10,
)
print(scores)
Sample Output:
[
[
{"id": "3", "score": 11.266985893249512},
{"id": "1", "score": 10.303335189819336},
{"id": "2", "score": 9.502392768859863},
],
[
{"id": "1", "score": 10.88800048828125},
{"id": "3", "score": 9.950843811035156},
{"id": "2", "score": 9.602447509765625},
],
]
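The retriever returns one ranked list per query, in the same order as the query embeddings. Mapping the returned ids back to the original texts is plain Python:

queries = ["query for document 3", "query for document 1"]
id_to_document = dict(zip(documents_ids, documents))

for query, results in zip(queries, scores):
    top = results[0]
    print(f"{query} -> {id_to_document[top['id']]} (score={top['score']:.2f})")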
Contributing
We welcome contributions! To get started:
- Install the development dependencies:
pip install "pylate[dev]"
- Run tests:
make test
- Format code with Ruff:
make ruff
- Build the documentation:
make livedoc
Download files
Source Distribution
File details
Details for the file pylate-0.0.1.tar.gz.
File metadata
- Download URL: pylate-0.0.1.tar.gz
- Upload date:
- Size: 37.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.0 CPython/3.10.14
File hashes
Algorithm | Hash digest
---|---
SHA256 | 0f4b8b8b77630681b30e0ef88cf7dd8606d74955a4bfe5a5f1382cfbdf79a0f5
MD5 | 927c50dde55e96ee3cda2abcd22bceb4
BLAKE2b-256 | f3a7d2161eecd95af03eb16981ce81a87bf2a34077cce99d9e9e0093f241bf13