
Asyncval

Asyncval: A toolkit for asynchronously validating dense retriever checkpoints during training.

Validating dense retriever checkpoints during training is time-consuming. Asyncval is a toolkit that aims to accelerate this process. Asyncval decouples the validation loop from the training loop and uses a separate GPU to automatically validate new DR checkpoints as they are generated during training, thus permitting validation to run asynchronously alongside training.


Installation

For customized dense retriever encoders, clone this repo and install as editable,

git clone https://github.com/ielab/asyncval.git
cd asyncval
pip install --editable .

Note: the current code base has been tested with torch==1.10.0, transformers==4.5.1, datasets==1.16.1, faiss-cpu==1.7.1, python==3.7, and ir_measures==0.2.3.

Preparation

To be able to use Asyncval to validate your dense retriever checkpoints, there are only two things you need to prepare:

(1) Prepare the corpus and the validation query files

Corpus and validation query files must be in JSON lines format: each line in the file is a JSON object that represents a passage or a query:

{"text_id": str, "text": List[int]}

where text_id is a unique id for a passage or query in the corpus or query file, and text is a list of integers representing the token ids of the passage or query.

Note that the text field can also be the raw string of the passage or query; however, we suggest that Asyncval users pre-tokenize their passages and queries and supply only token ids. Pre-tokenizing all text once speeds up validation, as the same queries and passages do not need to be re-tokenized for every model checkpoint.
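For illustration, the expected file layout can be produced with the standard library alone. The token ids below are made-up placeholders; in practice they would come from the same Huggingface tokenizer you later pass via --tokenizer_name_or_path:

```python
import json

# Write a corpus file where each line is {"text_id": str, "text": List[int]}.
# The token ids here are illustrative placeholders, not real tokenizer output.
passages = [
    ("doc0", [101, 2023, 2003, 1037, 6019, 102]),
    ("doc1", [101, 2178, 2742, 6019, 102]),
]

with open("corpus.json", "w") as f:
    for text_id, token_ids in passages:
        f.write(json.dumps({"text_id": text_id, "text": token_ids}) + "\n")
```

Query files follow exactly the same per-line layout.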

(2) Rewrite the DenseModel class in src/asyncval/modeling.py (Tevatron users can skip this step)

After the corpus and query files are prepared, a Python class called DenseModel needs to be defined:

Rewrite the __init__ constructor method to initialize your dense retriever encoder model. This method should take the checkpoint path and the pre-defined AsyncvalArguments as inputs. AsyncvalArguments pre-defines several useful arguments that can be passed via the command line.

Rewrite the encode_passage and encode_query methods to encode passages and queries into dense vectors. Both methods take a dictionary as input, whose keys are Huggingface transformers input keys such as input_ids and attention_mask, and whose values are the batched token ids and the attention mask matrix. Researchers implement and override these two methods by (i) feeding the input dictionary to the DR model built in the constructor, and (ii) returning the batch of dense vectors.
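As a minimal, framework-free sketch of this interface (the class and method names come from the description above; the stand-in encoder that mean-pools token ids is purely illustrative — a real implementation would load a transformers model from the checkpoint path in __init__ and return torch tensors):

```python
class DenseModel:
    """Sketch of the interface Asyncval expects; the encoder below is a stand-in."""

    def __init__(self, ckpt_path, asyncval_args=None):
        # A real implementation would load encoder weights from ckpt_path here,
        # e.g. AutoModel.from_pretrained(ckpt_path).
        self.ckpt_path = ckpt_path
        self.args = asyncval_args  # the pre-defined AsyncvalArguments

    def _encode(self, input_ids, attention_mask):
        # Stand-in "dense vector": masked mean of token ids (illustration only).
        kept = [t for t, m in zip(input_ids, attention_mask) if m]
        return [sum(kept) / len(kept)]

    def encode_passage(self, batch):
        # batch: {"input_ids": [[int, ...]], "attention_mask": [[int, ...]]}
        return [self._encode(ids, mask)
                for ids, mask in zip(batch["input_ids"], batch["attention_mask"])]

    def encode_query(self, batch):
        # A symmetric encoder can share the passage code path.
        return self.encode_passage(batch)
```

Asymmetric retrievers (separate query and passage towers) would instead route each method to its own encoder.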

Running Asyncval validation

Once the corpus and query files and the DenseModel class are prepared, you can run Asyncval with the following command:

python -m asyncval \
	--query_file List[str] \
	--candidate_dir str \
	--ckpts_dir str \
	--tokenizer_name_or_path str \
	--qrel_file str \
	--output_dir str

Arguments Description

| name | description | type | default |
| --- | --- | --- | --- |
| query_file | The path to the pre-tokenized query JSON file. Multiple files can be provided. | List[str] | required |
| candidate_dir | The path to the folder that saves the pre-tokenized corpus JSON files. | str | required |
| ckpts_dir | The path to the folder that saves DR checkpoints. | str | required |
| tokenizer_name_or_path | The path or name of the Huggingface tokenizer (used for padding and attention masking). | str | required |
| q_max_len | The maximum number of query tokens. | int | 32 |
| p_max_len | The maximum number of document tokens. | int | 128 |
| qrel_file | The path to the TREC format qrel file. | str | required |
| run_name | A descriptor for the run, typically used for wandb. | str | None |
| write_run | Whether to write run files to disk. | bool | True |
| output_dir | The path to save checkpoint run files. | str | required |
| logging_dir | Tensorboard log dir. | str | None |
| metrics | The list of ir_measures metrics for validation. | List[str] | ['RR@10','nDCG@10'] |
| per_device_eval_batch_size | The batch size for encoding. | int | 8 |
| retrieve_batch_size | The batch size for FAISS retrieval. | int | 64 |
| fp16 | Whether to use fp16 16-bit (mixed) precision encoding instead of 32-bit encoding. | bool | False |
| report_to | The list of integrations to report the results and logs to. | str | None |
| device | Which GPU device to use for running Asyncval. | str | 'cuda:0' |
| depth | Retrieval ranking depth (cutoff). | int | 100 |
| cache_dir | Folder for saving the encoding dataset cache. | str | None |
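The metric strings follow ir_measures notation. As a plain-Python illustration of what one of the defaults, RR@10 (reciprocal rank at cutoff 10), computes for a single query — inside Asyncval the real computation is delegated to the ir_measures package:

```python
def rr_at_10(ranked_doc_ids, relevant_ids):
    """Reciprocal rank of the first relevant document within the top 10 results."""
    for rank, doc_id in enumerate(ranked_doc_ids[:10], start=1):
        if doc_id in relevant_ids:
            return 1.0 / rank
    return 0.0  # no relevant document retrieved in the top 10
```

Asyncval reports each metric averaged over all validation queries, per checkpoint.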

Use the splitter to sample a subset of the corpus for fast validation

To further speed up checkpoint validation, we also provide a feature to sample a subset of the full corpus, avoiding the need to encode every corpus passage for each checkpoint. The subset is sampled based on a given run file for the validation query set (such as BM25 results) and a TREC-standard qrel file that provides the gold (relevant) passages for each query in the validation set. To do this, run the following command:

python -m asyncval.splitter \
		--candidate_dir str \
		--run_file str \
		--qrel_file str \
		--output_dir str \
		--depth int

where --candidate_dir is the path to the folder that saves the pre-tokenized full corpus JSON files; --run_file is the path to the run file; --qrel_file is the path to the TREC qrel file; --output_dir is the path to the folder in which to save the JSON file for the subset; and --depth is the number of top passages to keep for each query in the run file. For example, setting --depth to 100 means that only the top 100 passages for each query are kept. This trades off validation accuracy for speed. You can then use the generated subset corpus file for validation. We found that, instead of BM25, using a strong dense retriever baseline to generate the subset can dramatically reduce the validation time of each checkpoint without loss of fidelity.
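The selection logic the splitter applies can be sketched in plain Python. This is an illustration, not the actual implementation, and it assumes the standard TREC formats: run lines of the form "qid Q0 docid rank score tag" and qrel lines of the form "qid 0 docid rel":

```python
def subset_doc_ids(run_lines, qrel_lines, depth):
    """Ids of passages to keep: top-`depth` per query, plus all gold passages."""
    keep = set()
    for line in run_lines:
        _qid, _, doc_id, rank, _score, _tag = line.split()
        if int(rank) <= depth:
            keep.add(doc_id)
    for line in qrel_lines:
        _qid, _, doc_id, rel = line.split()
        if int(rel) > 0:
            keep.add(doc_id)  # gold passages are always kept, regardless of rank
    return keep
```

The subset corpus is then simply the lines of the full corpus JSON file whose text_id falls in this set.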
