
Sparse Embeddings for Neural Search.


SparsEmbed - Splade

Neural search

This repository presents an unofficial replication of Splade and SparseEmbed, two state-of-the-art models in information retrieval.

Note: This project is a work in progress. The Splade model is ready to use, but SparseEmbed is still under development. 🔨🧹

Installation

We can install sparsembed using:

pip install sparsembed

If we plan to evaluate our model while training, we should also install:

pip install "sparsembed[eval]"

Retriever

Splade

We can initialize a Splade retriever directly from the splade_v2_max checkpoint available on Hugging Face. Retrievers are based on PyTorch sparse matrices, stored in memory and accelerated with GPU. We can reduce the number of activated tokens via the k_tokens parameter in order to reduce the memory usage of those sparse matrices.

from sparsembed import model, retrieve
from transformers import AutoModelForMaskedLM, AutoTokenizer

device = "cpu"

batch_size = 10

# List documents to index:
documents = [
 {'id': 0,
  'title': 'Paris',
  'url': 'https://en.wikipedia.org/wiki/Paris',
  'text': 'Paris is the capital and most populous city of France.'},
 {'id': 1,
  'title': 'Paris',
  'url': 'https://en.wikipedia.org/wiki/Paris',
  'text': "Since the 17th century, Paris has been one of Europe's major centres of science, and arts."},
 {'id': 2,
  'title': 'Paris',
  'url': 'https://en.wikipedia.org/wiki/Paris',
  'text': 'The City of Paris is the centre and seat of government of the region and province of Île-de-France.'
}]

model = model.Splade(
    model=AutoModelForMaskedLM.from_pretrained("naver/splade_v2_max").to(device),
    tokenizer=AutoTokenizer.from_pretrained("naver/splade_v2_max"),
    device=device
)

retriever = retrieve.SpladeRetriever(
    key="id", # Key identifier of each document.
    on=["title", "text"], # Fields to search.
    model=model # Splade model.
)

retriever = retriever.add(
    documents=documents,
    batch_size=batch_size,
    k_tokens=256, # Number of activated tokens.
)

retriever(
    ["paris", "Toulouse"], # Queries 
    k_tokens=20, # Maximum number of activated tokens.
    k=100, # Number of documents to retrieve.
    batch_size=batch_size
)
[[{'id': 0, 'similarity': 11.481657981872559},
  {'id': 2, 'similarity': 11.294965744018555},
  {'id': 1, 'similarity': 10.059721946716309}],
 [{'id': 0, 'similarity': 0.7379149198532104},
  {'id': 2, 'similarity': 0.6973429918289185},
  {'id': 1, 'similarity': 0.5428210496902466}]]

SparsEmbed

We can also initialize a retriever dedicated to the SparseEmbed model. The naver/splade_v2_max checkpoint is not a trained SparseEmbed model, so we should train one before using it as a retriever.

from sparsembed import model, retrieve
from transformers import AutoModelForMaskedLM, AutoTokenizer

device = "cpu"

batch_size = 10

# List documents to index:
documents = [
 {'id': 0,
  'title': 'Paris',
  'url': 'https://en.wikipedia.org/wiki/Paris',
  'text': 'Paris is the capital and most populous city of France.'},
 {'id': 1,
  'title': 'Paris',
  'url': 'https://en.wikipedia.org/wiki/Paris',
  'text': "Since the 17th century, Paris has been one of Europe's major centres of science, and arts."},
 {'id': 2,
  'title': 'Paris',
  'url': 'https://en.wikipedia.org/wiki/Paris',
  'text': 'The City of Paris is the centre and seat of government of the region and province of Île-de-France.'
}]

model = model.SparsEmbed(
    model=AutoModelForMaskedLM.from_pretrained("naver/splade_v2_max").to(device),
    tokenizer=AutoTokenizer.from_pretrained("naver/splade_v2_max"),
    device=device
)

retriever = retrieve.SparsEmbedRetriever(
    key="id", # Key identifier of each document.
    on=["title", "text"], # Fields to search.
    model=model # SparsEmbed model.
)

retriever = retriever.add(
    documents=documents,
    batch_size=batch_size,
    k_tokens=256, # Number of activated tokens.
)

retriever(
    ["paris", "Toulouse"], # Queries 
    k_tokens=20, # Maximum number of activated tokens.
    k=100, # Number of documents to retrieve.
    batch_size=batch_size
)

Training

Let's fine-tune Splade and SparsEmbed.

Dataset

Your training dataset must be made of triples (anchor, positive, negative), where anchor is a query, positive is a document relevant to the anchor, and negative is a document that is not relevant to the anchor.

X = [
    ("anchor 1", "positive 1", "negative 1"),
    ("anchor 2", "positive 2", "negative 2"),
    ("anchor 3", "positive 3", "negative 3"),
]
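
In practice, the triples usually come from a file rather than being written by hand. Below is a minimal sketch that builds X from a tab-separated file with one (anchor, positive, negative) triple per line; the file name triples.tsv is only an example:

import csv

# Load (anchor, positive, negative) triples from a hypothetical TSV file.
with open("triples.tsv", newline="", encoding="utf-8") as f:
    X = [tuple(row) for row in csv.reader(f, delimiter="\t")]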

Models

Both the Splade and SparseEmbed models can be initialized from pretrained AutoModelForMaskedLM checkpoints.

from transformers import AutoModelForMaskedLM, AutoTokenizer
from sparsembed import model

device = "cpu" # cpu or cuda

model = model.Splade(
    model=AutoModelForMaskedLM.from_pretrained("naver/splade_v2_max").to(device),
    tokenizer=AutoTokenizer.from_pretrained("naver/splade_v2_max"),
    device=device,
)

from transformers import AutoModelForMaskedLM, AutoTokenizer
from sparsembed import model

device = "cpu" # cpu or cuda

model = model.SparsEmbed(
    model=AutoModelForMaskedLM.from_pretrained("naver/splade_v2_max").to(device),
    tokenizer=AutoTokenizer.from_pretrained("naver/splade_v2_max"),
    embedding_size=64, # Size of the embeddings of activated tokens.
    k_tokens=96, # Number of activated tokens.
    device=device,
)

Splade

The following PyTorch code snippet illustrates the training loop to fine-tune Splade:

from transformers import AutoModelForMaskedLM, AutoTokenizer, optimization
from sparsembed import model, utils, train, retrieve, losses
import torch

device = "cpu" # cpu or cuda
batch_size = 8
epochs = 1 # Number of times the model will train over the whole dataset.

model = model.Splade(
    model=AutoModelForMaskedLM.from_pretrained("naver/splade_v2_max").to(device),
    tokenizer=AutoTokenizer.from_pretrained("naver/splade_v2_max"),
    device=device
)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

scheduler = optimization.get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=6000, # Number of warmup steps.
    num_training_steps=4_000_000, # Total number of training steps.
)

flops_scheduler = losses.FlopsScheduler(weight=1e-4, steps=50_000) # Scheduler for the FLOPS regularization weight.

X = [
    ("anchor 1", "positive 1", "negative 1"),
    ("anchor 2", "positive 2", "negative 2"),
    ("anchor 3", "positive 3", "negative 3"),
]

for anchor, positive, negative in utils.iter(
        X,
        epochs=epochs,
        batch_size=batch_size,
        shuffle=True
    ):
        loss = train.train_splade(
            model=model,
            optimizer=optimizer,
            anchor=anchor,
            positive=positive,
            negative=negative,
            flops_loss_weight=flops_scheduler.get(),
        )

        scheduler.step()
        flops_scheduler.step()

documents, queries, qrels = utils.load_beir("scifact", split="test")

retriever = retrieve.SpladeRetriever(
    key="id",
    on=["title", "text"],
    model=model
)

retriever = retriever.add(
    documents=documents,
    batch_size=batch_size,
    k_tokens=96,
)

utils.evaluate(
    retriever=retriever,
    batch_size=batch_size,
    qrels=qrels,
    queries=queries,
    k=100,
    k_tokens=96,
    metrics=["map", "ndcg@10", "recall@10", "hits@10"]
)
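
To keep the fine-tuned weights, one option is to hold separate references to the underlying Hugging Face model and tokenizer and save them with the standard Transformers API. This is a minimal sketch, not part of the library; the checkpoint directory name is arbitrary:

from transformers import AutoModelForMaskedLM, AutoTokenizer
from sparsembed import model

device = "cpu"

# Keep references to the Hugging Face objects so they can be saved after training.
mlm = AutoModelForMaskedLM.from_pretrained("naver/splade_v2_max").to(device)
tokenizer = AutoTokenizer.from_pretrained("naver/splade_v2_max")

splade = model.Splade(model=mlm, tokenizer=tokenizer, device=device)

# ... run the training loop above, passing `splade` as the model ...

# Save the fine-tuned weights with the standard Transformers API.
mlm.save_pretrained("checkpoints/splade")
tokenizer.save_pretrained("checkpoints/splade")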

SparsEmbed

The following PyTorch code snippet illustrates the training loop to fine-tune SparseEmbed:

from transformers import AutoModelForMaskedLM, AutoTokenizer, optimization
from sparsembed import model, utils, train, retrieve, losses
import torch

device = "cuda" # cpu or cuda
batch_size = 8
epochs = 1 # Number of times the model will train over the whole dataset.

model = model.SparsEmbed(
    model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device),
    tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"),
    device=device
)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

scheduler = optimization.get_linear_schedule_with_warmup(
    optimizer=optimizer, 
    num_warmup_steps=6000, # Number of warmup steps.
    num_training_steps=4_000_000, # Total number of training steps.
)

flops_scheduler = losses.FlopsScheduler(weight=1e-4, steps=50_000)

X = [
    ("anchor 1", "positive 1", "negative 1"),
    ("anchor 2", "positive 2", "negative 2"),
    ("anchor 3", "positive 3", "negative 3"),
]

for anchor, positive, negative in utils.iter(
        X,
        epochs=epochs,
        batch_size=batch_size,
        shuffle=True
    ):
        loss = train.train_sparsembed(
            model=model,
            optimizer=optimizer,
            k_tokens=96,
            anchor=anchor,
            positive=positive,
            negative=negative,
            flops_loss_weight=flops_scheduler.get(),
            sparse_loss_weight=0.1,
        )

        scheduler.step()
        flops_scheduler.step()

documents, queries, qrels = utils.load_beir("scifact", split="test")

retriever = retrieve.SparsEmbedRetriever(
    key="id",
    on=["title", "text"],
    model=model
)

retriever = retriever.add(
    documents=documents,
    k_tokens=96,
    batch_size=batch_size
)

utils.evaluate(
    retriever=retriever,
    batch_size=batch_size,
    qrels=qrels,
    queries=queries,
    k=100,
    k_tokens=96,
    metrics=["map", "ndcg@10", "recall@10", "hits@10"]
)

Utils

We can visualize the activated tokens:

model.decode(**model.encode(["deep learning, information retrieval, sparse models"]))
['deep sparse model retrieval information models depth fuzzy learning dense poor memory recall processing reading lacy include remember knowledge training heavy retrieve guide vague type small learn data']
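
The same helpers can be used to compare the tokens activated by a query and by a document. A short sketch reusing the encode and decode calls shown above:

# Compare the expansions produced for a short query and for a document.
query_tokens = model.decode(**model.encode(["paris"]))
document_tokens = model.decode(**model.encode(
    ["Paris is the capital and most populous city of France."]
))
print(query_tokens)
print(document_tokens)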
