Easy-to-use class-balanced cross-entropy and focal loss implementation for PyTorch.
Theory
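When the labels in a training dataset are imbalanced, one remedy is to balance the loss across classes.
- First, the effective number of samples is calculated for each class as:

$$E_{n} = \frac{1 - \beta^{n}}{1 - \beta}$$

where $n$ is the number of samples in the class and $\beta \in [0, 1)$ is a hyperparameter.
- Then the class-balanced loss function is defined as:

$$CB(p, y) = \frac{1}{E_{n_y}} L(p, y) = \frac{1 - \beta}{1 - \beta^{n_y}} L(p, y)$$

where $n_y$ is the number of training samples of the ground-truth class $y$ and $L(p, y)$ is the underlying loss (cross-entropy, focal loss, etc.).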
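The two steps can be illustrated with a small pure-Python sketch. It follows the formulation in the referenced paper (https://arxiv.org/abs/1901.05555), where the effective number for a class with n samples is (1 - beta^n) / (1 - beta) and each class is weighted by its inverse. The function names are hypothetical, and rescaling the weights so they sum to the number of classes is an assumption (a common convention), not necessarily what this package does internally.

```python
def effective_number(n: int, beta: float) -> float:
    """Effective number of samples for a class with n samples."""
    return (1.0 - beta ** n) / (1.0 - beta)


def class_balanced_weights(samples_per_class, beta=0.999):
    """Per-class weights proportional to 1 / effective number.

    Assumption: weights are rescaled to sum to the number of classes.
    """
    raw = [1.0 / effective_number(n, beta) for n in samples_per_class]
    scale = len(raw) / sum(raw)
    return [w * scale for w in raw]


# Same class counts as in the usage examples of this README
weights = class_balanced_weights([30, 100, 25], beta=0.999)
print(weights)  # rarer classes receive larger weights
```

With these counts, the class with 100 samples gets the smallest weight and the class with 25 samples the largest. As beta approaches 0 every effective number approaches 1 (no re-weighting), and as beta approaches 1 the weights approach inverse class frequency.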
Installation
pip install balanced-loss
Usage
- Standard losses:
```python
import torch
from balanced_loss import Loss

# outputs and labels
logits = torch.tensor([[0.78, 0.1, 0.05]]) # 1 batch, 3 class
labels = torch.tensor([0]) # 1 batch

# focal loss
focal_loss = Loss(loss_type="focal_loss")
loss = focal_loss(logits, labels)

# cross-entropy loss
ce_loss = Loss(loss_type="cross_entropy")
loss = ce_loss(logits, labels)

# binary cross-entropy loss
bce_loss = Loss(loss_type="binary_cross_entropy")
loss = bce_loss(logits, labels)
```
- Class-balanced losses:
```python
import torch
from balanced_loss import Loss

# outputs and labels
logits = torch.tensor([[0.78, 0.1, 0.05]]) # 1 batch, 3 class
labels = torch.tensor([0]) # 1 batch

# number of samples per class in the training dataset
samples_per_class = [30, 100, 25] # 30, 100, 25 samples for labels 0, 1 and 2, respectively

# class-balanced focal loss
focal_loss = Loss(
    loss_type="focal_loss",
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = focal_loss(logits, labels)

# class-balanced cross-entropy loss
ce_loss = Loss(
    loss_type="cross_entropy",
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = ce_loss(logits, labels)

# class-balanced binary cross-entropy loss
bce_loss = Loss(
    loss_type="binary_cross_entropy",
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = bce_loss(logits, labels)
```
- Customize parameters:
```python
import torch
from balanced_loss import Loss

# outputs and labels
logits = torch.tensor([[0.78, 0.1, 0.05]]) # 1 batch, 3 class
labels = torch.tensor([0])

# number of samples per class in the training dataset
samples_per_class = [30, 100, 25] # 30, 100, 25 samples for labels 0, 1 and 2, respectively

# class-balanced focal loss
focal_loss = Loss(
    loss_type="focal_loss",
    beta=0.999, # class-balanced loss beta
    fl_gamma=2, # focal loss gamma
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = focal_loss(logits, labels)
```
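To give a feel for what the focal loss gamma controls, here is a minimal pure-Python sketch of the standard softmax-based focal loss, -(1 - p_t)^gamma * log(p_t), where p_t is the predicted probability of the true class. This is an illustration under that assumption only: the helper names are hypothetical, and the package itself may compute focal loss differently (for example, in a sigmoid-based form).

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def focal_loss_sketch(logits, label, gamma=2.0):
    """Softmax focal loss for a single sample (illustrative only)."""
    p_t = softmax(logits)[label]
    # (1 - p_t) ** gamma down-weights samples that are already well classified
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


logits = [0.78, 0.1, 0.05]  # same values as the README examples
ce = focal_loss_sketch(logits, 0, gamma=0.0)  # gamma = 0 reduces to cross-entropy
fl = focal_loss_sketch(logits, 0, gamma=2.0)
print(ce, fl)
```

In this sketch, a larger gamma shrinks the contribution of confidently classified samples, so the focal value is never larger than the plain cross-entropy value for the same sample.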
Improvements
What is the difference between this repo and vandit15's?
- This repo is a PyPI-installable package
- This repo implements loss functions as torch.nn.Module
- In addition to class-balanced losses, this repo also supports the standard versions of the cross-entropy/focal loss etc. over the same API
- All typos and errors in vandit15's source are fixed
- Continuously tested on PyTorch 1.13.1 and 2.5.1
References
https://arxiv.org/abs/1901.05555