Simple PyTorch implementation of focal loss

Project description

focal_loss_torch

A simple PyTorch implementation of the focal loss introduced by Lin et al. [1].
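For reference, the paper defines the focal loss of a single example as FL(p_t) = -α_t (1 - p_t)^γ log(p_t), where p_t is the predicted probability of the true class, γ is the focusing parameter and α_t is a class weight. The sketch below evaluates that formula in plain Python to show how γ down-weights well-classified examples. It is an illustration of the paper's formula, not this package's code, and the function name focal_loss_single is hypothetical.

```python
import math


def focal_loss_single(p_t: float, gamma: float, alpha_t: float = 1.0) -> float:
    """FL(p_t) = -alpha_t * (1 - p_t) ** gamma * log(p_t).

    p_t is the predicted probability of the true class, 0 < p_t <= 1.
    """
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


# gamma = 0, alpha_t = 1: the focal loss reduces to cross-entropy, -log(p_t).
ce_easy = focal_loss_single(0.9, gamma=0.0)

# gamma = 2: the easy example (p_t = 0.9) keeps (1 - 0.9) ** 2 = 1% of its
# cross-entropy, while the hard example (p_t = 0.1) keeps (1 - 0.1) ** 2 = 81%.
fl_easy = focal_loss_single(0.9, gamma=2.0)
fl_hard = focal_loss_single(0.1, gamma=2.0)
print(ce_easy, fl_easy, fl_hard)
```

Raising γ therefore shifts the training signal toward hard, misclassified examples, which is the motivation given in the paper.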

Usage

Install the package using pip:

pip install focal_loss_torch

Focal loss is now accessible in your PyTorch environment:

import torch
from focal_loss.focal_loss import FocalLoss

# Without class weights
criterion = FocalLoss(gamma=0.7)

# With class weights (one weight per class)
# The weights parameter is similar to the alpha value described in the paper
weights = torch.tensor([2.0, 3.2, 0.7])
criterion = FocalLoss(gamma=0.7, weights=weights)

# Ignore targets equal to a given index (here 0)
criterion = FocalLoss(gamma=0.7, ignore_index=0)

# With gamma=0 the focal term vanishes and the loss behaves like cross-entropy
criterion = FocalLoss(gamma=0)
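In the paper's formula, the weights parameter corresponds to α_t: each example's term is scaled by the weight of its true class. The following minimal sketch (plain Python, hypothetical function name weighted_focal_loss) assumes the batch loss is a plain mean over examples. The package's actual reduction is not documented here, and PyTorch's own cross-entropy with class weights instead divides by the sum of the selected weights.

```python
import math
from typing import Optional, Sequence


def weighted_focal_loss(
    probs: Sequence[Sequence[float]],
    targets: Sequence[int],
    gamma: float,
    weights: Optional[Sequence[float]] = None,
) -> float:
    """Mean focal loss over a batch of per-class probability vectors.

    weights holds one value per class and plays the role of alpha_t:
    each example is scaled by the weight of its true class.
    Plain averaging over the batch is an assumption of this sketch.
    """
    total = 0.0
    for row, t in zip(probs, targets):
        p_t = row[t]  # probability assigned to the true class
        alpha_t = 1.0 if weights is None else weights[t]
        total += -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)
    return total / len(targets)


probs = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4]]
targets = [0, 1, 2]

# gamma=0 without weights: mean negative log-likelihood (cross-entropy)
ce = weighted_focal_loss(probs, targets, gamma=0)
# gamma=0.7 with the per-class weights used in the usage example
fl = weighted_focal_loss(probs, targets, gamma=0.7, weights=[2, 3.2, 0.7])
print(ce, fl)
```

Setting gamma=0 and omitting weights reproduces the mean negative log-likelihood of the true class, which is why FocalLoss(gamma=0) behaves like cross-entropy.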

Examples

For binary classification

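import torch
from focal_loss.focal_loss import FocalLoss

criterion = FocalLoss(gamma=0.7)
batch_size = 10
m = torch.nn.Sigmoid()  # the criterion is given probabilities, not raw logits
logits = torch.randn(batch_size)
target = torch.randint(0, 2, size=(batch_size,))
loss = criterion(m(logits), target)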
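In the binary case the criterion receives Sigmoid outputs, i.e. the probability p of class 1. A minimal sketch of the corresponding binary focal loss follows, assuming (as the example suggests) that a 1-D input is read as the probability of class 1 and that the batch loss is the mean. For a target y, p_t is p when y = 1 and 1 - p when y = 0. The helper names sigmoid and binary_focal_loss are hypothetical, and this is plain Python rather than the package's code.

```python
import math
from typing import Sequence


def sigmoid(x: float) -> float:
    # Numerically stable logistic function (avoids overflow in math.exp).
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def binary_focal_loss(
    logits: Sequence[float],
    targets: Sequence[int],
    gamma: float = 0.7,
    eps: float = 1e-12,
) -> float:
    """Mean binary focal loss; targets are 0 or 1."""
    total = 0.0
    for x, y in zip(logits, targets):
        p = sigmoid(x)  # probability of class 1
        p_t = p if y == 1 else 1.0 - p  # probability of the true class
        p_t = max(p_t, eps)  # clamp so log() never receives 0
        total += -((1.0 - p_t) ** gamma) * math.log(p_t)
    return total / len(targets)


print(binary_focal_loss([2.0, -1.0, 0.0], [1, 0, 1]))
```

With gamma=0 this reduces to ordinary binary cross-entropy, -log(p_t).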

For multi-class classification

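import torch
from focal_loss.focal_loss import FocalLoss

criterion = FocalLoss(gamma=0.7)
batch_size = 10
n_class = 5
m = torch.nn.Softmax(dim=-1)  # turn logits into class probabilities
logits = torch.randn(batch_size, n_class)
target = torch.randint(0, n_class, size=(batch_size,))
loss = criterion(m(logits), target)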
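The multi-class example applies Softmax before calling the criterion, so the loss is computed from probabilities. The sketch below (plain Python, hypothetical helper names, mean reduction assumed) starts from raw logits instead and works with log-softmax, which yields log(p_t) directly and avoids taking the log of a probability that has underflowed to 0.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Subtracting the max leaves the result unchanged but prevents overflow.
    m = max(logits)
    log_sum = math.log(sum(math.exp(x - m) for x in logits))
    return [x - m - log_sum for x in logits]


def multiclass_focal_loss(
    batch_logits: Sequence[Sequence[float]],
    targets: Sequence[int],
    gamma: float = 0.7,
) -> float:
    """Mean focal loss computed from raw logits via log-softmax."""
    total = 0.0
    for logits, t in zip(batch_logits, targets):
        log_p_t = log_softmax(logits)[t]  # log-probability of the true class
        p_t = math.exp(log_p_t)
        total += -((1.0 - p_t) ** gamma) * log_p_t
    return total / len(targets)


print(multiclass_focal_loss([[0.0] * 5, [2.0, 0.5, -1.0, 0.0, 1.0]], [3, 0]))
```

For example, with five equal logits every class gets probability 0.2, so that example contributes 0.8^0.7 · log 5 at γ = 0.7.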

For multi-class sequence classification

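import torch
from focal_loss.focal_loss import FocalLoss

criterion = FocalLoss(gamma=0.7)
batch_size = 10
max_length = 20
n_class = 5
m = torch.nn.Softmax(dim=-1)  # softmax over the class dimension
logits = torch.randn(batch_size, max_length, n_class)
target = torch.randint(0, n_class, size=(batch_size, max_length))
loss = criterion(m(logits), target)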
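For sequences, the (batch, max_length, n_class) probabilities and (batch, max_length) targets can be treated as batch × max_length independent tokens. The sketch below (plain Python, hypothetical function name sequence_focal_loss) flattens both and averages over the tokens whose target is not ignore_index, the usual way to exclude padding positions. Averaging over the kept tokens and raising an error when every token is ignored are choices of this sketch, not documented behavior of the package.

```python
import math
from typing import Optional, Sequence


def sequence_focal_loss(
    probs: Sequence[Sequence[Sequence[float]]],  # (batch, length, n_class)
    targets: Sequence[Sequence[int]],  # (batch, length)
    gamma: float = 0.7,
    ignore_index: Optional[int] = None,
) -> float:
    """Mean focal loss over all tokens, skipping tokens whose target is ignore_index."""
    # Flatten (batch, length, ...) into one list of (probabilities, target) pairs.
    tokens = [
        (row, t)
        for seq_probs, seq_targets in zip(probs, targets)
        for row, t in zip(seq_probs, seq_targets)
    ]
    kept = [(row, t) for row, t in tokens if t != ignore_index]
    if not kept:
        raise ValueError("every target equals ignore_index")
    total = sum(-((1.0 - row[t]) ** gamma) * math.log(row[t]) for row, t in kept)
    return total / len(kept)


# Two sequences of length 2 over 3 classes; class 0 marks padding here.
probs = [
    [[0.1, 0.7, 0.2], [0.1, 0.3, 0.6]],
    [[0.2, 0.2, 0.6], [0.8, 0.1, 0.1]],  # last position is padding
]
targets = [[1, 2], [2, 0]]
loss = sequence_focal_loss(probs, targets, gamma=0.7, ignore_index=0)
print(loss)
```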

Contributions

Contributions, criticism or corrections are always welcome. Just send me a pull request!

References

[1] Lin, T. Y., et al. "Focal loss for dense object detection." arXiv preprint arXiv:1708.02002 (2017).

Download files

Source Distribution

focal_loss_torch-0.1.1.tar.gz (3.9 kB)

Uploaded Source

Built Distribution

focal_loss_torch-0.1.1-py3-none-any.whl (4.4 kB)

Uploaded Python 3

File details

Details for the file focal_loss_torch-0.1.1.tar.gz.

File metadata

  • Download URL: focal_loss_torch-0.1.1.tar.gz
  • Size: 3.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.6

File hashes

Hashes for focal_loss_torch-0.1.1.tar.gz
Algorithm Hash digest
SHA256 25cdc1e722ed8a2b33c04ff90a5b67451bac1c18fbd076aaba2bd153eb261fe4
MD5 f30e04d3fbc8a7f8c3c56ff992d60c27
BLAKE2b-256 fa5ee104a0216422d7cad8eda34765d924d9e4876f3acacbeb93f8b2996dbc99

File details

Details for the file focal_loss_torch-0.1.1-py3-none-any.whl.

File hashes

Hashes for focal_loss_torch-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 4105de643c08df3aebf9e77cc69044a98c2f12f9f29c0874e39ec9f0fca207b1
MD5 c7fd5917f755f1dd853b74c35f86a590
BLAKE2b-256 749962be85a90fefe8a9f1c9eee28363197423982733bdd77a7a48722e70a93d
