PyTorch dataset wrappers for in-memory caching

Project description

KappaData

Utilities for datasets and dataloading with PyTorch

  • modular datasets
  • caching datasets in-memory
  • additional transforms
    • allow deterministic augmentations (e.g. for calculating the test loss on augmented samples)
    • RandAugment
    • patchwise augmentations
  • various dataset filters and other dataset manipulations
    • filter by class
    • limit the size to a percentage
    • Mixup
    • Cutmix
    • label smoothing
    • ...
  • repeated augmentations

Modular datasets

PyTorch datasets load all data in __getitem__. KappaData decouples __getitem__ such that individual properties of a sample can be loaded independently.

Image classification dataset example

Let's take an image classification dataset as an example. A sample consists of an image with an associated class label.

import torch.utils.data

class ImageClassificationDataset(torch.utils.data.Dataset):
    def __init__(self, image_paths):
        super().__init__()
        self.image_paths = image_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = load_image(self.image_paths[idx])
        class_label = image_path_to_class_label(self.image_paths[idx])
        return img, class_label

If part of your training process only requires the class labels, the dataset still has to load all the images, which can take a long time (whereas loading only the labels is very fast). With KappaData, the __getitem__ method is split into subparts:

import kappadata

# inherit from kappadata.KDDataset
class ImageClassificationDataset(kappadata.KDDataset):
    def __init__(self, image_paths):
        super().__init__()
        self.image_paths = image_paths

    def __len__(self):
        return len(self.image_paths)

    # replace __getitem__ with getitem_x and getitem_y
    def getitem_x(self, idx, ctx=None):
        return load_image(self.image_paths[idx])

    def getitem_y(self, idx, ctx=None):
        return image_path_to_class_label(self.image_paths[idx])

Now each subpart of the dataset can be retrieved by wrapping the dataset into a ModeWrapper:

ds = ImageClassificationDataset(image_paths=...)
for y in kappadata.ModeWrapper(ds, mode="y"):
    ...
for x, y in kappadata.ModeWrapper(ds, mode="x y"):
    ...

torch.utils.data.Subset / torch.utils.data.ConcatDataset can be used by simply replacing them with kappadata.KDSubset / kappadata.KDConcatDataset, as sketched below.
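
A minimal sketch (assuming the image classification dataset from above; the constructor signatures are assumed to mirror their torch counterparts):

import kappadata

ds = ImageClassificationDataset(image_paths=...)
# KDSubset/KDConcatDataset keep the getitem_* interface, so the result can still be wrapped
subset = kappadata.KDSubset(ds, indices=list(range(100)))
concat = kappadata.KDConcatDataset([ds, ds])
for x, y in kappadata.ModeWrapper(concat, mode="x y"):
    ...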

Wrappers

"Dataset Wrappers"

KappaData implements various ways to manipulate datasets (kappadata.wrappers.dataset_wrappers); these wrappers can be chained, as sketched after the list.

  • Filter by class
    • kappadata.ClassFilterWrapper(ds, valid_classes=[0, 1])
    • kappadata.ClassFilterWrapper(ds, invalid_classes=[0, 1])
  • Balance data by oversampling underrepresented classes: kappadata.OversamplingWrapper(ds)
  • Subset by specifying percentages
    • kappadata.PercentFilterWrapper(ds, from_percent=0.25)
    • kappadata.PercentFilterWrapper(ds, to_percent=0.75)
    • kappadata.PercentFilterWrapper(ds, from_percent=0.25, to_percent=0.75)
  • Repeat the whole dataset
    • repeat twice: kappadata.RepeatWrapper(ds, repetitions=2)
    • repeat until size is > 100: kappadata.RepeatWrapper(ds, min_size=100)
  • Shuffle dataset
    • kappadata.ShuffleWrapper(ds, seed=5)
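
For example, a minimal sketch (assuming the image classification dataset from above) that chains a class filter with a percent filter before iterating:

import kappadata

ds = ImageClassificationDataset(image_paths=...)
# keep only classes 0 and 1, then keep the first 25% of the remaining samples
ds = kappadata.ClassFilterWrapper(ds, valid_classes=[0, 1])
ds = kappadata.PercentFilterWrapper(ds, to_percent=0.25)
for x, y in kappadata.ModeWrapper(ds, mode="x y"):
    ...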

"Sample Wrappers"

KappaData implements various ways to manipulate how samples are drawn from the underlying dataset (kappadata.wrappers.sample_wrappers). "Sample Wrappers" are similar to transforms in that they transform a sample in some way, but they are more powerful because they have full access to the underlying dataset, whereas normal transforms only have access to a single sample.

import numpy as np
import kappadata as kd

class Transform:
    def forward(self, x):
        # only x can be manipulated (e.g. normalized, image-transforms, ...)
        ...

class SampleWrapper(kd.KDWrapper):
    def getitem_x(self, idx, ctx=None):
        # access to the underlying dataset via self.dataset
        # e.g. return the sum of two different samples
        idx2 = np.random.randint(len(self))
        return self.dataset.getitem_x(idx, ctx) + self.dataset.getitem_x(idx2, ctx)

This allows implementing more complex transformations. KappaData implements the following sample wrappers (a combined usage sketch follows the list):

  • Mixup: kappadata.MixupWrapper(dataset=ds, alpha=1., p=1.)
  • Cutmix: kappadata.CutmixWrapper(dataset=ds, alpha=1., p=1.)
  • Mixup or Cutmix: kappadata.MixWrapper(dataset=ds, cutmix_alpha=1., mixup_alpha=1., p=1., cutmix_p=0.5)
  • TODO sampling multiple views
  • Label smoothing: kappadata.LabelSmoothingWrapper(dataset=ds, smoothing=.1)
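
A minimal usage sketch (assuming the image classification dataset from above; the parameter values are the ones shown in the list and the wrapper order is an illustrative choice):

import kappadata

ds = ImageClassificationDataset(image_paths=...)
# smooth the labels first, then randomly apply either mixup or cutmix on top
ds = kappadata.LabelSmoothingWrapper(dataset=ds, smoothing=0.1)
ds = kappadata.MixWrapper(dataset=ds, cutmix_alpha=1., mixup_alpha=1., p=1., cutmix_p=0.5)
for x, y in kappadata.ModeWrapper(ds, mode="x y"):
    ...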

Augmentation parameters

With KappaData you can also retrieve various properties of your data preprocessing (e.g. augmentation parameters). The following example shows how to retrieve the parameters of torchvision.transforms.RandomResizedCrop.

import kappadata
import torchvision.transforms
import torchvision.transforms.functional as F

class MyRandomResizedCrop(torchvision.transforms.RandomResizedCrop):
    def forward(self, img, ctx=None):
        # make random resized crop
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        cropped = F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
        # store parameters
        if ctx is not None:
            ctx["crop_parameters"] = (i, j, h, w)
        return cropped

class ImageClassificationDataset(kappadata.KDDataset):
    def __init__(self, ...):
        ...
        # RandomResizedCrop requires a target size
        self.random_resized_crop = MyRandomResizedCrop(size=224)
    ...
    def getitem_x(self, idx, ctx=None):
        img = load_image(self.image_paths[idx])
        return self.random_resized_crop(img, ctx=ctx)

When you want to access the parameters, simply pass return_ctx=True to the ModeWrapper:

ds = ImageClassificationDataset(image_paths=...)
for x, ctx in kappadata.ModeWrapper(ds, mode="x", return_ctx=True):
    print(ctx["crop_parameters"])
for (x, y), ctx in kappadata.ModeWrapper(ds, mode="x y", return_ctx=True):
    ...

Caching datasets in-memory

SharedDictDataset

kappadata.SharedDictDataset provides a wrapper to store arbitrary datasets in-memory via a dictionary shared between all worker processes (using python multiprocessing data structures). The shared-memory part is important for dataloading with num_workers > 0. Small and medium-sized datasets can be cached in-memory to avoid bottlenecks when loading data from disk. For example, even the full ImageNet can be cached on many servers, as it is ~130GB and it's not too uncommon for GPU servers to have more RAM than that.
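
A minimal sketch of wrapping an existing dataset (MyDiskDataset is a hypothetical stand-in for whatever dataset you already have; the lazy cache-fill behavior described in the comments is an assumption):

import kappadata

ds = MyDiskDataset(...)  # hypothetical dataset that loads samples from disk
cached_ds = kappadata.SharedDictDataset(ds)
sample = cached_ds[0]  # presumably loads from the underlying dataset and stores the result in the shared dict
sample = cached_ds[0]  # later accesses (from any dataloader worker) can be served from memory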

Caching image datasets

Naively caching image datasets can lead to high memory consumption because image data is usually stored in a compressed format and decompressed during loading. To reduce memory consumption, the raw (still compressed) data needs to be cached and decompressed only after it is retrieved from the cache.

Example caching a torchvision.datasets.ImageFolder:

import kappadata
import torchvision.datasets
from kappadata.loading.image_folder import raw_image_loader, raw_image_folder_sample_to_pil_sample

class CachedImageFolder(kappadata.KDDataset):
    def __init__(self, ...):
        super().__init__()
        # modify ImageFolder to load raw samples (NOTE: can't apply transforms onto raw data)
        self.ds = torchvision.datasets.ImageFolder(..., transform=None, loader=raw_image_loader)
        # initialize cached dataset that decompresses the raw data into a PIL image
        self.cached_ds = kappadata.SharedDictDataset(self.ds, transform=raw_image_folder_sample_to_pil_sample)
        # store transforms to apply after decompression
        self.transform = ...

    def getitem_x(self, idx, ctx=None):
        x, _ = self.cached_ds[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x

Automatically copy datasets to a local (fast) disk

Datasets are often stored on global (slow) storage and moved to a local (fast) disk before training. kappadata.copy_folder_from_global_to_local provides a utility function to do this automatically:

  • local path doesn't exist -> automatically copy from global to local
  • local path exists -> do nothing
  • local path exists but is incomplete -> clear directory and copy again

from pathlib import Path
from kappadata import copy_folder_from_global_to_local
global_path = Path("/system/data/ImageNet")
local_path = Path("/local/data")
# /system/data/ImageNet contains a 'train' and a 'val' folder -> copy whole dataset
copy_folder_from_global_to_local(global_path, local_path)
# copy only "train"
copy_folder_from_global_to_local(global_path, local_path, relative_path="train")

The above code will also work (without modification) if /system/data/ImageNet contains only two zip files, train.zip and val.zip.

Miscellaneous

  • all datasets derived from kappadata.KDDataset automatically support python slicing
    • all_class_labels = ModeWrapper(ds, mode="y")[:]
    • all_class_labels = ModeWrapper(ds, mode="y")[5:-3:2]
  • all datasets derived from kappadata.KDDataset implement __iter__
    for y in ModeWrapper(ds, mode="y"):
        ...
    
  • retrieve the original dataset without wrappers via ds.root_dataset
