Sparse Embeddings for Neural Search.
Project description
SparsEmbed - Splade
Neural search
This repository presents an unofficial replication of both Splade and SparseEmbed, which are state-of-the-art models in information retrieval:
- SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking, authored by Thibault Formal, Benjamin Piwowarski and Stéphane Clinchant, SIGIR 2021.
- SPLADE v2: Sparse Lexical and Expansion Model for Information Retrieval, authored by Thibault Formal, Carlos Lassance, Benjamin Piwowarski and Stéphane Clinchant, SIGIR 2022.
- SparseEmbed: Learning Sparse Lexical Representations with Contextual Embeddings for Retrieval, authored by Weize Kong, Jeffrey M. Dudek, Cheng Li, Mingyang Zhang and Mike Bendersky, SIGIR 2023.
Note: This project is a work in progress. The Splade model is ready to use, while SparseEmbed support is still under development. 🔨🧹
Installation
We can install sparsembed using:
pip install sparsembed
To evaluate the model during training, install the optional evaluation dependencies:
pip install "sparsembed[eval]"
Retriever
Splade
We can initialize a Splade retriever directly from the splade_v2_max checkpoint available on Hugging Face. Retrievers are based on PyTorch sparse matrices, which are stored in memory and can be accelerated on GPU. We can reduce the number of activated tokens via the k_tokens parameter in order to reduce the memory usage of those sparse matrices.
from sparsembed import model, retrieve
from transformers import AutoModelForMaskedLM, AutoTokenizer
device = "cpu"
batch_size = 10
# List documents to index:
documents = [
{'id': 0,
'title': 'Paris',
'url': 'https://en.wikipedia.org/wiki/Paris',
'text': 'Paris is the capital and most populous city of France.'},
{'id': 1,
'title': 'Paris',
'url': 'https://en.wikipedia.org/wiki/Paris',
'text': "Since the 17th century, Paris has been one of Europe's major centres of science, and arts."},
{'id': 2,
'title': 'Paris',
'url': 'https://en.wikipedia.org/wiki/Paris',
'text': 'The City of Paris is the centre and seat of government of the region and province of Île-de-France.'
}]
model = model.Splade(
model=AutoModelForMaskedLM.from_pretrained("naver/splade_v2_max").to(device),
tokenizer=AutoTokenizer.from_pretrained("naver/splade_v2_max"),
device=device
)
retriever = retrieve.SpladeRetriever(
key="id", # Key identifier of each document.
on=["title", "text"], # Fields to search.
model=model # Splade model.
)
retriever = retriever.add(
documents=documents,
batch_size=batch_size,
k_tokens=256, # Number of activated tokens.
)
retriever(
["paris", "Toulouse"], # Queries
k_tokens=20, # Maximum number of activated tokens.
k=100, # Number of documents to retrieve.
batch_size=batch_size
)
[[{'id': 0, 'similarity': 11.481657981872559},
{'id': 2, 'similarity': 11.294965744018555},
{'id': 1, 'similarity': 10.059721946716309}],
[{'id': 0, 'similarity': 0.7379149198532104},
{'id': 2, 'similarity': 0.6973429918289185},
{'id': 1, 'similarity': 0.5428210496902466}]]
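The similarities above are scores between sparse vectors. As a rough mental model, and not the library's PyTorch sparse-matrix implementation, we can picture each query and document as a {token: weight} mapping: k_tokens keeps only the highest-weighted tokens, and the similarity is the dot product over the tokens both vectors share. The helpers prune_top_k, sparse_dot and search below are hypothetical and only illustrate the idea.

```python
import heapq


def prune_top_k(weights: dict[str, float], k_tokens: int) -> dict[str, float]:
    """Hypothetical helper: keep only the k_tokens highest-weighted activated tokens."""
    top = heapq.nlargest(k_tokens, weights.items(), key=lambda item: item[1])
    return dict(top)


def sparse_dot(query: dict[str, float], document: dict[str, float]) -> float:
    """Dot product between two sparse vectors stored as {token: weight} dicts."""
    # Iterate over the smaller vector: only tokens present in both contribute.
    if len(query) > len(document):
        query, document = document, query
    return sum(weight * document[token] for token, weight in query.items() if token in document)


def search(query: dict[str, float], index: dict[int, dict[str, float]], k: int) -> list[dict]:
    """Hypothetical brute-force search: score every document, return the k best."""
    scores = [{"id": doc_id, "similarity": sparse_dot(query, vector)} for doc_id, vector in index.items()]
    return sorted(scores, key=lambda result: result["similarity"], reverse=True)[:k]


# Toy index with made-up weights (not produced by a real Splade model).
index = {
    0: {"paris": 2.0, "capital": 1.0},
    1: {"paris": 1.0, "science": 1.5},
}
query = prune_top_k({"paris": 1.5, "city": 0.25, "the": 0.125}, k_tokens=2)
results = search(query, index, k=2)  # [{'id': 0, 'similarity': 3.0}, {'id': 1, 'similarity': 1.5}]
```

Keeping fewer tokens per vector shrinks the index and speeds up the dot products, at the cost of dropping low-weight expansion terms.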
SparsEmbed
We can also initialize a retriever dedicated to the SparsEmbed model. The naver/splade_v2_max checkpoint was not trained as a SparseEmbed model, so we should fine-tune one before using it as a SparsEmbed retriever.
from sparsembed import model, retrieve
from transformers import AutoModelForMaskedLM, AutoTokenizer
device = "cpu"
batch_size = 10
# List documents to index:
documents = [
{'id': 0,
'title': 'Paris',
'url': 'https://en.wikipedia.org/wiki/Paris',
'text': 'Paris is the capital and most populous city of France.'},
{'id': 1,
'title': 'Paris',
'url': 'https://en.wikipedia.org/wiki/Paris',
'text': "Since the 17th century, Paris has been one of Europe's major centres of science, and arts."},
{'id': 2,
'title': 'Paris',
'url': 'https://en.wikipedia.org/wiki/Paris',
'text': 'The City of Paris is the centre and seat of government of the region and province of Île-de-France.'
}]
model = model.SparsEmbed(
model=AutoModelForMaskedLM.from_pretrained("naver/splade_v2_max").to(device),
tokenizer=AutoTokenizer.from_pretrained("naver/splade_v2_max"),
device=device
)
retriever = retrieve.SparsEmbedRetriever(
key="id", # Key identifier of each document.
on=["title", "text"], # Fields to search.
model=model # SparsEmbed model.
)
retriever = retriever.add(
documents=documents,
batch_size=batch_size,
k_tokens=256, # Number of activated tokens.
)
retriever(
["paris", "Toulouse"], # Queries
k_tokens=20, # Maximum number of activated tokens.
k=100, # Number of documents to retrieve.
batch_size=batch_size
)
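Unlike Splade, which scores a document with a dot product between sparse vectors, SparseEmbed also attaches a contextual embedding to each activated token. As we read the SparseEmbed paper, the query-document score sums, over the tokens activated in both the query and the document, the dot products of their contextual embeddings. The sketch below uses plain Python lists as embeddings; sparsembed_score is a hypothetical name, and SparsEmbedRetriever may compute scores differently (for instance with PyTorch tensors).

```python
def dot(u: list[float], v: list[float]) -> float:
    """Dot product of two dense vectors of equal length."""
    return sum(a * b for a, b in zip(u, v))


def sparsembed_score(query: dict[str, list[float]], document: dict[str, list[float]]) -> float:
    """Hypothetical SparseEmbed-style score.

    query / document map each activated (top-k) token to its contextual embedding.
    Only tokens activated on both sides contribute to the score.
    """
    shared_tokens = query.keys() & document.keys()
    return sum(dot(query[token], document[token]) for token in shared_tokens)


# Toy 2-dimensional embeddings (made-up values).
query = {"paris": [1.0, 0.0], "city": [0.5, 0.5]}
document = {"paris": [2.0, 1.0], "france": [1.0, 1.0]}
score = sparsembed_score(query, document)  # only "paris" is shared: 1.0 * 2.0 + 0.0 * 1.0
```

The sparse token overlap still decides which pairs are compared, while the embeddings let two occurrences of the same token contribute differently depending on context.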
Training
Let's fine-tune Splade and SparsEmbed.
Dataset
Your training dataset must be made of triples (anchor, positive, negative), where the anchor is a query, the positive is a document relevant to the anchor, and the negative is a document that is not relevant to the anchor.
X = [
("anchor 1", "positive 1", "negative 1"),
("anchor 2", "positive 2", "negative 2"),
("anchor 3", "positive 3", "negative 3"),
]
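The training loops in this section consume X through utils.iter and unpack each batch into anchor, positive and negative. The hypothetical iter_triples below is a stdlib stand-in that shows one plausible batch shape, assuming each of the three is a list of strings of length batch_size; it is not the library's implementation.

```python
import random


def iter_triples(X, epochs, batch_size, shuffle=True, seed=None):
    """Hypothetical stand-in for utils.iter.

    Yields (anchors, positives, negatives), three aligned lists per batch.
    """
    rng = random.Random(seed)
    for _ in range(epochs):
        triples = list(X)
        if shuffle:
            rng.shuffle(triples)  # Shuffle whole triples so columns stay aligned.
        for start in range(0, len(triples), batch_size):
            batch = triples[start:start + batch_size]
            anchors, positives, negatives = (list(column) for column in zip(*batch))
            yield anchors, positives, negatives


X = [
    ("anchor 1", "positive 1", "negative 1"),
    ("anchor 2", "positive 2", "negative 2"),
    ("anchor 3", "positive 3", "negative 3"),
]

for anchor, positive, negative in iter_triples(X, epochs=1, batch_size=2, shuffle=False):
    print(anchor, positive, negative)
```

Shuffling whole triples (rather than each column separately) keeps every query paired with its own positive and negative documents.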
Models
Both Splade and SparsEmbed models can be initialized from pretrained AutoModelForMaskedLM checkpoints:
from sparsembed import model
from transformers import AutoModelForMaskedLM, AutoTokenizer
device = "cpu" # cpu or cuda
# Splade model (a distinct variable name keeps the `model` module usable).
splade_model = model.Splade(
    model=AutoModelForMaskedLM.from_pretrained("naver/splade_v2_max").to(device),
    tokenizer=AutoTokenizer.from_pretrained("naver/splade_v2_max"),
    device=device,
)
# SparsEmbed model.
sparsembed_model = model.SparsEmbed(
    model=AutoModelForMaskedLM.from_pretrained("naver/splade_v2_max").to(device),
    tokenizer=AutoTokenizer.from_pretrained("naver/splade_v2_max"),
    embedding_size=64,
    k_tokens=96, # Number of activated tokens.
    device=device,
)
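Starting from a masked-language-model head makes sense because SPLADE v2 derives one weight per vocabulary token from the MLM logits: it applies log(1 + ReLU(logit)) at every input position and then max-pools over positions (the "max" in splade_v2_max). The pure-Python sketch below shows that aggregation on nested lists; splade_max_pooling is a hypothetical name, and whether model.Splade computes exactly this (for example, how it masks padding) is an assumption.

```python
import math


def splade_max_pooling(logits: list[list[float]], attention_mask: list[int]) -> list[float]:
    """Hypothetical sketch of SPLADE max aggregation.

    logits: seq_len x vocab_size MLM logits for one text.
    attention_mask: seq_len values, 1 for real tokens and 0 for padding.
    Returns one non-negative weight per vocabulary token.
    """
    vocab_size = len(logits[0])
    weights = [0.0] * vocab_size  # log1p(ReLU(x)) >= 0, so 0.0 is a safe starting point.
    for token_logits, mask in zip(logits, attention_mask):
        if not mask:
            continue  # Ignore padding positions.
        for j, logit in enumerate(token_logits):
            weights[j] = max(weights[j], math.log1p(max(logit, 0.0)))
    return weights


# Two real positions and one padding position over a 3-token vocabulary.
weights = splade_max_pooling(
    logits=[[1.0, -2.0, 0.0], [3.0, 0.5, -1.0], [100.0, 100.0, 100.0]],
    attention_mask=[1, 1, 0],
)
```

The ReLU zeroes out negative logits, which is where most of the sparsity comes from, and the log dampens very large activations.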
Splade
The following PyTorch code snippet illustrates the training loop to fine-tune Splade:
from transformers import AutoModelForMaskedLM, AutoTokenizer, optimization
from sparsembed import model, utils, train, retrieve, losses
import torch
device = "cpu" # cpu or cuda
batch_size = 8
epochs = 1 # Number of times the model will train over the whole dataset.
model = model.Splade(
model=AutoModelForMaskedLM.from_pretrained("naver/splade_v2_max").to(device),
tokenizer=AutoTokenizer.from_pretrained("naver/splade_v2_max"),
device=device
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = optimization.get_linear_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=6000, # Number of warmup steps.
num_training_steps=4_000_000, # Total number of training steps.
)
flops_scheduler = losses.FlopsScheduler(weight=1e-4, steps=50_000)
X = [
("anchor 1", "positive 1", "negative 1"),
("anchor 2", "positive 2", "negative 2"),
("anchor 3", "positive 3", "negative 3"),
]
for anchor, positive, negative in utils.iter(
    X,
    epochs=epochs,
    batch_size=batch_size,
    shuffle=True
):
    loss = train.train_splade(
        model=model,
        optimizer=optimizer,
        anchor=anchor,
        positive=positive,
        negative=negative,
        flops_loss_weight=flops_scheduler.get(),
    )
    # Update the learning rate and the FLOPS weight after each batch.
    scheduler.step()
    flops_scheduler.step()
documents, queries, qrels = utils.load_beir("scifact", split="test")
retriever = retrieve.SpladeRetriever(
key="id",
on=["title", "text"],
model=model
)
retriever = retriever.add(
documents=documents,
batch_size=batch_size,
k_tokens=96,
)
utils.evaluate(
retriever=retriever,
batch_size=batch_size,
qrels=qrels,
queries=queries,
k=100,
k_tokens=96,
metrics=["map", "ndcg@10", "recall@10", "hits@10"]
)
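The flops_loss_weight passed to train.train_splade weights the FLOPS regularizer used by SPLADE, which penalizes the squared mean activation of each vocabulary token over a batch and so pushes representations toward sparsity. The SPLADE paper ramps this weight up quadratically over the first training steps. We assume losses.FlopsScheduler(weight=1e-4, steps=50_000) behaves similarly, but the exact schedule below (weight * min(1, step / steps) ** 2) and the names flops_loss and QuadraticWeightScheduler are hypothetical.

```python
def flops_loss(batch_weights: list[list[float]]) -> float:
    """FLOPS regularizer: sum over tokens of the squared mean activation in the batch.

    batch_weights: N sparse representations, each a dense list of vocab_size weights.
    """
    n = len(batch_weights)
    vocab_size = len(batch_weights[0])
    mean_per_token = [sum(vector[j] for vector in batch_weights) / n for j in range(vocab_size)]
    return sum(mean * mean for mean in mean_per_token)


class QuadraticWeightScheduler:
    """Hypothetical scheduler: weight * min(1, step / steps) ** 2, constant afterwards."""

    def __init__(self, weight: float = 1e-4, steps: int = 50_000):
        self.weight = weight
        self.steps = steps
        self.step_count = 0

    def get(self) -> float:
        return self.weight * min(1.0, self.step_count / self.steps) ** 2

    def step(self) -> None:
        self.step_count += 1


# Toy batch of two 2-token representations: token means are 2.0 and 1.0.
regularization = flops_loss([[1.0, 0.0], [3.0, 2.0]])  # 2.0 ** 2 + 1.0 ** 2
```

Starting the weight near zero lets the model first learn useful term weights before sparsity pressure kicks in.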
SparsEmbed
The following PyTorch code snippet illustrates the training loop to fine-tune SparseEmbed:
from transformers import AutoModelForMaskedLM, AutoTokenizer, optimization
from sparsembed import model, utils, train, retrieve, losses
import torch
device = "cuda" # cpu or cuda
batch_size = 8
epochs = 1 # Number of times the model will train over the whole dataset.
model = model.SparsEmbed(
model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device),
tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"),
device=device
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = optimization.get_linear_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=6000, # Number of warmup steps.
num_training_steps=4_000_000 # Total number of training steps.
)
flops_scheduler = losses.FlopsScheduler(weight=1e-4, steps=50_000)
X = [
("anchor 1", "positive 1", "negative 1"),
("anchor 2", "positive 2", "negative 2"),
("anchor 3", "positive 3", "negative 3"),
]
for anchor, positive, negative in utils.iter(
    X,
    epochs=epochs,
    batch_size=batch_size,
    shuffle=True
):
    loss = train.train_sparsembed(
        model=model,
        optimizer=optimizer,
        k_tokens=96,
        anchor=anchor,
        positive=positive,
        negative=negative,
        flops_loss_weight=flops_scheduler.get(),
        sparse_loss_weight=0.1,
    )
    # Update the learning rate and the FLOPS weight after each batch.
    scheduler.step()
    flops_scheduler.step()
documents, queries, qrels = utils.load_beir("scifact", split="test")
retriever = retrieve.SparsEmbedRetriever(
key="id",
on=["title", "text"],
model=model
)
retriever = retriever.add(
documents=documents,
k_tokens=96,
batch_size=batch_size
)
utils.evaluate(
retriever=retriever,
batch_size=batch_size,
qrels=qrels,
queries=queries,
k=100,
k_tokens=96,
metrics=["map", "ndcg@10", "recall@10", "hits@10"]
)
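utils.evaluate reports standard ranking metrics against the BEIR relevance judgments (qrels). As a reference for two of them, here is a plain-Python sketch of recall@k (share of relevant documents found in the top k) and hits@k (1 if at least one relevant document is in the top k), averaged over queries. The helper names and the run/qrels dict layouts are hypothetical, not necessarily what utils.evaluate uses internally.

```python
def recall_at_k(ranked_ids: list[str], relevant_ids: list[str], k: int) -> float:
    """Fraction of the relevant documents that appear in the top k results."""
    relevant = set(relevant_ids)
    if not relevant:
        return 0.0
    return len(relevant.intersection(ranked_ids[:k])) / len(relevant)


def hits_at_k(ranked_ids: list[str], relevant_ids: list[str], k: int) -> float:
    """1.0 if at least one relevant document appears in the top k results, else 0.0."""
    return 1.0 if set(relevant_ids).intersection(ranked_ids[:k]) else 0.0


def mean_metric(metric, run: dict[str, list[str]], qrels: dict[str, list[str]], k: int) -> float:
    """Average a per-query metric over every query that has relevance judgments.

    run: {query_id: ranked document ids}; qrels: {query_id: relevant document ids}.
    """
    scores = [metric(run.get(query_id, []), relevant, k) for query_id, relevant in qrels.items()]
    return sum(scores) / len(scores)


run = {"q1": ["d1", "d2", "d3"], "q2": ["d4", "d5"]}
qrels = {"q1": ["d2", "d9"], "q2": ["d6"]}
recall = mean_metric(recall_at_k, run, qrels, k=2)  # (0.5 + 0.0) / 2
hits = mean_metric(hits_at_k, run, qrels, k=2)      # (1.0 + 0.0) / 2
```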
Utils
We can visualize the activated tokens:
model.decode(**model.encode(["deep learning, information retrieval, sparse models"]))
['deep sparse model retrieval information models depth fuzzy learning dense poor memory recall processing reading lacy include remember knowledge training heavy retrieve guide vague type small learn data']
Project details
Download files
Source Distribution
File details
Details for the file sparsembed-0.0.9.tar.gz.
File metadata
- Download URL: sparsembed-0.0.9.tar.gz
- Upload date:
- Size: 17.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | 5800bf1cd081bcc54e419038651a201490642c3da4bc1905e1bbe896180c0672
MD5 | f89751ac829baf5461aeeac2ac21da63
BLAKE2b-256 | 410388548276d42920bc510a441590c14aeb6577a26902bd3d8bf2fc3f34975f