A simple PyTorch implementation of focal loss.

Project description

focal-loss-pytorch

Simple vectorized PyTorch implementation of binary unweighted focal loss as specified by [1].
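For reference, the unweighted binary focal loss in [1] is FL(p_t) = -(1 - p_t)^gamma * log(p_t), where p_t is the probability the model assigns to the true class. The snippet below is a minimal scalar sketch of that formula using only the standard library. It is not this package's implementation: the function names and the `eps` clamp are assumptions made for illustration.

```python
import math


def binary_focal_loss(p, y, gamma=2.0, eps=1e-7):
    """Unweighted binary focal loss for a single prediction (illustrative).

    p: predicted probability of the positive class, in [0, 1]
    y: ground-truth label, 0 or 1
    """
    p = min(max(p, eps), 1.0 - eps)   # clamp to avoid log(0)
    p_t = p if y == 1 else 1.0 - p    # probability given to the true class
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


def binary_cross_entropy(p, y, eps=1e-7):
    """Plain binary cross-entropy, i.e. focal loss with gamma = 0."""
    p = min(max(p, eps), 1.0 - eps)
    p_t = p if y == 1 else 1.0 - p
    return -math.log(p_t)


# A confident correct prediction versus a confident wrong one, both with y = 1
easy = binary_focal_loss(0.9, 1, gamma=5)
hard = binary_focal_loss(0.1, 1, gamma=5)
print(f"easy example: {easy:.3e}, hard example: {hard:.3e}")
```

With gamma = 0 the modulating factor is 1 and the loss reduces to cross-entropy. Larger values of gamma shrink the contribution of well-classified examples, so training concentrates on the hard ones.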

Installation

This package can be installed using pip as follows:

python3 -m pip install focal-loss-pytorch
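One way to confirm the installation succeeded is to query the installed distribution's metadata from Python. This is a small sketch using the standard library's `importlib.metadata`. The helper name `installed_version` is hypothetical, and the distribution name is taken from the pip command above.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name="focal-loss-pytorch"):
    """Return the installed version string, or None if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


found = installed_version()
print(found if found is not None else "focal-loss-pytorch is not installed")
```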

Example Usage

Here is a quick example of how to import the BinaryFocalLoss class and use it to train a model:

import torch
from torch.utils.data import DataLoader

from focal_loss_pytorch.focal_loss_pytorch.focal_loss import BinaryFocalLoss

# Select the device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# model, train_set, val_set, learning_rate, and epochs are assumed to be
# defined elsewhere; move the model to the same device as the data
model = model.to(device)

# Initialize the loss function and optimizer
loss_fn = BinaryFocalLoss(gamma=5)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Load the datasets
train_loader = DataLoader(train_set, batch_size=10, shuffle=False)
val_loader = DataLoader(val_set, batch_size=10, shuffle=False)

# Train! :)
for e in range(epochs):
    model.train()  # enable training mode once per epoch
    for data in train_loader:
        input_img = data['img'].to(device)
        ref_img = data['ref'].to(device)
        output_img = model(input_img)

        loss = loss_fn(output_img, ref_img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Documentation

Documentation for this package is available on Read the Docs.

References

[1] Lin, T.-Y., Goyal, P., Girshick, R., He, K., and Dollár, P. "Focal loss for dense object detection." arXiv preprint arXiv:1708.02002 (2017).

Download files

Source Distribution

focal_loss_pytorch-0.0.3.tar.gz (14.5 kB)

Uploaded Source

Built Distribution

focal_loss_pytorch-0.0.3-py3-none-any.whl (14.4 kB)

Uploaded Python 3

File details

Details for the file focal_loss_pytorch-0.0.3.tar.gz.

File metadata

  • Download URL: focal_loss_pytorch-0.0.3.tar.gz
  • Size: 14.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for focal_loss_pytorch-0.0.3.tar.gz
Algorithm    Hash digest
SHA256       5f3817152cbf07f6fb7374097f2ba0de24c8a197e695e1172f361cbf92a7a55f
MD5          6e9fcd2ee5156f600f9af979c95dc8c4
BLAKE2b-256  fa5bbf6154bfa8b2004fb995741f09fdab6bc55ffed15a0081d3c83cf021b620

File details

Details for the file focal_loss_pytorch-0.0.3-py3-none-any.whl.

File hashes

Hashes for focal_loss_pytorch-0.0.3-py3-none-any.whl
Algorithm    Hash digest
SHA256       42e6f325ea58df1a7824898457ec732f95c0f82948ca6fddd73f823c480e2c0d
MD5          8f563a6e56cbb3056f9af0896b9f580f
BLAKE2b-256  eeb8cba34303d9ae307750561de5e12f66c805d5b13cc13b60d982b8c9a575d0
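The digests listed above can be checked against a downloaded file with Python's `hashlib`. The following is a sketch under two assumptions: the helper name `file_digests` is hypothetical, and "BLAKE2b-256" is taken to mean BLAKE2b with a 32-byte digest.

```python
import hashlib


def file_digests(path, chunk_size=65536):
    """Compute SHA256 and BLAKE2b-256 hex digests of a file, reading in chunks."""
    sha256 = hashlib.sha256()
    blake2b_256 = hashlib.blake2b(digest_size=32)  # assumed meaning of BLAKE2b-256
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            sha256.update(chunk)
            blake2b_256.update(chunk)
    return {
        "SHA256": sha256.hexdigest(),
        "BLAKE2b-256": blake2b_256.hexdigest(),
    }
```

For example, after downloading the source distribution, `file_digests("focal_loss_pytorch-0.0.3.tar.gz")["SHA256"]` should equal the SHA256 value listed for that file. A mismatch suggests the download is corrupted or is not the published file.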
