Simple dataset to dataloader library for pytorch
This is a simple library for creating readable dataset pipelines and reusing best practices for issues such as imbalanced datasets. There are just two components to keep track of: Dataset and Datastream.
Dataset is a simple mapping between an index and an example. It provides pipelining of functions in a readable syntax, originally adapted from TensorFlow 2's tf.data.Dataset.
Datastream combines a Dataset and a sampler into a stream of examples. It provides a simple solution to oversampling / stratification, weighted sampling, and finally converting to a torch.utils.data.DataLoader.
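The index-to-example idea behind Dataset can be sketched in plain Python. This is a conceptual toy to illustrate the pipelining style, not the library's actual implementation:

```python
# Toy sketch: a Dataset is a mapping from an index to an example,
# and .map composes functions into a pipeline.
class ToyDataset:
    def __init__(self, items, fn=lambda x: x):
        self.items = items
        self.fn = fn

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        # Apply the accumulated pipeline to the raw item
        return self.fn(self.items[index])

    def map(self, fn):
        # Compose: run the existing pipeline first, then fn
        return ToyDataset(self.items, lambda x, f=self.fn: fn(f(x)))


squares = ToyDataset([1, 2, 3]).map(lambda x: x * x)
print(squares[2])  # 9
```

The real library adds many conveniences on top of this idea (subsetting, splitting, caching, zipping), but the core contract is the same: indexable access plus composable maps.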
Install
poetry add pytorch-datastream
Or, for the old-timers:
pip install pytorch-datastream
Usage
The list below showcases functions that are useful in most standard and non-standard cases; it is not exhaustive. See the documentation for a more complete overview of the API and usage.
- Dataset.from_subscriptable
- Dataset.from_dataframe
- Dataset
  - .map
  - .subset
  - .split
  - .cache
  - .with_columns
- Datastream.merge
- Datastream.zip
- Datastream
  - .map
  - .data_loader
  - .zip_index
  - .update_weights_
  - .update_example_weight_
  - .weight
  - .state_dict
  - .load_state_dict
Simple image dataset example
Here’s a basic example of loading images from a directory:
```python
from pathlib import Path
from PIL import Image
from datastream import Dataset

# Assuming images are in a directory structure like:
# images/
#     class1/
#         image1.jpg
#         image2.jpg
#     class2/
#         image3.jpg
#         image4.jpg

image_dir = Path("images")
image_paths = list(image_dir.glob("**/*.jpg"))

dataset = (
    Dataset.from_paths(
        image_paths,
        pattern=r".*/(?P<class_name>\w+)/(?P<image_name>\w+)\.jpg",
    )
    .map(lambda row: dict(
        image=Image.open(row["path"]),
        class_name=row["class_name"],
        image_name=row["image_name"],
    ))
)

# Access an item from the dataset
first_item = dataset[0]
print(f"Class: {first_item['class_name']}, Image name: {first_item['image_name']}")
```
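The pattern passed to Dataset.from_paths uses named regex groups to pull the class and image names out of each path. The extraction itself can be verified with the standard re module (shown here with the trailing dot escaped, as a small tightening of the pattern):

```python
import re

# Named groups capture path components: .../<class_name>/<image_name>.jpg
pattern = r".*/(?P<class_name>\w+)/(?P<image_name>\w+)\.jpg"

match = re.match(pattern, "images/class1/image1.jpg")
print(match.groupdict())
# {'class_name': 'class1', 'image_name': 'image1'}
```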
Merge / stratify / oversample datastreams
Each fruit datastream below repeatedly yields its fruit type as a string.
>>> datastream = Datastream.merge([
... (apple_datastream, 2),
... (pear_datastream, 1),
... (banana_datastream, 1),
... ])
>>> next(iter(datastream.data_loader(batch_size=8)))
['apple', 'apple', 'pear', 'banana', 'apple', 'apple', 'pear', 'banana']
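The proportions passed to merge can be thought of as a weighted round-robin: draw two examples from the apple stream, then one each from the pear and banana streams, and repeat. A minimal pure-Python sketch of that idea (not the library's implementation, which samples from real datasets):

```python
from itertools import cycle, islice

def merge_streams(weighted_streams):
    # Weighted round-robin: per cycle, draw `weight` items from each stream.
    iterators = [(iter(stream), weight) for stream, weight in weighted_streams]
    while True:
        for iterator, weight in iterators:
            for _ in range(weight):
                yield next(iterator)

apples = cycle(["apple"])
pears = cycle(["pear"])
bananas = cycle(["banana"])

merged = merge_streams([(apples, 2), (pears, 1), (bananas, 1)])
print(list(islice(merged, 8)))
# ['apple', 'apple', 'pear', 'banana', 'apple', 'apple', 'pear', 'banana']
```

This is how a minority class can be oversampled: give it a larger weight relative to the other streams and it appears more often per cycle.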
Zip independently sampled datastreams
Each fruit datastream below repeatedly yields its fruit type as a string.
>>> datastream = Datastream.zip([
... apple_datastream,
... Datastream.merge([pear_datastream, banana_datastream]),
... ])
>>> next(iter(datastream.data_loader(batch_size=4)))
[('apple', 'pear'), ('apple', 'banana'), ('apple', 'pear'), ('apple', 'banana')]
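Zipping pairs up examples drawn independently from each stream, so every batch element combines one sample from each side. The behavior can be sketched with plain itertools (the inner merged stream is stood in for by a simple alternating cycle):

```python
from itertools import cycle, islice

apples = cycle(["apple"])
# Stand-in for Datastream.merge([pear_datastream, banana_datastream]),
# which alternates between its two streams with equal weight.
pears_bananas = cycle(["pear", "banana"])

zipped = zip(apples, pears_bananas)
print(list(islice(zipped, 4)))
# [('apple', 'pear'), ('apple', 'banana'), ('apple', 'pear'), ('apple', 'banana')]
```

Because each side is sampled independently, zipping is useful for tasks like pairing anchors with negatives, where the two streams should not be index-aligned.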
More usage examples
See the documentation for more usage examples.
Hashes for pytorch_datastream-0.4.12.tar.gz

Algorithm | Hash digest
---|---
SHA256 | db9c1da627a3f9d4583abfed3a650ad40931ea84f460a01b60a8cc00ef6f6782
MD5 | 608a46a91756b1a1646083123151394d
BLAKE2b-256 | d255d9e4109f64f42c2a1b48aa28cb8e6ed57e7547b85657f42da325dcd3023b
Hashes for pytorch_datastream-0.4.12-py3-none-any.whl

Algorithm | Hash digest
---|---
SHA256 | c9b9c3aa5b7815b9a0bb3cd0059546cf1566575173d8306616d1c56a18684431
MD5 | d8696743c51e5cb8b58e011ed7be07a7
BLAKE2b-256 | 2ff2dc41dc75af3fa3f0c31033466aea916dbd5f3822809bd41c0d2a412c2620