
Simple dataset to data loader library for PyTorch

Project description

This is a simple library for creating very readable dataset pipelines and reusing best practices for issues such as imbalanced datasets. There are just two components to keep track of: Dataset and Datastream.

Dataset is a simple mapping between an index and an example. It provides pipelining of functions in a very readable syntax originally adapted from TensorFlow 2's tf.data.Dataset.

Datastream combines a Dataset and a sampler into a stream of examples. It provides a simple solution to oversampling / stratification, weighted sampling, and finally converting to a torch.utils.data.DataLoader.

Install

pip install pytorch-datastream

Usage

The list below showcases functions that are useful in most standard and non-standard cases. It is not meant to be exhaustive. See the documentation for a more extensive description of the API and its usage.

Dataset.from_subscriptable
Dataset.from_dataframe
Dataset
    .map
    .subset

Datastream.merge
Datastream.zip
Datastream
    .map
    .data_loader
    .zip_index
    .update_weights_
    .update_example_weight_
    .weight
    .state_dict
    .load_state_dict
    .multi_sample
    .sample_proportion

Dataset from subscriptable

from datastream import Dataset

fruits_and_cost = (
    ('apple', 5),
    ('pear', 7),
    ('banana', 14),
    ('kiwi', 100),
)

dataset = (
    Dataset.from_subscriptable(fruits_and_cost)
    .map(lambda fruit, cost: (
        fruit,
        cost * 2,
    ))
)

print(dataset[2]) # ('banana', 28)

Dataset from pandas DataFrame

from PIL import Image
from imgaug import augmenters as iaa
from datastream import Dataset

augmenter = iaa.Sequential([...])

def preprocess(image, class_names):
    ...

# df is assumed to be a pandas DataFrame with
# 'image_path' and 'class_names' columns
dataset = (
    Dataset.from_dataframe(df)
    .map(lambda row: (
        row['image_path'],
        row['class_names'],
    ))
    .map(lambda image_path, class_names: (
        Image.open(image_path),
        class_names,
    ))
    .map(lambda image, class_names: (
        augmenter.augment(image=image),
        class_names,
    ))
    .map(preprocess)
)

Datastream to PyTorch data loader

data_loader = (
    Datastream(dataset)
    .data_loader(
        batch_size=32,
        num_workers=8,
        n_batches_per_epoch=100,
    )
)

More usage examples

See the documentation for examples with oversampling / stratification and weighted sampling.
