A DataLoader using RPC-based workers

RPC Dataloader

This library implements a variant of the PyTorch DataLoader that uses remote workers. This makes it possible to distribute the data-loading workers over remote servers instead of running them on the machine that runs the main script.

To use it, start one or more worker daemons on remote computers. The RPCDataloader on the main computer then dispatches item requests to the workers and awaits the returned values.
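As a rough illustration of this dispatching, the sketch below groups (optionally shuffled) dataset indices into batches and assigns each index to a worker address. The helper plan_requests and its round-robin policy are hypothetical, chosen only for illustration; RPCDataloader's actual scheduling is not described here and may differ.

```python
import random
from itertools import cycle


def plan_requests(n_items, batch_size, workers, shuffle=True, seed=0):
    """Hypothetical helper: split indices into batches of (worker, index) requests.

    Round-robin assignment is an assumption made for this sketch, not
    necessarily how RPCDataloader schedules item requests.
    """
    indices = list(range(n_items))
    if shuffle:
        random.Random(seed).shuffle(indices)  # reproducible shuffle
    next_worker = cycle(workers)  # hand out workers in turn
    batches = []
    for start in range(0, n_items, batch_size):
        batch = indices[start:start + batch_size]
        batches.append([(next(next_worker), i) for i in batch])
    return batches


plan = plan_requests(5, batch_size=2, workers=["node01:6543", "node02:6543"], shuffle=False)
print(plan)
# [[('node01:6543', 0), ('node02:6543', 1)], [('node01:6543', 2), ('node02:6543', 3)], [('node01:6543', 4)]]
```

The last batch is shorter when the dataset size is not a multiple of batch_size, mirroring the default behaviour of a PyTorch DataLoader without drop_last.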

Although similar to torch.distributed.rpc, this library uses its own, simpler implementation of RPC (Remote Procedure Call): it requires no initialization step and does not conflict with PyTorch's RPC framework.
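To make the RPC idea concrete, here is a toy request/response protocol over TCP written with the Python standard library only. It is not rpcdataloader's wire protocol: the framing (an 8-byte length prefix followed by a pickle) and the names send_msg, recv_msg, serve_forever and call are assumptions for this sketch. Because unpickling can execute arbitrary code, a scheme like this should only be used between trusted hosts.

```python
import pickle
import socket
import struct
import threading


# Toy framing (an assumption, not rpcdataloader's format):
# 8-byte big-endian length prefix, then a pickled payload.
def send_msg(sock, obj):
    data = pickle.dumps(obj)
    sock.sendall(struct.pack("!Q", len(data)) + data)


def recv_exact(sock, n):
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise ConnectionError("connection closed before message was complete")
        buf.extend(chunk)
    return bytes(buf)


def recv_msg(sock):
    (size,) = struct.unpack("!Q", recv_exact(sock, 8))
    return pickle.loads(recv_exact(sock, size))


def serve_forever(server_sock, functions):
    """Worker side: answer one (name, args) request per connection."""
    while True:
        conn, _ = server_sock.accept()
        with conn:
            try:
                name, args = recv_msg(conn)
            except ConnectionError:
                continue  # client went away; wait for the next one
            try:
                send_msg(conn, ("ok", functions[name](*args)))
            except Exception as exc:  # report the error back to the caller
                send_msg(conn, ("err", repr(exc)))


def call(address, name, *args):
    """Main side: send a request to a worker and block until it replies."""
    with socket.create_connection(address) as sock:
        send_msg(sock, (name, args))
        status, value = recv_msg(sock)
    if status == "err":
        raise RuntimeError(f"remote call failed: {value}")
    return value


# Usage: run a worker in a background thread on a free local port.
server = socket.create_server(("127.0.0.1", 0))
ADDRESS = server.getsockname()[:2]
threading.Thread(
    target=serve_forever,
    args=(server, {"square": lambda x: x * x, "add": lambda a, b: a + b}),
    daemon=True,
).start()
print(call(ADDRESS, "square", 7))  # prints 49
```

Only the function name and arguments travel over the network; the functions themselves live in the worker process, which is what lets the worker hold state such as a dataset.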

Installation

pip install git+https://github.com/CEA-LIST/RPCDataloader.git#egg=rpcdataloader

Usage

To use the RPC dataloader, start a few workers either from the command line:

python -m rpcdataloader.launch --host=0.0.0.0 --port=6543

or by calling rpcdataloader.run_worker directly from a Python script.
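Before creating the dataloader, it can help to confirm that every worker address accepts connections. The following standard-library sketch uses hypothetical helper names (parse_address, unreachable_workers); it only checks that something is listening on the TCP port, not that an rpcdataloader worker is the one answering.

```python
import socket


def parse_address(address):
    """Split a 'host:port' string such as 'node01:6543' into (host, port)."""
    host, _, port = address.rpartition(":")
    return host, int(port)


def unreachable_workers(addresses, timeout=2.0):
    """Return the addresses that refuse, time out, or fail to resolve."""
    failed = []
    for address in addresses:
        try:
            # Open and immediately close a TCP connection to the worker port
            with socket.create_connection(parse_address(address), timeout=timeout):
                pass
        except OSError:  # covers refused connections, timeouts and DNS errors
            failed.append(address)
    return failed


# Example: unreachable_workers(['node01:6543']) returns [] if the port is open.
```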

Then instantiate the dataloader:

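import torchvision
import rpcdataloader

# The dataset is given as a class plus constructor kwargs rather than an instance
dataloader = rpcdataloader.RPCDataloader(
    workers=['node01:6543'],  # host:port of each running worker
    dataset=torchvision.datasets.FakeData,
    kwargs={'transform': torchvision.transforms.ToTensor()},
    batch_size=2,
    shuffle=True,
    pin_memory=True)

for minibatch in dataloader:
    ...  # training step goes here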

Slurm integration

Slurm integration is slightly tricky because it relies on a less common feature: heterogeneous jobs. To run the workers on CPU nodes and the trainers on GPU nodes, use the following Slurm script template:

#!/usr/bin/env bash
# bash rather than sh: the script relies on $RANDOM, which POSIX sh does not provide
#SBATCH --time=3-00:00:00

#SBATCH --partition=gpu
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=2
#SBATCH --cpus-per-task=2
#SBATCH --mem=64G
#SBATCH --gres=gpu:2

#SBATCH hetjob

#SBATCH --partition=cpu
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=16
#SBATCH --cpus-per-task=2
#SBATCH --mem=72G

# create an output dir
export OUT_DIR="./outputs/${SLURM_JOB_NAME}.${SLURM_JOB_ID}"
mkdir -p $OUT_DIR

# start workers and collect host and port list
rm -f ${OUT_DIR}/workers && touch ${OUT_DIR}/workers
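srun --het-group=1 -I --exclusive --kill-on-bad-exit=1 \
    bash -c '
        # $RANDOM is a bash feature in [0, 32767], so port falls in [16384, 49151]
        export port=$(( 16384 + $RANDOM % 49182 ))
        # append host:port to the shared list; flock serializes concurrent writers
        echo $(hostname):$port \
            | flock ${OUT_DIR}/workers tee -a ${OUT_DIR}/workers \
            &> /dev/null
        python -u -m rpcdataloader.launch --host=0.0.0.0 --port=$port
        ' &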
worker_task_pid=$!

# block until all workers have written their host:port line
# (with --nodes=1 in the CPU group, tasks per node equals the number of workers)
tail -f ${OUT_DIR}/workers | head -n $SLURM_NTASKS_PER_NODE_HET_GROUP_1

# parse worker list
export workers=$(tr '\n' ' ' < ${OUT_DIR}/workers)

# run training script
export MASTER_ADDR="$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n1)"
export MASTER_PORT=$(( 16384 + $RANDOM % 49182 ))
srun --het-group=0 -I --exclusive --kill-on-bad-exit=1 \
    python -u example.py \
        --workers $workers

# stop workers
kill $worker_task_pid
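The script passes $workers unquoted, so each host:port entry reaches example.py as a separate argument (the trailing space left by tr is dropped by word splitting). The training script itself is not shown; below is a minimal sketch of how it might parse that option with argparse, assuming only that the option is named --workers as in the srun line. The function name parse_args and the sample addresses are illustrative.

```python
import argparse


def parse_args(argv=None):
    """Hypothetical CLI for example.py: --workers host:port [host:port ...]."""
    parser = argparse.ArgumentParser(description="Training script using RPC dataloader workers")
    # Each host:port arrives as its own argument; nargs="+" collects them into a list.
    parser.add_argument("--workers", nargs="+", required=True,
                        help="addresses of running rpcdataloader workers, as host:port")
    return parser.parse_args(argv)


args = parse_args(["--workers", "node01:16500", "node01:23011"])
print(args.workers)  # ['node01:16500', 'node01:23011']
```

args.workers can then be passed as the workers argument of rpcdataloader.RPCDataloader.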


Download files

Source distribution: rpcdataloader-0.1.0.tar.gz (19.6 kB), uploaded via twine/4.0.2 CPython/3.10.9
Built distribution: rpcdataloader-0.1.0-py3-none-any.whl (18.5 kB), Python 3

File hashes

rpcdataloader-0.1.0.tar.gz
  SHA256       565fb610e2fa4e048fb37141a2ee5010bf71b50e525ad9d6bb770db8953d24e8
  MD5          6ca7d058411471e4052857d5c2dc171d
  BLAKE2b-256  16e37873d5b00ab87fd7058d5a9ae68ca1482749e9cd3165884864dc315123bf

rpcdataloader-0.1.0-py3-none-any.whl
  SHA256       66a142628e9d20ca20e411a35936e75cefd66200a9f824c2192463fa64edb226
  MD5          82a88207cc64924a14759c6d318981dc
  BLAKE2b-256  894e921e04dd08ad579a24266ad4e42f45e749c22162e6c3dbae8ba0fd9badd9
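To check a downloaded file against the published SHA256 digests, a short standard-library sketch (the helper names sha256_of and sha256_matches are illustrative):

```python
import hashlib


def sha256_of(path, chunk_size=1 << 20):
    """Hash a file in 1 MiB chunks so large files need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def sha256_matches(path, expected_hex):
    """Compare against a published hex digest, ignoring letter case."""
    return sha256_of(path) == expected_hex.strip().lower()
```

For example, sha256_matches('rpcdataloader-0.1.0.tar.gz', '565fb610e2fa4e048fb37141a2ee5010bf71b50e525ad9d6bb770db8953d24e8') should return True for an intact download.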
