
Simple PyTorch implementation of focal loss


focal_loss_torch

A simple PyTorch implementation of the focal loss introduced by Lin et al. [1].
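For reference, the focal loss from [1] for a single example is shown below. Here $p_t$ is the predicted probability of the true class, $\gamma \ge 0$ is the focusing parameter (the `gamma` argument), and $\alpha_t$ is an optional per-class weight, roughly the role this package gives to `weights`:

```latex
\mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \, \log(p_t)
```

With $\gamma = 0$ and $\alpha_t = 1$ this reduces to the ordinary cross-entropy $-\log(p_t)$.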

Usage

Install the package using pip:

pip install focal_loss_torch

Focal loss is now accessible in your PyTorch environment:

import torch
from focal_loss.focal_loss import FocalLoss

# Without class weights
criterion = FocalLoss(gamma=0.7)

# With class weights (one weight per class).
# The weights parameter is similar to the alpha value mentioned in the paper.
weights = torch.tensor([2.0, 3.2, 0.7])
criterion = FocalLoss(gamma=0.7, weights=weights)

# To ignore a target index (here, class 0)
criterion = FocalLoss(gamma=0.7, ignore_index=0)

# To make it behave like cross-entropy loss
criterion = FocalLoss(gamma=0)
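To illustrate what these options do, here is a minimal, framework-free sketch in plain Python of a mean-reduced focal loss on probability inputs. It is not the package's code: the function name `focal_loss`, the mean reduction, and the way ignored targets are excluded from the average are assumptions made for illustration.

```python
import math

def focal_loss(probs, targets, gamma=2.0, weights=None, ignore_index=None):
    """Mean focal loss over a batch (illustrative sketch, not the package's code).

    probs:        list of per-example class-probability lists (e.g. softmax outputs)
    targets:      list of integer class indices
    weights:      optional per-class weights, playing the role of alpha_t
    ignore_index: targets equal to this value are skipped (assumed behaviour)
    """
    total, count = 0.0, 0
    for p, y in zip(probs, targets):
        if ignore_index is not None and y == ignore_index:
            continue  # excluded from both the sum and the count
        p_t = p[y]  # probability assigned to the true class
        alpha_t = weights[y] if weights is not None else 1.0
        # (1 - p_t) ** gamma shrinks the loss of well-classified examples
        total += -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)
        count += 1
    # Plain mean over the non-ignored examples (an assumption, see below)
    return total / count if count else 0.0
```

For example, `focal_loss([[0.7, 0.2, 0.1]], [0], gamma=0.0)` equals `-math.log(0.7)`, the cross-entropy, which is why `gamma=0` behaves like cross-entropy. Note that this sketch divides by the number of examples even when weights are given, whereas PyTorch's `nll_loss` with `weight` divides by the sum of the selected weights; which convention this package follows is not stated here.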

Examples

For binary classification

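import torch
from focal_loss.focal_loss import FocalLoss

criterion = FocalLoss(gamma=0.7)
batch_size = 10
m = torch.nn.Sigmoid()
logits = torch.randn(batch_size)
target = torch.randint(0, 2, size=(batch_size,))
# The criterion expects probabilities, so apply the sigmoid first
loss = criterion(m(logits), target)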
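In the binary case the input is a single probability per example, $p = \sigma(\text{logit}) = P(y = 1)$, and [1] defines $p_t = p$ when $y = 1$ and $p_t = 1 - p$ otherwise. The plain-Python sketch below shows that binary form (without $\alpha_t$); the helper names are hypothetical, and whether the package computes $p_t$ exactly this way for 1-D inputs is an assumption.

```python
import math

def sigmoid(x):
    """Map a logit to P(y = 1)."""
    return 1.0 / (1.0 + math.exp(-x))

def binary_focal_loss(prob, target, gamma=2.0):
    """Focal loss for one binary example (hypothetical helper, no alpha term)."""
    # p_t is the probability the model gives to the true label, as in [1]
    p_t = prob if target == 1 else 1.0 - prob
    return -((1.0 - p_t) ** gamma) * math.log(p_t)

# Mean over a small batch of logits and 0/1 targets
logits = [2.0, -1.0, 0.5]
targets = [1, 0, 0]
loss = sum(binary_focal_loss(sigmoid(z), y, gamma=0.7)
           for z, y in zip(logits, targets)) / len(logits)
```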

For multi-class classification

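import torch
from focal_loss.focal_loss import FocalLoss

criterion = FocalLoss(gamma=0.7)
batch_size = 10
n_class = 5
m = torch.nn.Softmax(dim=-1)  # softmax over the class dimension
logits = torch.randn(batch_size, n_class)
target = torch.randint(0, n_class, size=(batch_size,))
loss = criterion(m(logits), target)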
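In the multi-class case $p_t$ is the softmax probability at the target index. The plain-Python sketch below (hypothetical helper names, not the package's API) picks out $p_t$ and returns both the cross-entropy and the focal loss, so the effect of the modulating factor $(1 - p_t)^{\gamma}$ can be compared directly.

```python
import math

def softmax(row):
    """Softmax of one row of logits."""
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

def ce_and_focal(logits, target, gamma=2.0):
    """Return (cross-entropy, focal loss) for one example."""
    p_t = softmax(logits)[target]  # softmax probability of the true class
    ce = -math.log(p_t)
    return ce, (1.0 - p_t) ** gamma * ce
```

At $\gamma = 2$, an easy example with $p_t = 0.9$ keeps only $0.1^2 = 1\%$ of its cross-entropy, while a hard example with $p_t = 0.1$ keeps $0.9^2 = 81\%$, so training focuses on the misclassified cases.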

For multi-class sequence classification

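import torch
from focal_loss.focal_loss import FocalLoss

criterion = FocalLoss(gamma=0.7)
batch_size = 10
max_length = 20
n_class = 5
m = torch.nn.Softmax(dim=-1)  # softmax over the class dimension
logits = torch.randn(batch_size, max_length, n_class)
target = torch.randint(0, n_class, size=(batch_size, max_length))
loss = criterion(m(logits), target)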
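For sequences, one common way to compute a per-token loss is to flatten the (batch, length) positions into a single list of tokens and average over all of them. The plain-Python sketch below does that with nested lists; the function name is hypothetical, and the assumption that the package averages over every batch × length position is not confirmed by this README.

```python
import math

def sequence_focal_loss(probs, targets, gamma=2.0):
    """Mean focal loss over every token of a batch of sequences (sketch).

    probs:   nested list shaped (batch, length, n_class), rows already softmaxed
    targets: nested list shaped (batch, length) of class indices
    """
    # Flatten (batch, length) into one list of tokens, the pure-Python
    # analogue of .view(-1, n_class) on a PyTorch tensor.
    flat_p = [row for seq in probs for row in seq]
    flat_t = [t for seq in targets for t in seq]
    losses = [-((1.0 - p[t]) ** gamma) * math.log(p[t])
              for p, t in zip(flat_p, flat_t)]
    return sum(losses) / len(losses)
```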

Contributions

Contributions, criticism or corrections are always welcome. Just send me a pull request!

References

[1] Lin, T. Y., et al. "Focal loss for dense object detection." arXiv preprint arXiv:1708.02002 (2017).


Download files


Source Distribution

focal_loss_torch-0.1.2.tar.gz (4.1 kB)

Built Distribution

focal_loss_torch-0.1.2-py3-none-any.whl (4.5 kB, Python 3)

File details

Details for the file focal_loss_torch-0.1.2.tar.gz.

File metadata

  • Download URL: focal_loss_torch-0.1.2.tar.gz
  • Size: 4.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.3

File hashes

Hashes for focal_loss_torch-0.1.2.tar.gz
  • SHA256: 837938c411e8b215a89f4a14c00fbb36216562a3390163755d2142857569c51e
  • MD5: 336642f08d19cd4b0bd96f0812ac9840
  • BLAKE2b-256: 662a41a5b59ba1040f43f2d50c1f1e79bafa4efaf9ca58327aa56d5dc3b22760


File details

Details for the file focal_loss_torch-0.1.2-py3-none-any.whl.

File hashes

Hashes for focal_loss_torch-0.1.2-py3-none-any.whl
  • SHA256: cdbf6f2429d41f6f385695ea75a7ef089211986e94febb1007ca1c6dc3329476
  • MD5: 7673bb08bf0c99c64adc9fa63cb5c70a
  • BLAKE2b-256: ee8a32a0ec54b86f557f2ddce695b89634c94f74a6288fa8d0b0b5260383e37c

