
PyTorch dataset loader for MNIST, Fashion-MNIST, EMNIST-Balanced, CIFAR10, SVHN, MalImg, and AG News datasets


PyTorch Datasets

License: AGPL v3 | Python 3.7 | Python 3.8

Overview

This repository provides easier and faster access to commonly used benchmark datasets. It loads the datasets in a ready-to-use form for PyTorch models. It can also produce low-dimensional features of the datasets listed below, encoded using PCA, t-SNE, or UMAP.
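As a rough illustration of what a low-dimensional encoding does, the sketch below implements plain PCA with NumPy's SVD and projects flattened images onto their top two principal components. It is a minimal sketch under our own assumptions (the function name `pca_encode` and its `dim` parameter are hypothetical), not the package's actual encoder; t-SNE and UMAP are non-linear methods and are not shown here.

```python
import numpy as np


def pca_encode(features: np.ndarray, dim: int = 2) -> np.ndarray:
    """Project `features` of shape (N, D) onto their top `dim` principal components.

    Hypothetical helper for illustration; not part of pt_datasets.
    """
    # center each feature column at zero mean
    centered = features - features.mean(axis=0)
    # rows of vt are the principal directions, sorted by decreasing singular value
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    # project the centered data onto the first `dim` directions
    return centered @ vt[:dim].T


# toy usage: 100 random "images" of 8x8 pixels, flattened to 64 features each
rng = np.random.default_rng(1024)
images = rng.random((100, 8, 8))
flat = images.reshape(images.shape[0], -1)
encoded = pca_encode(flat, dim=2)
print(encoded.shape)  # (100, 2)
```

The first output column always carries at least as much variance as the second, because the SVD sorts the principal directions by singular value.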

Datasets

  • MNIST
  • Fashion-MNIST
  • EMNIST-Balanced
  • CIFAR10
  • SVHN
  • MalImg
  • AG News

Usage

It is recommended to use a virtual environment to isolate the project dependencies.

$ virtualenv env --python=python3  # we use python 3
$ source env/bin/activate  # activate the environment
$ pip install pt-datasets  # install the package

We use the tsnecuda library for the CUDA-accelerated t-SNE encoder. It can be installed by following the instructions in its wiki.

Alternatively, a provided script installs tsne-cuda from source.

$ bash setup/install_tsnecuda

Note that this script has only been tested on Ubuntu 20.04 LTS with an NVIDIA GTX 960M GPU.
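Because the CUDA t-SNE encoder depends on tsnecuda being installed, it can help to check for the package before passing `use_cuda=True`. A minimal sketch, assuming a hypothetical helper `tsnecuda_available()` that only checks whether the module is importable; it does not verify that a compatible GPU or CUDA runtime is present.

```python
import importlib.util


def tsnecuda_available() -> bool:
    """Return True if the `tsnecuda` package can be imported (hypothetical helper)."""
    # find_spec returns None when no top-level module with this name is installed
    return importlib.util.find_spec("tsnecuda") is not None


# only request the CUDA t-SNE encoder when tsnecuda is present
use_cuda = tsnecuda_available()
print(f"use_cuda={use_cuda}")
```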

We can then use this package to create ready-to-use data loaders:

from pt_datasets import load_dataset, create_dataloader

# load the training and test data
train_data, test_data = load_dataset(name="cifar10")

# create a data loader for the training data
train_loader = create_dataloader(
    dataset=train_data, batch_size=64, shuffle=True, num_workers=1
)

...

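# use the data loader for training
# (`model` is assumed to be defined in the elided code above and to expose
# a Keras-style fit() method; plain torch.nn.Module objects do not have one)
model.fit(train_loader, epochs=10)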
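Conceptually, a data loader built with `batch_size=64` and `shuffle=True` visits every sample once per epoch in a random order, grouped into mini-batches. The pure-Python sketch below mimics that behaviour with a hypothetical `iterate_minibatches()` generator. It is not how `create_dataloader` or `torch.utils.data.DataLoader` is implemented, and it ignores `num_workers`, which only controls how many worker processes load data in parallel.

```python
import random
from typing import Iterator, List


def iterate_minibatches(
    num_samples: int, batch_size: int, shuffle: bool = True, seed: int = 0
) -> Iterator[List[int]]:
    """Yield lists of sample indices, mimicking one epoch of a data loader."""
    indices = list(range(num_samples))
    if shuffle:
        # a seeded RNG makes the epoch order reproducible
        random.Random(seed).shuffle(indices)
    # the last batch may be smaller than batch_size (no samples are dropped)
    for start in range(0, num_samples, batch_size):
        yield indices[start:start + batch_size]


batches = list(iterate_minibatches(num_samples=150, batch_size=64))
print([len(batch) for batch in batches])  # [64, 64, 22]
```

Keeping the final, smaller batch mirrors the default behaviour of `torch.utils.data.DataLoader`, whose `drop_last` option defaults to `False`.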

We can also encode the dataset features into a lower-dimensional space:

import seaborn as sns
import matplotlib.pyplot as plt
from pt_datasets import load_dataset, encode_features

# load the training and test data
train_data, test_data = load_dataset(name="fashion_mnist")

# get the features as a numpy array
# (the encoders only accept np.ndarray inputs)
train_features = train_data.data.numpy()

# flatten each 28x28 image into a 784-dimensional feature vector
train_features = train_features.reshape(
    train_features.shape[0], -1
)

# get the labels
train_labels = train_data.targets.numpy()

# get the class names
classes = train_data.classes

# encode training features using t-SNE with CUDA
encoded_train_features = encode_features(
    features=train_features,
    seed=1024,
    use_cuda=True,
    encoder="tsne"
)

# use seaborn styling
sns.set_style("darkgrid")

# scatter-plot the first two encoded dimensions of each class
for index in range(len(classes)):
    plt.scatter(
        encoded_train_features[train_labels == index, 0],
        encoded_train_features[train_labels == index, 1],
        label=classes[index],
        edgecolors="black"
    )
plt.legend(loc="upper center", title="Fashion-MNIST classes", ncol=5)
plt.show()
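The plotting loop relies on NumPy boolean-mask indexing: `train_labels == index` is a boolean array that selects the rows of the embedding belonging to one class. A small sketch of that selection on a toy embedding, using a hypothetical `split_by_class()` helper that is not part of pt_datasets:

```python
import numpy as np


def split_by_class(embedding: np.ndarray, labels: np.ndarray, num_classes: int) -> list:
    """Return one (N_c, 2) array of embedded points per class label 0..num_classes-1."""
    # labels == index is a boolean mask selecting the rows that belong to one class
    return [embedding[labels == index] for index in range(num_classes)]


# toy 2-D embedding of 6 samples from 3 classes
embedding = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0],
                      [3.0, 3.0], [4.0, 4.0], [5.0, 5.0]])
labels = np.array([0, 1, 2, 0, 1, 2])
groups = split_by_class(embedding, labels, num_classes=3)
print(groups[0].tolist())  # [[0.0, 0.0], [3.0, 3.0]]
```

Each `groups[c][:, 0]` and `groups[c][:, 1]` pair corresponds to the x and y arguments passed to `plt.scatter` for class `c` in the loop above.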

Citation

When using the Malware Image classification (MalImg) dataset, please cite the following:

  • BibTeX
@article{agarap2017towards,
    title={Towards building an intelligent anti-malware system: a deep learning approach using support vector machine (SVM) for malware classification},
    author={Agarap, Abien Fred},
    journal={arXiv preprint arXiv:1801.00318},
    year={2017}
}
  • MLA
Agarap, Abien Fred. "Towards building an intelligent anti-malware system: a
deep learning approach using support vector machine (SVM) for malware
classification." arXiv preprint arXiv:1801.00318 (2017).

License

PyTorch Datasets utility repository
Copyright (C) 2020  Abien Fred Agarap

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.


Download files

Source distribution: pt-datasets-0.3.0.tar.gz (7.7 kB)
