
A (PyTorch) imbalanced dataset sampler for oversampling rare classes and undersampling highly frequent ones.

Project description

Imbalanced Dataset Sampler


Introduction

In many machine learning applications, we come across datasets in which some classes occur far more often than others. Take the identification of rare diseases: there are usually many more normal samples than disease samples. In such cases, we need to make sure that the trained model is not biased towards the class that has more data. For example, consider a dataset with 5 disease images and 20 normal images. A model that predicts every image to be normal reaches an accuracy of 80% (20/25) and an F1-score of about 0.89 when 'normal' is treated as the positive class, yet it never detects a single disease case (its F1-score for the disease class is 0). Headline metrics can therefore look good while the model is completely biased toward the 'normal' class.
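The arithmetic in this example can be checked with a short, dependency-free sketch. The helper name accuracy_and_f1 is hypothetical (it is not part of torchsampler); it scores the "always predict normal" model once with 'normal' and once with 'disease' as the positive class.

```python
def accuracy_and_f1(y_true, y_pred, positive):
    """Return (accuracy, F1) with `positive` treated as the positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)
    fp = sum(t != positive and p == positive for t, p in pairs)
    fn = sum(t == positive and p != positive for t, p in pairs)
    accuracy = sum(t == p for t, p in pairs) / len(pairs)
    # Define precision/recall/F1 as 0 when their denominators are 0.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return accuracy, f1


y_true = ["disease"] * 5 + ["normal"] * 20
y_pred = ["normal"] * 25  # a model that always predicts "normal"

acc, f1_normal = accuracy_and_f1(y_true, y_pred, positive="normal")
_, f1_disease = accuracy_and_f1(y_true, y_pred, positive="disease")
print(f"accuracy={acc:.2f}, F1(normal)={f1_normal:.2f}, F1(disease)={f1_disease:.2f}")
# prints: accuracy=0.80, F1(normal)=0.89, F1(disease)=0.00
```

The 0.89 comes from precision 20/25 = 0.8 and recall 1.0, so F1 = 2(0.8)(1.0)/1.8 = 16/18.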

A widely adopted technique for this problem is resampling. It consists of removing samples from the majority class (under-sampling) and/or adding more samples from the minority class (over-sampling). Although these techniques balance the classes, they also have their weaknesses (there is no free lunch). The simplest implementation of over-sampling duplicates random records from the minority class, which can cause overfitting. The simplest implementation of under-sampling removes random records from the majority class, which can cause loss of information.
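The two naive strategies described above can be sketched in plain Python. This is an illustration of random over- and under-sampling on a list of labels only; the function random_resample and its mode values are hypothetical and are not how ImbalancedDatasetSampler works internally.

```python
import random
from collections import defaultdict


def random_resample(labels, mode, seed=0):
    """Return dataset indices rebalanced by naive random resampling.

    mode="over":  duplicate random minority-class indices until every class
                  matches the largest class (risk: overfitting to duplicates).
    mode="under": drop random majority-class indices until every class
                  matches the smallest class (risk: losing information).
    """
    if mode not in ("over", "under"):
        raise ValueError("mode must be 'over' or 'under'")
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    sizes = [len(idxs) for idxs in by_class.values()]
    target = max(sizes) if mode == "over" else min(sizes)
    out = []
    for idxs in by_class.values():
        if mode == "over":
            # Keep every original index, then add duplicates drawn with replacement.
            out.extend(idxs + rng.choices(idxs, k=target - len(idxs)))
        else:
            # Keep a random subset drawn without replacement.
            out.extend(rng.sample(idxs, target))
    return out


labels = ["disease"] * 5 + ["normal"] * 20
over = random_resample(labels, "over")    # 40 indices, 20 per class
under = random_resample(labels, "under")  # 10 indices, 5 per class
```

Over-sampling keeps all 20 normal samples but repeats disease samples; under-sampling keeps no duplicates but discards 15 of the 20 normal samples.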


In this repo, we implement an easy-to-use PyTorch sampler ImbalancedDatasetSampler that is able to

  • rebalance the class distributions when sampling from the imbalanced dataset
  • estimate the sampling weights automatically
  • avoid creating a new balanced dataset
  • mitigate overfitting when it is used in conjunction with data augmentation techniques
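"Estimate the sampling weights automatically" can be illustrated with inverse class frequency weights: each sample gets weight 1 / (number of samples with its label), so every class carries the same total weight. This is a minimal sketch of that idea assuming only a list of labels; it is not the library's source, and the helper name inverse_frequency_weights is hypothetical.

```python
from collections import Counter


def inverse_frequency_weights(labels):
    """Weight each sample by 1 / (number of samples sharing its label)."""
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


labels = ["disease"] * 5 + ["normal"] * 20
weights = inverse_frequency_weights(labels)

# Each disease sample weighs 1/5 and each normal sample 1/20,
# so both classes sum to (approximately) 1.0.
per_class = Counter()
for label, w in zip(labels, weights):
    per_class[label] += w
print(dict(per_class))
```

Because the per-class totals are equal, sampling indices in proportion to these weights draws each class with roughly equal probability regardless of how many samples it has.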

Usage

To get started, install the package with pip:

pip install torchsampler

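Pass an ImbalancedDatasetSampler as the sampler argument when creating a DataLoader. For example:

import torch
from torchsampler import ImbalancedDatasetSampler

# train_dataset, args and kwargs come from your own training script.
# Do not also pass shuffle=True: DataLoader rejects shuffle together with a sampler.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    sampler=ImbalancedDatasetSampler(train_dataset),
    batch_size=args.batch_size,
    **kwargs
)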

In each epoch, the loader then draws as many samples as the dataset contains, weighting each sample inversely to the frequency of its class.
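The per-epoch behaviour can be approximated without PyTorch. The class below is a hypothetical stand-in, not torchsampler's implementation: it assumes sampling with replacement (needed so minority samples can repeat) and uses Python's random.choices for the weighted draw. Since it defines __iter__ and __len__, an object like this can be passed as the sampler of a DataLoader; PyTorch's built-in torch.utils.data.WeightedRandomSampler implements the same weighted-draw pattern.

```python
import random
from collections import Counter


class InverseFrequencySampler:
    """Hypothetical inverse-frequency sampler (illustration only).

    Each epoch yields len(labels) indices drawn with replacement, where an
    index's probability is proportional to 1 / (count of its class).
    """

    def __init__(self, labels, seed=None):
        counts = Counter(labels)
        self.weights = [1.0 / counts[label] for label in labels]
        self.num_samples = len(labels)
        self.rng = random.Random(seed)

    def __iter__(self):
        indices = range(len(self.weights))
        return iter(self.rng.choices(indices, weights=self.weights, k=self.num_samples))

    def __len__(self):
        return self.num_samples


labels = ["disease"] * 5 + ["normal"] * 20
sampler = InverseFrequencySampler(labels, seed=0)

# Over many epochs, roughly half of the drawn indices are disease samples,
# even though they make up only 20% of the dataset.
drawn = [labels[i] for _ in range(200) for i in sampler]
print(Counter(drawn)["disease"] / len(drawn))
```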

Example: Imbalanced MNIST Dataset

Distribution of classes in the imbalanced dataset: (figure)

With Imbalanced Dataset Sampler: (figure; left: test accuracy in each epoch; right: confusion matrix)

Without Imbalanced Dataset Sampler: (figure; left: test accuracy in each epoch; right: confusion matrix)

Accuracy improves clearly for minority classes such as 2, 6, and 9, while the accuracy of the other classes is preserved.

Contributing

We appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion. If you plan to contribute new features, utility functions or extensions, please first open an issue and discuss the feature with us.

Licensing

MIT licensed.



Download files


Source Distribution

torchsampler-0.1.2.tar.gz (6.6 kB)

Built Distribution

torchsampler-0.1.2-py3-none-any.whl (5.6 kB)

File details

Details for the file torchsampler-0.1.2.tar.gz.

File metadata

  • Download URL: torchsampler-0.1.2.tar.gz
  • Size: 6.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.0 CPython/3.9.13

File hashes

Hashes for torchsampler-0.1.2.tar.gz
Algorithm Hash digest
SHA256 6503acf0ff76888905595006a45c2a2fa017f7a13fffbfcddb3827ef8226ea78
MD5 85e3dba094bee795ea338cfaf7078156
BLAKE2b-256 12c0f4fd80e4a8ce5b698d85d037fb3259e1d7cb415d9d85c7f336a1c15a55fe


File details

Details for the file torchsampler-0.1.2-py3-none-any.whl.


File hashes

Hashes for torchsampler-0.1.2-py3-none-any.whl
Algorithm Hash digest
SHA256 a77a1d4bb7d5f134b31ca4aff4de9647d9aa0f416a62e315a403f253ebaf244f
MD5 2fb2ccdc31c4437244db838043df7fec
BLAKE2b-256 152b74ff8086106f27dbe7eec0ce680476785d79fd6f934f5c474ce7b597d1fc

