
Simple dataset to dataloader library for PyTorch

This is a simple library for creating readable dataset pipelines and reusing best practices for issues such as imbalanced datasets. There are just two components to keep track of: Dataset and Datastream.

Dataset is a simple mapping between an index and an example. It provides pipelining of functions in a readable syntax originally adapted from TensorFlow 2’s tf.data.Dataset.
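To make the "index-to-example mapping with pipelining" idea concrete, here is a minimal, library-independent sketch. ToyDataset is a hypothetical class written only for illustration: it mimics the spirit of Dataset.from_subscriptable and .map, but it is not pytorch-datastream's implementation.

```python
from typing import Any, Callable, Sequence


class ToyDataset:
    """Hypothetical illustration: an index -> example mapping with a lazy map()."""

    def __init__(self, length: int, get_item: Callable[[int], Any]):
        self._length = length
        self._get_item = get_item

    @classmethod
    def from_subscriptable(cls, subscriptable: Sequence[Any]) -> "ToyDataset":
        # Wrap anything that supports len() and [] indexing.
        return cls(len(subscriptable), lambda index: subscriptable[index])

    def map(self, function: Callable[[Any], Any]) -> "ToyDataset":
        # Return a new dataset; `function` only runs when an item is accessed.
        get_item = self._get_item
        return ToyDataset(self._length, lambda index: function(get_item(index)))

    def __len__(self) -> int:
        return self._length

    def __getitem__(self, index: int) -> Any:
        if not 0 <= index < self._length:
            raise IndexError(index)
        return self._get_item(index)


squares = ToyDataset.from_subscriptable([1, 2, 3]).map(lambda x: x * x)
print(len(squares), squares[2])  # 3 9
```

Because each map returns a new object and defers the work to __getitem__, chained calls read top to bottom like a pipeline while nothing is computed until an example is requested.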

Datastream combines a Dataset and a sampler into a stream of examples. It provides simple solutions for oversampling/stratification and weighted sampling, and finally converts the stream to a torch.utils.data.DataLoader.
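The core idea of pairing a dataset with a weighted sampler can be sketched with the standard library alone: indices are drawn with replacement according to per-example weights and then looked up in an indexable dataset. weighted_stream is a hypothetical helper for illustration; it is not part of pytorch-datastream and not how the library implements its samplers.

```python
import random
from itertools import islice
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def weighted_stream(
    dataset: Sequence[T], weights: Sequence[float], seed: int = 0
) -> Iterator[T]:
    """Hypothetical helper: endlessly yield examples sampled by weight, with replacement."""
    if len(dataset) != len(weights):
        raise ValueError("need exactly one weight per example")
    rng = random.Random(seed)
    indices = range(len(dataset))
    while True:
        # Draw one index at a time; an example with weight 0 is never drawn.
        (index,) = rng.choices(indices, weights=weights, k=1)
        yield dataset[index]


# Oversample the single "cat" example so it is drawn about as often
# as both "dog" examples combined (weight 2 out of a total of 4).
stream = weighted_stream(["dog", "dog", "cat"], weights=[1.0, 1.0, 2.0])
print(list(islice(stream, 8)))
```

The stream never ends on its own, which is why islice is used to take a finite number of examples.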

Install

poetry add pytorch-datastream

Or, for the old-timers:

pip install pytorch-datastream

Usage

The list below is meant to showcase functions that are useful in most standard and non-standard cases. It is not meant to be an exhaustive list. See the documentation for a more extensive list on API and usage.

Dataset.from_subscriptable
Dataset.from_dataframe
Dataset
    .map
    .subset
    .split
    .cache
    .with_columns

Datastream.merge
Datastream.zip
Datastream
    .map
    .data_loader
    .zip_index
    .update_weights_
    .update_example_weight_
    .weight
    .state_dict
    .load_state_dict

Simple image dataset example

Here’s a basic example of loading images from a directory:

from datastream import Dataset
from pathlib import Path
from PIL import Image

# Assuming images are in a directory structure like:
# images/
#   class1/
#     image1.jpg
#     image2.jpg
#   class2/
#     image3.jpg
#     image4.jpg

image_dir = Path("images")
image_paths = list(image_dir.glob("**/*.jpg"))

dataset = (
    Dataset.from_paths(image_paths, pattern=r".*/(?P<class_name>\w+)/(?P<image_name>\w+)\.jpg")
    .map(lambda row: dict(
        image=Image.open(row["path"]),
        class_name=row["class_name"],
        image_name=row["image_name"],
    ))
)

# Access an item from the dataset
first_item = dataset[0]
print(f"Class: {first_item['class_name']}, Image name: {first_item['image_name']}")
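In the example above, each named group in pattern is later read as a field of the row (row["class_name"], row["image_name"]). The snippet below checks what that pattern captures using only Python's re module, without pytorch-datastream. It assumes POSIX-style "/" separators and uses re.fullmatch; whether the library itself uses fullmatch or match is an assumption here. The dot before jpg is escaped so it only matches a literal period, and \w+ does not match characters such as "-" or spaces.

```python
import re
from pathlib import PurePosixPath

# Same named groups as the example, with the dot before "jpg" escaped.
PATTERN = re.compile(r".*/(?P<class_name>\w+)/(?P<image_name>\w+)\.jpg")

paths = [
    PurePosixPath("images/class1/image1.jpg"),
    PurePosixPath("images/class2/image3.jpg"),
    PurePosixPath("images/class2/image-4.jpg"),  # "-" is not matched by \w
]

for path in paths:
    match = PATTERN.fullmatch(str(path))
    # groupdict() maps each group name to the captured text.
    print(path, "->", match.groupdict() if match else "no match")
```

Files whose names fall outside the pattern are worth catching early, before building a dataset from them.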

Merge / stratify / oversample datastreams

Each fruit datastream below repeatedly yields the string of its fruit type.

>>> datastream = Datastream.merge([
...     (apple_datastream, 2),
...     (pear_datastream, 1),
...     (banana_datastream, 1),
... ])
>>> next(iter(datastream.data_loader(batch_size=8)))
['apple', 'apple', 'pear', 'banana', 'apple', 'apple', 'pear', 'banana']
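With weights 2, 1 and 1, apples appear twice as often as each of the other fruits. The sketch below reproduces the interleaving in this output with a simple deterministic schedule built from itertools. It only illustrates the effect of the weights: weighted_round_robin is a hypothetical helper, the repeat(...) iterators stand in for the fruit datastreams, and pytorch-datastream's own merge sampler may order examples differently.

```python
from itertools import cycle, islice, repeat
from typing import Iterator, List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def weighted_round_robin(streams_and_weights: Sequence[Tuple[Iterator[T], int]]) -> Iterator[T]:
    """Hypothetical helper: take `weight` items from each stream in turn, forever.

    Streams are assumed to be infinite, like the fruit datastreams.
    """
    for stream, weight in cycle(streams_and_weights):
        yield from islice(stream, weight)


# Stand-ins for apple_datastream, pear_datastream and banana_datastream.
apples, pears, bananas = repeat("apple"), repeat("pear"), repeat("banana")

merged = weighted_round_robin([(apples, 2), (pears, 1), (bananas, 1)])
batch: List[str] = list(islice(merged, 8))
print(batch)
# ['apple', 'apple', 'pear', 'banana', 'apple', 'apple', 'pear', 'banana']
```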

Zip independently sampled datastreams

Each fruit datastream below repeatedly yields the string of its fruit type.

>>> datastream = Datastream.zip([
...     apple_datastream,
...     Datastream.merge([pear_datastream, banana_datastream]),
... ])
>>> next(iter(datastream.data_loader(batch_size=4)))
[('apple', 'pear'), ('apple', 'banana'), ('apple', 'pear'), ('apple', 'banana')]
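Zipping pairs the next example from each datastream, and each part is sampled on its own. The following sketch mimics this output with plain iterators: zip combines an endless apple stream with an equal-weight merge of pears and bananas, modeled here as strict alternation. The alternation is an assumption chosen to match the output shown; it does not describe how pytorch-datastream actually samples an unweighted merge.

```python
from itertools import cycle, islice, repeat

# Stand-in for apple_datastream: repeats its fruit forever.
apples = repeat("apple")
# Hypothetical model of Datastream.merge([pear_datastream, banana_datastream])
# with equal weights: one pear, then one banana, and so on.
pears_or_bananas = cycle(["pear", "banana"])

# zip draws one item from each stream per step, producing pairs.
pairs = list(islice(zip(apples, pears_or_bananas), 4))
print(pairs)
# [('apple', 'pear'), ('apple', 'banana'), ('apple', 'pear'), ('apple', 'banana')]
```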

More usage examples

See the documentation for more usage examples.

Download files

Source Distribution

pytorch_datastream-0.4.12.tar.gz (24.2 kB)

Uploaded Source

Built Distribution

pytorch_datastream-0.4.12-py3-none-any.whl (29.3 kB)

Uploaded Python 3

File details

Details for the file pytorch_datastream-0.4.12.tar.gz.

File metadata

  • Download URL: pytorch_datastream-0.4.12.tar.gz
  • Size: 24.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.2.2 CPython/3.8.18 Linux/6.5.0-1025-azure

File hashes

Hashes for pytorch_datastream-0.4.12.tar.gz
Algorithm     Hash digest
SHA256        db9c1da627a3f9d4583abfed3a650ad40931ea84f460a01b60a8cc00ef6f6782
MD5           608a46a91756b1a1646083123151394d
BLAKE2b-256   d255d9e4109f64f42c2a1b48aa28cb8e6ed57e7547b85657f42da325dcd3023b

File details

Details for the file pytorch_datastream-0.4.12-py3-none-any.whl.

File hashes

Hashes for pytorch_datastream-0.4.12-py3-none-any.whl
Algorithm     Hash digest
SHA256        c9b9c3aa5b7815b9a0bb3cd0059546cf1566575173d8306616d1c56a18684431
MD5           d8696743c51e5cb8b58e011ed7be07a7
BLAKE2b-256   2ff2dc41dc75af3fa3f0c31033466aea916dbd5f3822809bd41c0d2a412c2620