PyTorch classification loss with a SmoothMax-based calibration penalty
FairLoss
PyTorch classification loss that combines cross-entropy with a smooth argmax penalty. Inspired by Penalized Logarithmic Loss (PLL) from Ahmadian et al. (2025).
The problem
In multi-class classification, we often want a loss to rank correct predictions above incorrect ones. Strictly proper scores such as cross-entropy encourage calibrated probabilities, but they do not guarantee that ranking.
Consider two logit vectors, both with true class 0:
| Sample | Logits | True class | Argmax correct? |
|---|---|---|---|
| 1 | [0.34, 0.33, 0.33] | 0 | Yes |
| 2 | [0.49, 0.51, 0.00] | 0 | No |
Sample 2 assigns a higher logit to the true class than sample 1 does (0.49 vs. 0.34), but its argmax is class 1, so it picks the wrong label. Cross-entropy prefers sample 2; intuitively, sample 1 should win.
Ahmadian et al. call this the Superior property: a scoring rule should always rank correct predictions above incorrect ones. PLL adds a hard argmax penalty while staying strictly proper. FairLoss applies the same idea with a differentiable penalty suitable for gradient-based training.
Running examples/difference.py on these logits:
Cross-entropy: tensor([1.0920, 0.9681]) # sample 2 wins
FairLoss: tensor([1.0920, 1.3014]) # sample 1 wins
Cross-entropy gives sample 2 the lower loss because it is more confident on the true class. FairLoss adds a penalty when the argmax is wrong, so sample 1 ends up with the lower loss. Sample 1 keeps the same loss under both criteria because its prediction is already correct.
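The cross-entropy values above can be checked with plain PyTorch. This is a minimal sketch that reproduces only the cross-entropy numbers and the argmax check; it does not call FairLoss, and the variable names are illustrative.

```python
import torch
import torch.nn.functional as F

# The two samples from the table, both with true class 0.
logits = torch.tensor([[0.34, 0.33, 0.33],
                       [0.49, 0.51, 0.00]])
targets = torch.tensor([0, 0])

# Per-sample cross-entropy (no reduction over the batch).
ce = F.cross_entropy(logits, targets, reduction="none")

# Whether the largest logit sits on the true class.
argmax_correct = logits.argmax(dim=1) == targets

print(ce)              # approximately tensor([1.0920, 0.9681])
print(argmax_correct)  # tensor([ True, False])
```

Sample 2 gets the lower cross-entropy even though its argmax is wrong, which is the ranking problem described here.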
How it works
FairLoss keeps standard cross-entropy and adds a SmoothMax-based penalty:
total_loss = cross_entropy + penalty
The penalty is built from the gap between the true-class logit and a smooth approximation of the largest logit:
- SmoothMax approximates max(logits), the logit of the predicted class.
- logit_true - smooth_max is near zero when the argmax is correct and negative when it is wrong.
- A sigmoid transform maps that gap into a smooth penalty: near zero for correct predictions, positive for incorrect ones.
Full derivation: formulation.pdf
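The three steps can be sketched in a few lines. This is one possible reading, not the package's implementation: the log-sum-exp smooth maximum, the rescaled sigmoid `2 * sigmoid(-scale * gap) - 1`, and the names `smooth_max` and `sketch_penalty` are all assumptions made for illustration. The actual formula is in formulation.pdf and may differ, so this sketch should not be expected to reproduce the FairLoss numbers shown earlier.

```python
import torch

def smooth_max(logits, beta=1000.0):
    # Assumed form: log-sum-exp smooth maximum, which approaches
    # max(logits) from above as beta grows.
    return torch.logsumexp(beta * logits, dim=1) / beta

def sketch_penalty(logits, targets, beta=1000.0, scale=1000.0):
    # Logit of the true class for each sample.
    logit_true = logits.gather(1, targets.unsqueeze(1)).squeeze(1)
    # About zero when the true class holds the maximum, negative otherwise.
    gap = logit_true - smooth_max(logits, beta)
    # Assumed transform: 0 at gap == 0, approaching 1 as the gap grows negative.
    return 2.0 * torch.sigmoid(-scale * gap) - 1.0

logits = torch.tensor([[0.34, 0.33, 0.33],
                       [0.49, 0.51, 0.00]])
targets = torch.tensor([0, 0])

penalty = sketch_penalty(logits, targets)
print(penalty)  # near 0 for sample 1, near 1 for sample 2
```

Every operation here is differentiable, which is what lets a penalty of this kind be used with gradient-based training where a hard argmax test could not.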
Install
pip install pytorch-fairloss
Import as from fairloss import FairLoss (the PyPI name differs from the import path).
Or install from source:
pip install git+https://github.com/stormaref/FairLoss.git
Quick start
import torch
import torch.nn as nn
import torch.optim as optim
from fairloss import FairLoss
model = nn.Linear(784, 10)
criterion = FairLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
x = torch.randn(64, 784)
y = torch.randint(0, 10, (64,))
logits = model(x)
loss = criterion(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
See examples/basic_usage.py for a training loop and examples/difference.py for the cross-entropy comparison above.
API
FairLoss(cross_entropy=None, smooth_max_beta=1000.0, scaling_factor=1000.0, reduction="mean")
Wraps an nn.CrossEntropyLoss and adds the FairLoss penalty on top.
| Parameter | Description |
|---|---|
| cross_entropy | Optional nn.CrossEntropyLoss with reduction="none". Defaults to nn.CrossEntropyLoss(reduction="none"). Pass a custom instance to set weight, label_smoothing, ignore_index, etc. |
| smooth_max_beta | Temperature for the smooth maximum approximation over logits. |
| scaling_factor | Scale applied before the sigmoid penalty transform. |
| reduction | Reduction applied to the combined loss: "mean", "sum", or "none". Default: "mean". |
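As a sketch of the cross_entropy parameter, the snippet below builds a custom nn.CrossEntropyLoss with reduction="none" and confirms it returns one value per sample. The class weights and smoothing value are illustrative, and the FairLoss call is left as a comment so the snippet runs without the package installed.

```python
import torch
import torch.nn as nn

# A custom base loss; reduction="none" keeps one loss value per sample,
# as the cross_entropy parameter requires.
base_ce = nn.CrossEntropyLoss(
    weight=torch.tensor([1.0, 2.0, 1.0]),  # illustrative class weights
    label_smoothing=0.1,                   # illustrative smoothing value
    reduction="none",
)

logits = torch.randn(4, 3)
targets = torch.tensor([0, 1, 2, 1])

per_sample = base_ce(logits, targets)
print(per_sample.shape)  # torch.Size([4])

# With the package installed, the instance would be passed in like this:
# from fairloss import FairLoss
# criterion = FairLoss(cross_entropy=base_ce)
```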
Forward pass
- output: Logits tensor of shape (N, C), where N is the batch size and C is the number of classes.
- target: Integer class labels of shape (N,).
- Returns: A scalar loss when reduction is "mean" (the default) or "sum", or per-sample losses when reduction="none".
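The three reduction modes follow the usual PyTorch convention. The helper below is a hypothetical illustration of that convention applied to a vector of per-sample losses; `reduce_loss` is not part of the package, and it assumes a plain (unweighted) mean.

```python
import torch

def reduce_loss(per_sample, reduction="mean"):
    # Hypothetical helper mirroring the documented reduction options.
    if reduction == "mean":
        return per_sample.mean()
    if reduction == "sum":
        return per_sample.sum()
    if reduction == "none":
        return per_sample
    raise ValueError(f"unknown reduction: {reduction!r}")

# Stand-in for combined per-sample losses of a batch with N = 4.
per_sample = torch.tensor([1.0, 2.0, 3.0, 6.0])

print(reduce_loss(per_sample, "mean"))        # tensor(3.)
print(reduce_loss(per_sample, "sum"))         # tensor(12.)
print(reduce_loss(per_sample, "none").shape)  # torch.Size([4])
```

Using reduction="none" is handy for inspecting which samples in a batch are being penalized before the values are averaged away.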
Requirements
- Python 3.9+
- PyTorch 2.0+
License
MIT — see LICENSE.
Reference
Ahmadian, R., Ghatee, M., & Wahlström, J. (2025). Superior scoring rules for probabilistic evaluation of single-label multi-class classification tasks. International Journal of Approximate Reasoning, 182, 109421. https://doi.org/10.1016/j.ijar.2025.109421
Preprint: arXiv:2407.17697