
A simple PyTorch implementation of focal loss.


focal-loss-pytorch

Simple vectorized PyTorch implementation of binary unweighted focal loss as specified by [1].
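For reference, [1] defines the focal loss for a binary target y ∈ {0, 1} and a predicted positive-class probability p as shown below. We read "unweighted" as meaning that the α-balancing factor, which [1] also proposes, is omitted; that reading is an assumption, not something this page states.

$$
\mathrm{FL}(p_t) = -(1 - p_t)^{\gamma} \log(p_t),
\qquad
p_t = \begin{cases} p & \text{if } y = 1, \\ 1 - p & \text{otherwise.} \end{cases}
$$

Setting γ = 0 recovers ordinary binary cross-entropy, -log(p_t).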

Installation

This package can be installed using pip as follows:

python3 -m pip install focal-loss-pytorch

Example Usage

Here is a quick example of how to import the BinaryFocalLoss class and use it to train a model:

from focal_loss_pytorch.focal_loss_pytorch.focal_loss import BinaryFocalLoss
import torch
from torch.utils.data import DataLoader

# model, train_set, val_set, learning_rate and epochs are placeholders
# assumed to be defined elsewhere

# Initialize device and move the model onto it
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

# Initialize loss fn + optimizer
loss_fn = BinaryFocalLoss(gamma=5)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Load datasets (shuffle the training set each epoch; keep validation order fixed)
train_loader = DataLoader(train_set, batch_size=10, shuffle=True)
val_loader = DataLoader(val_set, batch_size=10, shuffle=False)

# Train! :)
for epoch in range(epochs):
    model.train()  # switch to training mode once per epoch
    for data in train_loader:
        input_img = data['img'].to(device)
        ref_img = data['ref'].to(device)
        output_img = model(input_img)

        loss = loss_fn(output_img, ref_img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
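
The gamma=5 argument above is the focusing parameter γ from [1]. To illustrate what it does, here is a minimal plain-Python, per-element sketch of the formula. It is not this package's code: the function name binary_focal_loss, the eps clamp, and the assumption that the input is a probability (rather than a logit) are all our own.

```python
import math


def binary_focal_loss(p: float, y: int, gamma: float = 2.0, eps: float = 1e-7) -> float:
    """Hypothetical per-element binary focal loss, -(1 - p_t)^gamma * log(p_t).

    p is the predicted probability of the positive class and y is the 0/1 target.
    No alpha weighting is applied (the "unweighted" variant, as we read it).
    """
    # Clamp p away from 0 and 1 so log() stays finite (our own choice of eps)
    p = min(max(p, eps), 1.0 - eps)
    # p_t is the probability the model assigns to the true class
    p_t = p if y == 1 else 1.0 - p
    # The (1 - p_t)^gamma factor down-weights confident, correct predictions
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


if __name__ == "__main__":
    for gamma in (0.0, 5.0):
        easy = binary_focal_loss(0.9, 1, gamma)  # confident and correct
        hard = binary_focal_loss(0.1, 1, gamma)  # confident and wrong
        print(f"gamma={gamma}: easy={easy:.3g}, hard={hard:.3g}")
```

With γ = 0 the value equals binary cross-entropy. Raising γ to 5 scales the confident, correct example's loss by (1 - 0.9)^5 = 10^-5, while the misclassified example keeps about 59% of its loss (0.9^5 ≈ 0.59), so the gradient signal concentrates on hard examples.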

Documentation

Documentation for this package is available on Read the Docs.

References

[1] Lin, T. Y., et al. "Focal loss for dense object detection." arXiv preprint arXiv:1708.02002 (2017).



Download files


Source Distribution

focal_loss_pytorch-0.0.3.tar.gz (14.5 kB)

Built Distribution

focal_loss_pytorch-0.0.3-py3-none-any.whl (14.4 kB)
