Sparse Embeddings for Neural Search.
Project description
SparsEmbed - Splade
Neural search
This repository presents an unofficial replication of the research papers:
- SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking, by Thibault Formal, Benjamin Piwowarski, and Stéphane Clinchant, SIGIR 2021.
- SparseEmbed: Learning Sparse Lexical Representations with Contextual Embeddings for Retrieval, by Weize Kong, Jeffrey M. Dudek, Cheng Li, Mingyang Zhang, and Mike Bendersky, SIGIR 2023.
Note: This project is currently a work in progress and models are not ready to use. 🔨🧹
Installation
pip install sparsembed
If you plan to evaluate your model, install:
pip install "sparsembed[eval]"
Training
Dataset
Your training dataset must consist of triples (anchor, positive, negative), where anchor is a query, positive is a document relevant to the anchor, and negative is a document that is not relevant to the query.
X = [
    ("anchor 1", "positive 1", "negative 1"),
    ("anchor 2", "positive 2", "negative 2"),
    ("anchor 3", "positive 3", "negative 3"),
]
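For larger datasets, the same structure can be built from a file. The helper below is a hypothetical illustration, not part of the sparsembed API; it assumes a tab-separated file with one (anchor, positive, negative) triple per line:

import csv

def load_triples(path: str) -> list[tuple[str, str, str]]:
    """Read (anchor, positive, negative) triples from a tab-separated file."""
    triples = []
    with open(path, newline="", encoding="utf-8") as f:
        for row in csv.reader(f, delimiter="\t"):
            if len(row) == 3:
                triples.append((row[0], row[1], row[2]))
    return triples

X = load_triples("triples.tsv")  # hypothetical file name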
Models
Both the Splade and SparsEmbed models can be initialized from a pretrained AutoModelForMaskedLM checkpoint.
from transformers import AutoModelForMaskedLM, AutoTokenizer

from sparsembed import model

device = "cuda"  # or "cpu"

model = model.Splade(
    model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device),
    tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"),
    device=device,
)
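For intuition, the SPLADE paper represents a text as a sparse vector over the MLM vocabulary: the token-level MLM logits are passed through log(1 + ReLU(·)) and max-pooled over the sequence. The standalone sketch below illustrates that activation directly with Hugging Face transformers; it is a conceptual illustration, not the library's internal code:

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
mlm = AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased")

tokens = tokenizer("sparse retrieval with lexical expansion", return_tensors="pt")
with torch.no_grad():
    logits = mlm(**tokens).logits  # shape: (1, sequence_length, vocab_size)

# SPLADE activation: log-saturated ReLU over the MLM logits, max-pooled over
# the sequence, giving one non-negative weight per vocabulary term.
weights = torch.log1p(torch.relu(logits)).amax(dim=1)  # shape: (1, vocab_size)

# Most vocabulary terms get a zero weight; the non-zero entries form the
# expanded lexical representation of the text.
print((weights > 0).sum().item(), "active terms out of", weights.shape[-1])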
from transformers import AutoModelForMaskedLM, AutoTokenizer

from sparsembed import model

device = "cuda"  # or "cpu"

model = model.SparsEmbed(
    model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device),
    tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"),
    embedding_size=64,
    k_tokens=96,
    device=device,
)
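In SparsEmbed, embedding_size is the dimension of the contextual embedding attached to each activated term, and k_tokens is the number of top-scoring terms kept per text. The sketch below only illustrates the idea from the SparseEmbed paper (keep the top-k terms from SPLADE-style activations, then attach a small projected embedding to each); the tensors are placeholders and this is not the library's code:

import torch

vocab_size, hidden_size = 30522, 768   # distilbert-base-uncased sizes
embedding_size, k_tokens = 64, 96

# Hypothetical per-text outputs: SPLADE-style term weights over the vocabulary
# and a dense state associated with each vocabulary term (random placeholders).
term_weights = torch.rand(vocab_size)
term_states = torch.rand(vocab_size, hidden_size)

# Keep only the k_tokens highest-weighted terms.
top_weights, top_indices = term_weights.topk(k=k_tokens)

# Attach a small contextual embedding to each retained term.
projection = torch.nn.Linear(hidden_size, embedding_size)
embeddings = projection(term_states[top_indices])  # (k_tokens, embedding_size)

print(top_indices.shape, embeddings.shape)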
Splade
The following PyTorch code snippet illustrates the training loop to fine-tune Splade:
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

from sparsembed import losses, model, retrieve, train, utils

device = "cuda"  # or "cpu"
batch_size = 8
model = model.Splade(
    model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device),
    tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"),
    device=device,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

# Schedules the weight of the FLOPS (sparsity) regularization during training.
flops_scheduler = losses.FlopsScheduler(weight=3e-5, steps=10_000)
X = [
    ("anchor 1", "positive 1", "negative 1"),
    ("anchor 2", "positive 2", "negative 2"),
    ("anchor 3", "positive 3", "negative 3"),
]
for anchor, positive, negative in utils.iter(
    X,
    epochs=1,
    batch_size=batch_size,
    shuffle=True,
):
    loss = train.train_splade(
        model=model,
        optimizer=optimizer,
        anchor=anchor,
        positive=positive,
        negative=negative,
        flops_loss_weight=flops_scheduler(),
        in_batch_negatives=True,
    )
# Evaluate on the SciFact dataset from the BEIR benchmark.
documents, queries, qrels = utils.load_beir("scifact", split="test")

retriever = retrieve.SpladeRetriever(
    key="id",
    on=["title", "text"],
    model=model,
)

retriever = retriever.add(
    documents=documents,
    batch_size=batch_size,
    k_tokens=96,
)
utils.evaluate(
    retriever=retriever,
    batch_size=batch_size,
    qrels=qrels,
    queries=queries,
    k=100,
    k_tokens=96,
    metrics=["map", "ndcg@10", "recall@10", "hits@10"],
)
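Once fine-tuned, the model can be checkpointed with plain PyTorch. This is a sketch rather than a documented sparsembed feature; it assumes the Splade wrapper behaves like a standard torch.nn.Module (its parameters() are already passed to the optimizer above):

# Save the fine-tuned weights (assumes the wrapper is a torch.nn.Module).
torch.save(model.state_dict(), "splade_scifact.pt")

# Later, rebuild the model as above and restore the weights:
# model.load_state_dict(torch.load("splade_scifact.pt", map_location=device))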
SparsEmbed
The following PyTorch code snippet illustrates the training loop to fine-tune SparseEmbed:
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

from sparsembed import losses, model, retrieve, train, utils

device = "cuda"  # or "cpu"
batch_size = 8
model = model.SparsEmbed(
    model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device),
    tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"),
    device=device,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
flops_scheduler = losses.FlopsScheduler(weight=3e-5, steps=10_000)
X = [
    ("anchor 1", "positive 1", "negative 1"),
    ("anchor 2", "positive 2", "negative 2"),
    ("anchor 3", "positive 3", "negative 3"),
]
for anchor, positive, negative in utils.iter(
    X,
    epochs=1,
    batch_size=batch_size,
    shuffle=True,
):
    loss = train.train_sparsembed(
        model=model,
        optimizer=optimizer,
        k_tokens=96,
        anchor=anchor,
        positive=positive,
        negative=negative,
        flops_loss_weight=flops_scheduler(),
        sparse_loss_weight=0.1,
        in_batch_negatives=True,
    )
documents, queries, qrels = utils.load_beir("scifact", split="test")
retriever = retrieve.SparsEmbedRetriever(
    key="id",
    on=["title", "text"],
    model=model,
)

retriever = retriever.add(
    documents=documents,
    k_tokens=96,
    batch_size=batch_size,
)
utils.evaluate(
    retriever=retriever,
    batch_size=batch_size,
    qrels=qrels,
    queries=queries,
    k=100,
    k_tokens=96,
    metrics=["map", "ndcg@10", "recall@10", "hits@10"],
)
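After indexing, the retriever can also be queried directly for the top documents of a given query. The call below is an assumption that mirrors the parameters of add and evaluate above; check the sparsembed API for the exact signature:

# Assumed call signature, mirroring the add / evaluate parameters above.
results = retriever(
    ["Is remdesivir an effective treatment for COVID-19?"],
    k=10,
    k_tokens=96,
    batch_size=batch_size,
)
print(results)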
File details
Details for the file sparsembed-0.0.8.tar.gz.
File metadata
- Download URL: sparsembed-0.0.8.tar.gz
- Upload date:
- Size: 15.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | 31abec163b31925ac2b1bf453369d3f94168772c42dd012d6895503cab6b7a58
MD5 | 71d1545cb757e58ba56aa65a254c5866
BLAKE2b-256 | b9c5bd6fff3bca41ac7c853323adfad735bf7d3e352abfa7ccf309fbdb787168