Library for loading PyTorch datasets and data loaders.

PyTorch Datasets

PyPI version License: AGPL v3 Python 3.9

Overview

This repository provides easier and faster access to commonly used benchmark datasets, loading them in a ready-to-use form for PyTorch models. It can also load low-dimensional features of those datasets, encoded using PCA, t-SNE, or UMAP.

Datasets

Note on COVID19 datasets: Models trained on these datasets are not intended for direct clinical diagnosis. Please do not use the model output for self-diagnosis; seek help from your local health authorities instead.

Usage

It is recommended to use a virtual environment to isolate the project dependencies.

$ virtualenv env --python=python3  # create a Python 3 virtual environment
$ source env/bin/activate  # activate it (Linux/macOS shells)
$ pip install pt-datasets  # install the package

We can then use this package to load ready-to-use data loaders:

from pt_datasets import load_dataset, create_dataloader

# load the training and test data
train_data, test_data = load_dataset(name="cifar10")

# create a data loader for the training data
train_loader = create_dataloader(
    dataset=train_data, batch_size=64, shuffle=True, num_workers=1
)

...

# use the data loader for training
# (`model` is assumed to be defined in the elided code above and to expose `fit`)
model.fit(train_loader, epochs=10)
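
If the model does not expose a `fit` method, a data loader can be consumed in a manual training loop instead. The following is a minimal sketch that uses plain `torch.utils.data.DataLoader` on a synthetic dataset so that it runs without downloading anything. It assumes, without confirmation from this README, that `create_dataloader` returns an object that can be iterated in the same way; the tiny linear model and the hyperparameters are placeholders for illustration only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(1024)

# Synthetic stand-in for an image dataset:
# 256 samples of shape 3x32x32 with labels from 10 classes.
features = torch.randn(256, 3, 32, 32)
labels = torch.randint(0, 10, (256,))
dataset = TensorDataset(features, labels)

# num_workers=0 keeps loading in the main process for this small example.
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)

# Placeholder model: flatten each image and apply one linear layer.
model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(3 * 32 * 32, 10),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

# One epoch: iterate over mini-batches and update the weights.
num_batches = 0
for batch_features, batch_labels in loader:
    optimizer.zero_grad()
    loss = criterion(model(batch_features), batch_labels)
    loss.backward()
    optimizer.step()
    num_batches += 1
```

With 256 samples and a batch size of 64, the loop runs over four full batches per epoch.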

We can also encode the dataset features to a lower-dimensional space:

import seaborn as sns
import matplotlib.pyplot as plt
from pt_datasets import load_dataset, encode_features

# load the training and test data
train_data, test_data = load_dataset(name="fashion_mnist")

# get the numpy array of the features
# the encoders can only accept np.ndarray types
train_features = train_data.data.numpy()

# flatten the tensors
train_features = train_features.reshape(
    train_features.shape[0], -1
)

# get the labels
train_labels = train_data.targets.numpy()

# get the class names
classes = train_data.classes

# encode training features using t-SNE
encoded_train_features = encode_features(
    features=train_features,
    seed=1024,
    encoder="tsne"
)

# use seaborn styling
sns.set_style("darkgrid")

# scatter plot each feature w.r.t class
for index in range(len(classes)):
    plt.scatter(
        encoded_train_features[train_labels == index, 0],
        encoded_train_features[train_labels == index, 1],
        label=classes[index],
        edgecolors="black"
    )
plt.legend(loc="upper center", title="Fashion-MNIST classes", ncol=5)
plt.show()
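
To make the idea of a lower-dimensional encoding concrete, here is a minimal PCA sketch written directly in NumPy. It is not the implementation behind `encode_features`, which this README does not show; `pca_encode` is a hypothetical helper, and the random array stands in for image data such as the flattened Fashion-MNIST features above.

```python
import numpy as np


def pca_encode(features: np.ndarray, n_components: int = 2) -> np.ndarray:
    """Project features onto their top principal components (hypothetical helper)."""
    # Center each feature column so the components capture variance, not the mean.
    centered = features - features.mean(axis=0)
    # Rows of vt are the principal directions, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T


# Synthetic stand-in for 100 grayscale images of size 28x28.
rng = np.random.default_rng(1024)
images = rng.random((100, 28, 28))

# Flatten each image into a single feature vector, as in the example above.
flat = images.reshape(images.shape[0], -1)

encoded = pca_encode(flat, n_components=2)
print(encoded.shape)
```

The resulting two-column array can be passed to the same scatter-plot loop used for the t-SNE encoding.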

Citation

When using the Malware Image classification dataset, please cite the following:

  • BibTeX
@article{agarap2017towards,
    title={Towards building an intelligent anti-malware system: a deep learning approach using support vector machine (SVM) for malware classification},
    author={Agarap, Abien Fred},
    journal={arXiv preprint arXiv:1801.00318},
    year={2017}
}
  • MLA
Agarap, Abien Fred. "Towards building an intelligent anti-malware system: a
deep learning approach using support vector machine (SVM) for malware
classification." arXiv preprint arXiv:1801.00318 (2017).

If you use this library, please cite it as:

@misc{agarap2020pytorch,
    author       = "Abien Fred Agarap",
    title        = "{PyTorch} datasets",
    howpublished = "\url{https://gitlab.com/afagarap/pt-datasets}",
    note         = "Accessed: 20xx-xx-xx"
}

License

PyTorch Datasets utility repository
Copyright (C) 2020-2023  Abien Fred Agarap

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.
