
Library for loading PyTorch datasets and data loaders.


PyTorch Datasets

License: AGPL v3 | Python 3.9

Overview

This repository provides easier and faster access to commonly used benchmark datasets, loading them in a ready-to-use form for PyTorch models. It can also load low-dimensional features of these datasets, encoded using PCA, t-SNE, or UMAP.

Datasets

Note on COVID19 datasets: Models trained on these datasets are not intended for direct clinical diagnosis. Please do not use model output for self-diagnosis; seek help from your local health authorities instead.

Usage

It is recommended to use a virtual environment to isolate the project dependencies.

$ virtualenv env --python=python3  # create the environment with Python 3
$ source env/bin/activate  # activate the environment before installing
$ pip install pt-datasets  # install the package

We can then use this package to load ready-to-use data loaders:

from pt_datasets import load_dataset, create_dataloader

# load the training and test data
train_data, test_data = load_dataset(name="cifar10")

# create a data loader for the training data
train_loader = create_dataloader(
    dataset=train_data, batch_size=64, shuffle=True, num_workers=1
)

...

# use the data loader for training
model.fit(train_loader, epochs=10)
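
The `model` object above is left undefined in the snippet, so its `fit` method is not something this package is shown to provide. For a plain `torch.nn.Module`, a manual training loop over the data loader might look like the following minimal sketch. It assumes synthetic CIFAR-10-shaped tensors as a stand-in for the output of `load_dataset`, and the linear model and the `train_one_epoch` helper are hypothetical names used only for illustration.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic stand-in for a loaded dataset (assumption):
# 256 samples shaped like CIFAR-10 images, with 10 classes.
features = torch.randn(256, 3, 32, 32)
labels = torch.randint(0, 10, (256,))
train_loader = DataLoader(
    TensorDataset(features, labels), batch_size=64, shuffle=True
)

# Hypothetical model: one linear layer over the flattened pixels.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)


def train_one_epoch(model, loader, criterion, optimizer):
    """Run one pass over the loader and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for batch_features, batch_labels in loader:
        optimizer.zero_grad()                    # clear old gradients
        logits = model(batch_features)           # forward pass
        loss = criterion(logits, batch_labels)   # compute the loss
        loss.backward()                          # backpropagate
        optimizer.step()                         # update the weights
        total_loss += loss.item()
    return total_loss / len(loader)


epoch_losses = [
    train_one_epoch(model, train_loader, criterion, optimizer)
    for _ in range(3)
]
print(epoch_losses)
```

Any object that yields `(features, labels)` batches can be passed as `loader`, so the same loop should work with a loader built by `create_dataloader`, provided its batches have that form.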

We can also encode the dataset features to a lower-dimensional space:

import seaborn as sns
import matplotlib.pyplot as plt
from pt_datasets import load_dataset, encode_features

# load the training and test data
train_data, test_data = load_dataset(name="fashion_mnist")

# get the numpy array of the features
# the encoders can only accept np.ndarray types
train_features = train_data.data.numpy()

# flatten the tensors
train_features = train_features.reshape(
    train_features.shape[0], -1
)

# get the labels
train_labels = train_data.targets.numpy()

# get the class names
classes = train_data.classes

# encode training features using t-SNE
encoded_train_features = encode_features(
    features=train_features,
    seed=1024,
    encoder="tsne"
)

# use seaborn styling
sns.set_style("darkgrid")

# scatter plot each feature w.r.t class
for index in range(len(classes)):
    plt.scatter(
        encoded_train_features[train_labels == index, 0],
        encoded_train_features[train_labels == index, 1],
        label=classes[index],
        edgecolors="black"
    )
plt.legend(loc="upper center", title="Fashion-MNIST classes", ncol=5)
plt.show()
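
To make the flattening step and the idea of a low-dimensional encoding concrete, here is a minimal sketch that uses only NumPy. It assumes random 28x28 arrays as a stand-in for Fashion-MNIST images, and `pca_project` is a hypothetical helper written for this illustration. It is not part of `pt_datasets` and does not describe how `encode_features` works internally.

```python
import numpy as np


def pca_project(features, n_components=2):
    """Project row vectors onto their top principal components."""
    # Center each feature (column) at zero mean.
    centered = features - features.mean(axis=0)
    # The rows of vt are the principal directions, ordered by
    # decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T


rng = np.random.default_rng(1024)

# Stand-in for image data (assumption): 100 grayscale 28x28 images.
images = rng.random((100, 28, 28))

# Flatten each image into one row vector, as in the example above.
flat = images.reshape(images.shape[0], -1)

encoded = pca_project(flat, n_components=2)
print(flat.shape, encoded.shape)  # (100, 784) (100, 2)
```

The two columns of `encoded` can be passed to `plt.scatter` in the same way as the t-SNE output above.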

Citation

When using the Malware Image classification dataset, kindly use the following citations:

  • BibTeX
@article{agarap2017towards,
    title={Towards building an intelligent anti-malware system: a deep learning approach using support vector machine (SVM) for malware classification},
    author={Agarap, Abien Fred},
    journal={arXiv preprint arXiv:1801.00318},
    year={2017}
}
  • MLA
Agarap, Abien Fred. "Towards building an intelligent anti-malware system: a
deep learning approach using support vector machine (svm) for malware
classification." arXiv preprint arXiv:1801.00318 (2017).

If you use this library, kindly cite it as:

@misc{agarap2020pytorch,
    author       = "Abien Fred Agarap",
    title        = "{PyTorch} datasets",
    howpublished = "\url{https://gitlab.com/afagarap/pt-datasets}",
    note         = "Accessed: 20xx-xx-xx"
}

License

PyTorch Datasets utility repository
Copyright (C) 2020-2023  Abien Fred Agarap

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.
