SPRINT: A Unified Toolkit for Evaluating and Demystifying Zero-shot Neural Sparse Retrieval

A unified repository for evaluating diverse state-of-the-art neural sparse-retrieval methods in one click.

Getting Started

This repo is backed by Pyserini/Anserini, both of which rely on Java. To make setup easier, we recommend following the steps below via conda:

#### Create a new conda environment ####
$ conda create -n sprint_env python=3.8
$ conda activate sprint_env

# Install JDK 11 via conda
$ conda install -c conda-forge openjdk=11

# Install Pyserini and BEIR from PyPI
$ pip install pyserini
$ pip install beir

#### Git clone this repository and install it (required to run the examples)
$ git clone https://github.com/thakur-nandan/sprint.git
$ cd sprint
$ pip install -e .
Alternatively, create the whole environment in one step from the provided environment file (this also installs the Java/JDK dependency):

$ conda env create -f environment.yml

This creates a conda environment named sparse-retrieval. If you want another name, change the name field in environment.yml.
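
If the installation worked, importing Pyserini should start the JVM successfully. A quick sanity check (illustrative, not part of the toolkit):

# sanity_check.py: verify that Pyserini can find the JDK installed above
from pyserini.search.lucene import LuceneSearcher  # import fails if the Java/JDK setup is broken

print('Pyserini imported successfully')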

Inference

Quick start

For a quick start, go to the example that evaluates SPLADE (distilsplade_max) on the BeIR/SciFact dataset:

cd examples/inference/distilsplade_max/beir_scifact
bash all_in_one.sh

This runs the whole pipeline and writes the final evaluation results to beir_scifact-distilsplade_max-quantized/evaluation/metrics.json:

Results: distilsplade_max on BeIR/SciFact
   cat beir_scifact-distilsplade_max-quantized/evaluation/metrics.json 
   # {
   #     "nDCG": {
   #         "NDCG@1": 0.60333,
   #         "NDCG@3": 0.65969,
   #         "NDCG@5": 0.67204,
   #         "NDCG@10": 0.6925,
   #         "NDCG@100": 0.7202,
   #         "NDCG@1000": 0.72753
   #     },
   #     "MAP": {
   #         "MAP@1": 0.57217,
   #     ...
   # }

Or, if you prefer running Python directly, run the code snippet below to evaluate castorini/unicoil-noexp-msmarco-passage on BeIR/SciFact:

from sprint.inference import aio


if __name__ == '__main__':  # aio.run can only be called within __main__
    aio.run(
        encoder_name='unicoil',
        ckpt_name='castorini/unicoil-noexp-msmarco-passage',
        data_name='beir/scifact',
        gpus=[0, 1],
        output_dir='beir_scifact-unicoil_noexp',
        do_quantization=True,
        quantization_method='range-nbits',  # So the doc term weights will be quantized by `(term_weights / 5) * (2 ** 8)`
        original_score_range=5,
        quantization_nbits=8,
        original_query_format='beir',
        topic_split='test'
    )
    # You would get "NDCG@10": 0.68563
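
For reference, the range-nbits quantization mentioned in the comment above boils down to the arithmetic below (a minimal sketch for illustration, not the toolkit's internal code; the function name is ours):

from typing import Dict

def quantize_range_nbits(term_weights: Dict[str, float],
                         original_score_range: float = 5,
                         nbits: int = 8) -> Dict[str, int]:
    # Map float term weights in [0, original_score_range]
    # to integer impacts in [0, 2 ** nbits].
    return {term: int(weight / original_score_range * (2 ** nbits))
            for term, weight in term_weights.items()}

print(quantize_range_nbits({'covid': 3.2, 'vaccine': 1.7}))
# {'covid': 163, 'vaccine': 87}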

Step by step

One can also run the above process in six separate steps under the step_by_step folder:

  1. encode: Encode documents into term weights, with multiprocessing over multiple GPUs;
  2. quantize: Quantize the document term weights into integers (can be skipped);
  3. index: Index the term weights into a Lucene index (backed by Pyserini);
  4. reformat: Reformat the queries file (e.g. the ones from BeIR) into the Pyserini format, as sketched below;
  5. search: Retrieve the relevant documents (backed by Pyserini);
  6. evaluate: Evaluate the results against labeled data, e.g. the qrels used in BeIR (backed by BeIR).
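
To make step 4 (reformat) concrete: for BeIR data it amounts to rewriting queries.jsonl into the qid<TAB>query topics format that Pyserini consumes. A minimal sketch (the file names here are assumptions for illustration):

import json

# Convert BeIR-style queries (one JSON object per line, with "_id" and "text" fields)
# into Pyserini's tab-separated topics format.
with open('queries.jsonl') as fin, open('topics.tsv', 'w') as fout:
    for line in fin:
        query = json.loads(line)
        fout.write(f"{query['_id']}\t{query['text']}\n")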

Currently it directly supports the following methods (with reproduction verified):

Currently it supports the following data formats (downloaded automatically):

  • BeIR

Support for other models and data formats will be added.

Custom encoders

To add a custom encoder, one can refer to the example in examples/inference/custom_encoder/beir_scifact, where distilsplade_max is evaluated on BeIR/SciFact with stopwords filtered out.

In detail, one just needs to define custom encoder classes and write a new encoder builder function:

from typing import Dict, List
from pyserini.encode import QueryEncoder, DocumentEncoder

class CustomQueryEncoder(QueryEncoder):

    def encode(self, text, **kwargs) -> Dict[str, float]:
        # Just an example: weight every whitespace-separated term equally
        terms = text.split()
        term_weights = {term: 1.0 for term in terms}
        return term_weights  # Dict object, where keys/values are terms/term scores, resp.

class CustomDocumentEncoder(DocumentEncoder):

    def encode(self, texts, **kwargs) -> List[Dict[str, float]]:
        # Just an example: weight every whitespace-separated term equally
        term_weights_batch = []
        for text in texts:
            terms = text.split()
            term_weights = {term: 1.0 for term in terms}
            term_weights_batch.append(term_weights)
        return term_weights_batch

def custom_encoder_builder(ckpt_name, etype, device='cpu'):
    if etype == 'query':
        return CustomQueryEncoder(ckpt_name, device=device)
    elif etype == 'document':
        return CustomDocumentEncoder(ckpt_name, device=device)
    else:
        raise ValueError(f"etype should be either 'query' or 'document', got {etype}")

Then register custom_encoder_builder with sprint.inference.encoder_builders.register before usage:

from sprint.inference.encoder_builders import register

register('custom_encoder_builder', custom_encoder_builder)
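
After registration, the builder can be selected by its registered name, for example (a sketch assuming the same aio.run entry point as in the quick start; how ckpt_name is interpreted is up to your builder):

from sprint.inference import aio

if __name__ == '__main__':
    aio.run(
        encoder_name='custom_encoder_builder',  # the name registered above
        ckpt_name='distilsplade_max',           # forwarded to custom_encoder_builder
        data_name='beir/scifact',
        gpus=[0],
        output_dir='beir_scifact-custom_encoder',
        original_query_format='beir',
        topic_split='test'
    )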

Training (Experimental)

Will be added.

Contacts

The main contributors of this repository are:
