
Easy-to-use class-balanced cross-entropy and focal loss implementation for PyTorch.


Theory

When the labels of a training dataset are imbalanced, one remedy is to re-weight the loss per class so that rare classes are not drowned out by frequent ones.

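  • First, the effective number of samples is calculated for each class $y$, where $n_y$ is the number of training samples in that class and $\beta \in [0, 1)$ is a hyperparameter:

$$E_{n_y} = \frac{1 - \beta^{n_y}}{1 - \beta}$$

  • Then the class-balanced loss re-weights a base loss $\mathcal{L}$ (cross-entropy, focal loss, etc.) for a sample with prediction $\mathbf{p}$ and label $y$ by the inverse effective number of its class:

$$\mathrm{CB}(\mathbf{p}, y) = \frac{1}{E_{n_y}} \mathcal{L}(\mathbf{p}, y) = \frac{1 - \beta}{1 - \beta^{n_y}} \mathcal{L}(\mathbf{p}, y)$$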
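For intuition, the effective number $E_n = (1 - \beta^n)/(1 - \beta)$ and the per-class weight $1/E_n$ from the referenced paper can be computed in a few lines of plain Python. This is a minimal sketch, not the package's code: `effective_number`, `class_balanced_weights` and `cb_cross_entropy` are hypothetical helper names, and rescaling the weights so they sum to the number of classes is an assumed convention that the package may or may not apply.

```python
import math

def effective_number(n: int, beta: float) -> float:
    """Effective number of samples E_n = (1 - beta**n) / (1 - beta), for 0 <= beta < 1."""
    return (1.0 - beta ** n) / (1.0 - beta)

def class_balanced_weights(samples_per_class, beta, normalize=True):
    """Per-class weights 1 / E_{n_y}."""
    weights = [1.0 / effective_number(n, beta) for n in samples_per_class]
    if normalize:
        # Assumed convention (hypothetical here): rescale so the weights sum to the class count.
        total = sum(weights)
        weights = [w * len(weights) / total for w in weights]
    return weights

def cb_cross_entropy(logits, label, weights):
    """Class-balanced softmax cross-entropy for a single sample: w_y * (-log softmax(logits)[y])."""
    m = max(logits)  # subtract the max logit for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return weights[label] * (log_sum_exp - logits[label])

# Same counts and logits as in the usage examples of this README
samples_per_class = [30, 100, 25]
weights = class_balanced_weights(samples_per_class, beta=0.999)
loss = cb_cross_entropy([0.78, 0.1, 0.05], label=0, weights=weights)
print(weights, loss)
```

The rarest class (25 samples) receives the largest weight and the most frequent class (100 samples) the smallest, so errors on rare classes cost more.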

Installation

```bash
pip install balanced-loss
```

Usage

  • Standard losses:
```python
import torch
from balanced_loss import Loss

# model outputs (logits) and target labels
logits = torch.tensor([[0.78, 0.1, 0.05]])  # batch size 1, 3 classes
labels = torch.tensor([0])  # batch size 1

# focal loss
focal_loss = Loss(loss_type="focal_loss")
loss = focal_loss(logits, labels)

# cross-entropy loss
ce_loss = Loss(loss_type="cross_entropy")
loss = ce_loss(logits, labels)

# binary cross-entropy loss
bce_loss = Loss(loss_type="binary_cross_entropy")
loss = bce_loss(logits, labels)
```
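To illustrate what a focal loss computes, the sketch below uses one common softmax formulation, $\mathrm{FL}(p_t) = -(1 - p_t)^\gamma \log p_t$, where $p_t$ is the predicted probability of the true class. It is for intuition only: `softmax_focal_loss` is a hypothetical helper, and the package's `focal_loss` may be computed differently (for example, with a per-class sigmoid). With $\gamma = 0$ it reduces to ordinary cross-entropy.

```python
import torch
import torch.nn.functional as F

def softmax_focal_loss(logits: torch.Tensor, labels: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    """Batch-mean focal loss -(1 - p_t)**gamma * log(p_t), softmax formulation."""
    log_probs = F.log_softmax(logits, dim=-1)                       # (batch, num_classes)
    log_p_t = log_probs.gather(1, labels.unsqueeze(1)).squeeze(1)   # log-prob of the true class
    p_t = log_p_t.exp()
    # (1 - p_t)**gamma shrinks the loss of well-classified (high p_t) samples
    return ((1.0 - p_t) ** gamma * (-log_p_t)).mean()

logits = torch.tensor([[0.78, 0.1, 0.05]])  # batch size 1, 3 classes
labels = torch.tensor([0])
print(softmax_focal_loss(logits, labels, gamma=2.0))
print(F.cross_entropy(logits, labels))
```

Raising `gamma` makes confidently correct samples contribute less, which focuses training on hard examples.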
  • Class-balanced losses:
```python
import torch
from balanced_loss import Loss

# model outputs (logits) and target labels
logits = torch.tensor([[0.78, 0.1, 0.05]])  # batch size 1, 3 classes
labels = torch.tensor([0])  # batch size 1

# number of samples per class in the training dataset
samples_per_class = [30, 100, 25]  # 30, 100 and 25 samples for labels 0, 1 and 2, respectively

# class-balanced focal loss
focal_loss = Loss(
    loss_type="focal_loss",
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = focal_loss(logits, labels)

# class-balanced cross-entropy loss
ce_loss = Loss(
    loss_type="cross_entropy",
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = ce_loss(logits, labels)

# class-balanced binary cross-entropy loss
bce_loss = Loss(
    loss_type="binary_cross_entropy",
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = bce_loss(logits, labels)
```
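`samples_per_class` holds the number of training samples for each class index. One way to count them from a 1-D tensor of integer training labels is `torch.bincount`; the helper name and the toy labels below are hypothetical.

```python
import torch

def count_samples_per_class(train_labels: torch.Tensor, num_classes: int) -> list[int]:
    """Number of training samples for each class index 0..num_classes-1."""
    # minlength guarantees an entry (possibly 0) for classes that never occur
    return torch.bincount(train_labels, minlength=num_classes).tolist()

# Hypothetical labels of a small 3-class training set
train_labels = torch.tensor([0, 1, 1, 2, 1, 0, 1])
samples_per_class = count_samples_per_class(train_labels, num_classes=3)
print(samples_per_class)
```

A class with zero samples has effective number $E_0 = 0$, so its weight $1/E_0$ is undefined; such classes need separate handling before building the loss.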
  • Customize parameters:
```python
import torch
from balanced_loss import Loss

# model outputs (logits) and target labels
logits = torch.tensor([[0.78, 0.1, 0.05]])  # batch size 1, 3 classes
labels = torch.tensor([0])  # batch size 1

# number of samples per class in the training dataset
samples_per_class = [30, 100, 25]  # 30, 100 and 25 samples for labels 0, 1 and 2, respectively

# class-balanced focal loss
focal_loss = Loss(
    loss_type="focal_loss",
    beta=0.999,  # class-balanced loss beta
    fl_gamma=2,  # focal loss gamma
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = focal_loss(logits, labels)
```
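To see what `beta` controls, the sketch below evaluates the unnormalized class-balanced weight $(1 - \beta)/(1 - \beta^{n_y})$ from the referenced paper for the example counts. With $\beta = 0$ every class gets the same weight (no re-balancing), and as $\beta$ approaches 1 the weights approach inverse class frequency; $\beta = 1$ itself would divide by zero. `raw_class_weights` is a hypothetical helper, not the package's code.

```python
def raw_class_weights(samples_per_class, beta):
    """Unnormalized class-balanced weights (1 - beta) / (1 - beta**n), for 0 <= beta < 1."""
    return [(1.0 - beta) / (1.0 - beta ** n) for n in samples_per_class]

samples_per_class = [30, 100, 25]
for beta in (0.0, 0.9, 0.99, 0.999, 0.9999):
    w = raw_class_weights(samples_per_class, beta)
    # weight of each class relative to the most frequent class (index 1, 100 samples)
    print(beta, [round(x / w[1], 3) for x in w])
```

Larger `beta` therefore means stronger re-balancing toward rare classes, while smaller `beta` keeps the loss closer to its unweighted form.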

Improvements

What is the difference between this repo and vandit15's?

  • This repo is a PyPI-installable package
  • This repo implements the loss functions as torch.nn.Module classes
  • In addition to the class-balanced losses, this repo also supports the standard versions of the cross-entropy, binary cross-entropy and focal losses through the same API
  • All typos and errors in vandit15's source have been fixed
  • Continuously tested on PyTorch 1.13.1 and 2.5.1

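References

  • Cui et al., "Class-Balanced Loss Based on Effective Number of Samples", CVPR 2019: https://arxiv.org/abs/1901.05555
  • Official TensorFlow implementation: https://github.com/richardaecn/class-balanced-loss
  • vandit15's PyTorch implementation: https://github.com/vandit15/Class-balanced-loss-pytorch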


Download files


Source Distribution

balanced_loss-0.1.1.tar.gz (6.2 kB)

Uploaded Source

Built Distribution


balanced_loss-0.1.1-py3-none-any.whl (5.5 kB)

Uploaded Python 3

File details

Details for the file balanced_loss-0.1.1.tar.gz.

File metadata

  • Download URL: balanced_loss-0.1.1.tar.gz
  • Upload date:
  • Size: 6.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.12.8

File hashes

Hashes for balanced_loss-0.1.1.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | f282420d5e743f530c818d89a3b114746bad2d01a929155d6cfb1667f1cd60fb |
| MD5 | 56a037b32aef34660ceea8e4f0c5e02f |
| BLAKE2b-256 | 879ec8d9e2a1df92968f7f7c2f440431363240df9004795f8e407c8efab7076d |


File details

Details for the file balanced_loss-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: balanced_loss-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 5.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.12.8

File hashes

Hashes for balanced_loss-0.1.1-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1e7f993c6751a52d9c2aa250f091ba37acc558cc576d9756e5072922e6258a13 |
| MD5 | b970a5a5ae5f55ca4531adffd6ce9e94 |
| BLAKE2b-256 | fe4ba6d4f86228c88c71637da4729723844df4c40185598e84d6125f386766ab |

