
PyTorch samplers that output roughly balanced batches with support for multilabel datasets.

PyTorch Multilabel Balanced Samplers

This package provides samplers that draw examples from multilabel datasets in a roughly class-balanced manner. Balanced sampling is especially useful for mitigating class imbalance, where rare labels would otherwise seldom appear in a batch.

Samplers

  • BaseMultilabelBalancedRandomSampler: This is the base class for all the provided samplers. It initializes the basic structure required for sampling, such as class indices.

  • RandomClassSampler: This sampler randomly chooses a class and then picks a random example from that class.

  • ClassCycleSampler: As the name suggests, it cycles through each class and fetches a random example from the current class.

  • LeastSampledClassSampler: Chooses the class with the fewest samples fetched so far and retrieves a random example from that class.
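
The three class-selection strategies above can be sketched in plain Python. This is an illustration under stated assumptions, not the package's implementation: the helper names are hypothetical, ties in the least-sampled strategy are broken by lowest class id, and a drawn example is assumed to count toward every class it is labeled with.

```python
import random


def build_class_indices(labels):
    """For each class, list the example indices whose row has a 1 in that column."""
    n_classes = len(labels[0])
    return [
        [i for i, row in enumerate(labels) if row[c] == 1]
        for c in range(n_classes)
    ]


def random_class_draw(class_indices, rng):
    """Pick a class uniformly at random, then a random example of that class."""
    members = rng.choice(class_indices)
    return rng.choice(members)


def class_cycle_draws(class_indices, rng, n_draws):
    """Visit classes in round-robin order, drawing one random example per visit."""
    return [
        rng.choice(class_indices[step % len(class_indices)])
        for step in range(n_draws)
    ]


def least_sampled_draws(labels, class_indices, rng, n_draws):
    """Always draw from the class that has been seen the fewest times so far."""
    counts = [0] * len(class_indices)
    draws = []
    for _ in range(n_draws):
        # Assumption: ties are broken by taking the lowest class id.
        target = counts.index(min(counts))
        idx = rng.choice(class_indices[target])
        draws.append(idx)
        # Assumption: the example counts toward every class it carries.
        for c, flag in enumerate(labels[idx]):
            counts[c] += flag
    return draws, counts


# Toy multilabel dataset: class 0 is common, classes 1 and 2 are rare.
labels = [
    [1, 0, 0],
    [1, 0, 0],
    [1, 1, 0],
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
]
rng = random.Random(0)
class_indices = build_class_indices(labels)
single = random_class_draw(class_indices, rng)
cycled = class_cycle_draws(class_indices, rng, 6)
draws, counts = least_sampled_draws(labels, class_indices, rng, 30)
```

In this toy dataset only one example in six carries class 2, yet the round-robin strategy draws it on every third step, which is the balancing effect the samplers aim for.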

Usage

Installation:

This package is installable via pip:

pip install pytorch-multilabel-balanced-sampler

Initialization:

For all samplers, the initialization arguments are:

  • labels: A 2D tensor of shape (n_examples, n_classes) containing the binary (multi-hot) labels for the dataset. Each row may contain several ones, one per class the example belongs to.
  • indices: A sequence of integers giving the dataset indices to sample from. Defaults to every index in the dataset, i.e. the range of the dataset size.

from pytorch_multilabel_balanced_sampler.samplers import (
    RandomClassSampler,
    ClassCycleSampler,
    LeastSampledClassSampler,
)

sampler1 = RandomClassSampler(labels=my_labels, indices=my_indices)
sampler2 = ClassCycleSampler(labels=my_labels)
sampler3 = LeastSampledClassSampler(labels=my_labels, indices=my_indices)
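
If the dataset stores labels as sets of class names rather than as a matrix, the labels argument has to be built first. The following is a minimal sketch of that encoding step; `to_multi_hot`, `class_names`, and `label_sets` are hypothetical names used only for illustration.

```python
def to_multi_hot(label_sets, class_names):
    """Encode each example's label set as a 0/1 row of length n_classes."""
    column = {name: j for j, name in enumerate(class_names)}
    rows = []
    for names in label_sets:
        row = [0] * len(class_names)
        for name in names:
            row[column[name]] = 1
        rows.append(row)
    return rows


class_names = ["cat", "dog", "bird"]
label_sets = [{"cat"}, {"cat", "dog"}, {"bird"}, set()]
rows = to_multi_hot(label_sets, class_names)
# Wrapping the nested list with torch.tensor(rows) gives a 2D tensor of
# shape (n_examples, n_classes) that can be passed as `labels`.
```

Note that the last example has no labels at all, so its row is all zeros; such an example belongs to no class's index list.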

Fetching samples:

Iterate over the sampler object to fetch samples:

for sample in sampler1:
    print(sample)

Note:

All samplers inherit from BaseMultilabelBalancedRandomSampler, which in turn inherits from PyTorch's Sampler class. They are therefore compatible with PyTorch's data loading utilities.
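
Because the samplers are Sampler subclasses, one can be passed to a DataLoader through its sampler argument. The sketch below shows the wiring with a hypothetical stand-in, `CyclingIndexSampler`, so that it runs without this package installed. It is not the package's code, and it assumes every class has at least one example. In practice, a sampler from this package would take its place.

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class CyclingIndexSampler(Sampler):
    """Hypothetical stand-in: yields one random index per class, in turn."""

    def __init__(self, labels):
        self.labels = labels
        # For each class, the dataset indices whose label row has a 1 there.
        self.class_indices = [
            torch.nonzero(labels[:, c]).flatten().tolist()
            for c in range(labels.shape[1])
        ]

    def __iter__(self):
        for step in range(len(self)):
            members = self.class_indices[step % len(self.class_indices)]
            pick = torch.randint(len(members), (1,)).item()
            yield members[pick]

    def __len__(self):
        return self.labels.shape[0]


features = torch.randn(6, 4)
labels = torch.tensor([
    [1, 0, 0],
    [1, 0, 0],
    [1, 1, 0],
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
])
dataset = TensorDataset(features, labels)

# The sampler decides which indices are drawn; the DataLoader groups them.
loader = DataLoader(dataset, sampler=CyclingIndexSampler(labels), batch_size=3)
batches = [batch_labels for _, batch_labels in loader]
```

With three classes and a batch size of three, every batch here contains one example drawn for each class, even though class 2 has a single example in the dataset.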

License

The MIT License (MIT).

Feedback & Issues

For feedback, issues, or feature requests, please raise an issue on the GitHub repository.
