An implementation of loss functions from "Unified Focal loss: Generalising Dice and cross entropy-based losses to handle class imbalanced medical image segmentation"
Unified Focal Loss PyTorch
This implementation extends the losses to multiclass classification and accepts an ignore index.
Note: This implementation has not been tested against the original implementation. Where the two differ, the differences reflect my own interpretation of the paper.
Installation
pip install unified-focal-loss-pytorch
Usage
import torch
import torch.nn.functional as F

from unified_focal_loss import AsymmetricUnifiedFocalLoss

loss_fn = AsymmetricUnifiedFocalLoss(
    delta=0.7,
    gamma=0.5,
    ignore_index=2,
)

logits = torch.tensor([
    [[0.1000, 0.4000],
     [0.2000, 0.5000],
     [0.3000, 0.6000]],
    [[0.7000, 0.0000],
     [0.8000, 0.1000],
     [0.9000, 0.2000]]
])

# Shape should be (batch_size, num_classes, ...)
probs = F.softmax(logits, dim=1)

# Shape should be (batch_size, ...). Not one-hot encoded.
targets = torch.tensor([
    [0, 1],
    [2, 0],
])

loss = loss_fn(probs, targets)
print(loss)
# >>> tensor(0.6737)
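The shape conventions above carry over to image-shaped inputs. The sketch below does not call the package. It only illustrates, in plain PyTorch, how class probabilities of shape (batch_size, num_classes, H, W) line up with integer targets of shape (batch_size, H, W), and what an ignore index conventionally means: positions whose target equals that index are left out of the loss. The helper name `valid_mask_and_one_hot` is hypothetical, and whether the package handles ignored positions this way internally is an assumption, not something the README states.

```python
import torch
import torch.nn.functional as F


def valid_mask_and_one_hot(targets, num_classes, ignore_index):
    """Hypothetical helper, not part of the package.

    targets: integer tensor of shape (batch_size, ...), not one-hot encoded.
    Returns (mask, one_hot) with shapes (batch_size, ...) and
    (batch_size, num_classes, ...).
    """
    # True where the position should contribute to the loss.
    mask = targets != ignore_index
    # Replace ignored labels with 0 so one_hot always sees a valid class;
    # the mask zeroes those positions out again below.
    safe_targets = torch.where(mask, targets, torch.zeros_like(targets))
    one_hot = F.one_hot(safe_targets, num_classes)  # (batch_size, ..., num_classes)
    one_hot = one_hot.movedim(-1, 1)                # (batch_size, num_classes, ...)
    one_hot = one_hot * mask.unsqueeze(1)           # all-zero rows at ignored positions
    return mask, one_hot


torch.manual_seed(0)
logits = torch.randn(2, 3, 4, 4)          # (batch_size, num_classes, H, W)
probs = F.softmax(logits, dim=1)          # sums to 1 over the class dimension
targets = torch.randint(0, 3, (2, 4, 4))  # (batch_size, H, W), values in {0, 1, 2}

mask, one_hot = valid_mask_and_one_hot(targets, num_classes=3, ignore_index=2)
print(probs.shape, targets.shape, one_hot.shape)
```

With the package installed, `probs` and `targets` shaped like this are what the `loss_fn(probs, targets)` call shown earlier expects, per the shape comments in that example.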
Hashes for unified_focal_loss_pytorch-0.1.1.tar.gz
Algorithm | Hash digest
---|---
SHA256 | 565e5ea30d7974c2ed9912e6b95aaabad00cd3af6f82c9a98948cb4c7ee830a2
MD5 | b47c9a8d1660c5f5b3acf7f7b1719770
BLAKE2b-256 | 23f6f0a8613a6b35c45789788e1ef351eb9df61c14904f6991a42a3301a8eeb2
Hashes for unified_focal_loss_pytorch-0.1.1-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 8c6f7f414471676b8b87958d11ac6c4c1bd473d9638b89d3ddb100957a072c8c
MD5 | 5704034e1a12a003a50cfc1d8ad27f85
BLAKE2b-256 | 56b8ed98db441909c5aa933fe8958927d7009d04e3eb48f6e0172727769732ff