A simple PyTorch implementation of focal loss.
Project description
focal-loss-pytorch
Simple vectorized PyTorch implementation of binary unweighted focal loss as specified by [1].
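For reference, the binary focal loss from [1] can be written as follows. Here p is the predicted probability of the positive class, y ∈ {0, 1} is the label, and γ ≥ 0 is the focusing parameter; the `gamma` argument of `BinaryFocalLoss` presumably corresponds to γ. This assumes "unweighted" means the α_t class-balancing factor from [1] is omitted. The worked numbers use γ = 2 and the natural log. They show how the modulating factor (1 − p_t)^γ suppresses well-classified examples far more than hard ones. With γ = 0 the expression reduces to ordinary binary cross-entropy.

```latex
p_t =
\begin{cases}
  p     & \text{if } y = 1 \\
  1 - p & \text{otherwise}
\end{cases}
\qquad
\mathrm{FL}(p_t) = -(1 - p_t)^{\gamma} \log(p_t)

% Worked example, gamma = 2
\text{easy } (p_t = 0.9):\quad (1 - 0.9)^2 \cdot (-\log 0.9) \approx 0.01 \times 0.1054 \approx 0.00105
\text{hard } (p_t = 0.1):\quad (1 - 0.1)^2 \cdot (-\log 0.1) \approx 0.81 \times 2.3026 \approx 1.865
```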
Installation
This package can be installed using pip as follows:
```bash
python3 -m pip install focal-loss-pytorch
```
Example Usage
Here is a quick example of how to import the BinaryFocalLoss class and use it to train a model:
```python
import torch
from torch.utils.data import DataLoader

from focal_loss_pytorch.focal_loss_pytorch.focal_loss import BinaryFocalLoss

# model, learning_rate, epochs, train_set and val_set are placeholders
# that you must define yourself; they are not provided by this package.

# Initialize device and move the model onto it
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

# Initialize loss fn + optimizer
loss_fn = BinaryFocalLoss(gamma=5)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Load datasets
train_loader = DataLoader(train_set, batch_size=10, shuffle=False)
val_loader = DataLoader(val_set, batch_size=10, shuffle=False)

# Train! :)
for e in range(epochs):
    model.train()
    for data in train_loader:
        input_img = data['img'].to(device)
        ref_img = data['ref'].to(device)
        output_img = model(input_img)
        loss = loss_fn(output_img, ref_img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```
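This page does not show the package's source. The function below is a minimal sketch of what a vectorized, unweighted binary focal loss computes. `binary_focal_loss` is a hypothetical helper, not this package's API. It assumes the inputs are probabilities (for example, after a sigmoid) rather than logits. Whether `BinaryFocalLoss` expects probabilities or logits is not stated here, so check the documentation before relying on either.

```python
import torch
import torch.nn.functional as F


def binary_focal_loss(probs: torch.Tensor, targets: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    """Mean binary focal loss without alpha weighting (hypothetical sketch).

    Assumes `probs` holds probabilities and `targets` holds 0/1 labels
    of the same shape.
    """
    targets = targets.to(probs.dtype)
    # Clamp so that log() never sees an exact 0.
    eps = 1e-7
    probs = probs.clamp(eps, 1.0 - eps)
    # p_t: probability assigned to the true class, computed element-wise.
    p_t = probs * targets + (1.0 - probs) * (1.0 - targets)
    # Cross-entropy term -log(p_t) scaled by the modulating factor (1 - p_t)^gamma.
    loss = -((1.0 - p_t) ** gamma) * torch.log(p_t)
    return loss.mean()


if __name__ == "__main__":
    probs = torch.tensor([0.9, 0.1, 0.6, 0.3])
    targets = torch.tensor([1.0, 1.0, 0.0, 0.0])
    print("focal (gamma=2):", binary_focal_loss(probs, targets, gamma=2.0).item())
    print("focal (gamma=0):", binary_focal_loss(probs, targets, gamma=0.0).item())
    print("BCE:            ", F.binary_cross_entropy(probs, targets).item())
```

Setting `gamma=0` should reproduce `F.binary_cross_entropy`. This makes a convenient sanity check when comparing any focal loss implementation against a known baseline.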
Documentation
Documentation for this package is available on Read the Docs.
References
[1] Lin, T. Y., et al. "Focal loss for dense object detection." arXiv preprint arXiv:1708.02002 (2017).
Download files
Source Distribution: focal_loss_pytorch-0.0.3.tar.gz
Built Distribution: focal_loss_pytorch-0.0.3-py3-none-any.whl
File details
Details for the file focal_loss_pytorch-0.0.3.tar.gz.
File metadata
- Download URL: focal_loss_pytorch-0.0.3.tar.gz
- Upload date:
- Size: 14.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.9.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5f3817152cbf07f6fb7374097f2ba0de24c8a197e695e1172f361cbf92a7a55f |
| MD5 | 6e9fcd2ee5156f600f9af979c95dc8c4 |
| BLAKE2b-256 | fa5bbf6154bfa8b2004fb995741f09fdab6bc55ffed15a0081d3c83cf021b620 |
File details
Details for the file focal_loss_pytorch-0.0.3-py3-none-any.whl.
File metadata
- Download URL: focal_loss_pytorch-0.0.3-py3-none-any.whl
- Upload date:
- Size: 14.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.9.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 42e6f325ea58df1a7824898457ec732f95c0f82948ca6fddd73f823c480e2c0d |
| MD5 | 8f563a6e56cbb3056f9af0896b9f580f |
| BLAKE2b-256 | eeb8cba34303d9ae307750561de5e12f66c805d5b13cc13b60d982b8c9a575d0 |