imbalanced-losses
imbalanced-losses is a PyTorch library of training losses for class-imbalanced classification — including Focal Loss, Smooth-AP, and Recall-at-Quantile — with built-in DDP all-gather support for globally-correct rank estimation and normalization across multi-GPU training.
What's in it:
- `SigmoidFocalLoss`: Binary/multi-label focal loss (Lin et al., ICCV 2017). Sigmoid activation; `alpha` re-balances positives vs. negatives and `gamma` down-weights easy examples. Drop-in replacement for `BCEWithLogitsLoss`.
- `SoftmaxFocalLoss`: Multiclass focal loss with softmax. Supports `mean_positive` reduction (RetinaNet convention: normalize by positive count), per-class `alpha` weighting, label smoothing, and arbitrary spatial/sequence input shapes.
- `SmoothAPLoss`: Differentiable approximation of average precision (AP) (Brown et al., ECCV 2020). Uses sigmoid-based soft rank estimation in O(|P|×M), where |P| is the positive count and M = batch + queue size. Supports multi-class, binary, and seq2seq settings.
- `RecallAtQuantileLoss`: Optimizes recall above a score threshold placed at the (1 − q) quantile of the pooled score distribution, so that the top q fraction of scores lies above it. Useful for alert/detection workloads (e.g. the top 0.5% of scores).
- `LossWarmupWrapper`: Training utility that runs a standard loss (BCE/CE) during warmup, linearly blends into the ranking loss over a configurable transition window, then applies geometric temperature decay. It automatically resets the memory queue at the phase switch so that warmup-era logits cannot poison the queue.
Design points:
- Circular memory queue stabilizes gradient estimates across small batches — critical at low positive rates (e.g. 0.5%)
- Compatible with PyTorch Lightning via `on_train_epoch_start` / `on_train_batch_start` hooks
- `toy_demo.py` demonstrates the full warmup → blend → AP pipeline on a highly imbalanced binary classification task using sklearn's `make_classification`
Losses
SigmoidFocalLoss — Focal Loss, binary / multi-label (Lin et al., 2017)
Replaces `BCEWithLogitsLoss` for imbalanced binary or multi-label classification. `gamma` suppresses the contribution of easy (well-classified) examples so training focuses on hard ones; `alpha` re-weights the positive class:
```
p_t  = sigmoid(logit) · y + (1 − sigmoid(logit)) · (1 − y)
α_t  = α · y + (1 − α) · (1 − y)
loss = −α_t · (1 − p_t)^γ · log(p_t)
```
```python
import torch

from imbalanced_losses import SigmoidFocalLoss

loss_fn = SigmoidFocalLoss(alpha=0.25, gamma=2.0, reduction="mean")
logits = torch.randn(32, 1, requires_grad=True)  # arbitrary shape
targets = torch.randint(0, 2, (32, 1)).float()   # float 0/1, same shape as logits
loss = loss_fn(logits, targets)
loss.backward()
```
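To make the formula concrete, here is a dependency-free sketch that evaluates it for a single logit. It is not the library's implementation: `focal_loss_scalar` and `bce_scalar` are hypothetical helpers, α_t is assumed to follow the usual α·y + (1 − α)·(1 − y) convention, and a negative `alpha` is assumed to switch weighting off, mirroring `alpha=-1` in the parameter table.

```python
import math


def focal_loss_scalar(logit: float, y: int, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Sigmoid focal loss for one logit and one 0/1 target (illustrative sketch)."""
    p = 1.0 / (1.0 + math.exp(-logit))        # sigmoid(logit)
    p_t = p * y + (1.0 - p) * (1 - y)          # probability assigned to the true class
    # Assumption: a negative alpha disables class weighting (mirrors alpha=-1)
    alpha_t = 1.0 if alpha < 0 else alpha * y + (1.0 - alpha) * (1 - y)
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


def bce_scalar(logit: float, y: int) -> float:
    """Plain binary cross-entropy on a logit, for comparison."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(y * math.log(p) + (1 - y) * math.log(1.0 - p))


# Ratio focal / BCE: how much each positive example is down-weighted (alpha disabled)
easy = focal_loss_scalar(4.0, 1, alpha=-1) / bce_scalar(4.0, 1)   # confidently correct
hard = focal_loss_scalar(-1.0, 1, alpha=-1) / bce_scalar(-1.0, 1) # misclassified
print(f"easy scale {easy:.4f}, hard scale {hard:.4f}")
```

With `gamma=0` and weighting disabled the two functions agree exactly; with `gamma=2` the easy positive keeps only a tiny fraction of its BCE loss while the misclassified one keeps about half.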
SoftmaxFocalLoss — Focal Loss, multiclass (Lin et al., 2017)
Extends focal loss to mutually-exclusive classification via softmax. Supports all standard input shapes (N, C), (N, C, L), (N, C, H, W), etc.
```python
import torch

from imbalanced_losses import SoftmaxFocalLoss

# Standard multiclass
loss_fn = SoftmaxFocalLoss(gamma=2.0, reduction="mean")
logits = torch.randn(32, 10)            # [N, C]
targets = torch.randint(0, 10, (32,))   # [N] integer labels
loss = loss_fn(logits, targets)

# RetinaNet-style: normalize by positive count, not total
loss_fn = SoftmaxFocalLoss(
    gamma=2.0,
    alpha=torch.full((10,), 0.25),  # per-class weights (one per class)
    reduction="mean_positive",      # denominator = #positives only
    background_class=0,
    ignore_index=-100,
)
loss = loss_fn(logits, targets)
```
mean_positive reduction: The numerator sums loss over all valid (non-ignored) positions including background. The denominator counts only non-background valid positions. This matches the RetinaNet convention and stabilizes the loss scale when the vast majority of samples are background.
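The `mean_positive` arithmetic can be sketched in plain Python over flat lists of per-position losses and targets. This is an illustration, not the library's code: `reduce_mean_positive` is a hypothetical helper, and clamping the denominator to 1 when a batch has no positives is an assumption, since that edge case is not documented here.

```python
IGNORE_INDEX = -100
BACKGROUND = 0


def reduce_mean_positive(per_position_loss, targets, background=BACKGROUND, ignore_index=IGNORE_INDEX):
    """Sum loss over valid positions, divide by the number of valid non-background positions."""
    numerator = 0.0
    num_positive = 0
    for loss, t in zip(per_position_loss, targets):
        if t == ignore_index:
            continue                  # padding: contributes neither loss nor count
        numerator += loss             # background still contributes to the numerator
        if t != background:
            num_positive += 1         # only foreground positions count in the denominator
    # Assumption: guard against batches with no positives
    return numerator / max(num_positive, 1)


losses = [0.01, 0.02, 0.01, 0.9, 0.5, 3.0]
targets = [0, 0, 0, 2, 1, IGNORE_INDEX]
print(reduce_mean_positive(losses, targets))  # (0.01 + 0.02 + 0.01 + 0.9 + 0.5) / 2 positives
```

Plain `mean` would divide the same numerator by all 5 valid positions, so the many near-zero background terms would shrink the loss; `mean_positive` keeps its scale tied to the number of foreground examples.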
SmoothAPLoss — Smooth Average Precision (Brown et al., 2020)
Approximates AP using sigmoid-based soft rank estimation. For each positive i in the pool:
```
ŝ_i   = 1 + Σ_{j≠i} σ((s_j − s_i) / τ)        # soft overall rank
ŝ_i^+ = 1 + Σ_{j≠i, j∈P} σ((s_j − s_i) / τ)   # soft rank among positives
AP    ≈ (1/|P|) · Σ_{i∈P} ŝ_i^+ / ŝ_i
loss  = 1 − AP
```
Complexity: O(|P|×M) where |P| is the number of positives and M = batch + queue size. At a 0.5% positive rate this is ~200× cheaper than O(M²). Keep M ≤ ~4096.
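A direct, loop-based transcription of the Smooth-AP formulas above for a single class, in plain Python. It is a sketch for intuition only (no queue, no batching, no autograd), and `smooth_ap` and `sigmoid` are hypothetical names. The nested loop makes the O(|P|×M) cost visible: one pass over the pool per positive instead of one per pool element, which is where the M/|P| = 1/0.005 = 200× saving comes from.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def smooth_ap(scores, labels, tau=0.01):
    """Smooth-AP for one class; scores are floats, labels are 0/1."""
    pos = [i for i, y in enumerate(labels) if y == 1]
    if not pos:
        return float("nan")            # AP is undefined without positives
    total = 0.0
    for i in pos:                      # |P| outer iterations ...
        rank_all = 1.0                 # soft overall rank of positive i
        rank_pos = 1.0                 # soft rank of i among positives
        for j, s_j in enumerate(scores):   # ... times M inner iterations
            if j == i:
                continue
            w = sigmoid((s_j - scores[i]) / tau)  # ~1 if j outranks i, ~0 otherwise
            rank_all += w
            if labels[j] == 1:
                rank_pos += w
        total += rank_pos / rank_all
    return total / len(pos)


# Positives ranked above every negative: AP close to 1
print(smooth_ap([3.0, 2.0, 0.0, -1.0], [1, 1, 0, 0]))
# A single positive ranked second: AP close to 1/2
print(smooth_ap([1.0, 2.0], [1, 0]))
```

With τ = 0.01 the soft ranks are nearly hard, so both calls land close to the exact AP values; raising τ blurs the ranks and moves the estimate away from them.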
RecallAtQuantileLoss — Recall at Quantile
Optimizes recall above a score threshold set at the q-th quantile of the pooled score distribution. The threshold is treated as a stop-gradient constant each forward pass:
```
θ           = quantile(scores, 1 − q)            # detached: no gradient
soft_recall = (1/|P|) · Σ_{i∈P} σ((s_i − θ) / τ)
loss        = 1 − soft_recall
```
Gradient flows only through positive scores, pushing them above the cutoff. Useful for alert/detection settings (e.g. quantile=0.005 = top 50 bps).
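The threshold-then-soft-count logic can likewise be sketched without PyTorch. `quantile_higher` reproduces the `'higher'` interpolation rule of `torch.quantile` (take the element at the rounded-up fractional index), which the parameter table lists as the default; all helper names are hypothetical. In PyTorch θ would also be detached; this plain-Python version has no gradients, so that step has no counterpart here.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def quantile_higher(values, q):
    """q-quantile with 'higher' interpolation: the element at ceil(q * (n - 1))."""
    ordered = sorted(values)
    return ordered[math.ceil(q * (len(ordered) - 1))]


def recall_at_quantile_loss(scores, labels, quantile=0.005, tau=0.01):
    """Return (threshold, 1 - soft recall) for 0/1 labels; assumes at least one positive."""
    theta = quantile_higher(scores, 1.0 - quantile)  # plain constant, like a detached tensor
    pos = [s for s, y in zip(scores, labels) if y == 1]
    soft_recall = sum(sigmoid((s - theta) / tau) for s in pos) / len(pos)
    return theta, 1.0 - soft_recall


# 1,000 pooled scores, two positives: one inside the top 1%, one far below it
scores = [float(i) for i in range(1000)]
labels = [0] * 1000
labels[995] = 1
labels[500] = 1
theta, loss = recall_at_quantile_loss(scores, labels, quantile=0.01, tau=1.0)
print(theta, round(loss, 4))
```

Only the positive near the cutoff contributes meaningfully to the soft recall; the one far below it sits in the flat tail of the sigmoid, which is why the loss ends up close to 0.5 here.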
Features
All losses support DDP all-gather via gather_distributed (auto-detected by default).
Focal losses (SigmoidFocalLoss, SoftmaxFocalLoss):
- Arbitrary input shapes: `(N, C)`, `(N, C, L)`, `(N, C, H, W)`, …
- `ignore_index` masking: padded positions contribute zero loss and zero gradient
- `mean` reduction divides by the valid (non-ignored) count, not the total tensor size
- `mean_positive` reduction (softmax only): normalizes by the positive count for detection tasks
- `alpha`: scalar (sigmoid) or per-class tensor (softmax) class reweighting
- `label_smoothing` (softmax only): forwarded directly to `F.cross_entropy`
Ranking losses (SmoothAPLoss, RecallAtQuantileLoss):
- Memory queue: a circular buffer accumulates past batches to stabilize estimates over small batch sizes; set `queue_size=0` to disable
- Multi-class: one-vs-rest per class using `logits[:, c]`
- Binary: set `num_classes=1` with targets in `{0, 1}`
- Seq2seq: flatten `[B, T, C]` → `[B*T, C]` upstream before passing
- Padding: `ignore_index` rows are excluded from ranking and from the positive set
- Reductions: `'mean'` (default), `'sum'`, or `'none'` (per-class tensor; degenerate classes are `nan`)
- Per-class logging: `return_per_class=True` returns `(loss, per_class, valid_mask)` without a second forward pass
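The memory queue behaves like a fixed-capacity FIFO: each call ranks the current batch against the stored history, then appends the batch and evicts the oldest rows. The sketch below models that with `collections.deque(maxlen=...)`. `ScoreQueue` and its methods are hypothetical and only approximate the library's behaviour; a real implementation would store detached tensors and drop `ignore_index` rows.

```python
from collections import deque


class ScoreQueue:
    """Fixed-capacity FIFO of (score, label) pairs; the oldest entries are evicted first."""

    def __init__(self, queue_size: int):
        self.queue_size = queue_size
        self._buf = deque(maxlen=queue_size) if queue_size > 0 else None

    def pool(self, batch_scores, batch_labels):
        """Return current batch + queued history (M = batch + queue), then enqueue the batch."""
        batch = list(zip(batch_scores, batch_labels))
        if self._buf is None:             # queue_size=0: rank within the batch only
            return batch
        pooled = batch + list(self._buf)  # current rows first, history after
        self._buf.extend(batch)           # deque(maxlen=...) drops the oldest rows
        return pooled

    def reset(self):
        """Forget all history, e.g. between training and validation."""
        if self._buf is not None:
            self._buf.clear()


queue = ScoreQueue(queue_size=4)
queue.pool([0.9, 0.1], [1, 0])
print(len(queue.pool([0.8, 0.2, 0.3], [0, 0, 1])))  # 3 current rows + 2 queued rows
```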
Installation
Requires Python ≥ 3.10 and PyTorch ≥ 2.8.
```bash
# from PyPI
pip install imbalanced-losses

# from GitHub (latest dev)
pip install git+https://github.com/chris-santiago/imbalanced-losses.git

# with uv (for development / contributing)
git clone https://github.com/chris-santiago/imbalanced-losses.git
cd imbalanced-losses
uv sync
```
To run the example scripts, install the optional demo dependencies:
pip install "imbalanced-losses[demo]"
# or with uv:
uv sync --extra demo
Usage
```python
import torch

from imbalanced_losses import RecallAtQuantileLoss, SmoothAPLoss

# Multi-class AP loss
loss_fn = SmoothAPLoss(num_classes=4, queue_size=1024, temperature=0.01)
logits = torch.randn(32, 4, requires_grad=True)  # [N, C] raw logits
targets = torch.randint(0, 4, (32,))             # [N] integer class labels
loss = loss_fn(logits, targets)
loss.backward()

# Recall at top-0.5%
loss_fn = RecallAtQuantileLoss(num_classes=4, quantile=0.005, queue_size=1024)
loss = loss_fn(logits, targets)
loss.backward()

# Binary classification
loss_fn = SmoothAPLoss(num_classes=1, queue_size=256)
logits = torch.randn(32, 1, requires_grad=True)
targets = torch.randint(0, 2, (32,))  # {0, 1}
loss = loss_fn(logits, targets)

# Per-class logging (inside a LightningModule, e.g. in training_step)
loss, per_class, valid = loss_fn(logits, targets, return_per_class=True)
for c in valid.nonzero(as_tuple=True)[0].tolist():
    self.log(f"train/ap_loss_class_{c}", per_class[c])

# Seq2seq: flatten [B, T, C] -> [B*T, C] upstream
B, T, C = 8, 16, 4
loss_fn = SmoothAPLoss(num_classes=C, queue_size=1024)
logits = torch.randn(B, T, C, requires_grad=True)
targets = torch.randint(0, C, (B, T))
loss = loss_fn(logits.reshape(-1, C), targets.reshape(-1))

# Reset queue between training and validation
loss_fn.reset_queue()
```
Parameters
Focal losses
| Parameter | Default | Description |
|---|---|---|
| `alpha` | `0.25` / `None` | Pos/neg balance weight in [0, 1], or `-1` to disable (sigmoid); per-class tensor or `None` (softmax) |
| `gamma` | `2.0` | Focusing exponent; `0` recovers vanilla BCE/CE when `alpha` is disabled |
| `reduction` | `'none'` / `'mean'` | `'none'`, `'mean'`, `'sum'`, or `'mean_positive'` (softmax only) |
| `ignore_index` | `-100` | (`SoftmaxFocalLoss` only) Target value for padding positions |
| `background_class` | `0` | (`SoftmaxFocalLoss` only) Class excluded from the `mean_positive` denominator |
| `label_smoothing` | `0.0` | (`SoftmaxFocalLoss` only) Forwarded to `F.cross_entropy` |
| `gather_distributed` | `None` | `None` = auto-detect DDP; `False` = always local; `True` = always gather |

Where two defaults are shown, they are for `SigmoidFocalLoss` / `SoftmaxFocalLoss` respectively.
Ranking losses
| Parameter | Default | Description |
|---|---|---|
| `num_classes` | required | Number of output classes; use `1` for binary |
| `queue_size` | `1024` | Circular buffer size (rows); `0` to disable |
| `temperature` | `0.01` | Sigmoid sharpness τ; smaller = sharper gradients |
| `reduction` | `'mean'` | `'mean'`, `'sum'`, or `'none'` |
| `ignore_index` | `-100` | Target value for padding positions |
| `update_queue_in_eval` | `False` | Allow queue updates during `model.eval()` |
| `gather_distributed` | `None` | `None` = auto-detect DDP; `False` = always local; `True` = always gather |
| `quantile` | `0.005` | (`RecallAtQuantileLoss` only) Top fraction to target |
| `quantile_interpolation` | `'higher'` | (`RecallAtQuantileLoss` only) `torch.quantile` interpolation method |
Temperature guidance: 0.005–0.05 is the practical range. Lower values approximate the true (discontinuous) rank more closely, but the sigmoid then saturates for all but near-tied pairs, so gradients become sharper and sparser.
Queue size guidance: For `quantile=0.005` (top 50 bps), the pool (batch + queue) needs at least ~200 samples (1 / 0.005) before even one sample lies above the 99.5th percentile; larger pools give a more stable threshold estimate.
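The rule of thumb is simple arithmetic on the pool size M = batch + queue: roughly M·q pooled samples sit above the (1 − q) cutoff. A small sketch with hypothetical helper names:

```python
import math


def min_pool_size(quantile: float, min_above: int = 1) -> int:
    """Smallest pool with at least `min_above` samples above the (1 - quantile) cutoff."""
    return math.ceil(min_above / quantile)


def approx_above(batch_size: int, queue_size: int, quantile: float) -> float:
    """Approximate number of pooled samples above the cutoff, with M = batch + queue."""
    return (batch_size + queue_size) * quantile


print(min_pool_size(0.005))             # 1 / 0.005
print(approx_above(32, 1024, 0.005))    # (32 + 1024) * 0.005
```

With a batch of 32 and the default `queue_size=1024`, only about five pooled scores lie above the top-0.5% cutoff, which is why larger queues help at very small quantiles.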
LossWarmupWrapper — BCE/CE warmup + loss blending + geometric temperature decay
A wrapper that trains with a standard loss (e.g. CrossEntropyLoss) for a warmup period, optionally blends both losses over a transition period, then switches to the ranking loss with a geometrically decaying temperature schedule.
```
temp(t) = temp_start × (temp_end / temp_start) ^ (elapsed_steps / temp_decay_steps)
```
The schedule clock starts at the moment of phase switch, not at training start.
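The schedule can be written as a small pure function of the steps elapsed since the phase switch. This is a sketch, not the wrapper's code: the function name is hypothetical, and holding the temperature at `temp_end` after `temp_decay_steps` is an assumption based on the parameter description "Temperature after `temp_decay_steps` steps".

```python
def temperature(elapsed_steps: int, temp_start: float, temp_end: float, temp_decay_steps: int) -> float:
    """Geometric decay from temp_start to temp_end, counted from the phase switch."""
    # Assumption: clamp progress to [0, 1] so the temperature holds at temp_end afterwards
    frac = min(max(elapsed_steps, 0), temp_decay_steps) / temp_decay_steps
    return temp_start * (temp_end / temp_start) ** frac


for step in (0, 25_000, 50_000, 100_000):
    print(step, round(temperature(step, 0.5, 0.01, 50_000), 4))
```

Halfway through the decay the temperature equals the geometric mean √(temp_start · temp_end), about 0.0707 for the values used in the Lightning example.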
Queue poisoning fix: At the switch point the wrapper automatically calls main_loss.reset_queue() (if available), ensuring the ranking loss never sees stale warmup-era logits.
Blending
blend_epochs adds a linear ramp between warmup and pure AP:
```
Epoch 0–W−1:      warmup_loss only         (main_weight = 0)
Epoch W:          (1−w)×warmup + w×AP      w = 1/(blend_epochs+1)
Epoch W+1:        (1−w)×warmup + w×AP      w = 2/(blend_epochs+1)
...
Epoch W+B onward: main_loss only           (main_weight = 1)
```
With warmup_epochs=2, blend_epochs=2: epochs 2→1/3 AP, 3→2/3 AP, 4+→pure AP.
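The blend weight follows directly from the table above. A hypothetical pure-Python helper (not part of the library's API) reproduces the `warmup_epochs=2, blend_epochs=2` example:

```python
def main_weight(epoch: int, warmup_epochs: int, blend_epochs: int) -> float:
    """Weight on main_loss: 0 during warmup, linear ramp during blend, 1 afterwards."""
    if epoch < warmup_epochs:
        return 0.0
    if epoch < warmup_epochs + blend_epochs:
        return (epoch - warmup_epochs + 1) / (blend_epochs + 1)
    return 1.0


def blended_loss(warmup_value: float, main_value: float, w: float) -> float:
    """(1 - w) * warmup + w * main, as in the schedule above."""
    return (1.0 - w) * warmup_value + w * main_value


print([round(main_weight(e, warmup_epochs=2, blend_epochs=2), 3) for e in range(6)])
# [0.0, 0.0, 0.333, 0.667, 1.0, 1.0]
```

Setting `blend_epochs=0` makes the middle branch unreachable, which is the hard switch.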
Usage (PyTorch Lightning)
```python
import pytorch_lightning as pl
import torch.nn as nn

from imbalanced_losses import LossWarmupWrapper, SmoothAPLoss


class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.loss_fn = LossWarmupWrapper(
            warmup_loss=nn.CrossEntropyLoss(),
            main_loss=SmoothAPLoss(num_classes=10, queue_size=1024),
            warmup_epochs=5,
            blend_epochs=2,          # gradual transition
            temp_start=0.5,          # soft at switch: stable gradients
            temp_end=0.01,           # sharp after schedule: closer to true rank
            temp_decay_steps=50_000,
        )

    def on_train_epoch_start(self):
        self.loss_fn.on_train_epoch_start(self.current_epoch)

    def on_train_batch_start(self, batch, batch_idx):
        self.loss_fn.on_train_batch_start(self.global_step)

    def training_step(self, batch, batch_idx):
        logits, targets = batch  # simplified: normally logits = self(inputs)
        loss = self.loss_fn(logits, targets)
        self.log("train/loss", loss)
        self.log("train/main_weight", self.loss_fn.main_weight)
        if (t := self.loss_fn.current_temperature) is not None:
            self.log("train/temperature", t)
        return loss
```
`**kwargs` (e.g. `return_per_class=True`) are forwarded to `main_loss` only when `main_weight == 1.0`; they are silently ignored during the warmup and blend phases.
Parameters
| Parameter | Default | Description |
|---|---|---|
| `warmup_loss` | required | Loss used during warmup; must accept `(logits, targets)` |
| `main_loss` | required | Loss used after warmup; must accept `(logits, targets, **kwargs)` |
| `warmup_epochs` | required | Epochs to use `warmup_loss`; `0` to skip warmup entirely |
| `temp_start` | required | Temperature at phase switch |
| `temp_end` | required | Temperature after `temp_decay_steps` steps |
| `temp_decay_steps` | required | Steps over which to decay temperature |
| `blend_epochs` | `0` | Epochs to linearly ramp from warmup to main loss; `0` = hard switch |
| `reset_queue_each_epoch` | `False` | Call `main_loss.reset_queue()` at the start of each main-phase epoch |
Properties / methods
| Property / method | Description |
|---|---|
| `in_warmup` | `True` while `epoch < warmup_epochs` |
| `in_blend` | `True` during the `blend_epochs` transition period |
| `main_weight` | Current main-loss weight: `0.0` during warmup, linear ramp during blend, `1.0` after |
| `current_temperature` | Current `main_loss.temperature`, or `None` if unavailable |
| `on_train_epoch_start(epoch)` | Advance epoch counter; detect phase switch; optionally reset queue |
| `on_train_batch_start(global_step)` | Latch `switch_step` on first main-phase batch; reset queue; update temperature |
Distributed Training (DDP)
All losses support DDP via built-in all-gather, but globally-correct computation is especially critical for rank-based losses. The imbalanced_losses.distributed module provides two all-gather helpers that handle this correctly.
Why this matters
In DDP each GPU sees only N/world_size samples. The soft-rank computation in SmoothAPLoss and the quantile threshold in RecallAtQuantileLoss become noisy or biased when computed on a shard. For SoftmaxFocalLoss with mean_positive reduction, the positive count in the denominator is similarly unreliable when positives are rare and unevenly distributed across ranks. Gathering logits and targets across all workers before passing them to the loss fixes this for all three cases.
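A small numeric illustration of the `mean_positive` case, using made-up per-rank sums and positive counts. Normalizing on each shard and then averaging (which is what DDP's gradient averaging amounts to) gives a different value from gathering first and normalizing once. The `max(n, 1)` guard for ranks without positives is an assumption of this sketch.

```python
# Hypothetical per-rank statistics for SoftmaxFocalLoss(reduction="mean_positive"):
# (summed loss over valid positions, number of non-background positions) on each GPU.
shards = [
    (4.0, 3),   # rank 0: three positives
    (1.0, 1),   # rank 1: one positive
    (0.5, 0),   # rank 2: background only
    (0.5, 0),   # rank 3: background only
]

# Globally correct: gather first, then one sum divided by one count.
global_loss = sum(s for s, _ in shards) / max(sum(n for _, n in shards), 1)

# Per-rank: each rank normalizes locally (assumed guard for zero positives), then average.
local_losses = [s / max(n, 1) for s, n in shards]
naive_loss = sum(local_losses) / len(local_losses)

print(f"global={global_loss:.3f}  per-rank average={naive_loss:.3f}")
```

The background-only ranks pull the per-rank average well below the global value, even though they hold no positives at all.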
Helpers
| Function | Description |
|---|---|
| `all_gather_with_grad(tensor)` | Gathers tensors across all workers; preserves gradients for the local rank's slice so autograd works correctly |
| `all_gather_no_grad(tensor)` | Gathers tensors without gradient tracking; use for integer targets/labels |
all_gather_with_grad replaces the local rank's slice in the output with the original tensor (restoring the gradient connection), while other workers' slices remain detached — matching standard DDP semantics where each worker optimizes its own parameters via all-reduced gradients.
Queue synchronization: Because every worker calls all_gather before passing to the loss, every worker enqueues the same global-batch data. No extra synchronization of the memory queue is needed.
Usage
```python
from imbalanced_losses import SmoothAPLoss
from imbalanced_losses.distributed import all_gather_no_grad, all_gather_with_grad

loss_fn = SmoothAPLoss(num_classes=4, queue_size=1024)

# Inside training_step on each GPU (logits: [N, C], targets: [N]):
logits_global = all_gather_with_grad(logits)   # [world_size * N, C]; grad flows
targets_global = all_gather_no_grad(targets)   # [world_size * N]; no grad
loss = loss_fn(logits_global, targets_global)
loss.backward()
```
Both helpers raise RuntimeError if torch.distributed is not available or not initialized. They are no-ops (return the input unchanged) when world_size == 1.
PyTorch Lightning (DDP)
```python
import pytorch_lightning as pl

from imbalanced_losses import SmoothAPLoss
from imbalanced_losses.distributed import all_gather_no_grad, all_gather_with_grad


class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.loss_fn = SmoothAPLoss(num_classes=4, queue_size=1024)

    def training_step(self, batch, batch_idx):
        logits, targets = batch  # simplified: normally logits = self(inputs)
        logits_g = all_gather_with_grad(logits)   # grad flows through the local slice
        targets_g = all_gather_no_grad(targets)
        loss = self.loss_fn(logits_g, targets_g)
        return loss
```
Examples
The example scripts require the demo extras:

```bash
uv sync --extra demo
# or: pip install scikit-learn
```
toy_demo.py — single-run trace
Trains one model (warmup → blend → AP) and prints epoch-by-epoch phase, main_weight, temperature, loss, and AUCPR.
```bash
python examples/toy_demo.py                    # default: 3 warmup + 2 blend epochs
python examples/toy_demo.py --blend-epochs 0   # hard switch (no blend)
python examples/toy_demo.py --pos-rate 0.05    # easier problem
```
focal_demo.py — BCE vs focal loss comparison
Trains four models on the same imbalanced data and prints per-epoch AUCPR:
| Strategy | Description |
|---|---|
| BCE | Vanilla BCEWithLogitsLoss; easy negatives dominate |
| BCE+weight | BCEWithLogitsLoss with pos_weight = n_neg/n_pos |
| focal α γ | SigmoidFocalLoss(alpha=0.25, gamma=2) — RetinaNet defaults |
| focal γ only | SigmoidFocalLoss(alpha=-1, gamma=2) — focusing only, no alpha |
```bash
python examples/focal_demo.py
python examples/focal_demo.py --pos-rate 0.02   # easier problem
python examples/focal_demo.py --gamma 5 --alpha 0.5
```
compare_demo.py — side-by-side comparison
Trains three models on the same data and seed and prints a per-epoch AUCPR table:
| Strategy | Description |
|---|---|
| warmup-only | BCE for all epochs; never switches to AP |
| AP-only | SmoothAPLoss from epoch 0, no warmup |
| warmup+blend | BCE warmup → linear blend → pure SmoothAPLoss |
```bash
python examples/compare_demo.py
python examples/compare_demo.py --pos-rate 0.05
python examples/compare_demo.py --warmup-epochs 5 --blend-epochs 3
```
Key flags (`toy_demo.py` and `compare_demo.py`): `--pos-rate`, `--warmup-epochs`, `--blend-epochs`, `--total-epochs`, `--batch-size`, `--queue-size`, `--temp-start`, `--temp-end`, `--lr`, `--seed`.
Tests
```bash
pytest tests/ -v
```
References
Lin, T.-Y., Goyal, P., Girshick, R., He, K., & Dollár, P. (2017). Focal Loss for Dense Object Detection. ICCV 2017.
Brown, A., Xie, W., Kalogeiton, V., & Zisserman, A. (2020). Smooth-AP: Smoothing the Path Towards Large-Scale Image Retrieval. ECCV 2020.