

Unified Focal Loss PyTorch

An implementation of loss functions from “Unified Focal loss: Generalising Dice and cross entropy-based losses to handle class imbalanced medical image segmentation”

The losses are extended to multiclass classification and accept an ignore index.
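To illustrate what an ignore index involves in a multiclass setting, here is a minimal sketch of one common way to mask ignored positions when turning class-index targets into one-hot form. `one_hot_with_ignore` is a hypothetical helper written for illustration, not part of this package's API, and the package may handle ignored pixels differently.

```python
import torch
import torch.nn.functional as F


def one_hot_with_ignore(targets: torch.Tensor, num_classes: int, ignore_index: int):
    """Hypothetical helper (not part of unified_focal_loss).

    Converts class-index targets of shape (batch_size, ...) into one-hot targets of
    shape (batch_size, num_classes, ...), zeroing positions labelled ignore_index.
    Also returns a float mask of shape (batch_size, 1, ...): 1.0 = keep, 0.0 = ignore.
    """
    valid = targets != ignore_index
    # Swap ignored labels for class 0 so F.one_hot never sees an out-of-range
    # label such as 255; the mask zeroes those positions afterwards.
    safe = torch.where(valid, targets, torch.zeros_like(targets))
    one_hot = F.one_hot(safe, num_classes).movedim(-1, 1).float()
    valid = valid.unsqueeze(1).float()
    return one_hot * valid, valid


# The targets from the usage example: the pixel labelled 2 is ignored.
targets = torch.tensor([[0, 1], [2, 0]])
probs = F.softmax(torch.randn(2, 3, 2), dim=1)
one_hot, valid = one_hot_with_ignore(targets, num_classes=3, ignore_index=2)

# Example per-pixel term (cross-entropy), averaged over non-ignored pixels only.
per_pixel_ce = -(one_hot * torch.log(probs)).sum(dim=1, keepdim=True)
masked_ce = (per_pixel_ce * valid).sum() / valid.sum()
```

The ignored pixel adds nothing to the numerator and is not counted in the denominator, which is the same convention `F.nll_loss` uses for its `ignore_index` argument; this package's exact reduction may differ.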

Note: This implementation has not been tested against the original implementation, and it differs from the original in places because it is based on my own interpretation of the paper.

Installation

pip install unified-focal-loss-pytorch

Usage

import torch
import torch.nn.functional as F
from unified_focal_loss import AsymmetricUnifiedFocalLoss

loss_fn = AsymmetricUnifiedFocalLoss(
    delta=0.7,
    gamma=0.5,
    ignore_index=2,
)

logits = torch.tensor([
    [[0.1000, 0.4000],
     [0.2000, 0.5000],
     [0.3000, 0.6000]],

    [[0.7000, 0.0000],
     [0.8000, 0.1000],
     [0.9000, 0.2000]]
])

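# Probabilities with shape (batch_size, num_classes, ...); here (2, 3, 2).
probs = F.softmax(logits, dim=1)
# Class-index targets with shape (batch_size, ...); here (2, 2). Not one-hot encoded.
# The pixel labelled 2 is excluded from the loss because ignore_index=2.
targets = torch.tensor([
    [0, 1],
    [2, 0],
])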

loss = loss_fn(probs, targets)
print(loss)
# >>> tensor(0.6737)
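For a typical 2D segmentation batch, the inputs can be prepared in the same layout. This is a minimal sketch assuming the ground-truth masks arrive one-hot encoded; the batch size, class count and image size are arbitrary illustration values, and random tensors stand in for real model outputs and labels.

```python
import torch
import torch.nn.functional as F

# Illustration values: 4 images, 3 classes, 8x8 pixels.
batch_size, num_classes, height, width = 4, 3, 8, 8
logits = torch.randn(batch_size, num_classes, height, width)

# Softmax over the class dimension: (batch_size, num_classes, H, W).
probs = F.softmax(logits, dim=1)

# Stand-in for one-hot masks of shape (batch_size, num_classes, H, W).
labels = torch.randint(0, num_classes, (batch_size, height, width))
one_hot_masks = F.one_hot(labels, num_classes).movedim(-1, 1)

# argmax over the class dimension recovers integer class indices (batch_size, H, W),
# which is the non-one-hot target format the loss expects.
targets = one_hot_masks.argmax(dim=1)
```

`probs` and `targets` would then be passed as `loss_fn(probs, targets)`. If some pixels carry no label (an all-zero one-hot vector), argmax silently maps them to class 0; setting those positions to the `ignore_index` value instead keeps them out of the loss.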

Detailed API Reference

See API docs.

License

See LICENSE.


Download files


Source Distribution

unified_focal_loss_pytorch-0.1.2.tar.gz (6.0 kB)

Built Distribution

unified_focal_loss_pytorch-0.1.2-py3-none-any.whl

File details

Details for the file unified_focal_loss_pytorch-0.1.2.tar.gz.

File metadata

  • Download URL: unified_focal_loss_pytorch-0.1.2.tar.gz
  • Size: 6.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.5.1 CPython/3.10.12 Linux/5.15.0-1041-azure

File hashes

Hashes for unified_focal_loss_pytorch-0.1.2.tar.gz
  • SHA256: d43be8e91943bd951ee7353e6a64b07a296373827de2374cbb0126ed8d30ab01
  • MD5: d8ba46ff13474867cddd6f4cbd2a159f
  • BLAKE2b-256: 6adc581e25141b9068933bf5edad4c790ebe34868fd82fb19ff2a81173b6a5a8
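To check a downloaded archive against the SHA256 digest listed above, a minimal sketch using only the Python standard library. The local file path in the commented call is an assumption; adjust it to wherever the file was saved.

```python
import hashlib
from pathlib import Path

# Published SHA256 digest of the source distribution, copied from the list above.
EXPECTED_SHA256 = "d43be8e91943bd951ee7353e6a64b07a296373827de2374cbb0126ed8d30ab01"


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Hex SHA256 digest of a file, read in chunks to keep memory use flat."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected: str = EXPECTED_SHA256) -> bool:
    """True if the file's SHA256 digest matches the expected hex string."""
    return sha256_of(path) == expected.lower()


# Assumed download location:
# print(verify(Path("unified_focal_loss_pytorch-0.1.2.tar.gz")))
```

pip can also enforce this itself in hash-checking mode: list the package in a requirements file with a `--hash=sha256:...` entry and install with `--require-hashes`.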


File details

Details for the file unified_focal_loss_pytorch-0.1.2-py3-none-any.whl.

File hashes

Hashes for unified_focal_loss_pytorch-0.1.2-py3-none-any.whl
  • SHA256: 0631b618d10c9537a0385f618c78ad6f1ecec8d68489f9812d2f76a54c5b73b6
  • MD5: 0892020a484a22a59ab6094dcc147cbc
  • BLAKE2b-256: fe15756875ed41bd147d7e7a9eee5bb76429b83ce3867c0cfff4d1b0cda04766

