
Easy-to-use, class-balanced cross-entropy and focal loss implementation for PyTorch.

Project description



Theory

When the class labels in a training dataset are imbalanced, one remedy is to reweight the loss per class so that each class contributes more evenly to training.

  • First, the effective number of samples is calculated for each class as:

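$$E_{n_y} = \frac{1 - \beta^{n_y}}{1 - \beta}$$

where $n_y$ is the number of training samples in class $y$ and $\beta \in [0, 1)$ is a hyperparameter.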

  • Then the class-balanced loss is defined as:

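$$\mathrm{CB}(\mathbf{p}, y) = \frac{1 - \beta}{1 - \beta^{n_y}} \, \mathcal{L}(\mathbf{p}, y)$$

that is, the base loss $\mathcal{L}$ (cross-entropy, binary cross-entropy, or focal loss) of a sample with label $y$ and predicted class probabilities $\mathbf{p}$ is weighted by the inverse of the effective number of samples of class $y$.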
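The two formulas above, the effective number $(1 - \beta^{n})/(1 - \beta)$ and the class weight $(1 - \beta)/(1 - \beta^{n})$, can be checked numerically with a few lines of plain Python (no PyTorch needed). This is a minimal sketch, not the package's code: the helper names are hypothetical, the base loss is softmax cross-entropy for a single sample, and the weights are left unnormalized exactly as in the formula (implementations often rescale them, which is not shown here).

```python
import math


def effective_number(n_y: int, beta: float) -> float:
    """Effective number of samples E_{n_y} = (1 - beta**n_y) / (1 - beta)."""
    return (1.0 - beta ** n_y) / (1.0 - beta)


def class_balanced_weight(n_y: int, beta: float) -> float:
    """Class-balanced weight (1 - beta) / (1 - beta**n_y), i.e. 1 / E_{n_y}."""
    return 1.0 / effective_number(n_y, beta)


def softmax_cross_entropy(logits: list[float], label: int) -> float:
    """Softmax cross-entropy -log p_label for a single sample."""
    m = max(logits)  # shift by the max logit for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def class_balanced_cross_entropy(logits, label, samples_per_class, beta=0.999):
    """Cross-entropy of one sample scaled by its class weight.

    beta=0.999 is an illustrative default, not necessarily the package's.
    """
    weight = class_balanced_weight(samples_per_class[label], beta)
    return weight * softmax_cross_entropy(logits, label)


samples_per_class = [30, 100, 25]  # same counts as in the usage examples
beta = 0.999

for label, n_y in enumerate(samples_per_class):
    print(f"class {label}: n = {n_y:3d}, "
          f"E_n = {effective_number(n_y, beta):6.2f}, "
          f"weight = {class_balanced_weight(n_y, beta):.4f}")

logits = [0.78, 0.1, 0.05]  # one sample, three classes
print("cross-entropy:", softmax_cross_entropy(logits, 0))
print("class-balanced cross-entropy:",
      class_balanced_cross_entropy(logits, 0, samples_per_class, beta))
```

Since $E_{n}$ equals 1 when $\beta = 0$ and approaches $n$ as $\beta \to 1$, $\beta$ interpolates between no reweighting and weighting by inverse class frequency.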

Installation

pip install balanced-loss

Usage

  • Standard losses:
import torch
from balanced_loss import Loss

# outputs and labels
logits = torch.tensor([[0.78, 0.1, 0.05]]) # batch size 1, 3 classes
labels = torch.tensor([0]) # batch size 1

# focal loss
focal_loss = Loss(loss_type="focal_loss")
loss = focal_loss(logits, labels)
# cross-entropy loss
ce_loss = Loss(loss_type="cross_entropy")
loss = ce_loss(logits, labels)
# binary cross-entropy loss
bce_loss = Loss(loss_type="binary_cross_entropy")
loss = bce_loss(logits, labels)
  • Class-balanced losses:
import torch
from balanced_loss import Loss

# outputs and labels
logits = torch.tensor([[0.78, 0.1, 0.05]]) # batch size 1, 3 classes
labels = torch.tensor([0]) # batch size 1

# number of samples per class in the training dataset
samples_per_class = [30, 100, 25] # 30, 100, 25 samples for labels 0, 1 and 2, respectively

# class-balanced focal loss
focal_loss = Loss(
    loss_type="focal_loss",
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = focal_loss(logits, labels)
# class-balanced cross-entropy loss
ce_loss = Loss(
    loss_type="cross_entropy",
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = ce_loss(logits, labels)
# class-balanced binary cross-entropy loss
bce_loss = Loss(
    loss_type="binary_cross_entropy",
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = bce_loss(logits, labels)
  • Customize parameters:
import torch
from balanced_loss import Loss

# outputs and labels
logits = torch.tensor([[0.78, 0.1, 0.05]]) # batch size 1, 3 classes
labels = torch.tensor([0]) # batch size 1

# number of samples per class in the training dataset
samples_per_class = [30, 100, 25] # 30, 100, 25 samples for labels 0, 1 and 2, respectively

# class-balanced focal loss
focal_loss = Loss(
    loss_type="focal_loss",
    beta=0.999, # class-balanced loss beta
    fl_gamma=2, # focal loss gamma
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = focal_loss(logits, labels)
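To see what the two knobs in the example above control, here is a plain-Python sketch of a class-balanced focal loss for a single sample. It is an illustration under stated assumptions, not the package's implementation: it uses the softmax form of focal loss, $-(1 - p_t)^{\gamma} \log p_t$, whereas the package's focal_loss may use a different (for example sigmoid-based) formulation, and the hypothetical helper rescales the class weights to sum to the number of classes, a common convention that this description does not confirm for balanced-loss.

```python
import math


def class_balanced_weights(samples_per_class, beta):
    """Weights (1 - beta) / (1 - beta**n), rescaled to sum to the number of classes.

    The rescaling step is an assumption made for readability.
    """
    raw = [(1.0 - beta) / (1.0 - beta ** n) for n in samples_per_class]
    scale = len(raw) / sum(raw)
    return [w * scale for w in raw]


def softmax(logits):
    m = max(logits)  # shift by the max logit for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cb_focal_loss(logits, label, samples_per_class, beta=0.999, gamma=2.0):
    """Class-balanced softmax focal loss for one sample (hypothetical helper)."""
    p_t = softmax(logits)[label]  # predicted probability of the true class
    focal = -((1.0 - p_t) ** gamma) * math.log(p_t)
    return class_balanced_weights(samples_per_class, beta)[label] * focal


samples_per_class = [30, 100, 25]
logits = [0.78, 0.1, 0.05]

# beta = 0 gives equal weights; beta close to 1 approaches inverse class frequency.
for beta in (0.0, 0.9, 0.99, 0.999, 0.9999):
    weights = class_balanced_weights(samples_per_class, beta)
    print(f"beta={beta}: weights = {[round(w, 3) for w in weights]}")

# gamma = 0 reduces to class-balanced cross-entropy; larger gamma
# shrinks the loss of samples that are already classified confidently.
for gamma in (0.0, 1.0, 2.0):
    loss = cb_focal_loss(logits, 0, samples_per_class, gamma=gamma)
    print(f"gamma={gamma}: loss = {loss:.4f}")
```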

Improvements

What is the difference between this repo and vandit15's?

  • This repo is a pip-installable package on PyPI
  • This repo implements the loss functions as torch.nn.Module classes
  • In addition to the class-balanced losses, this repo supports the standard cross-entropy, binary cross-entropy, and focal losses through the same API
  • All typos and errors in vandit15's source code are fixed

References

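Cui et al., "Class-Balanced Loss Based on Effective Number of Samples", CVPR 2019: https://arxiv.org/abs/1901.05555

Original TensorFlow implementation: https://github.com/richardaecn/class-balanced-loss

vandit15's PyTorch implementation: https://github.com/vandit15/Class-balanced-loss-pytorch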


Download files


Source Distribution

balanced-loss-0.1.0.tar.gz (5.4 kB)

Uploaded Source

Built Distribution

balanced_loss-0.1.0-py3-none-any.whl (5.2 kB)

Uploaded Python 3

File details

Details for the file balanced-loss-0.1.0.tar.gz.

File metadata

  • Download URL: balanced-loss-0.1.0.tar.gz
  • Size: 5.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.10.5

File hashes

Hashes for balanced-loss-0.1.0.tar.gz
Algorithm     Hash digest
SHA256        e55957cda998ed84963b8aa9c4f32456a7edb4fd94a5938d17604bb7763dff07
MD5           9b030f64c65a8cf2e789768ac4890b9d
BLAKE2b-256   264a7fbab9ae35b9c490fbfe574c2247dfb1af32bba438c59443bb26ae983403


File details

Details for the file balanced_loss-0.1.0-py3-none-any.whl.


File hashes

Hashes for balanced_loss-0.1.0-py3-none-any.whl
Algorithm     Hash digest
SHA256        9504e5d52dc3773d701f0af07090e470b155eb77060fd00c1b0ac6fbff68f10c
MD5           1d191e256902fab2fadbe25a3962a7cb
BLAKE2b-256   7ea7171d43fae753004d156b008d9db32458c487203df888841c5b2bc4f3f310

