A simple PyTorch implementation of focal loss.
focal-loss-pytorch
A simple, vectorized PyTorch implementation of the binary, unweighted focal loss specified in [1].
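For reference, the binary focal loss from [1] can be written as follows. Here p is the model's predicted probability for the positive class and y ∈ {0, 1} is the target. We assume "unweighted" means that the α_t balancing factor of the α-balanced variant in [1] is omitted. This README does not confirm that reading.

```latex
p_t =
\begin{cases}
p & \text{if } y = 1,\\
1 - p & \text{otherwise,}
\end{cases}
\qquad
\mathrm{FL}(p_t) = -(1 - p_t)^{\gamma}\,\log(p_t).
```

With γ = 0 this reduces to standard binary cross-entropy. Larger values of γ shrink the loss on well-classified examples, where p_t is close to 1.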
Installation
This package can be installed using pip as follows:
python3 -m pip install focal-loss-pytorch
Example Usage
Here is a quick example of how to import the BinaryFocalLoss class and use it to train a model:
from focal_loss_pytorch.focal_loss_pytorch.focal_loss import BinaryFocalLoss
import torch
from torch.utils.data import DataLoader

# Assumes `model`, `train_set`, `val_set`, `learning_rate` and `epochs`
# are already defined elsewhere.

# Initialize device and move the model onto it
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

# Initialize loss fn + optimizer
loss_fn = BinaryFocalLoss(gamma=5)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Load datasets (shuffle the training set each epoch)
train_loader = DataLoader(train_set, batch_size=10, shuffle=True)
val_loader = DataLoader(val_set, batch_size=10, shuffle=False)

# Train! :)
for e in range(epochs):
    model.train()  # set training mode once per epoch
    for data in train_loader:
        input_img = data['img'].to(device)
        ref_img = data['ref'].to(device)
        output_img = model(input_img)
        loss = loss_fn(output_img, ref_img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
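The sketch below shows what `gamma=5` does in the example above. It is a minimal, framework-free version of the binary focal loss formula from [1] and is not this package's implementation. The function name `binary_focal_loss`, the clamping `eps` and the mean reduction are our own assumptions. The function also takes probabilities rather than logits.

```python
import math
from typing import Sequence


def binary_focal_loss(
    probs: Sequence[float],
    targets: Sequence[int],
    gamma: float = 2.0,
    eps: float = 1e-7,
) -> float:
    """Mean binary focal loss, FL(p_t) = -(1 - p_t)**gamma * log(p_t).

    Hypothetical helper for illustration only: `probs` are predicted
    probabilities of the positive class, `targets` are 0/1 labels.
    """
    if len(probs) != len(targets):
        raise ValueError("probs and targets must have the same length")
    total = 0.0
    for p, y in zip(probs, targets):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        p_t = p if y == 1 else 1.0 - p   # probability assigned to the true class
        total += -((1.0 - p_t) ** gamma) * math.log(p_t)
    return total / len(probs)


# An easy example (p_t = 0.9) versus a hard one (p_t = 0.1), with gamma=5.
easy = binary_focal_loss([0.9], [1], gamma=5)
hard = binary_focal_loss([0.1], [1], gamma=5)
print(easy, hard)
```

With `gamma=5`, the easy example's cross-entropy term is scaled by (1 - 0.9)^5 = 1e-5. The hard example keeps about 59% of its cross-entropy, because 0.9^5 ≈ 0.59. Setting `gamma=0` recovers plain binary cross-entropy.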
Documentation
Documentation for this package is available on Read the Docs.
References
[1] Lin, T.-Y., et al. "Focal loss for dense object detection." arXiv preprint arXiv:1708.02002 (2017).