PyTorch samplers that output roughly balanced batches with support for multilabel datasets.
PyTorch Multilabel Balanced Samplers
This package provides samplers to fetch data samples from multilabel datasets in a balanced manner. Balanced sampling from multilabel datasets can be especially useful to handle class imbalance issues.
Samplers
- BaseMultilabelBalancedRandomSampler: This is the base class for all the provided samplers. It initializes the basic structure required for sampling, such as class indices.
- RandomClassSampler: This sampler randomly chooses a class and then picks a random example from that class.
- ClassCycleSampler: As the name suggests, it cycles through each class and fetches a random example from the current class.
- LeastSampledClassSampler: Chooses the class with the least number of samples fetched so far and retrieves a random example from that class.
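The three selection strategies described above can be illustrated with a small, library-independent sketch. This is not the package's implementation: the function names `build_class_indices` and `sample_balanced` are hypothetical, labels are plain nested lists rather than tensors, and two details are assumptions made for illustration, namely that each fetched example counts toward every class it carries and that ties in the least-sampled strategy go to the lowest class index.

```python
import random

# Toy multilabel dataset: 6 examples, 3 classes; class 2 is rare.
labels = [
    [1, 0, 0],
    [1, 0, 0],
    [1, 1, 0],
    [0, 1, 0],
    [1, 0, 1],
    [1, 1, 0],
]


def build_class_indices(labels):
    """Map each class to the list of example indices that carry it."""
    n_classes = len(labels[0])
    return [
        [i for i, row in enumerate(labels) if row[c] == 1]
        for c in range(n_classes)
    ]


def sample_balanced(labels, n_samples, strategy, seed=0):
    """Return (picked example indices, per-class counts of fetched labels)."""
    rng = random.Random(seed)
    class_indices = build_class_indices(labels)
    n_classes = len(class_indices)
    counts = [0] * n_classes
    picked = []
    for step in range(n_samples):
        if strategy == "random":
            target = rng.randrange(n_classes)          # any class, uniformly
        elif strategy == "cycle":
            target = step % n_classes                  # classes in fixed rotation
        elif strategy == "least_sampled":
            # Assumption: ties are broken by the lowest class index.
            target = min(range(n_classes), key=lambda c: counts[c])
        else:
            raise ValueError(f"unknown strategy: {strategy}")
        idx = rng.choice(class_indices[target])        # random example of that class
        picked.append(idx)
        # Assumption: a fetched example counts toward every class it carries.
        for c in range(n_classes):
            counts[c] += labels[idx][c]
    return picked, counts


for name in ("random", "cycle", "least_sampled"):
    picked, counts = sample_balanced(labels, n_samples=9, strategy=name)
    print(name, picked, counts)
```

In all three cases the rare class is fetched far more often than under uniform sampling over examples, which is the sense in which the batches are roughly balanced.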
Usage
Installation:
This package is installable via pip:
pip install pytorch-multilabel-balanced-sampler
Initialization:
For all samplers, the initialization arguments are:
- labels: A 2D tensor of shape (n_examples, n_classes) containing the one-hot encoded labels for the dataset.
- indices: A sequence of integers representing the indices of the dataset. Default is the range of the dataset size.
from pytorch_multilabel_balanced_sampler.samplers import RandomClassSampler, ClassCycleSampler, LeastSampledClassSampler
sampler1 = RandomClassSampler(labels=my_labels, indices=my_indices)
sampler2 = ClassCycleSampler(labels=my_labels)
sampler3 = LeastSampledClassSampler(labels=my_labels, indices=my_indices)
Fetching samples:
Iterate over the sampler object to fetch samples:
for sample in sampler1:
    print(sample)
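To check how balanced the fetched samples are, one option is to tally the labels of the fetched indices. The sketch below is illustrative only: `class_frequencies` is a hypothetical helper, not part of the package, and it assumes the sampler yields integer dataset indices. Plain lists stand in for the sampler output here; label rows may be lists or tensor rows, since each entry is passed through `int()`.

```python
from itertools import islice


def class_frequencies(labels, sampled_indices, limit=1000):
    """Fraction of fetched examples that carry each class.

    `limit` caps how many indices are consumed, in case the iterable is long.
    """
    n_classes = len(labels[0])
    totals = [0] * n_classes
    n_seen = 0
    for idx in islice(sampled_indices, limit):
        for c in range(n_classes):
            totals[c] += int(labels[idx][c])
        n_seen += 1
    return [t / n_seen for t in totals] if n_seen else totals


# Toy labels: class 0 is common, class 1 is rare.
labels = [[1, 0], [1, 0], [1, 0], [0, 1]]

uniform_pass = [0, 1, 2, 3]    # one pass over the dataset in order
balanced_draw = [0, 3, 1, 3]   # stand-in for indices from a balanced sampler

print(class_frequencies(labels, uniform_pass))   # [0.75, 0.25]
print(class_frequencies(labels, balanced_draw))  # [0.5, 0.5]
```

With a real sampler, the second argument would be the sampler object itself, for example `class_frequencies(my_labels, sampler1)`.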
Note:
All samplers inherit from BaseMultilabelBalancedRandomSampler, which in turn inherits from PyTorch's Sampler class. This ensures compatibility with PyTorch's data loading utilities.
License
The MIT License (MIT).
Feedback & Issues
For feedback, issues, or feature requests, please raise an issue on the GitHub repository.