proxy-losses is a PyTorch library of differentiable proxy losses for ranking metrics — intended as drop-in replacements for cross-entropy when the real objective is Average Precision or recall at a specific operating point.
This project has been archived.
The maintainers of this project have marked this project as archived. No new releases are expected.
proxy-losses
What's in it:
- `SmoothAPLoss`: differentiable approximation of AP (Brown et al., ECCV 2020). Uses sigmoid-based soft rank estimation; O(M²) in pool size. Supports multi-class, binary, and seq2seq settings.
- `RecallAtQuantileLoss`: optimizes recall above a score threshold set at the (1 − q) quantile of the pooled distribution, i.e. the cutoff for the top-q fraction of scores. Useful for alert/detection workloads (e.g. the top 0.5% of scores).
- `LossWarmupWrapper`: training utility that runs a standard loss (BCE/CE) during warmup, linearly blends into the ranking loss over a configurable transition window, then applies geometric temperature decay. It automatically resets the memory queue at the phase switch to prevent queue poisoning from warmup-era logits.
Design points:
- Circular memory queue stabilizes gradient estimates across small batches, which is critical at low positive rates (e.g. 0.5%)
- Compatible with PyTorch Lightning via the `on_train_epoch_start` / `on_train_batch_start` hooks
- `toy_demo.py` demonstrates the full warmup→blend→AP pipeline on a highly imbalanced binary classification task using scikit-learn's `make_classification`
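The circular memory queue can be pictured with the minimal pure-Python sketch below. It assumes the queue stores past (score, label) rows, that the ranking pool is the current batch followed by the queued rows, and that padding rows (`ignore_index`, default -100) never enter either; the class `MemoryQueue` and its methods are hypothetical, not the library's API.

```python
from collections import deque

IGNORE_INDEX = -100  # the documented default ignore_index


class MemoryQueue:
    """Hypothetical fixed-capacity FIFO of past (score, label) rows.

    Once full, appending a row evicts the oldest one, which is the
    circular-buffer behaviour described above. queue_size=0 disables it.
    """

    def __init__(self, queue_size: int) -> None:
        self._rows: deque[tuple[float, int]] = deque(maxlen=max(queue_size, 0))

    def build_pool(self, scores: list[float], labels: list[int]) -> list[tuple[float, int]]:
        # Current batch first (these would carry gradients in a real model),
        # followed by stored rows from earlier batches (constants).
        batch = [(s, y) for s, y in zip(scores, labels) if y != IGNORE_INDEX]
        return batch + list(self._rows)

    def update(self, scores: list[float], labels: list[int]) -> None:
        # In PyTorch this is where scores would be detached before storing.
        for s, y in zip(scores, labels):
            if y != IGNORE_INDEX:
                self._rows.append((s, y))

    def reset(self) -> None:
        # e.g. at the warmup -> ranking switch, or between training and validation
        self._rows.clear()
```

The motivation is simple arithmetic: at a 0.5% positive rate a 32-row batch holds 0.16 positives on average, so most batches alone contain no positive to rank; the queued rows supply them.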
Losses
SmoothAPLoss — Smooth Average Precision (Brown et al., 2020)
Approximates AP using sigmoid-based soft rank estimation. For each positive i in the pool:
```
ŝ_i   = 1 + Σ_{j≠i} σ((s_j − s_i) / τ)        # soft overall rank
ŝ_i^+ = 1 + Σ_{j≠i, j∈P} σ((s_j − s_i) / τ)   # soft rank among positives
AP   ≈ (1/|P|) · Σ_{i∈P} ŝ_i^+ / ŝ_i
loss = 1 − AP
```
Complexity: O(M²) in memory and compute where M = batch + queue size. Keep M ≤ ~4096.
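To make the formulas concrete, here is a direct pure-Python transcription of the soft-rank AP above for a single class. It uses explicit loops rather than the vectorized PyTorch code the library presumably uses, and the function names are illustrative. With a small τ the soft ranks approach the true integer ranks, so the result approaches exact AP.

```python
import math


def sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def smooth_ap(scores: list[float], labels: list[int], tau: float = 0.01) -> float:
    """Soft AP over one pool: mean over positives of ŝ_i^+ / ŝ_i."""
    positives = [i for i, y in enumerate(labels) if y == 1]
    if not positives:
        return float("nan")  # AP is undefined without positives
    total = 0.0
    for i in positives:
        rank_all = 1.0  # ŝ_i
        rank_pos = 1.0  # ŝ_i^+
        for j, (s_j, y_j) in enumerate(zip(scores, labels)):
            if j == i:
                continue
            # ≈1 when item j outscores item i, ≈0 when it scores lower
            above = sigmoid((s_j - scores[i]) / tau)
            rank_all += above
            if y_j == 1:
                rank_pos += above
        total += rank_pos / rank_all
    return total / len(positives)


def smooth_ap_loss(scores: list[float], labels: list[int], tau: float = 0.01) -> float:
    return 1.0 - smooth_ap(scores, labels, tau)
```

For scores `[3.0, 2.0, -1.0, -2.0]` with labels `[0, 0, 1, 1]` the positives sit at ranks 3 and 4, so exact AP is (1/3 + 2/4) / 2 = 5/12 ≈ 0.417, and `smooth_ap` with τ = 0.01 matches it closely. The double loop also makes the O(M²) cost visible: if the pairwise terms are held as a dense float32 M × M matrix, M = 4096 means 4096² ≈ 16.8 million entries (about 67 MB) per class, before any autograd buffers.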
RecallAtQuantileLoss — Recall at Quantile
Optimizes recall above a score threshold θ set at the (1 − q) quantile of the pooled score distribution, i.e. the cutoff for the top-q fraction of scores. The threshold is treated as a stop-gradient constant on each forward pass:
```
θ = quantile(scores, 1 − q)                        [detached, no grad]
soft_recall = (1/|P|) · Σ_{i∈P} σ((s_i − θ) / τ)
loss = 1 − soft_recall
```
Gradient flows only through positive scores, pushing them above the cutoff. Useful for alert/detection settings (e.g. quantile=0.005 = top 50 bps).
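A minimal pure-Python sketch of the recall-at-quantile computation follows. There is no autograd here, so the stop-gradient on θ is implicit (in PyTorch it corresponds to detaching θ). `quantile_higher` reproduces the `interpolation="higher"` rule of `torch.quantile`; all names are illustrative.

```python
import math


def sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def quantile_higher(values: list[float], p: float) -> float:
    # Same rule as torch.quantile(..., interpolation="higher"): the target
    # position is p * (n - 1) in sorted order, rounded up to the next index.
    ordered = sorted(values)
    return ordered[math.ceil(p * (len(ordered) - 1))]


def recall_at_quantile_loss(
    scores: list[float], labels: list[int], quantile: float = 0.005, tau: float = 0.01
) -> float:
    # θ is computed from all pooled scores but treated as a constant.
    theta = quantile_higher(scores, 1.0 - quantile)
    positives = [s for s, y in zip(scores, labels) if y == 1]
    if not positives:
        return float("nan")
    # Each positive contributes ≈1 if it clears θ, ≈0 if it falls below.
    soft_recall = sum(sigmoid((s - theta) / tau) for s in positives) / len(positives)
    return 1.0 - soft_recall
```

With 100 pooled scores `0.0 … 99.0` and `quantile=0.05`, θ is 95.0; positives scored 97 and 98 clear it while positives scored 10 and 20 do not, so soft recall is about 0.5 and the loss about 0.5. With `"higher"` interpolation θ is always one of the pooled scores, and a positive sitting exactly at θ contributes σ(0) = 0.5.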
Features
Both losses share the same interface and design:
- Memory queue: circular buffer accumulates past batches to stabilize estimates over small batch sizes; set `queue_size=0` to disable
- Multi-class: one-vs-rest per class using `logits[:, c]`
- Binary: set `num_classes=1` with targets in `{0, 1}`
- Seq2seq: flatten `[B, T, C]` → `[B*T, C]` upstream before passing
- Padding: `ignore_index` rows are excluded from ranking and from the positive set
- Reductions: `'mean'` (default), `'sum'`, or `'none'` (per-class tensor; degenerate classes are `nan`)
- Per-class logging: `return_per_class=True` returns `(loss, per_class, valid_mask)` without a second forward pass
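The multi-class, binary, padding, and reduction rules above combine roughly as in this sketch. It assumes a class counts as degenerate when it has no positives among the non-padding rows, and it takes the per-class binary loss as a callable (for instance 1 − smooth AP). `one_vs_rest_loss` is a hypothetical helper, not part of the library.

```python
import math
from typing import Callable

IGNORE_INDEX = -100

BinaryLoss = Callable[[list[float], list[int]], float]


def one_vs_rest_loss(
    logits: list[list[float]],
    targets: list[int],
    binary_loss: BinaryLoss,
    reduction: str = "mean",
):
    num_classes = len(logits[0])
    per_class: list[float] = []
    for c in range(num_classes):
        scores, labels = [], []
        for row, t in zip(logits, targets):
            if t == IGNORE_INDEX:
                continue  # padding never enters ranking or the positive set
            scores.append(row[c])  # column c, like logits[:, c]
            # num_classes=1: targets are already {0, 1}; otherwise one-vs-rest
            labels.append(t if num_classes == 1 else int(t == c))
        # Assumption: a class with no positives is degenerate -> nan
        per_class.append(binary_loss(scores, labels) if any(labels) else math.nan)

    if reduction == "none":
        return per_class
    valid = [v for v in per_class if not math.isnan(v)]
    if reduction == "sum":
        return sum(valid)
    if reduction == "mean":
        return sum(valid) / len(valid) if valid else math.nan
    raise ValueError(f"unknown reduction: {reduction!r}")
```

Returning `per_class` alongside the reduced value is what lets per-class logging skip a second forward pass: the per-class numbers already exist before reduction.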
Installation
Requires Python ≥ 3.10 and PyTorch ≥ 2.10.
```bash
# with uv
uv sync
# or pip (quote the specifier so the shell does not treat ">" as a redirect)
pip install "torch>=2.10"
```
Usage
```python
import torch
from proxy_losses import RecallAtQuantileLoss, SmoothAPLoss

# Multi-class AP loss
loss_fn = SmoothAPLoss(num_classes=4, queue_size=1024, temperature=0.01)
logits = torch.randn(32, 4, requires_grad=True)  # [N, C] raw logits (stand-in for model output)
targets = torch.randint(0, 4, (32,))             # [N] integer class labels
loss = loss_fn(logits, targets)
loss.backward()

# Recall at top-0.5%
loss_fn = RecallAtQuantileLoss(num_classes=4, quantile=0.005, queue_size=1024)
loss = loss_fn(logits, targets)
loss.backward()

# Binary classification
loss_fn = SmoothAPLoss(num_classes=1, queue_size=256)
logits = torch.randn(32, 1, requires_grad=True)
targets = torch.randint(0, 2, (32,))  # {0, 1}
loss = loss_fn(logits, targets)

# Per-class logging (inside a PyTorch Lightning hook such as training_step,
# where `self` is the LightningModule)
loss, per_class, valid = loss_fn(logits, targets, return_per_class=True)
for c in valid.nonzero(as_tuple=True)[0].tolist():
    self.log(f"train/ap_loss_class_{c}", per_class[c])

# Seq2seq: flatten [B, T, C] -> [B*T, C] upstream
B, T, C = 8, 16, 4
seq_logits = torch.randn(B, T, C, requires_grad=True)
seq_targets = torch.randint(0, C, (B, T))  # padding positions would hold ignore_index (-100)
loss_fn = SmoothAPLoss(num_classes=C)
loss = loss_fn(seq_logits.reshape(-1, C), seq_targets.reshape(-1))

# Reset the queue between training and validation
loss_fn.reset_queue()
```
Parameters
| Parameter | Default | Description |
|---|---|---|
| `num_classes` | required | Number of output classes; use 1 for binary |
| `queue_size` | 1024 | Circular buffer size (rows); 0 to disable |
| `temperature` | 0.01 | Sigmoid sharpness τ; smaller = sharper gradients |
| `reduction` | `'mean'` | `'mean'`, `'sum'`, or `'none'` |
| `ignore_index` | -100 | Target value for padding positions |
| `update_queue_in_eval` | False | Allow queue updates during `model.eval()` |
| `quantile` | 0.005 | (`RecallAtQuantileLoss` only) Top fraction to target |
| `quantile_interpolation` | `'higher'` | (`RecallAtQuantileLoss` only) `torch.quantile` interpolation method |
Temperature guidance: 0.005–0.05 is the practical range. Lower values approximate the true, discontinuous rank more closely, but their gradients concentrate on nearly tied pairs and vanish for well-separated ones, which makes optimization harder.
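The trade-off can be checked numerically: each pairwise term σ((s_j − s_i)/τ) has derivative σ(x)·σ(−x)/τ with respect to the score gap, where x = gap/τ. The sketch below (plain Python, `pair_grad` is an illustrative name) evaluates it for a tied pair and for a pair separated by 0.1 across the suggested range.

```python
import math


def _sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def pair_grad(gap: float, tau: float) -> float:
    """d/d(gap) of sigmoid(gap / tau) = sigmoid(x) * sigmoid(-x) / tau."""
    x = gap / tau
    return _sigmoid(x) * _sigmoid(-x) / tau


for tau in (0.05, 0.01, 0.005):
    print(f"tau={tau:<6} tied pair: {pair_grad(0.0, tau):6.1f}   gap 0.1: {pair_grad(0.1, tau):.2e}")
```

Going from τ = 0.05 to τ = 0.005 raises the gradient on a tied pair tenfold (from 5 to 50) but shrinks it for a pair 0.1 apart by a factor of several million (from about 2.1 to about 4e-7), so far fewer pairs drive learning.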
Queue size guidance: for quantile=0.005 (top 50 bps) the pool needs at least ~200 rows (1 / 0.005), the point at which the top 0.5% contains a single sample; larger pools give a steadier estimate of the 99.5th percentile.
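The same arithmetic generalizes: a pool of M rows has about M × q rows above the cutoff, so keeping at least k of them requires M ≥ k / q. A small hypothetical helper, using exact fractions so the division does not suffer float rounding:

```python
import math
from fractions import Fraction


def min_pool_size(quantile: float, rows_above_cutoff: int = 1) -> int:
    """Smallest pool size M with M * quantile >= rows_above_cutoff."""
    q = Fraction(str(quantile))  # "0.005" -> exactly 1/200
    return math.ceil(rows_above_cutoff / q)


def expected_rows_above(pool_size: int, quantile: float) -> float:
    """Average number of pooled rows that land above the (1 - q) quantile."""
    return pool_size * quantile
```

With a batch of 32 and the default `queue_size=1024`, the pool holds 1,056 rows, so only about 5 rows sit above the 99.5th percentile; θ then rests on very few points, which is one reason a larger queue can help at small q.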
LossWarmupWrapper — BCE/CE warmup + loss blending + geometric temperature decay
A wrapper that trains with a standard loss (e.g. CrossEntropyLoss) for a warmup period, optionally blends both losses over a transition period, then switches to the ranking loss with a geometrically decaying temperature schedule.
temp(t) = temp_start × (temp_end / temp_start) ^ (elapsed_steps / temp_decay_steps)
The schedule clock starts at the moment of phase switch, not at training start.
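The schedule can be sketched as a plain function of the steps elapsed since the switch (i.e. the global step minus the step latched at the phase switch). Clamping to the range [0, temp_decay_steps], so that the temperature holds at `temp_end` afterwards, is an assumption based on the parameter description; `temperature_at` is an illustrative name.

```python
def temperature_at(
    elapsed_steps: int, temp_start: float, temp_end: float, temp_decay_steps: int
) -> float:
    """Geometric decay from temp_start to temp_end over temp_decay_steps."""
    # Assumption: clamp progress to [0, 1]; the temperature stays at temp_end afterwards.
    progress = min(max(elapsed_steps, 0), temp_decay_steps) / temp_decay_steps
    return temp_start * (temp_end / temp_start) ** progress
```

Geometric decay spends equal numbers of steps on each multiplicative factor: with the example values below (0.5 → 0.01 over 50,000 steps) the temperature is about 0.0707 at step 25,000, the geometric mean of the two endpoints, rather than the 0.255 a linear ramp would give.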
Queue poisoning fix: At the switch point the wrapper automatically calls main_loss.reset_queue() (if available), ensuring the ranking loss never sees stale warmup-era logits.
Blending
blend_epochs adds a linear ramp between warmup and pure AP:
```
Epoch 0–W-1:  warmup_loss only                 (ap_weight = 0)
Epoch W:      (1−w)×warmup + w×AP   w = 1/(blend_epochs+1)
Epoch W+1:    (1−w)×warmup + w×AP   w = 2/(blend_epochs+1)
...
Epoch W+B+:   main_loss only                   (ap_weight = 1)
```
With warmup_epochs=2, blend_epochs=2: epochs 2→1/3 AP, 3→2/3 AP, 4+→pure AP.
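The blend schedule above reduces to a one-line weight function. This is a sketch of the documented schedule, not the wrapper's source; `ap_weight` and `blended` are illustrative names.

```python
def ap_weight(epoch: int, warmup_epochs: int, blend_epochs: int = 0) -> float:
    """Weight on the ranking (AP) loss for a given epoch, per the schedule above."""
    if epoch < warmup_epochs:
        return 0.0  # warmup: warmup_loss only
    ramp_step = epoch - warmup_epochs + 1  # 1 at the first post-warmup epoch
    return min(1.0, ramp_step / (blend_epochs + 1))


def blended(warmup_value: float, main_value: float, weight: float) -> float:
    """(1 − w) × warmup + w × AP."""
    return (1.0 - weight) * warmup_value + weight * main_value
```

With `warmup_epochs=2, blend_epochs=2`, epochs 0 and 1 get weight 0, epochs 2 and 3 get 1/3 and 2/3, and every later epoch gets 1. Setting `blend_epochs=0` makes the first post-warmup epoch jump straight to 1, which is the hard switch.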
Usage (PyTorch Lightning)
```python
import pytorch_lightning as pl
import torch
import torch.nn as nn
from proxy_losses import LossWarmupWrapper, SmoothAPLoss


class MyModel(pl.LightningModule):
    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone  # any module producing [N, 10] logits
        self.loss_fn = LossWarmupWrapper(
            warmup_loss=nn.CrossEntropyLoss(),
            main_loss=SmoothAPLoss(num_classes=10, queue_size=1024),
            warmup_epochs=5,
            blend_epochs=2,        # gradual transition
            temp_start=0.5,        # soft at switch: stable gradients
            temp_end=0.01,         # sharp after schedule: closer to true rank
            temp_decay_steps=50_000,
        )

    def on_train_epoch_start(self):
        # Advance the wrapper's epoch counter and detect the phase switch
        self.loss_fn.on_train_epoch_start(self.current_epoch)

    def on_train_batch_start(self, batch, batch_idx):
        # Latch the switch step, reset the queue, and update the temperature
        self.loss_fn.on_train_batch_start(self.global_step)

    def training_step(self, batch, batch_idx):
        x, targets = batch
        logits = self.backbone(x)
        loss = self.loss_fn(logits, targets)
        self.log("train/loss", loss)
        self.log("train/ap_weight", self.loss_fn.ap_weight)
        if (t := self.loss_fn.current_temperature) is not None:
            self.log("train/temperature", t)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)  # illustrative choice
```
`**kwargs` (e.g. `return_per_class=True`) are forwarded to `main_loss` only when `ap_weight == 1.0`; during the warmup and blend phases they are silently ignored.
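The forwarding rule can be sketched as a small dispatch function. It is a hypothetical mirror of the documented behaviour, not the wrapper's actual code, and it takes the current weight as an argument rather than computing it.

```python
from typing import Any, Callable


def wrapped_loss(
    warmup_loss: Callable[..., Any],
    main_loss: Callable[..., Any],
    ap_weight: float,
    logits: Any,
    targets: Any,
    **kwargs: Any,
) -> Any:
    """Hypothetical dispatch mirroring the documented kwargs rule."""
    if ap_weight == 0.0:
        return warmup_loss(logits, targets)  # warmup: kwargs dropped
    if ap_weight == 1.0:
        return main_loss(logits, targets, **kwargs)  # pure AP: kwargs forwarded
    # Blend: both losses are called plainly and mixed; kwargs dropped
    return (1.0 - ap_weight) * warmup_loss(logits, targets) + ap_weight * main_loss(logits, targets)
```

A practical consequence, assuming the wrapper behaves like this sketch: code that unpacks `loss, per_class, valid = self.loss_fn(logits, targets, return_per_class=True)` receives a single loss value during warmup and blend, so the unpacking fails until `ap_weight` reaches 1.0; checking `self.loss_fn.ap_weight == 1.0` first avoids that.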
Parameters
| Parameter | Default | Description |
|---|---|---|
| `warmup_loss` | required | Loss used during warmup; must accept `(logits, targets)` |
| `main_loss` | required | Loss used after warmup; must accept `(logits, targets, **kwargs)` |
| `warmup_epochs` | required | Epochs to use `warmup_loss`; 0 to skip warmup entirely |
| `temp_start` | required | Temperature at phase switch |
| `temp_end` | required | Temperature after `temp_decay_steps` steps |
| `temp_decay_steps` | required | Steps over which to decay temperature |
| `blend_epochs` | 0 | Epochs to linearly ramp from warmup to main loss; 0 = hard switch |
| `reset_queue_each_epoch` | False | Call `main_loss.reset_queue()` at the start of each main-phase epoch |
Properties / methods
| Name | Description |
|---|---|
| `in_warmup` | True while `epoch < warmup_epochs` |
| `in_blend` | True during the `blend_epochs` transition period |
| `ap_weight` | Current AP loss weight: 0.0 during warmup, linear ramp during blend, 1.0 after |
| `current_temperature` | Current `main_loss.temperature`, or `None` if unavailable |
| `on_train_epoch_start(epoch)` | Advance epoch counter; detect phase switch; optionally reset queue |
| `on_train_batch_start(global_step)` | Latch `switch_step` on first main-phase batch; reset queue; update temperature |
Toy demo
toy_demo.py trains a small MLP on an imbalanced binary classification task (default: 0.5% positive rate) using make_classification from scikit-learn. It prints epoch-by-epoch AUCPR so you can see whether the warmup→blend→AP transition helps or hurts.
Requires the demo extras:
```bash
uv sync --extra demo
# or: pip install scikit-learn
```

```bash
# Default: 3 warmup + 2 blend epochs, then pure AP
python examples/toy_demo.py

# Hard switch for comparison
python examples/toy_demo.py --blend-epochs 0

# Easier problem
python examples/toy_demo.py --pos-rate 0.05
```
Key flags: --pos-rate, --warmup-epochs, --blend-epochs, --total-epochs, --batch-size, --queue-size, --temp-start, --temp-end, --lr, --seed.
Tests
```bash
pytest tests/ -v
```
References
Brown, A., Xie, W., Kalogeiton, V., & Zisserman, A. (2020). Smooth-AP: Smoothing the Path Towards Large-Scale Image Retrieval. ECCV 2020.
Download files
File details
Details for the file proxy_losses-0.1.1.tar.gz.
File metadata
- Download URL: proxy_losses-0.1.1.tar.gz
- Upload date:
- Size: 85.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 8205fba38d88b12329d14f48d9c8678ad34b2b601c887504df40980dc5fbc0a4 |
| MD5 | 399fa06e5015081fb2576c286d588729 |
| BLAKE2b-256 | ffdd172299c735d5e3880d583b2a62ab33947998f0b3b1ed845986302789c2d8 |
Provenance
The following attestation bundles were made for proxy_losses-0.1.1.tar.gz:

Publisher: publish.yaml on chris-santiago/proxy-losses

Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: proxy_losses-0.1.1.tar.gz
- Subject digest: 8205fba38d88b12329d14f48d9c8678ad34b2b601c887504df40980dc5fbc0a4
- Sigstore transparency entry: 1091432771
- Sigstore integration time:
- Permalink: chris-santiago/proxy-losses@8eb6d01bac0bc21984d7a9ab3f68794e936b6e10
- Branch / Tag: refs/heads/main
- Owner: https://github.com/chris-santiago
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yaml@8eb6d01bac0bc21984d7a9ab3f68794e936b6e10
- Trigger Event: workflow_dispatch
File details
Details for the file proxy_losses-0.1.1-py3-none-any.whl.
File metadata
- Download URL: proxy_losses-0.1.1-py3-none-any.whl
- Upload date:
- Size: 19.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 13b53b9b8ebd072fdf38f2206c645debc01483bd9dee588ed6af0dd6675b1a75 |
| MD5 | 3747012c5174de47acb10347afe234e2 |
| BLAKE2b-256 | 5e594610a56b409ba3d567ca2eeadb323d82e779837575dd59511b78f3c22417 |
Provenance
The following attestation bundles were made for proxy_losses-0.1.1-py3-none-any.whl:

Publisher: publish.yaml on chris-santiago/proxy-losses

Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: proxy_losses-0.1.1-py3-none-any.whl
- Subject digest: 13b53b9b8ebd072fdf38f2206c645debc01483bd9dee588ed6af0dd6675b1a75
- Sigstore transparency entry: 1091432774
- Sigstore integration time:
- Permalink: chris-santiago/proxy-losses@8eb6d01bac0bc21984d7a9ab3f68794e936b6e10
- Branch / Tag: refs/heads/main
- Owner: https://github.com/chris-santiago
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yaml@8eb6d01bac0bc21984d7a9ab3f68794e936b6e10
- Trigger Event: workflow_dispatch