Bi-Tempered Loss for training under noisy labels.
Bi-Tempered Logistic Loss
An overview of the method is given in the Google AI Blogpost.
Also explore the interactive visualization that demonstrates the practical properties of the Bi-Tempered logistic loss.
Bi-Tempered logistic loss is a generalized softmax cross-entropy loss function with a bounded loss value per sample and a heavy-tailed softmax probability function.
Bi-tempered loss generalizes (with a bias correction term):
- Zhang & Sabuncu. "Generalized cross entropy loss for training deep neural networks with noisy labels." In NeurIPS 2018.
which is recovered when 0.0 <= t1 <= 1.0 and t2 = 1.0. It also includes:
- Ding & Vishwanathan. "t-Logistic regression." In NeurIPS 2010.
for t1 = 1.0 and t2 >= 1.0.
Bi-tempered loss is equal to the softmax cross entropy loss when t1 = t2 = 1.0. For 0.0 <= t1 < 1.0 and t2 > 1.0, bi-tempered loss provides a more robust alternative to the cross entropy loss for handling label noise and outliers.
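Concretely, the two temperatures enter through tempered generalizations of the logarithm and exponential. The following is an illustrative NumPy sketch of these building blocks (not the shipped TensorFlow/JAX implementation): `log_t` is bounded below by `-1/(1-t)` for `t < 1`, which is what bounds the loss, and `exp_t` decays as a power law for `t > 1`, which is what gives the heavy tails.

```python
import numpy as np

def log_t(u, t):
    """Tempered logarithm; recovers np.log(u) as t -> 1."""
    if t == 1.0:
        return np.log(u)
    return (u ** (1.0 - t) - 1.0) / (1.0 - t)

def exp_t(u, t):
    """Tempered exponential, the inverse of log_t on its domain;
    recovers np.exp(u) as t -> 1."""
    if t == 1.0:
        return np.exp(u)
    return np.maximum(0.0, 1.0 + (1.0 - t) * u) ** (1.0 / (1.0 - t))
```

For example, `exp_t(log_t(0.5, 0.7), 0.7)` returns `0.5`, and for `t1 < 1` the per-sample loss `-log_t(p)` stays finite even as the predicted probability `p` of the true class goes to zero.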
TensorFlow and JAX
A replacement for the standard logistic loss function tf.losses.softmax_cross_entropy is available here:
```python
def bi_tempered_logistic_loss(activations,
                              labels,
                              t1,
                              t2,
                              label_smoothing=0.0,
                              num_iters=5):
  """Bi-Tempered Logistic Loss with custom gradient.

  Args:
    activations: A multi-dimensional tensor with last dimension `num_classes`.
    labels: A tensor with the same shape and dtype as activations.
    t1: Temperature 1 (< 1.0 for boundedness).
    t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
    label_smoothing: Label smoothing parameter between [0, 1).
    num_iters: Number of iterations to run the method.

  Returns:
    A loss tensor.
  """
```
Replacements are also available for the transfer functions:
Tempered version of tf.nn.sigmoid and jax.nn.sigmoid:
```python
def tempered_sigmoid(activations, t, num_iters=5):
  """Tempered sigmoid function.

  Args:
    activations: Activations for the positive class for binary classification.
    t: Temperature > 0.0.
    num_iters: Number of iterations to run the method.

  Returns:
    A probabilities tensor.
  """
```
Tempered version of tf.nn.softmax and jax.nn.softmax:
```python
def tempered_softmax(activations, t, num_iters=5):
  """Tempered softmax function.

  Args:
    activations: A multi-dimensional tensor with last dimension `num_classes`.
    t: Temperature > 0.0.
    num_iters: Number of iterations to run the method.

  Returns:
    A probabilities tensor.
  """
```
Citation
When referencing Bi-Tempered loss, cite this paper:
```bibtex
@inproceedings{amid2019robust,
  title={Robust bi-tempered logistic loss based on Bregman divergences},
  author={Amid, Ehsan and Warmuth, Manfred KK and Anil, Rohan and Koren, Tomer},
  booktitle={Advances in Neural Information Processing Systems},
  pages={15013--15022},
  year={2019}
}
```
Contributions
We are eager to collaborate with you too! Please send us a pull request or open a GitHub issue. See this doc for further details.