A package for training and inference with contrastive learning on multiple GPUs (PyTorch DDP).

Project description

clddp: Contrastive Learning with Distributed Data Parallel

This Python package provides an implementation of contrastive learning with multiple GPUs (via PyTorch's Distributed Data Parallel), with a special focus on neural retrieval.

Installation

pip install -U clddp

To use ColBERT, additionally install its package:

pip install git+https://github.com/stanford-futuredata/ColBERT.git@21b460a606bed606e8a7fa105ada36b18e8084ec

Quick Start

Please have a look at the examples for a quick start. For instance, the following script runs multi-GPU training:

Training example

dataset="fiqa"
export DATA_DIR="data"
mkdir -p $DATA_DIR
export DATASET_PATH="$DATA_DIR/$dataset"
if [ ! -f "$DATASET_PATH.zip" ]; then
    wget https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/$dataset.zip -P $DATA_DIR
fi
if [ ! -d "$DATASET_PATH" ]; then
    unzip $DATASET_PATH.zip -d $DATA_DIR
fi
export CUDA_VISIBLE_DEVICES="0,1,2,3"
export WANDB_MODE="online"
export CHECKPOINT_DIR="checkpoints"
export CLI_ARGS="
--project=clddp_examples
--checkpoint_dir=$CHECKPOINT_DIR
--query_model_name_or_path=sentence-transformers/msmarco-distilbert-base-tas-b
--shared_encoder=True
--sep=blank
--pooling=cls
--similarity_function=dot_product
--query_max_length=350
--passage_max_length=350
--sim_scale=1.0
--fp16=True
--train_data=$DATASET_PATH
--train_dataloader=beir
--num_negatives=0
--dev_data=$DATASET_PATH
--dev_dataloader=beir
--do_dev=True
--quick_dev=False
--test_data=$DATASET_PATH
--test_dataloader=beir
--num_train_epochs=1
--eval_steps=0.4
--save_steps=0.4
"
export OUTPUT_DIR=$(python -m clddp.args.train $CLI_ARGS)
mkdir -p $OUTPUT_DIR
export LOG_PATH="$OUTPUT_DIR/logging.log"
echo "Logging file path: $LOG_PATH"
nohup torchrun --nproc_per_node=4 --master_port=29501 -m clddp.train $CLI_ARGS > $LOG_PATH &
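The flags similarity_function=dot_product and sim_scale above configure the scoring used in the contrastive objective. clddp's actual implementation may differ in detail, but the standard in-batch-negatives (InfoNCE) loss such flags point at can be sketched in plain NumPy:

```python
import numpy as np

def in_batch_contrastive_loss(q, p, scale=1.0):
    """InfoNCE loss with in-batch negatives and dot-product similarity.

    q: (B, D) query embeddings, p: (B, D) passage embeddings; the i-th
    passage is the positive for the i-th query, and every other passage
    in the batch acts as a negative. `scale` mirrors a sim_scale-style
    temperature on the similarities.
    """
    sim = scale * (q @ p.T)                         # (B, B) similarity matrix
    sim = sim - sim.max(axis=1, keepdims=True)      # stabilize the softmax
    log_probs = sim - np.log(np.exp(sim).sum(axis=1, keepdims=True))
    idx = np.arange(len(q))
    return float(-log_probs[idx, idx].mean())       # cross-entropy on the diagonal

# Well-separated matching embeddings give a near-zero loss:
q = 10.0 * np.eye(4, 8)
print(in_batch_contrastive_loss(q, q))  # near 0
```

With all-zero embeddings the similarities are uniform and the loss reduces to log(B), the chance-level value for a batch of size B.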

Search Example

This will run exact search with multiple GPUs and output the retrieval results in the TREC run format (each row is query-id Q0 passage-id rank score exp, where exp is the run name):

# Build a retriever checkpoint:
export CHECKPOINT_DIR="checkpoints/off-the-shell/tasb"
export CLI_ARGS="
--output_dir=$CHECKPOINT_DIR
--query_model_name_or_path=sentence-transformers/msmarco-distilbert-base-tas-b
--shared_encoder=True
--sep=blank
--pooling=cls
--similarity_function=dot_product
--query_max_length=350
--passage_max_length=350
--sim_scale=1.0
"
python -m clddp.retriever $CLI_ARGS

# Run search:
dataset="fiqa"
export DATA_DIR="data"
export DATASET_PATH="$DATA_DIR/$dataset"
if [ ! -d "$DATASET_PATH" ]; then
    echo "Data path $DATASET_PATH does not exist"
    exit 1
fi
if [ ! -d "$CHECKPOINT_DIR" ]; then
    echo "Checkpoint path $CHECKPOINT_DIR does not exist"
    exit 1
fi
export CUDA_VISIBLE_DEVICES="0,1,2,3"
export CLI_ARGS="
--checkpoint_dir=$CHECKPOINT_DIR
--data_dir=$DATASET_PATH
--dataloader=beir
"
export OUTPUT_DIR=$(python -m clddp.args.search $CLI_ARGS)
mkdir -p $OUTPUT_DIR
export LOG_PATH="$OUTPUT_DIR/logging.log"
echo "Logging file path: $LOG_PATH"
nohup torchrun --nproc_per_node=4 --master_port=29501 -m clddp.search $CLI_ARGS > $LOG_PATH &
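Once the run file is written, its TREC-format rows are easy to post-process. A minimal parser sketch (parse_trec_run is a hypothetical helper, not part of clddp):

```python
def parse_trec_run(lines):
    """Parse TREC run rows of the form: query-id Q0 passage-id rank score run-name.

    Returns {query_id: [(passage_id, rank, score), ...]}, with each list
    kept in the order the rows appear.
    """
    run = {}
    for line in lines:
        qid, _q0, pid, rank, score, _run_name = line.split()
        run.setdefault(qid, []).append((pid, int(rank), float(score)))
    return run

rows = [
    "q1 Q0 d42 1 12.7 exp",
    "q1 Q0 d7 2 11.9 exp",
    "q2 Q0 d3 1 9.4 exp",
]
run = parse_trec_run(rows)
print(run["q1"][0])  # ('d42', 1, 12.7)
```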

Evaluation is similar; see the evaluation example in the same examples directory.
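The evaluation step scores such a run against the dataset's relevance judgments with standard IR metrics. As an illustration only (not clddp's API), nDCG@k for a single query can be computed as:

```python
import math

def ndcg_at_k(ranked_passage_ids, qrels, k=10):
    """nDCG@k for a single query.

    ranked_passage_ids: passage ids sorted by descending retrieval score.
    qrels: {passage_id: relevance_gain} judgments for this query.
    """
    dcg = sum(qrels.get(pid, 0) / math.log2(i + 2)
              for i, pid in enumerate(ranked_passage_ids[:k]))
    ideal_gains = sorted(qrels.values(), reverse=True)[:k]
    idcg = sum(g / math.log2(i + 2) for i, g in enumerate(ideal_gains))
    return dcg / idcg if idcg > 0 else 0.0

# The only relevant passage ranked first -> perfect score:
print(ndcg_at_k(["d1", "d2", "d3"], {"d1": 1}))  # 1.0
```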

Custom Data

For loading custom data, one needs to inherit from clddp.dataloader.BaseDataLoader and register the subclass in the lookup map clddp.dataloader.DATA_LOADER_LOOKUP:

from clddp.train import main
from clddp.dataloader import BaseDataLoader, DATA_LOADER_LOOKUP
from clddp.dm import RetrievalDataset

class MyDataLoader(BaseDataLoader):
    def load_data(self, data_name_or_path: str, progress_bar: bool) -> RetrievalDataset:
        ...

DATA_LOADER_LOOKUP["my_dataloader"] = MyDataLoader

if __name__ == "__main__":
    main()  # Same for other entry points

Then, one can specify xxx_dataloader=my_dataloader (e.g. train_dataloader=my_dataloader) in the CLI arguments.
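The registration above is a plain lookup-map pattern. A self-contained sketch with stand-in classes (the real ones live in clddp.dataloader; the loader body here is purely hypothetical):

```python
# Stand-ins for clddp.dataloader.BaseDataLoader / DATA_LOADER_LOOKUP, so the
# registration pattern can be demonstrated without the package installed.
DATA_LOADER_LOOKUP = {}

class BaseDataLoader:
    def load_data(self, data_name_or_path, progress_bar):
        raise NotImplementedError

class InMemoryDataLoader(BaseDataLoader):
    """Hypothetical loader that treats the 'path' as comma-separated passages."""
    def load_data(self, data_name_or_path, progress_bar):
        return [{"passage_id": str(i), "text": t}
                for i, t in enumerate(data_name_or_path.split(","))]

DATA_LOADER_LOOKUP["my_dataloader"] = InMemoryDataLoader

# Framework side: the CLI value (e.g. train_dataloader=my_dataloader)
# selects the registered class by name.
loader = DATA_LOADER_LOOKUP["my_dataloader"]()
print(loader.load_data("hello,world", progress_bar=False))
# [{'passage_id': '0', 'text': 'hello'}, {'passage_id': '1', 'text': 'world'}]
```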
