PyTorch samplers that output roughly balanced batches with support for multilabel datasets.
PyTorch Multilabel Balanced Samplers
This package provides samplers that fetch data samples from multilabel datasets in a balanced manner. Balanced sampling from multilabel datasets is especially useful for handling class imbalance.
Samplers
- BaseMultilabelBalancedRandomSampler: This is the base class for all the provided samplers. It initializes the basic structure required for sampling, such as class indices.
- RandomClassSampler: This sampler randomly chooses a class and then picks a random example from that class.
- ClassCycleSampler: As the name suggests, it cycles through each class and fetches a random example from the current class.
- LeastSampledClassSampler: Chooses the class with the least number of samples fetched so far and retrieves a random example from that class.
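The three class-selection strategies listed above can be compared side by side in plain Python. This is a minimal sketch and not the package's implementation: the toy `labels` matrix, the helper names (`random_class`, `class_cycle`, `least_sampled_class`, `draw`), and the choice to count every class carried by a drawn example are all assumptions made for illustration.

```python
import random

# Hypothetical toy data: 6 examples, 3 classes, one row per example.
labels = [
    [1, 0, 0],
    [1, 0, 0],
    [1, 1, 0],
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
]
n_classes = len(labels[0])

# class_indices[c] lists the examples that carry class c.
class_indices = [
    [i for i, row in enumerate(labels) if row[c] == 1]
    for c in range(n_classes)
]


def random_class(rng, step, counts):
    """Pick any class uniformly at random."""
    return rng.randrange(n_classes)


def class_cycle(rng, step, counts):
    """Visit classes in a fixed round-robin order."""
    return step % n_classes


def least_sampled_class(rng, step, counts):
    """Pick the class seen least often so far (ties go to the lowest id)."""
    return min(range(n_classes), key=lambda c: counts[c])


def draw(choose_class, n_draws, seed=0):
    """Choose a class, then a random example of that class, n_draws times."""
    rng = random.Random(seed)
    counts = [0] * n_classes
    picked = []
    for step in range(n_draws):
        c = choose_class(rng, step, counts)
        example = rng.choice(class_indices[c])
        picked.append(example)
        # Assumption: an example counts towards every class it carries.
        for k, flag in enumerate(labels[example]):
            counts[k] += flag
    return picked, counts


for strategy in (random_class, class_cycle, least_sampled_class):
    picked, counts = draw(strategy, n_draws=12)
    print(strategy.__name__, picked, counts)
```

In this toy set class 0 covers four of the six examples, so uniform sampling over examples would mostly return class 0. Choosing the class first is what lets the rare classes 1 and 2 appear about as often.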
Usage
Installation:
This package is installable via pip:
pip install pytorch-multilabel-balanced-sampler
Initialization:
For all samplers, the initialization arguments are:
- labels: A 2D tensor of shape (n_examples, n_classes) containing the one-hot encoded labels for the dataset.
- indices: A sequence of integers representing the indices of the dataset. Default is the range of the dataset size.
from pytorch_multilabel_balanced_sampler.samplers import RandomClassSampler, ClassCycleSampler, LeastSampledClassSampler
sampler1 = RandomClassSampler(labels=my_labels, indices=my_indices)
sampler2 = ClassCycleSampler(labels=my_labels)
sampler3 = LeastSampledClassSampler(labels=my_labels, indices=my_indices)
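The `labels` argument expects one row per example and one column per class. As a rough sketch of how such a matrix might be built from per-example class lists, the snippet below uses plain Python lists; the `annotations` data and the `to_multi_hot` helper are hypothetical and not part of the package. Passing the result to `torch.tensor(...)` would give a 2D tensor of the shape described above.

```python
# Hypothetical annotations: each example lists the class ids it belongs to.
annotations = [[0], [0, 2], [1], [], [1, 2]]
n_classes = 3


def to_multi_hot(annotations, n_classes):
    """Return one row per example with a 1 in every column the example carries."""
    rows = []
    for class_ids in annotations:
        row = [0] * n_classes
        for c in class_ids:
            row[c] = 1
        rows.append(row)
    return rows


label_rows = to_multi_hot(annotations, n_classes)
for row in label_rows:
    print(row)
```

Because the data is multilabel, a row may contain several ones (example 1 here carries classes 0 and 2) or none at all.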
Fetching samples:
Iterate over the sampler object to fetch samples:
for sample in sampler1:
    print(sample)
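One way to check how balanced a run of samples is would be to tally the classes of the fetched examples. The sketch below assumes that, as with PyTorch samplers in general, each yielded value is a dataset index; the two index lists are made-up stand-ins for sampler output, and `class_counts` is a hypothetical helper.

```python
from collections import Counter


def class_counts(sampled_indices, label_rows):
    """Count how often each class appears among the sampled examples."""
    counts = Counter()
    for idx in sampled_indices:
        for c, flag in enumerate(label_rows[idx]):
            if flag:
                counts[c] += 1
    return counts


# Toy labels: class 0 is common (3 examples), class 1 is rare (1 example).
label_rows = [[1, 0], [1, 0], [1, 0], [0, 1]]

one_pass_in_order = [0, 1, 2, 3]    # every example exactly once
class_first_sample = [0, 3, 1, 3]   # alternate between the two classes

print(class_counts(one_pass_in_order, label_rows))
print(class_counts(class_first_sample, label_rows))
```

The first tally is skewed three to one towards class 0, while the second is even because the rare example is drawn repeatedly.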
Note:
All samplers inherit from BaseMultilabelBalancedRandomSampler, which in turn inherits from PyTorch's Sampler class. This ensures compatibility with PyTorch's data loading utilities.
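Any `Sampler` subclass can be handed to a `DataLoader` through its `sampler` argument. The sketch below illustrates that wiring with a small stand-in class so that it runs without the package installed. `CyclingStandInSampler` is hypothetical and is not the package's code, and the toy tensors are invented for the example.

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class CyclingStandInSampler(Sampler):
    """Hypothetical stand-in: cycles over classes, picks a random example each time."""

    def __init__(self, labels):
        self.labels = labels
        # For each class, the dataset indices whose label column is non-zero.
        self.class_indices = [
            torch.nonzero(labels[:, c]).flatten().tolist()
            for c in range(labels.shape[1])
        ]

    def __len__(self):
        return self.labels.shape[0]

    def __iter__(self):
        n_classes = len(self.class_indices)
        for step in range(len(self)):
            candidates = self.class_indices[step % n_classes]
            pick = torch.randint(len(candidates), (1,)).item()
            yield candidates[pick]


# Toy dataset: 4 examples, 2 features each, 2 classes (class 1 is rare).
features = torch.arange(8, dtype=torch.float32).reshape(4, 2)
labels = torch.tensor([[1, 0], [1, 0], [1, 0], [0, 1]])
dataset = TensorDataset(features, labels)

# Leave `shuffle` unset when a sampler is given; the sampler decides the order.
loader = DataLoader(dataset, sampler=CyclingStandInSampler(labels), batch_size=2)
batches = list(loader)
for batch_features, batch_labels in batches:
    print(batch_labels.sum(dim=0).tolist())
```

With the package installed, one of its samplers, for example `RandomClassSampler(labels=my_labels)`, would presumably take the place of the stand-in in the `sampler=` argument.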
License
The MIT License (MIT).
Feedback & Issues
For feedback, issues, or feature requests, please raise an issue on the GitHub repository.
File details
Details for the file pytorch_multilabel_balanced_sampler-0.1.3.tar.gz.
File metadata
- Download URL: pytorch_multilabel_balanced_sampler-0.1.3.tar.gz
- Upload date:
- Size: 4.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 06a52a5106f9cabd2a7630b5d606119a7956cce009e8e891a753a311037b1357
MD5 | 57d94a5b18f687f221f70fb991ff631c
BLAKE2b-256 | 347957965267d8e0f27a93501d30a74d9a2c954918182eb8d8a587c7b667445d
File details
Details for the file pytorch_multilabel_balanced_sampler-0.1.3-py3-none-any.whl.
File metadata
- Download URL: pytorch_multilabel_balanced_sampler-0.1.3-py3-none-any.whl
- Upload date:
- Size: 4.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | fb627bd394dc7e50d21f12f09998131a92bab11ea90e00310ac36aa051663887
MD5 | 042387e94b3630661551017395d05a6a
BLAKE2b-256 | ff1b48b250f91a0d02b7f56dc59a4fe7f4808144affbd13a2d2f30a04f7edbf7