Simple PyTorch implementation of focal loss
focal_loss_torch
A simple PyTorch implementation of the focal loss introduced by Lin et al. [1].
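For reference, the paper [1] defines the focal loss for a binary label y ∈ {0, 1} and a predicted positive-class probability p as follows, where γ ≥ 0 is the focusing parameter and α_t is an optional class-balancing weight:

```latex
p_t =
\begin{cases}
  p     & \text{if } y = 1,\\
  1 - p & \text{otherwise,}
\end{cases}
\qquad
\mathrm{FL}(p_t) = -\alpha_t \,(1 - p_t)^{\gamma} \log(p_t)
```

Setting γ = 0 and α_t = 1 recovers the cross-entropy −log(p_t). The usual multi-class extension takes p_t to be the predicted probability of the true class.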
Usage
Install the package using pip:
pip install focal_loss_torch
Focal loss is then available in your PyTorch environment:
import torch
from focal_loss.focal_loss import FocalLoss
# Without class weights
criterion = FocalLoss(gamma=0.7)
# With class weights
# The weights parameter is similar to the alpha value mentioned in the paper
weights = torch.tensor([2, 3.2, 0.7])
criterion = FocalLoss(gamma=0.7, weights=weights)
# To ignore a target index
criterion = FocalLoss(gamma=0.7, ignore_index=0)
# With gamma=0 it behaves like the cross-entropy loss
criterion = FocalLoss(gamma=0)
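To make the roles of `gamma`, `weights` and `ignore_index` concrete, here is a minimal from-scratch sketch of a multi-class focal loss computed on probabilities. It is not this package's implementation: the function name `focal_loss_sketch`, the `eps` clamp and the plain mean over non-ignored examples are assumptions, and the package's exact reduction (especially with weights) may differ.

```python
import torch


def focal_loss_sketch(probs, target, gamma=2.0, weights=None,
                      ignore_index=-100, eps=1e-8):
    """Hypothetical sketch, not the package's code.

    probs:  (N, C) class probabilities, e.g. a softmax output
    target: (N,) integer class indices
    """
    # Drop examples whose target equals ignore_index
    keep = target != ignore_index
    probs, target = probs[keep], target[keep]

    # p_t: probability assigned to the true class of each example
    p_t = probs.gather(1, target.unsqueeze(1)).squeeze(1).clamp_min(eps)

    # Focal term: (1 - p_t)^gamma * (-log p_t)
    loss = (1.0 - p_t) ** gamma * -torch.log(p_t)

    # Optional per-class weights act as alpha_t, looked up by true class
    if weights is not None:
        loss = weights[target] * loss

    # Assumed reduction: plain mean over the kept examples
    return loss.mean()


torch.manual_seed(0)
logits = torch.randn(8, 3)
target = torch.tensor([0, 2, 1, 1, 0, 2, -100, 1])  # -100 marks an ignored example
probs = logits.softmax(dim=-1)
loss = focal_loss_sketch(probs, target, gamma=0.7,
                         weights=torch.tensor([2.0, 3.2, 0.7]))
```

With `gamma=0.0` and no weights, the sketch returns the same value as `torch.nn.functional.cross_entropy` applied to the logits; any `gamma > 0` can only shrink each example's loss, because `(1 - p_t) ** gamma` is at most 1.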
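Examples
The criterion takes probabilities rather than raw logits, so apply a sigmoid (binary) or softmax (multi-class) to the model output first.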
For binary classification
import torch
from focal_loss.focal_loss import FocalLoss
criterion = FocalLoss(gamma=0.7)
batch_size = 10
m = torch.nn.Sigmoid()
logits = torch.randn(batch_size)
target = torch.randint(0, 2, size=(batch_size,))
loss = criterion(m(logits), target)
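In the binary case the sigmoid output p is the probability of class 1, so p_t = p when the target is 1 and 1 − p otherwise. The following stand-alone sketch spells that out; the function `binary_focal_loss_sketch`, the `eps` clamp and the plain mean are assumptions for illustration, not part of this package.

```python
import torch


def binary_focal_loss_sketch(p, target, gamma=2.0, eps=1e-8):
    """Hypothetical sketch: p holds sigmoid probabilities, target holds 0/1 labels."""
    target = target.float()
    # p_t = p for positive examples, 1 - p for negative ones
    p_t = torch.where(target == 1, p, 1.0 - p).clamp_min(eps)
    # Focal term averaged over the batch
    return ((1.0 - p_t) ** gamma * -torch.log(p_t)).mean()


torch.manual_seed(0)
batch_size = 10
logits = torch.randn(batch_size)
target = torch.randint(0, 2, size=(batch_size,))
p = torch.sigmoid(logits)
loss = binary_focal_loss_sketch(p, target, gamma=0.7)
```

With `gamma=0.0` this reduces to `torch.nn.functional.binary_cross_entropy(p, target.float())`.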
For multi-class classification
batch_size = 10
n_class = 5
m = torch.nn.Softmax(dim=-1)
logits = torch.randn(batch_size, n_class)
target = torch.randint(0, n_class, size=(batch_size,))
loss = criterion(m(logits), target)
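In the multi-class case p_t is the softmax probability of the true class, and the focal term scales its cross-entropy −log(p_t) by (1 − p_t)^γ. A quick plain-Python calculation shows how strongly γ = 2 down-weights confident predictions; the helper `focal_term` is hypothetical and assumes α_t = 1.

```python
import math


def focal_term(p_t, gamma):
    # Per-example focal loss for true-class probability p_t (alpha_t = 1)
    return (1.0 - p_t) ** gamma * -math.log(p_t)


for p_t in (0.9, 0.5, 0.1):
    ce = focal_term(p_t, gamma=0.0)  # gamma = 0 is plain cross-entropy
    fl = focal_term(p_t, gamma=2.0)
    print(f"p_t={p_t}: CE={ce:.4f}  FL(gamma=2)={fl:.4f}  ratio={fl / ce:.2f}")
```

The ratio is exactly (1 − p_t)^γ: an example predicted with p_t = 0.9 contributes 100 times less than under cross-entropy, while a hard example with p_t = 0.1 keeps 81% of its loss.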
For multi-class sequence classification
batch_size = 10
max_length = 20
n_class = 5
m = torch.nn.Softmax(dim=-1)
logits = torch.randn(batch_size, max_length, n_class)
target = torch.randint(0, n_class, size=(batch_size, max_length))
loss = criterion(m(logits), target)
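For sequence inputs of shape (batch_size, max_length, n_class), a per-token loss is commonly obtained by flattening the batch and time dimensions so that every token becomes one classification example. The sketch below assumes a mean over all tokens; that reduction is an assumption, the package may reduce sequences differently, and `sequence_focal_loss_sketch` is a hypothetical name.

```python
import torch


def sequence_focal_loss_sketch(probs, target, gamma=2.0, eps=1e-8):
    """Hypothetical sketch: probs is (B, L, C), target is (B, L)."""
    n_class = probs.size(-1)
    # Treat every token as one example: (B, L, C) -> (B*L, C), (B, L) -> (B*L,)
    flat_probs = probs.reshape(-1, n_class)
    flat_target = target.reshape(-1)
    # Probability of the true class for each token
    p_t = flat_probs.gather(1, flat_target.unsqueeze(1)).squeeze(1).clamp_min(eps)
    # Assumed reduction: mean over all B*L tokens
    return ((1.0 - p_t) ** gamma * -torch.log(p_t)).mean()


torch.manual_seed(0)
batch_size, max_length, n_class = 10, 20, 5
logits = torch.randn(batch_size, max_length, n_class)
target = torch.randint(0, n_class, size=(batch_size, max_length))
loss = sequence_focal_loss_sketch(logits.softmax(dim=-1), target)
```

With `gamma=0.0` the result equals `torch.nn.functional.cross_entropy(logits.permute(0, 2, 1), target)`, which expects the class dimension second.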
Contributions
Contributions, criticism or corrections are always welcome. Just send me a pull request!
References
[1] Lin, T.-Y., et al. "Focal Loss for Dense Object Detection." arXiv preprint arXiv:1708.02002 (2017).
Download files
Source distribution: focal_loss_torch-0.1.2.tar.gz (4.1 kB)
Built distribution: focal_loss_torch-0.1.2-py3-none-any.whl
Hashes for focal_loss_torch-0.1.2-py3-none-any.whl:

Algorithm | Hash digest
---|---
SHA256 | cdbf6f2429d41f6f385695ea75a7ef089211986e94febb1007ca1c6dc3329476
MD5 | 7673bb08bf0c99c64adc9fa63cb5c70a
BLAKE2b-256 | ee8a32a0ec54b86f557f2ddce695b89634c94f74a6288fa8d0b0b5260383e37c