Easy-to-use class-balanced cross-entropy and focal loss implementation for PyTorch.
Theory
When the labels in a training dataset are imbalanced, one remedy is to re-weight the loss so that every class contributes in a balanced way.
- First, the effective number of samples is calculated for each class as:
- Then the class-balanced loss function is defined as:
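The two definitions above can be written out as follows. This is a reconstruction based on the paper listed under References, not text taken from this package; the symbols are assumptions borrowed from that paper: n_y is the number of training samples in class y, β in [0, 1) is the class-balanced hyperparameter, and L is the underlying loss (cross-entropy or focal loss).

```latex
E_{n_y} = \frac{1 - \beta^{n_y}}{1 - \beta},
\qquad
\mathrm{CB}(\mathbf{p}, y)
  = \frac{1}{E_{n_y}} \, \mathcal{L}(\mathbf{p}, y)
  = \frac{1 - \beta}{1 - \beta^{n_y}} \, \mathcal{L}(\mathbf{p}, y)
```

With β = 0 every class gets the same weight (no re-balancing); as β approaches 1 the weight approaches inverse class frequency.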
Installation
pip install balanced-loss
Usage
- Standard losses:
import torch
from balanced_loss import Loss
# outputs and labels
logits = torch.tensor([[0.78, 0.1, 0.05]]) # batch size 1, 3 classes
labels = torch.tensor([0]) # one label per sample
# focal loss
focal_loss = Loss(loss_type="focal_loss")
loss = focal_loss(logits, labels)
# cross-entropy loss
ce_loss = Loss(loss_type="cross_entropy")
loss = ce_loss(logits, labels)
# binary cross-entropy loss
bce_loss = Loss(loss_type="binary_cross_entropy")
loss = bce_loss(logits, labels)
- Class-balanced losses:
import torch
from balanced_loss import Loss
# outputs and labels
logits = torch.tensor([[0.78, 0.1, 0.05]]) # batch size 1, 3 classes
labels = torch.tensor([0]) # one label per sample
# number of samples per class in the training dataset
samples_per_class = [30, 100, 25] # 30, 100, 25 samples for labels 0, 1 and 2, respectively
# class-balanced focal loss
focal_loss = Loss(
    loss_type="focal_loss",
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = focal_loss(logits, labels)
# class-balanced cross-entropy loss
ce_loss = Loss(
    loss_type="cross_entropy",
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = ce_loss(logits, labels)
# class-balanced binary cross-entropy loss
bce_loss = Loss(
    loss_type="binary_cross_entropy",
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = bce_loss(logits, labels)
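To make the role of samples_per_class concrete, here is a minimal sketch of how per-class weights could be derived from the effective number of samples. It is plain Python and does not call this package; the helper names (effective_number, class_balanced_weights), the choice beta=0.999, and the rescaling so that the weights sum to the number of classes are all assumptions made for illustration, and the package may normalize differently.

```python
def effective_number(n, beta):
    """Effective number of samples: (1 - beta**n) / (1 - beta), for 0 <= beta < 1."""
    return (1.0 - beta ** n) / (1.0 - beta)


def class_balanced_weights(samples_per_class, beta=0.999):
    """Per-class weights proportional to 1 / effective_number.

    Assumption: weights are rescaled to sum to the number of classes.
    """
    raw = [1.0 / effective_number(n, beta) for n in samples_per_class]
    scale = len(raw) / sum(raw)
    return [w * scale for w in raw]


# Same counts as in the examples above: labels 0, 1 and 2
samples_per_class = [30, 100, 25]

# beta close to 1: rare classes receive larger weights
weights = class_balanced_weights(samples_per_class, beta=0.999)

# beta = 0: every class has effective number 1, so no re-weighting happens
uniform = class_balanced_weights(samples_per_class, beta=0.0)

print(weights)
print(uniform)
```

In this sketch the rarest class (label 2, 25 samples) ends up with the largest weight and the most frequent class (label 1, 100 samples) with the smallest.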
- Customize parameters:
import torch
from balanced_loss import Loss
# outputs and labels
logits = torch.tensor([[0.78, 0.1, 0.05]]) # batch size 1, 3 classes
labels = torch.tensor([0]) # one label per sample
# number of samples per class in the training dataset
samples_per_class = [30, 100, 25] # 30, 100, 25 samples for labels 0, 1 and 2, respectively
# class-balanced focal loss
focal_loss = Loss(
    loss_type="focal_loss",
    beta=0.999, # class-balanced loss beta
    fl_gamma=2, # focal loss gamma
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = focal_loss(logits, labels)
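To illustrate what the focal loss gamma controls, here is a minimal sketch of a softmax-based focal loss written directly in PyTorch. It is not this package's implementation, which may use a different (for example sigmoid-based) formulation; the function name softmax_focal_loss is hypothetical.

```python
import torch
import torch.nn.functional as F


def softmax_focal_loss(logits, labels, gamma=2.0):
    """Mean of (1 - p_t)**gamma * CE, where p_t is the softmax
    probability assigned to the true class. Illustrative sketch only."""
    # Per-sample cross-entropy, i.e. -log(p_t)
    ce = F.cross_entropy(logits, labels, reduction="none")
    # Recover p_t from the cross-entropy value
    p_t = torch.exp(-ce)
    # Down-weight samples the model already classifies confidently
    return ((1.0 - p_t) ** gamma * ce).mean()


logits = torch.tensor([[0.78, 0.1, 0.05]])  # batch size 1, 3 classes
labels = torch.tensor([0])

plain_ce = F.cross_entropy(logits, labels)
focal_g0 = softmax_focal_loss(logits, labels, gamma=0.0)  # reduces to cross-entropy
focal_g2 = softmax_focal_loss(logits, labels, gamma=2.0)  # smaller than cross-entropy

print(plain_ce.item(), focal_g0.item(), focal_g2.item())
```

With gamma=0 the modulating factor is 1 and the sketch matches plain cross-entropy; larger gamma values shrink the contribution of well-classified samples.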
Improvements
What is the difference between this repo and vandit15's?
- This repo is a PyPI-installable package
- This repo implements the loss functions as torch.nn.Module
- In addition to class-balanced losses, this repo also supports the standard versions of cross-entropy, focal loss, etc. through the same API
- All typos and errors in vandit15's source are fixed
References
https://arxiv.org/abs/1901.05555