PyTorch classification loss with a SmoothMax-based calibration penalty

FairLoss

PyTorch classification loss that combines cross-entropy with a smooth argmax penalty. Inspired by Penalized Logarithmic Loss (PLL) from Ahmadian et al. (2025).

The problem

In multi-class classification, we often want a loss to rank correct predictions above incorrect ones. Strictly proper scoring rules such as cross-entropy encourage calibrated probabilities, but they do not always preserve that ranking.

Consider two logit vectors, both with true class 0:

Sample  Logits              True class  Argmax correct?
1       [0.34, 0.33, 0.33]  0           Yes
2       [0.49, 0.51, 0.00]  0           No

Sample 2 assigns a higher logit to the true class than sample 1 does (0.49 vs. 0.34), but its argmax is class 1, the wrong label. Cross-entropy therefore prefers sample 2; intuitively, sample 1 should win.

Ahmadian et al. call this the Superior property: a scoring rule should always rank correct predictions above incorrect ones. PLL adds a hard argmax penalty while staying strictly proper. FairLoss applies the same idea with a differentiable penalty suitable for gradient-based training.

Running examples/difference.py on these logits:

Cross-entropy:  tensor([1.0920, 0.9681])   # sample 2 wins
FairLoss:       tensor([1.0920, 1.3014])   # sample 1 wins

Cross-entropy gives sample 2 the lower loss because it is more confident on the true class. FairLoss adds a penalty when the argmax is wrong, so sample 1 ends up with the lower loss. Sample 1 keeps the same loss under both criteria because its prediction is already correct.
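The cross-entropy row above can be checked by hand. The following is a minimal sketch in plain Python (no PyTorch required) that treats the two vectors as raw logits and computes logsumexp(logits) - logits[true_class]. The helper name cross_entropy_from_logits is made up for this illustration and is not part of the package. It covers only the cross-entropy term, not the FairLoss penalty.

```python
import math

def cross_entropy_from_logits(logits, true_class):
    """Per-sample cross-entropy: logsumexp(logits) - logits[true_class].

    Hypothetical helper for illustration; not part of the fairloss package.
    """
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[true_class]

sample_1 = [0.34, 0.33, 0.33]  # argmax is class 0 (correct)
sample_2 = [0.49, 0.51, 0.00]  # argmax is class 1 (wrong)

ce_1 = cross_entropy_from_logits(sample_1, true_class=0)
ce_2 = cross_entropy_from_logits(sample_2, true_class=0)

print(f"{ce_1:.4f} {ce_2:.4f}")  # 1.0920 0.9681
```

The misclassified sample gets the lower cross-entropy simply because its true-class logit is larger relative to the others, which is the ranking problem described above.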

How it works

FairLoss keeps standard cross-entropy and adds a SmoothMax-based penalty:

total_loss = cross_entropy + penalty

The penalty is built from the gap between the true-class logit and a smooth approximation of the largest logit:

  1. SmoothMax approximates max(logits), the logit of the predicted class.
  2. logit_true - smooth_max is near zero when the argmax is correct, negative when it is wrong.
  3. A sigmoid transform maps that gap into a smooth penalty: near zero for correct predictions, positive for incorrect ones.

Full derivation: formulation.pdf
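The three steps above can be sketched in plain Python. This is an illustration under stated assumptions, not the package's implementation: it assumes SmoothMax is a scaled log-sum-exp, (1/beta) * log(sum(exp(beta * z))), and that the sigmoid transform is 2 * sigmoid(-scale * gap) - 1, chosen only because it is zero at gap = 0 and approaches 1 as the gap grows negative. The function names and the exact transform are hypothetical. formulation.pdf has the actual definition, and this sketch is not expected to reproduce the FairLoss values shown earlier.

```python
import math

def smooth_max(logits, beta):
    """Scaled log-sum-exp; approaches max(logits) from above as beta grows.

    Assumed form of SmoothMax, not taken from the package.
    """
    m = max(logits)  # shift by the max so the exponentials cannot overflow
    return m + math.log(sum(math.exp(beta * (z - m)) for z in logits)) / beta

def sketch_penalty(logits, true_class, beta=1000.0, scale=1000.0):
    """Hypothetical penalty: ~0 if the true class has the largest logit, ~1 otherwise."""
    gap = logits[true_class] - smooth_max(logits, beta)  # always <= 0
    # scale * gap <= 0, so math.exp below stays in (0, 1] and cannot overflow.
    sigmoid = 1.0 / (1.0 + math.exp(scale * gap))  # sigmoid(-scale * gap)
    return 2.0 * sigmoid - 1.0

p_correct = sketch_penalty([0.34, 0.33, 0.33], true_class=0)
p_wrong = sketch_penalty([0.49, 0.51, 0.00], true_class=0)
print(p_correct < 1e-3, p_wrong > 0.99)  # True True
```

With a log-sum-exp of this form, the smooth maximum exceeds the true maximum by at most log(C) / beta, so at beta = 1000 the gap for a correct prediction is tiny, while a wrong prediction with a margin of 0.02 already saturates the transform.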

Install

pip install pytorch-fairloss

The PyPI package name (pytorch-fairloss) differs from the import name (fairloss), so import it with from fairloss import FairLoss.

Or install from source:

pip install git+https://github.com/stormaref/FairLoss.git

Quick start

import torch
import torch.nn as nn
import torch.optim as optim
from fairloss import FairLoss

model = nn.Linear(784, 10)
criterion = FairLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(64, 784)
y = torch.randint(0, 10, (64,))

logits = model(x)
loss = criterion(logits, y)

optimizer.zero_grad()
loss.backward()
optimizer.step()

See examples/basic_usage.py for a training loop and examples/difference.py for the cross-entropy comparison above.

API

FairLoss(cross_entropy=None, smooth_max_beta=1000.0, scaling_factor=1000.0, reduction="mean")

Wraps an nn.CrossEntropyLoss and adds the FairLoss penalty on top.

Parameter        Description
cross_entropy    Optional nn.CrossEntropyLoss with reduction="none". Defaults to nn.CrossEntropyLoss(reduction="none"). Pass a custom instance to set weight, label_smoothing, ignore_index, etc.
smooth_max_beta  Temperature for the smooth maximum approximation over logits.
scaling_factor   Scale applied before the sigmoid penalty transform.
reduction        Reduction applied to the combined loss: "mean", "sum", or "none". Default: "mean".

Forward pass

  • output: Logits tensor of shape (N, C) where N is batch size and C is number of classes.
  • target: Integer class labels of shape (N,).
  • Returns: Scalar loss over the batch when reduction is "mean" (the default) or "sum"; per-sample losses of shape (N,) when reduction="none".
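
As a small illustration of the reduction options, the sketch below applies each mode to the two per-sample FairLoss values reported earlier (1.0920 and 1.3014). It is plain Python written for this example; the reduce_losses helper is hypothetical and only mirrors the usual PyTorch meaning of "mean", "sum", and "none".

```python
def reduce_losses(per_sample, reduction="mean"):
    """Hypothetical helper mirroring the usual PyTorch reduction semantics."""
    if reduction == "none":
        return list(per_sample)  # one loss per sample, length N
    if reduction == "sum":
        return sum(per_sample)  # scalar
    if reduction == "mean":
        return sum(per_sample) / len(per_sample)  # scalar
    raise ValueError(f"unknown reduction: {reduction!r}")

per_sample = [1.0920, 1.3014]  # FairLoss values from the comparison above

print(reduce_losses(per_sample, "none"))            # [1.092, 1.3014]
print(round(reduce_losses(per_sample, "sum"), 4))   # 2.3934
print(round(reduce_losses(per_sample, "mean"), 4))  # 1.1967
```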

Requirements

  • Python 3.9+
  • PyTorch 2.0+

License

MIT — see LICENSE.

Reference

Ahmadian, R., Ghatee, M., & Wahlström, J. (2025). Superior scoring rules for probabilistic evaluation of single-label multi-class classification tasks. International Journal of Approximate Reasoning, 182, 109421. https://doi.org/10.1016/j.ijar.2025.109421

Preprint: arXiv:2407.17697

GitHub: https://github.com/stormaref/FairLoss
