Dense Retriever
🦮 Golden Retriever
How to use
Install the library from PyPI:
pip install goldenretriever-core
or from source:
git clone https://github.com/Riccorl/golden-retriever.git
cd golden-retriever
pip install -e .
FAISS
Install with the optional dependencies for FAISS.
The FAISS PyPI package is only available for CPU. To use a GPU, you need to install FAISS from source or use the conda package.
For CPU:
pip install goldenretriever-core[faiss]
For GPU:
conda create -n goldenretriever python=3.11
conda activate goldenretriever
# install pytorch
conda install -y pytorch=2.3.1 pytorch-cuda=12.1 -c pytorch -c nvidia
# GPU
conda install -y -c pytorch -c nvidia faiss-gpu=1.8.0
# or GPU with NVIDIA RAFT
conda install -y -c pytorch -c nvidia -c rapidsai -c conda-forge faiss-gpu-raft=1.8.0
pip install goldenretriever-core
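After installing, it can help to check which FAISS build Python actually picks up. The helper below is a sketch of our own, not part of goldenretriever-core. It assumes only that FAISS exposes faiss.get_num_gpus(), and it guards that call in case a build does not provide it.

```python
import importlib.util


def faiss_backend() -> str:
    """Report which FAISS flavour is importable: "missing", "cpu" or "gpu"."""
    # find_spec returns None when the faiss module is not installed at all.
    if importlib.util.find_spec("faiss") is None:
        return "missing"
    import faiss  # imported lazily so the check also works without FAISS

    # Guard with getattr in case a build does not expose get_num_gpus.
    get_num_gpus = getattr(faiss, "get_num_gpus", None)
    if get_num_gpus is not None and get_num_gpus() > 0:
        return "gpu"
    return "cpu"


print(faiss_backend())
```

Note that a GPU build running on a machine with no visible GPU is reported as "cpu" by this check.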
Usage
Golden Retriever is built on top of PyTorch Lightning and Hydra. To run an experiment, you need to create a configuration file and pass it to the golden-retriever command. A few examples are provided in the conf folder.
Training
Here is a simple example of how to train a DPR-like retriever on the NQ dataset. First, download the dataset from DPR. Then run the following command:
golden-retriever train conf/nq-dpr.yaml
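DPR-style training, which the InBatchNegativesDataset name suggests, treats the other passages in a batch as negatives for each question. The pure-Python sketch below shows that objective, assuming dot-product similarity and that passage i is the positive for question i. It illustrates the idea only and is not GoldenRetriever's training code.

```python
import math


def dot(u: list[float], v: list[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def in_batch_negatives_loss(questions: list[list[float]], passages: list[list[float]]) -> float:
    """Mean cross-entropy where passage i is the positive for question i
    and every other passage in the batch acts as a negative."""
    if len(questions) != len(passages):
        raise ValueError("expected one positive passage per question")
    total = 0.0
    for i, q in enumerate(questions):
        scores = [dot(q, p) for p in passages]
        # Log-sum-exp with max subtraction for numerical stability.
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_z - scores[i]  # equals -log softmax(scores)[i]
    return total / len(questions)


questions = [[1.0, 0.0], [0.0, 1.0]]
aligned = [[1.0, 0.0], [0.0, 1.0]]   # passage i matches question i
shuffled = [[0.0, 1.0], [1.0, 0.0]]  # positives swapped
print(in_batch_negatives_loss(questions, aligned))   # lower loss
print(in_batch_negatives_loss(questions, shuffled))  # higher loss
```

Larger batches give each question more negatives for free, which is the main appeal of this scheme.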
Evaluation
from goldenretriever.trainer import Trainer
from goldenretriever import GoldenRetriever
from goldenretriever.data.datasets import InBatchNegativesDataset
retriever = GoldenRetriever(
question_encoder="",
document_index="",
device="cuda",
precision="16",
)
test_dataset = InBatchNegativesDataset(
name="test",
path="",
tokenizer=retriever.question_tokenizer,
question_batch_size=64,
passage_batch_size=400,
max_passage_length=64,
)
trainer = Trainer(
retriever=retriever,
test_dataset=test_dataset,
log_to_wandb=False,
top_k=[20, 100]
)
trainer.test()
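The top_k=[20, 100] argument suggests that results are reported at several cutoffs. A common metric for DPR-like retrievers is top-k accuracy: the fraction of questions with at least one positive passage among the first k retrieved. The function below is our own sketch of that metric on toy passage ids; the Trainer may compute its metrics differently.

```python
def top_k_accuracy(retrieved: list[list[str]], gold: list[set[str]], k: int) -> float:
    """Fraction of questions whose first k retrieved ids contain at least one gold id."""
    if len(retrieved) != len(gold):
        raise ValueError("retrieved and gold must have one entry per question")
    hits = sum(1 for ranked, positives in zip(retrieved, gold) if positives.intersection(ranked[:k]))
    return hits / len(retrieved)


retrieved = [["d3", "d1", "d7"], ["d2", "d9", "d4"]]  # ranked ids per question (toy data)
gold = [{"d1"}, {"d5"}]  # ids of the positive passages per question
for k in (1, 2, 3):
    print(k, top_k_accuracy(retrieved, gold, k))
```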
Distributed environment
Golden Retriever supports distributed training. For the moment, it is only possible to train on a single node with multiple GPUs and without model sharding, i.e.
only DDP and FSDP with NO_SHARD
strategy are supported.
To run a distributed training, just add the following keys to the configuration file:
devices: 4 # number of GPUs
# strategy: "ddp_find_unused_parameters_true" # DDP
# FSDP with NO_SHARD
strategy:
_target_: lightning.pytorch.strategies.FSDPStrategy
sharding_strategy: "NO_SHARD"
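The strategy block follows Hydra's _target_ convention: the dotted path names a class and the remaining keys become constructor arguments, so the config above is roughly equivalent to FSDPStrategy(sharding_strategy="NO_SHARD"). The sketch below is a deliberately simplified re-implementation of that idea, not Hydra's own instantiate (which also handles nested configs, positional arguments and more). It is demonstrated on a standard-library class so it runs without Lightning installed.

```python
import importlib
from typing import Any


def instantiate(config: dict[str, Any]) -> Any:
    """Build the object named by config["_target_"], passing the other keys as kwargs.

    Simplified stand-in for hydra.utils.instantiate: no recursion,
    no positional arguments, no interpolation.
    """
    params = dict(config)  # copy so the caller's config is left untouched
    module_path, _, attr = params.pop("_target_").rpartition(".")
    target = getattr(importlib.import_module(module_path), attr)
    return target(**params)


# Stand-in target instead of lightning.pytorch.strategies.FSDPStrategy.
delta = instantiate({"_target_": "datetime.timedelta", "hours": 1, "minutes": 30})
print(delta)  # 1:30:00
```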
Inference
from goldenretriever import GoldenRetriever
retriever = GoldenRetriever(
question_encoder="path/to/question/encoder",
passage_encoder="path/to/passage/encoder",
document_index="path/to/document/index"
)
# retrieve documents
retriever.retrieve("What is the capital of France?", k=5)
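Conceptually, a dense retriever encodes the question and returns the k indexed passages whose embeddings score highest against it. The brute-force sketch below shows only that final ranking step, assuming inner-product similarity as in DPR and using toy, hand-written vectors. The real document index (possibly FAISS-backed) and the return type of retrieve are not modelled here and may differ.

```python
import heapq


def retrieve_top_k(query: list[float], index: dict[str, list[float]], k: int) -> list[tuple[str, float]]:
    """Return the k (doc_id, score) pairs with the highest dot-product score."""
    scores = ((doc_id, sum(q * d for q, d in zip(query, emb))) for doc_id, emb in index.items())
    # nlargest avoids sorting the whole index when k is small.
    return heapq.nlargest(k, scores, key=lambda pair: pair[1])


index = {"paris": [0.9, 0.1], "berlin": [0.2, 0.8], "rome": [0.6, 0.4]}  # toy embeddings
result = retrieve_top_k([1.0, 0.0], index, k=2)
print(result)
```

Exhaustive scoring like this is exact but linear in the index size, which is why libraries such as FAISS are used for large collections.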
Data format
Input data
The retriever expects a jsonl file in a format similar to DPR's:
[
{
"question": "....",
"answers": ["...", "...", "..."],
"positive_ctxs": [{
"title": "...",
"text": "...."
}],
"negative_ctxs": ["..."],
"hard_negative_ctxs": ["..."]
},
...
]
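The description mentions jsonl while the example is laid out as a single JSON array, so the loader sketch below accepts both. It is a hypothetical helper, not the library's reader, and REQUIRED_KEYS is our own guess at a minimum set of keys.

```python
import json

# Hypothetical minimum set of keys; the library may require more or fewer.
REQUIRED_KEYS = {"question", "answers", "positive_ctxs"}


def load_samples(text: str) -> list[dict]:
    """Parse DPR-style samples from a JSON array or from one JSON object per line."""
    stripped = text.lstrip()
    if stripped.startswith("["):
        samples = json.loads(stripped)  # whole file is a single JSON array
    else:
        samples = [json.loads(line) for line in text.splitlines() if line.strip()]  # jsonl
    for i, sample in enumerate(samples):
        missing = REQUIRED_KEYS.difference(sample)
        if missing:
            raise ValueError(f"sample {i} is missing keys: {sorted(missing)}")
    return samples


line = json.dumps({
    "question": "What is the capital of France?",
    "answers": ["Paris"],
    "positive_ctxs": [{"title": "Paris", "text": "Paris is the capital of France."}],
    "negative_ctxs": [],
    "hard_negative_ctxs": [],
})
print(len(load_samples(line + "\n" + line)))  # 2
print(load_samples("[" + line + "]")[0]["answers"])  # ['Paris']
```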
Index data
The documents to index can be either a jsonl file or a tsv file similar to DPR:
jsonl: each line is a JSON object with the following keys: id, text, metadata
tsv: each line is a tab-separated string with the id and text columns, followed by any other columns, which will be stored in the metadata field
jsonl example:
[
{
"id": "...",
"text": "...",
"metadata": ["{...}"]
},
...
]
tsv example:
id \t text \t any other column
...
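A minimal reader for the tsv layout described above might look like the sketch below. It is hypothetical: keeping the extra columns as a plain list in "metadata" is our assumption, and the skip_header flag is only there because DPR's own passage file starts with a column-name line.

```python
import csv
import io


def read_tsv_documents(text: str, skip_header: bool = False) -> list[dict]:
    """Turn "id<TAB>text<TAB>extra..." rows into {"id", "text", "metadata"} dicts.

    Storing the extra columns as a list in "metadata" is an assumption of this sketch.
    """
    rows = csv.reader(io.StringIO(text), delimiter="\t")
    if skip_header:
        next(rows, None)  # drop a column-name line such as "id<TAB>text<TAB>title"
    documents = []
    for row in rows:
        if not row:
            continue  # csv.reader yields [] for blank lines
        doc_id, doc_text, *extra = row
        documents.append({"id": doc_id, "text": doc_text, "metadata": extra})
    return documents


sample = "1\tParis is the capital of France.\tParis\n2\tBerlin is the capital of Germany.\tBerlin\n"
for doc in read_tsv_documents(sample):
    print(doc)
```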
Download files
Hashes for goldenretriever-core-1.0.0.dev1.tar.gz
Algorithm | Hash digest
---|---
SHA256 | d70fa2e4000896faebe05c7121f5766810e89044ffca088f81e033f5321a7eb2
MD5 | 03eb2f29a81e34c0839da4b55115a724
BLAKE2b-256 | 0b4c7df5cea42e513f8fbeca518ae49e90b1d83011023f6d79457d0cb740d55b
Hashes for goldenretriever_core-1.0.0.dev1-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 31a6965d8911d760828073f8e7cdc6c826857ac4caae526918dcb84a85d5ee87
MD5 | 64042df44703656076d35cd7ecffba78
BLAKE2b-256 | 4a734463b2edbe8f11c19102c0024ba4cbdd59097f48f8239dd5d78614aa9f9d