Dense Retriever

🦮 Golden Retriever

How to use

Install the library from PyPI:

pip install goldenretriever-core

or from source:

git clone https://github.com/Riccorl/golden-retriever.git
cd golden-retriever
pip install -e .

FAISS

Install with optional dependencies for FAISS

The FAISS package on PyPI is CPU-only. For GPU support, install FAISS from source or use the conda package.

For CPU:

pip install "goldenretriever-core[faiss]"

For GPU:

conda create -n goldenretriever python=3.11
conda activate goldenretriever

# install pytorch
conda install -y pytorch=2.3.1 pytorch-cuda=12.1 -c pytorch -c nvidia

# GPU
conda install -y -c pytorch -c nvidia faiss-gpu=1.8.0
# or GPU with NVIDIA RAFT
conda install -y -c pytorch -c nvidia -c rapidsai -c conda-forge faiss-gpu-raft=1.8.0

pip install goldenretriever-core
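
After either install path, it can help to confirm which FAISS build the interpreter actually sees. The snippet below is a minimal sketch, not part of Golden Retriever: `faiss_status` is a hypothetical helper, and it assumes that `faiss.get_num_gpus` is only present on GPU-enabled builds, so it looks the function up defensively.

```python
import importlib.util


def faiss_status():
    """Describe the FAISS build visible to this interpreter (hypothetical helper)."""
    if importlib.util.find_spec("faiss") is None:
        return "faiss not installed"
    try:
        import faiss
    except ImportError:
        return "faiss found but failed to import"

    # Assumption: get_num_gpus may be missing on CPU-only builds.
    get_num_gpus = getattr(faiss, "get_num_gpus", None)
    num_gpus = get_num_gpus() if callable(get_num_gpus) else 0
    if num_gpus:
        return f"faiss installed, {num_gpus} GPU(s) visible"
    return "faiss installed, no GPU support detected"


print(faiss_status())
```

A result of "no GPU support detected" can mean either a CPU-only build or a GPU build that sees no devices.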

Usage

Golden Retriever is built on top of PyTorch Lightning and Hydra. To run an experiment, create a configuration file and pass it to the golden-retriever command. A few examples are provided in the conf folder.

Training

Here is a simple example of how to train a DPR-like retriever on the NQ dataset. First, download the dataset from DPR. Then run the following command:

golden-retriever train conf/nq-dpr.yaml

Evaluation

from goldenretriever.trainer import Trainer
from goldenretriever import GoldenRetriever
from goldenretriever.data.datasets import InBatchNegativesDataset

retriever = GoldenRetriever(
  question_encoder="",
  document_index="",
  device="cuda",
  precision="16",
)

test_dataset = InBatchNegativesDataset(
  name="test",
  path="",
  tokenizer=retriever.question_tokenizer,
  question_batch_size=64,
  passage_batch_size=400,
  max_passage_length=64,
)

trainer = Trainer(
  retriever=retriever,
  test_dataset=test_dataset,
  log_to_wandb=False,
  top_k=[20, 100]
)

trainer.test()

Distributed environment

Golden Retriever supports distributed training. For the moment, training is limited to a single node with multiple GPUs and no model sharding, i.e. only DDP and FSDP with the NO_SHARD strategy are supported.

To run distributed training, add the following keys to the configuration file:

devices: 4  # number of GPUs
# strategy: "ddp_find_unused_parameters_true"  # DDP
# FSDP with NO_SHARD
strategy:
  _target_: lightning.pytorch.strategies.FSDPStrategy
  sharding_strategy: "NO_SHARD"

Inference

from goldenretriever import GoldenRetriever

retriever = GoldenRetriever(
    question_encoder="path/to/question/encoder",
    passage_encoder="path/to/passage/encoder",
    document_index="path/to/document/index"
)

# retrieve documents
retriever.retrieve("What is the capital of France?", k=5)

Data format

Input data

The retriever expects a jsonl file similar to DPR:

[
  {
  "question": "....",
  "answers": ["...", "...", "..."],
  "positive_ctxs": [{
    "title": "...",
    "text": "...."
  }],
  "negative_ctxs": ["..."],
  "hard_negative_ctxs": ["..."]
  },
  ...
]
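
As an illustration of the layout above, the following sketch writes one record per line to a jsonl file and reads it back. It uses only the standard library; the sample content, the file name `train.jsonl`, and the helpers `write_jsonl` and `read_jsonl` are hypothetical. The negative lists are left empty because the expected shape of their entries is not spelled out here.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical sample; the keys follow the DPR-like layout shown above.
samples = [
    {
        "question": "What is the capital of France?",
        "answers": ["Paris"],
        "positive_ctxs": [
            {"title": "Paris", "text": "Paris is the capital of France."}
        ],
        "negative_ctxs": [],
        "hard_negative_ctxs": [],
    }
]


def write_jsonl(records, path):
    """Write one JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def read_jsonl(path):
    """Read a jsonl file back into a list of dicts, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "train.jsonl"
    write_jsonl(samples, path)
    loaded = read_jsonl(path)

print(len(loaded), "record(s) round-tripped")
```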

Index data

The documents to index can be stored in either a jsonl file or a tsv file similar to DPR:

  • jsonl: each line is a json object with the following keys: id, text, metadata
  • tsv: each line is a tab-separated string with the id and text columns, followed by any other columns, which are stored in the metadata field

jsonl example:

[
  {
    "id": "...",
    "text": "...",
    "metadata": ["{...}"]
  },
  ...
]

tsv example:

id \t text \t any other column
...
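
To make the relationship between the two formats concrete, here is a minimal sketch that converts tsv rows into jsonl-style records. It is not the library's own loader: `tsv_to_records` is a hypothetical helper, and it assumes the tsv has a header row (as in DPR's passage files) whose extra column names become keys of a metadata dictionary.

```python
import csv
import io
import json


def tsv_to_records(tsv_text):
    """Turn tsv text (with a header row) into a list of id/text/metadata dicts."""
    reader = csv.reader(
        io.StringIO(tsv_text), delimiter="\t", quoting=csv.QUOTE_NONE
    )
    header = next(reader)
    extra_names = header[2:]  # every column after id and text
    records = []
    for row in reader:
        if not row:
            continue  # skip blank lines
        doc_id, text, *extra = row
        records.append(
            {"id": doc_id, "text": text, "metadata": dict(zip(extra_names, extra))}
        )
    return records


# Hypothetical two-line tsv: a header plus one document.
tsv_text = "id\ttext\ttitle\n0\tParis is the capital of France.\tParis\n"
records = tsv_to_records(tsv_text)

# One JSON object per line, as in the jsonl variant.
jsonl_text = "\n".join(json.dumps(record) for record in records)
print(jsonl_text)
```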
