Library for loading PyTorch datasets and data loaders.
PyTorch Datasets
Overview
This repository provides easier and faster access to commonly used benchmark datasets. With it, one can load datasets in a ready-to-use form for PyTorch models. It can also load low-dimensional features of these datasets, encoded using PCA, t-SNE, or UMAP.
Datasets
- MNIST
- Fashion-MNIST
- EMNIST-Balanced
- CIFAR10
- SVHN
- MalImg
- AG News
- IMDB
- Yelp
- 20 Newsgroups
- KMNIST
- Wisconsin Diagnostic Breast Cancer
- COVID19 binary classification
- COVID19 multi-classification
Note on the COVID19 datasets: these datasets are not intended for building models for direct clinical diagnosis. Please do not use model outputs for self-diagnosis, and seek help from your local health authorities.
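Each dataset above is loaded by passing its name as a lowercase string to load_dataset (the Usage section below shows the full workflow). The minimal sketch below assumes that "mnist" follows the same naming convention as the "cifar10" and "fashion_mnist" strings used later in this document; the exact name strings for the other datasets should be checked against the library.
from pt_datasets import load_dataset
# "mnist" is assumed to follow the naming pattern of "cifar10" and
# "fashion_mnist" shown in the Usage section; other name strings may differ
train_data, test_data = load_dataset(name="mnist")
# number of training and test examples
print(len(train_data), len(test_data))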
Usage
It is recommended to use a virtual environment to isolate the project dependencies.
$ virtualenv env --python=python3  # we use python 3
$ source env/bin/activate  # activate the virtual environment
$ pip install pt-datasets  # install the package
We can then use this package to load ready-to-use data loaders,
from pt_datasets import load_dataset, create_dataloader
# load the training and test data
train_data, test_data = load_dataset(name="cifar10")
# create a data loader for the training data
train_loader = create_dataloader(
dataset=train_data, batch_size=64, shuffle=True, num_workers=1
)
...
# use the data loader for training
model.fit(train_loader, epochs=10)
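The model.fit(...) call above assumes a model object that exposes a Keras-style fit method. With a plain torch.nn.Module, the same data loader can be consumed in an ordinary PyTorch training loop; the following is a minimal sketch under that assumption, where the linear classifier, learning rate, and batch format are illustrative and not part of pt-datasets.
import torch
from pt_datasets import load_dataset, create_dataloader
# load CIFAR10 and build a training data loader, as in the snippet above
train_data, test_data = load_dataset(name="cifar10")
train_loader = create_dataloader(
    dataset=train_data, batch_size=64, shuffle=True, num_workers=1
)
# a deliberately simple classifier for illustration only; CIFAR10 images
# are 3x32x32, i.e. 3072 features, spread across 10 classes
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
criterion = torch.nn.CrossEntropyLoss()
model.train()
for epoch in range(10):
    # assumes each batch is an (images, labels) pair of tensors
    for images, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(images.float()), labels)
        loss.backward()
        optimizer.step()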
We can also encode the dataset features to a lower-dimensional space,
import seaborn as sns
import matplotlib.pyplot as plt
from pt_datasets import load_dataset, encode_features
# load the training and test data
train_data, test_data = load_dataset(name="fashion_mnist")
# get the numpy array of the features
# the encoders can only accept np.ndarray types
train_features = train_data.data.numpy()
# flatten the tensors
train_features = train_features.reshape(
train_features.shape[0], -1
)
# get the labels
train_labels = train_data.targets.numpy()
# get the class names
classes = train_data.classes
# encode training features using t-SNE
encoded_train_features = encode_features(
features=train_features,
seed=1024,
encoder="tsne"
)
# use seaborn styling
sns.set_style("darkgrid")
# scatter plot each feature w.r.t class
for index in range(len(classes)):
plt.scatter(
encoded_train_features[train_labels == index, 0],
encoded_train_features[train_labels == index, 1],
label=classes[index],
edgecolors="black"
)
plt.legend(loc="upper center", title="Fashion-MNIST classes", ncol=5)
plt.show()
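The Overview also mentions PCA and UMAP encoders. Assuming the encoder argument follows the same naming pattern as "tsne" above, switching encoders would look like the sketch below; the "pca" and "umap" strings are assumptions and should be checked against the library's API.
from pt_datasets import encode_features
# train_features as computed in the snippet above;
# the "pca" and "umap" encoder names are assumptions based on the Overview
pca_train_features = encode_features(
    features=train_features, seed=1024, encoder="pca"
)
umap_train_features = encode_features(
    features=train_features, seed=1024, encoder="umap"
)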
Citation
When using the Malware Image classification dataset, kindly use the following citations,
- BibTeX
@article{agarap2017towards,
title={Towards building an intelligent anti-malware system: a deep learning approach using support vector machine (SVM) for malware classification},
author={Agarap, Abien Fred},
journal={arXiv preprint arXiv:1801.00318},
year={2017}
}
- MLA
Agarap, Abien Fred. "Towards building an intelligent anti-malware system: a
deep learning approach using support vector machine (svm) for malware
classification." arXiv preprint arXiv:1801.00318 (2017).
If you use this library, kindly cite it as,
@misc{agarap2020pytorch,
author = "Abien Fred Agarap",
title = "{PyTorch} datasets",
howpublished = "\url{https://gitlab.com/afagarap/pt-datasets}",
note = "Accessed: 20xx-xx-xx"
}
License
PyTorch Datasets utility repository
Copyright (C) 2020-2023 Abien Fred Agarap
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
File details
Details for the file pt_datasets-0.20.15.tar.gz.
File metadata
- Download URL: pt_datasets-0.20.15.tar.gz
- Upload date:
- Size: 29.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.8.0 CPython/3.11.7 Linux/6.7.4-arch1-1.1-g14
File hashes
Algorithm | Hash digest
---|---
SHA256 | 11c0fac4161583be55f3da8f493dfc54dbbebb25f8b80ef14b573eb3b7a22017
MD5 | 45068c84dd37ec99d6c9ef5c50eda8b9
BLAKE2b-256 | dcc34ba5159191a8e94888beaed8b8071e5f1a250c745cac332bc6dd565ac5c3
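To verify a downloaded archive against the digests above, the hash can be recomputed locally; a minimal sketch using Python's standard hashlib, assuming the archive has been downloaded to the current directory:
import hashlib
# path to the downloaded source distribution (adjust as needed)
path = "pt_datasets-0.20.15.tar.gz"
# published SHA256 digest from the table above
expected = "11c0fac4161583be55f3da8f493dfc54dbbebb25f8b80ef14b573eb3b7a22017"
with open(path, "rb") as archive:
    digest = hashlib.sha256(archive.read()).hexdigest()
print("OK" if digest == expected else "hash mismatch")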
File details
Details for the file pt_datasets-0.20.15-py3-none-any.whl.
File metadata
- Download URL: pt_datasets-0.20.15-py3-none-any.whl
- Upload date:
- Size: 44.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.8.0 CPython/3.11.7 Linux/6.7.4-arch1-1.1-g14
File hashes
Algorithm | Hash digest
---|---
SHA256 | 9cac1fa0eeb6c8f6f214ff82add812d9cca36480d68d9a67c76e06cb249736e2
MD5 | 2898e4de11b34c4f9bfed6e3d687ff34
BLAKE2b-256 | 8ef3756f54e2bf34c277028750cd8c8b0ebdd09ee152bbc4abd1678c0c6dd018