SPRINT: A Unified Toolkit for Evaluating and Demystifying Zero-shot Neural Sparse Retrieval
SPRINT provides a unified repository to easily evaluate diverse state-of-the-art neural (BERT-based) sparse-retrieval models.
The SPRINT toolkit allows you to easily search or evaluate any neural sparse retriever across any dataset in the BEIR benchmark (or your own dataset). The toolkit is built as a wrapper around Pyserini. It provides evaluation of seven diverse (neural) sparse retrieval models: SPLADEv2, BT-SPLADE-L, uniCOIL, TILDEv2, DeepImpact, DocT5query and SPARTA.
If you want to read more about the SPRINT toolkit, or wish to know which model to use, please refer to our paper for more details:
- SPRINT: A Unified Toolkit for Evaluating and Demystifying Zero-shot Neural Sparse Retrieval (Accepted at SIGIR'23 Resource Track)
:runner: Getting Started
SPRINT is backed by Pyserini, which relies on Java. To make the installation easier, we recommend following the steps below via conda:
#### Create a new conda environment using conda ####
$ conda create -n sprint_env python=3.8
$ conda activate sprint_env
# Install JDK 11 via conda
$ conda install -c conda-forge openjdk=11
# Install SPRINT toolkit using PyPI
$ pip install sprint-toolkit
:runner: Quickstart with SPRINT Toolkit
Quick start
For a quick start, we can go to the example for evaluating SPLADE (distilsplade_max) on the BeIR/SciFact dataset:
cd examples/inference/distilsplade_max/beir_scifact
bash all_in_one.sh
This will go over the whole pipeline and give the final evaluation results in beir_scifact-distilsplade_max-quantized/evaluation/metrics.json:
Results: distilsplade_max on BeIR/SciFact
cat beir_scifact-distilsplade_max-quantized/evaluation/metrics.json
# {
# "nDCG": {
# "NDCG@1": 0.60333,
# "NDCG@3": 0.65969,
# "NDCG@5": 0.67204,
# "NDCG@10": 0.6925,
# "NDCG@100": 0.7202,
# "NDCG@1000": 0.72753
# },
# "MAP": {
# "MAP@1": 0.57217,
# ...
# }
Or, if you prefer running Python directly, run the code snippet below to evaluate castorini/unicoil-noexp-msmarco-passage on BeIR/SciFact:
from sprint.inference import aio

if __name__ == '__main__':  # aio.run can only be called within __main__
    aio.run(
        encoder_name='unicoil',
        ckpt_name='castorini/unicoil-noexp-msmarco-passage',
        data_name='beir/scifact',
        gpus=[0, 1],
        output_dir='beir_scifact-unicoil_noexp',
        do_quantization=True,
        quantization_method='range-nbits',  # So the doc term weights will be quantized by `(term_weights / 5) * (2 ** 8)`
        original_score_range=5,
        quantization_nbits=8,
        original_query_format='beir',
        topic_split='test'
    )
    # You would get "NDCG@10": 0.68563
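The range-nbits comment in the snippet above describes the quantization as `(term_weights / 5) * (2 ** 8)`. A minimal sketch of that arithmetic follows. The function name `quantize_range_nbits` is hypothetical and is not part of SPRINT, and the rounding mode (`round` here) is an assumption, since the comment does not say how the scaled value becomes an integer.

```python
from typing import Dict


def quantize_range_nbits(
    term_weights: Dict[str, float],
    score_range: float = 5,  # mirrors original_score_range=5 above
    nbits: int = 8,          # mirrors quantization_nbits=8 above
) -> Dict[str, int]:
    """Scale each weight by (weight / score_range) * 2**nbits, then make it an integer.

    Hypothetical helper for illustration; rounding with round() is an assumption.
    """
    scale = (2 ** nbits) / score_range
    return {term: int(round(weight * scale)) for term, weight in term_weights.items()}


if __name__ == '__main__':
    # With score_range=5 and nbits=8, a weight of 2.5 maps to 128.
    print(quantize_range_nbits({'covid': 2.5, 'vaccine': 1.0, 'the': 0.0}))
```

Integer weights matter because the index step stores them as term impacts; weights larger than `score_range` would exceed `2 ** nbits` under this sketch, so the range should be chosen to cover the encoder's typical output.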
Step by step
One can also run the above process in 6 separate steps under the step_by_step folder:
- encode: Encode documents into term weights by multiprocessing on multiple GPUs;
- quantize: Quantize the document term weights into integers (can be skipped);
- index: Index the term weights into a Lucene index (backed by Pyserini);
- reformat: Reformat the queries file (e.g. the ones from BeIR) into the Pyserini format;
- search: Retrieve the relevant documents (backed by Pyserini);
- evaluate: Evaluate the results against labeled data, e.g. the qrels used in BeIR (backed by BeIR)
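To make the index and search steps above more concrete, here is a toy sketch of how term weights can be turned into a ranking: documents go into an inverted index of integer impacts, and a query is scored by summing query weight times document impact over shared terms. This is a plain-Python illustration under that assumption, not the Lucene/Pyserini implementation; `build_inverted_index` and `search` are hypothetical names.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

Postings = Dict[str, List[Tuple[str, int]]]


def build_inverted_index(docs: Dict[str, Dict[str, int]]) -> Postings:
    """Map each term to the (doc_id, impact) pairs of the documents containing it."""
    index = defaultdict(list)
    for doc_id, term_weights in docs.items():
        for term, weight in term_weights.items():
            index[term].append((doc_id, weight))
    return dict(index)


def search(index: Postings, query_weights: Dict[str, float], k: int = 10) -> List[Tuple[str, float]]:
    """Score documents by the dot product of query and document term weights."""
    scores = defaultdict(float)
    for term, q_weight in query_weights.items():
        for doc_id, d_weight in index.get(term, []):
            scores[doc_id] += q_weight * d_weight
    # Highest score first; ties are broken by doc_id for a stable order
    return sorted(scores.items(), key=lambda item: (-item[1], item[0]))[:k]


if __name__ == '__main__':
    docs = {
        'd1': {'sparse': 3, 'retrieval': 2},
        'd2': {'dense': 4, 'retrieval': 1},
    }
    index = build_inverted_index(docs)
    print(search(index, {'sparse': 1.0, 'retrieval': 2.0}))
```

Only terms that appear in the query are ever touched, which is why sparse term weights can be served efficiently from an inverted index.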
It currently supports the following methods directly (with reproduction verified):
- uniCOIL;
- SPLADE: Go to examples/inference/distilsplade_max/beir_scifact for a fast reproduction of distilsplade_max on SciFact;
- SPARTA;
- TILDEv2: Go to examples/inference/tildev2-noexp/trecdl2019 for a fast reproduction of ielab/TILDEv2-noExp reranking on TREC-DL 2019;
- DeepImpact
It currently supports the following data formats (downloaded automatically):
- BeIR
Other models and data (formats) will be added.
Custom encoders
To add a custom encoder, one can refer to the example examples/inference/custom_encoder/beir_scifact, where distilsplade_max is evaluated on BeIR/SciFact with stopwords filtered out.
In detail, one only needs to define custom encoder classes and write a new encoder builder function:
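from typing import Dict, List

from pyserini.encode import QueryEncoder, DocumentEncoder


class CustomQueryEncoder(QueryEncoder):

    def __init__(self, ckpt_name, device='cpu'):
        # Accept the arguments passed by the builder below; a real encoder would load its model here
        self.ckpt_name = ckpt_name
        self.device = device

    def encode(self, text, **kwargs) -> Dict[str, float]:
        # Just an example: every whitespace-separated term gets weight 1
        terms = text.split()
        term_weights = {term: 1.0 for term in terms}
        return term_weights  # Dict object, where keys/values are terms/term scores, resp.


class CustomDocumentEncoder(DocumentEncoder):

    def __init__(self, ckpt_name, device='cpu'):
        # Accept the arguments passed by the builder below; a real encoder would load its model here
        self.ckpt_name = ckpt_name
        self.device = device

    def encode(self, texts, **kwargs) -> List[Dict[str, float]]:
        # Just an example: one term-weight dict per input text
        term_weights_batch = []
        for text in texts:
            terms = text.split()
            term_weights = {term: 1.0 for term in terms}
            term_weights_batch.append(term_weights)
        return term_weights_batch


def custom_encoder_builder(ckpt_name, etype, device='cpu'):
    if etype == 'query':
        return CustomQueryEncoder(ckpt_name, device=device)
    elif etype == 'document':
        return CustomDocumentEncoder(ckpt_name, device=device)
    else:
        raise ValueError(f"Unknown encoder type: {etype!r}")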
Then register custom_encoder_builder with sprint.inference.encoder_builders.register before usage:
from sprint.inference.encoder_builders import register
register('custom_encoder_builder', custom_encoder_builder)
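The example folder mentioned at the start of this section evaluates an encoder with stopwords filtered out. One way to sketch that idea, independent of any model, is to drop stopword keys from the term-weight dictionary that an encoder returns. The helper name `filter_stopwords` and the tiny stopword list below are assumptions for illustration; the actual example may filter differently.

```python
from typing import Dict, Iterable


def filter_stopwords(term_weights: Dict[str, float], stopwords: Iterable[str]) -> Dict[str, float]:
    """Return a copy of term_weights without stopword terms (case-insensitive match).

    Hypothetical helper; it could be called at the end of a custom encode() method.
    """
    blocked = {word.lower() for word in stopwords}
    return {term: weight for term, weight in term_weights.items() if term.lower() not in blocked}


if __name__ == '__main__':
    weights = {'the': 0.3, 'Vaccine': 1.2, 'of': 0.1}
    print(filter_stopwords(weights, ['The', 'of']))
```

Filtering on the encoder output keeps the model untouched, so the same checkpoint can be compared with and without stopwords.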
Training (Experimental)
Will be added.
Contacts
The main contributors of this repository are: