PyTorch dataset with Zarr backend

Zarr Torch Dataset

What is it?

Zarr Torch Dataset is a PyTorch Iterable Dataset implementation based on Zarr arrays. This library addresses the problem of data distribution in distributed data parallel PyTorch training.

During such training, several processes are executed at the same time. Ideally, each process holds a portion of the training data without it being duplicated in another process. In addition, PyTorch offers parallel production of batches within a process: workers prepare batches that will be consumed one by one during training. Thus, workers within a training process must also share the data without duplication.

While PyTorch implements distributed training (DDP) as well as worker creation and management (DataLoader), loading data into RAM is left to the user, as PyTorch makes no assumptions about the nature of the training data. To be precise: PyTorch does not distribute training data, but offers a data sampling mechanism based on data indexes (Samplers).

In general, when the training dataset is relatively small, users do not bother to optimize the memory footprint of training processes: each worker in each process loads the entire dataset into RAM. This behavior becomes a memory problem when the dataset is large: it is no longer possible to load it multiple times. Furthermore, if the dataset does not fit into the memory of a single compute node, it must be distributed across a cluster of nodes.
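To make the memory problem concrete, the following sketch estimates the RAM needed when every worker loads its own full copy of the data. All figures are assumptions chosen for illustration (a float32 array of shape (1400, 180, 360, 8), 4 training processes, 2 workers each); the helper name is hypothetical and not part of this library.

```python
import math


def replicated_footprint_gib(shape, itemsize, num_processes, workers_per_process):
    """RAM in GiB if every worker of every process loads the full array."""
    dataset_bytes = math.prod(shape) * itemsize
    copies = num_processes * workers_per_process
    return dataset_bytes * copies / 2**30


# Assumed setup: float32 values (4 bytes each), 4 processes, 2 workers per process.
shape = (1400, 180, 360, 8)
single_copy = replicated_footprint_gib(shape, itemsize=4, num_processes=1, workers_per_process=1)
all_copies = replicated_footprint_gib(shape, itemsize=4, num_processes=4, workers_per_process=2)
print(f"one copy: {single_copy:.2f} GiB, fully replicated: {all_copies:.2f} GiB")
```

The replicated footprint grows linearly with the number of processes and workers, which is the duplication that a chunk-based distribution is meant to avoid.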

This is where Zarr Torch Dataset comes in.

Like the Webdataset library, Zarr Torch Dataset distributes pieces of the training dataset across all workers in all training processes on all compute nodes, without duplication and with consistent data shuffling. This library is based on the Zarr data format, which implements the concept of data chunks with configurable dimensions. Zarr Torch Dataset also lets you select a subset of chunks for training, which is useful for minimizing the time spent loading data into memory while developing the training script.
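The idea of distributing chunks without duplication and with consistent shuffling can be sketched in plain Python. This is only an illustration of the principle, not the library's actual implementation: the function name, its parameters and the strided assignment are assumptions.

```python
import random


def chunks_for_worker(num_chunks, seed, rank, world_size, worker_id, num_workers):
    """Return the chunk indices assigned to one worker of one training process."""
    order = list(range(num_chunks))
    # Every process uses the same seed, so all of them derive the same permutation.
    random.Random(seed).shuffle(order)
    # Flatten (rank, worker_id) into a single slot and take a strided slice.
    slot = rank * num_workers + worker_id
    total_slots = world_size * num_workers
    return order[slot::total_slots]


# Assumed setup: 14 chunks, 2 training processes, 2 workers per process.
assignments = [
    chunks_for_worker(num_chunks=14, seed=42, rank=rank, world_size=2,
                      worker_id=worker_id, num_workers=2)
    for rank in range(2)
    for worker_id in range(2)
]
for chunk_ids in assignments:
    print(chunk_ids)
```

Because the permutation depends only on the shared seed, each chunk ends up in exactly one worker. This is why the chunk shuffle seed has to be identical in all training processes.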

[!important] There is only one constraint on chunk dimensions: chunking must be done only on the first dimension. For example, if the dataset has dimensions (1400, 180, 360, 8), then the chunk dimensions must be (n, 180, 360, 8) where n < 1400. The greater the number of chunks (not n, but 1400 / n), the greater the shuffling entropy, but the longer the read access time.
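The constraint above can be checked before training with a few lines of Python. The helper below is a hypothetical sketch (it is not provided by the library): it takes the array shape and the chunk shape as tuples, such as those exposed by the `shape` and `chunks` attributes of a Zarr array, and returns the resulting number of chunks.

```python
import math


def check_first_axis_chunking(shape, chunks):
    """Return the number of chunks, or raise if chunking is not limited to axis 0."""
    if len(shape) != len(chunks):
        raise ValueError("shape and chunks must have the same number of dimensions")
    if tuple(chunks[1:]) != tuple(shape[1:]):
        raise ValueError(f"only the first dimension may be chunked: {chunks} vs {shape}")
    if not 0 < chunks[0] < shape[0]:
        raise ValueError("the first chunk dimension must satisfy 0 < n < shape[0]")
    # The last chunk may be partial, hence the ceiling.
    return math.ceil(shape[0] / chunks[0])


num_chunks = check_first_axis_chunking((1400, 180, 360, 8), (100, 180, 360, 8))
print(num_chunks)
```

With n = 100 the array is split into 1400 / 100 = 14 chunks. A smaller n gives more chunks to shuffle, at the cost of more read operations.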

Why the Zarr format?

  • It is mature and commonly used in geosciences.
  • It allows decoupling between data storage (stores) and data indexing.
  • It offers a set of compression codecs wrapped in the Blosc meta compressor.
  • Selecting data from a Zarr array returns a NumPy ndarray.

[!important] Zarr Torch Dataset only supports Zarr datasets, not Xarray datasets persisted in Zarr. The library does not support Zarr shards: it distributes chunks, not shards.

Notes

  • reload_dataloaders_every_n_epochs in the PyTorch Lightning trainer should always be zero, unless you want to reload the entire dataset.
  • use_distributed_sampler in the PyTorch Lightning trainer is best set to False.
  • You can provide samples and targets as Zarr groups.
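The first two notes translate into trainer arguments as follows. This is a minimal sketch: both names are existing `lightning.Trainer` parameters in Lightning 2.x, while the `ddp` strategy value is taken from the usage example in this README. The dictionary form is only there so that the snippet runs without Lightning installed.

```python
# Trainer settings following the notes above.
trainer_kwargs = {
    "strategy": "ddp",
    # Zero means the dataloaders are not rebuilt between epochs.
    "reload_dataloaders_every_n_epochs": 0,
    # Keep Lightning from injecting its own distributed sampler.
    "use_distributed_sampler": False,
}

# With Lightning installed, the settings would be passed as:
#   import lightning
#   trainer = lightning.Trainer(**trainer_kwargs)
print(trainer_kwargs)
```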

How to install?

On IPSL GPU cluster HAL

Installation

module load pytorch-lightning/2.5.0.post0 # Or latest version.
mkdir -p "virtual_envs" # Create the environment parent dir, if it is not already done.
python -m venv --system-site-packages "virtual_envs/zarr_torch_dataset" # Create the virtual environment named zarr_torch_dataset.
source "virtual_envs/zarr_torch_dataset/bin/activate" # Activate the virtual environment.
pip install --no-cache-dir -U pip
pip install --no-cache-dir -U zarr-torch-dataset

Activation

module load pytorch-lightning/2.5.0.post0
source "virtual_envs/zarr_torch_dataset/bin/activate"

On IDRIS supercomputer Jean Zay

Installation

module load pytorch-gpu/py3/2.8.0 # Or latest version
mkdir -p "virtual_envs" # Create the environment parent dir, if it is not already done.
python -m venv --system-site-packages "virtual_envs/zarr_torch_dataset" # Create the virtual environment named zarr_torch_dataset.
source "virtual_envs/zarr_torch_dataset/bin/activate" # Activate the virtual environment.
pip install --no-cache-dir -U pip
pip install --no-cache-dir -U lightning zarr-torch-dataset

Activation

module load pytorch-gpu/py3/2.8.0
source "virtual_envs/zarr_torch_dataset/bin/activate"

How to use?

PyTorch

TODO: to be documented.

PyTorch Lightning

On IDRIS supercomputer Jean Zay

Example of distributed data parallel training:
import os

import lightning
import zarr
from zarrdataset import ZarrDataLoader, ZarrIterableDataset

# [...]

# Open your Zarr datasets.
train_samples = zarr.open('train_samples_zarr', mode='r')
train_targets = zarr.open('train_targets_zarr', mode='r')
validation_samples = zarr.open('validation_samples_zarr', mode='r')
validation_targets = zarr.open('validation_targets_zarr', mode='r')

# Instantiate the Zarr Iterable Datasets and Zarr Dataloaders.
# chunks_shuffle_seed can't be None, as all training processes must have the same seed value.
train_ds = ZarrIterableDataset(samples=train_samples, targets=train_targets, shuffle_chunks=True,
                               chunks_shuffle_seed=42, shuffle_buffer=True, buffer_shuffle_seed=None)

val_ds = ZarrIterableDataset(samples=validation_samples, targets=validation_targets,
                             shuffle_chunks=False,
                             shuffle_buffer=False)

train_da = ZarrDataLoader(dataset=train_ds, num_workers=2, batch_size=32,
                          pin_memory=True, prefetch_factor=2, drop_last=False)

val_da = ZarrDataLoader(dataset=val_ds, num_workers=2, batch_size=32,
                        pin_memory=True, prefetch_factor=2, drop_last=False)

trainer = lightning.Trainer(devices=int(os.environ['SLURM_GPUS_ON_NODE']),
                            num_nodes=int(os.environ['SLURM_NNODES']),
                            strategy='ddp')

# Start training.
trainer.fit(model=model, train_dataloaders=train_da, val_dataloaders=val_da)
Example of a Slurm script submitting a 4-GPU batch job (1 process per GPU) on partition gpu_p2:
#!/bin/bash

#SBATCH --account=YOUR_ACCOUNT
#SBATCH --partition=gpu_p2
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=4  # Max 8 per gpu_p2 nodes
#SBATCH --gres=gpu:4         # Max 8 per gpu_p2 nodes
#SBATCH --cpus-per-task=3    # Specific for gpu_p2 nodes
#SBATCH --hint=nomultithread # Specific for Jean Zay
#SBATCH --time=2:00:00
#SBATCH --job-name=pl_test
#SBATCH --error=pl_test.log
#SBATCH --output=pl_test.log

# Disable Python output buffering so that stdout and stderr are written to the log immediately.
export PYTHONUNBUFFERED=1

# The latest pytorch module at this time.
module load pytorch-gpu/py3/2.8.0

# Activate Python virtual env.
source "/SOME/WHERE/TO/YOUR/virtual_envs/zarr_torch_dataset/bin/activate"

# Avoid "No space left on device" errors on Jean Zay.
export TMPDIR=$JOBSCRATCH

# You can pass the chunks_shuffle_seed value as a parameter to your script thanks to "$@".
# $RANDOM can be used to generate the seed.
srun python YOUR_PYTHON_SCRIPT.py "$@"

returned_code=$?
echo "> script completed with exit code ${returned_code}"
exit ${returned_code}

[!tip] PyTorch and PyTorch Lightning Slurm scripts differ only in the way GPUs and tasks are allocated.
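Since the Slurm script above forwards its arguments to the Python script, the chunk shuffle seed can be read on the Python side with `argparse`. The sketch below is one possible way to do it; the option name `--chunks-shuffle-seed` is an assumption, not something the library defines.

```python
import argparse


def parse_args(argv=None):
    """Parse the training script arguments; argv=None reads sys.argv."""
    parser = argparse.ArgumentParser(description="Training script arguments (illustrative).")
    parser.add_argument(
        "--chunks-shuffle-seed",
        type=int,
        required=True,
        help="Seed shared by all training processes for chunk shuffling.",
    )
    return parser.parse_args(argv)


# Demonstration with an explicit argument list.
# In the real training script, call parse_args() without arguments.
args = parse_args(["--chunks-shuffle-seed", "42"])
print(args.chunks_shuffle_seed)
```

The parsed value would then be given to `ZarrIterableDataset` as `chunks_shuffle_seed`. If the seed is generated with `$RANDOM` on the `sbatch` command line, it is expanded once at submission time, so every process started by `srun` receives the same value.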

How to contribute?

Develop your patch in a forked repository, then submit a Merge Request to this repository.

  • Pip
pip install -e .
pip install pre-commit
pre-commit install
  • UV
uv sync
uv run pre-commit install
