
Dense Retriever


🦮 Golden Retriever


How to use

Install the library from PyPI:

pip install goldenretriever-core

or from source:

git clone https://github.com/Riccorl/golden-retriever.git
cd golden-retriever
pip install -e .

Usage

How to run an experiment

Training

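Here is a simple example of how to train a DPR-like retriever on one of the DPR datasets. First download the dataset from DPR, then run the following code (it uses the WebQuestions split, webq; the same setup applies to NQ by pointing the paths at the NQ files):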

from goldenretriever.trainer import Trainer
from goldenretriever import GoldenRetriever
from goldenretriever.data.datasets import InBatchNegativesDataset

# create a retriever
retriever = GoldenRetriever(
    question_encoder="intfloat/e5-small-v2",
    passage_encoder="intfloat/e5-small-v2"
)

# create the training and validation datasets
train_dataset = InBatchNegativesDataset(
    name="webq_train",
    path="path/to/webq_train.json",
    tokenizer=retriever.question_tokenizer,
    question_batch_size=64,
    passage_batch_size=400,
    max_passage_length=64,
    shuffle=True,
)
val_dataset = InBatchNegativesDataset(
    name="webq_dev",
    path="path/to/webq_dev.json",
    tokenizer=retriever.question_tokenizer,
    question_batch_size=64,
    passage_batch_size=400,
    max_passage_length=64,
)

trainer = Trainer(
    retriever=retriever,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    max_steps=25_000,
    wandb_online_mode=True,
    wandb_project_name="golden-retriever-dpr",
    wandb_experiment_name="e5-small-webq",
    max_hard_negatives_to_mine=5,
)

# start training
trainer.train()
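InBatchNegativesDataset and the question_batch_size/passage_batch_size settings point to the in-batch negatives objective used by DPR: each question is scored against every passage in the batch, and the passages that are not its positive act as negatives. The following is a minimal, framework-free sketch of that loss, assuming dot-product scoring. It is illustrative only, not goldenretriever's training code, and the function names are hypothetical.

```python
import math
from typing import List, Optional, Sequence

Vector = Sequence[float]


def dot(u: Vector, v: Vector) -> float:
    """Inner-product similarity between two embeddings."""
    return sum(a * b for a, b in zip(u, v))


def in_batch_negatives_loss(
    questions: List[Vector],
    passages: List[Vector],
    positive_idx: Optional[List[int]] = None,
) -> float:
    """Mean negative log-likelihood of each question's positive passage.

    positive_idx[i] is the index in `passages` of the positive for question i;
    every other passage in the batch is treated as a negative. The passage
    batch may be larger than the question batch (e.g. 400 passages for 64
    questions, as in the configuration above).
    """
    if positive_idx is None:
        # assumption: question i and passage i are a positive pair
        positive_idx = list(range(len(questions)))
    total = 0.0
    for q, pos in zip(questions, positive_idx):
        scores = [dot(q, p) for p in passages]
        # log-sum-exp with the max subtracted for numerical stability
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_z - scores[pos]
    return total / len(questions)


# toy 2-d embeddings standing in for encoder outputs
qs = [[1.0, 0.0], [0.0, 1.0]]
ps = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
print(round(in_batch_negatives_loss(qs, ps), 4))
```

Enlarging the passage batch adds more negatives per question at no extra encoding cost for the questions, which is one reason DPR-style setups keep passage_batch_size well above question_batch_size.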

Evaluation

from goldenretriever.trainer import Trainer
from goldenretriever import GoldenRetriever
from goldenretriever.data.datasets import InBatchNegativesDataset

retriever = GoldenRetriever(
    question_encoder="",  # path to the trained question encoder
    document_index="",  # path to the document index
    device="cuda",
    precision="16",
)

test_dataset = InBatchNegativesDataset(
    name="test",
    path="",  # path to the test file
    tokenizer=retriever.question_tokenizer,
    question_batch_size=64,
    passage_batch_size=400,
    max_passage_length=64,
)

trainer = Trainer(
    retriever=retriever,
    test_dataset=test_dataset,
    log_to_wandb=False,
    top_k=[20, 100],
)

# evaluate on the test set
trainer.test()
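The top_k=[20, 100] argument suggests that retrieval quality is reported at cutoffs 20 and 100. Assuming it follows the DPR convention of top-k retrieval accuracy (a question counts as a hit if at least one positive passage appears among its first k results), the metric can be sketched as follows. The function, its arguments and the passage ids are hypothetical, not part of goldenretriever.

```python
from typing import Dict, Sequence, Set


def top_k_accuracy(
    ranked_ids: Sequence[Sequence[str]],
    positive_ids: Sequence[Set[str]],
    ks: Sequence[int],
) -> Dict[int, float]:
    """For each k, the fraction of questions whose top-k ranked passages
    contain at least one positive passage (DPR-style top-k accuracy)."""
    results: Dict[int, float] = {}
    for k in ks:
        hits = sum(
            1
            for ranked, gold in zip(ranked_ids, positive_ids)
            if any(pid in gold for pid in ranked[:k])
        )
        results[k] = hits / len(ranked_ids)
    return results


# two toy questions with hypothetical passage ids, already ranked by score
ranked = [["d3", "d1", "d7"], ["d2", "d9", "d4"]]
gold = [{"d7"}, {"d5"}]
print(top_k_accuracy(ranked, gold, ks=[1, 3]))  # {1: 0.0, 3: 0.5}
```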

Inference

from goldenretriever import GoldenRetriever

retriever = GoldenRetriever(
    question_encoder="path/to/question/encoder",
    passage_encoder="path/to/passage/encoder",
    document_index="path/to/document/index"
)

# retrieve documents
retriever.retrieve("What is the capital of France?", k=5)
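Conceptually, a dense retriever answers a call like retrieve(...) by encoding the question with the question encoder and scoring it against passage embeddings precomputed in the document index. Assuming inner-product scoring (as in DPR), the lookup reduces to a top-k search. The sketch below is brute force, uses toy two-dimensional vectors in place of encoder outputs, and is not goldenretriever's index implementation; retrieve_top_k is a hypothetical name.

```python
import heapq
from typing import Dict, List, Sequence, Tuple


def retrieve_top_k(
    query: Sequence[float],
    index: Dict[str, Sequence[float]],
    k: int,
) -> List[Tuple[str, float]]:
    """Brute-force inner-product search: score every indexed passage
    against the query and keep the k highest-scoring (id, score) pairs."""
    scored = (
        (doc_id, sum(a * b for a, b in zip(query, emb)))
        for doc_id, emb in index.items()
    )
    return heapq.nlargest(k, scored, key=lambda item: item[1])


# toy 2-d embeddings standing in for encoded passages and an encoded question
index = {"paris": [0.9, 0.1], "berlin": [0.2, 0.8], "rome": [0.6, 0.4]}
print(retrieve_top_k([1.0, 0.0], index, k=2))  # [('paris', 0.9), ('rome', 0.6)]
```

heapq.nlargest avoids sorting the whole index when k is small; real indexes with millions of passages typically replace this linear scan with batched matrix products or an approximate nearest-neighbour library.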

Data format

Input data

The retriever expects a JSON file containing a list of samples, in the same format as DPR:

[
  {
    "question": "....",
    "answers": ["...", "...", "..."],
    "positive_ctxs": [{
      "title": "...",
      "text": "...."
    }],
    "negative_ctxs": ["..."],
    "hard_negative_ctxs": ["..."]
  },
  ...
]
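Before training, it can help to check that a file matches this layout. The sketch below loads the JSON list and validates each sample using only the standard library. Which keys are mandatory is an assumption (here question and positive_ctxs, with the other lists defaulting to empty); goldenretriever's own checks may differ, and validate_samples and load_dpr_file are hypothetical helpers.

```python
import json
from typing import Any, Dict, List

# Assumption: these two keys are treated as mandatory; the remaining lists
# are optional and default to []. goldenretriever's own checks may differ.
REQUIRED_KEYS = ("question", "positive_ctxs")
OPTIONAL_LIST_KEYS = ("answers", "negative_ctxs", "hard_negative_ctxs")


def validate_samples(samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Check each DPR-style sample and fill missing optional lists with []."""
    for i, sample in enumerate(samples):
        missing = [key for key in REQUIRED_KEYS if key not in sample]
        if missing:
            raise ValueError(f"sample {i} is missing required keys: {missing}")
        for key in OPTIONAL_LIST_KEYS:
            sample.setdefault(key, [])
    return samples


def load_dpr_file(path: str) -> List[Dict[str, Any]]:
    """Read a JSON file holding a list of DPR-style samples and validate it."""
    with open(path, encoding="utf-8") as f:
        return validate_samples(json.load(f))
```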

Index data

The documents to index can be stored in either a jsonl file or a tsv file, similar to DPR:

  • jsonl: each line is a JSON object with the keys id, text, and metadata
  • tsv: each line is a tab-separated string with the id and text columns, followed by any other columns, which are stored in the metadata field

jsonl example:

{"id": "...", "text": "...", "metadata": ["{...}"]}
{"id": "...", "text": "...", "metadata": ["{...}"]}
...

tsv example:

id \t text \t any other column
...
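Reading either index format into the same id/text/metadata records might look like the sketch below. It is a standard-library illustration, not goldenretriever's loader. Whether the tsv file starts with a header row is an assumption, exposed through the hypothetical has_header flag; with a header, the extra columns become a metadata dict keyed by column name, otherwise a plain list.

```python
import csv
import io
import json
from typing import Any, Dict, Iterator, TextIO


def read_jsonl(stream: TextIO) -> Iterator[Dict[str, Any]]:
    """Yield one {id, text, metadata} record per non-empty JSON line."""
    for line in stream:
        line = line.strip()
        if not line:
            continue
        obj = json.loads(line)
        yield {"id": obj["id"], "text": obj["text"], "metadata": obj.get("metadata")}


def read_tsv(stream: TextIO, has_header: bool = True) -> Iterator[Dict[str, Any]]:
    """Yield {id, text, metadata} records from tab-separated rows.

    Columns after id and text go into metadata: a dict keyed by the header
    names when has_header is True (an assumption), otherwise a list of values.
    """
    reader = csv.reader(stream, delimiter="\t")
    extra_names = next(reader, [])[2:] if has_header else None
    for row in reader:
        if not row:
            continue  # skip blank lines
        doc_id, text, *extra = row
        metadata = dict(zip(extra_names, extra)) if extra_names is not None else extra
        yield {"id": doc_id, "text": text, "metadata": metadata}


tsv = "id\ttext\ttitle\n1\tParis is the capital of France.\tParis\n"
print(list(read_tsv(io.StringIO(tsv))))
```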



Download files


Source Distribution

goldenretriever-core-0.9.0.tar.gz (79.1 kB)

Built Distribution

goldenretriever_core-0.9.0-py3-none-any.whl (105.0 kB)
