
Bi-Tempered loss for training under noisy labels.

Project description

Bi-Tempered Logistic Loss

An overview of the method is available in the Google AI Blog post.

Also, explore the interactive visualization that demonstrates the practical properties of the Bi-Tempered logistic loss.

Bi-Tempered logistic loss is a generalized softmax cross-entropy loss function with a bounded loss value per sample and a heavy-tailed softmax probability function.
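The two temperatures act through the tempered logarithm and tempered exponential used in the paper. The NumPy sketch below assumes the standard definitions log_t(x) = (x^(1-t) - 1)/(1 - t) and exp_t(x) = [1 + (1 - t)x]_+^(1/(1-t)), both of which reduce to log and exp at t = 1. The names `log_t` and `exp_t` are illustrative and not necessarily part of this package's API.

```python
import numpy as np

def log_t(x, t):
    """Tempered logarithm; equals np.log(x) when t == 1."""
    x = np.asarray(x, dtype=float)
    if t == 1.0:
        return np.log(x)
    return (x ** (1.0 - t) - 1.0) / (1.0 - t)

def exp_t(x, t):
    """Tempered exponential (inverse of log_t on its range); equals np.exp(x) when t == 1."""
    x = np.asarray(x, dtype=float)
    if t == 1.0:
        return np.exp(x)
    # Clipping at zero gives finite support for t < 1; for t > 1 the
    # result decays polynomially (a heavy tail) as x -> -inf.
    return np.maximum(1.0 + (1.0 - t) * x, 0.0) ** (1.0 / (1.0 - t))

# For t < 1, -log_t(p, t) stays below 1 / (1 - t) as p -> 0,
# whereas -log(p) grows without bound.
print(-log_t(1e-12, 0.5))  # close to 2.0
print(-np.log(1e-12))      # about 27.6
```

With t < 1 the bounded tempered logarithm is what keeps the per-sample loss finite, and with t > 1 the slow decay of exp_t produces the heavy-tailed probabilities.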

Bi-tempered loss generalizes (with a bias correction term):

  • Zhang & Sabuncu. "Generalized cross entropy loss for training deep neural networks with noisy labels." In NeurIPS 2018.

which is recovered when 0.0 <= t1 <= 1.0 and t2 = 1.0. It also includes:

  • Ding & Vishwanathan. "t-Logistic regression." In NeurIPS 2010.

for t1 = 1.0 and t2 >= 1.0.

Bi-tempered loss is equal to the softmax cross-entropy loss when t1 = t2 = 1.0. For 0.0 <= t1 < 1.0 and t2 > 1.0, bi-tempered loss provides a more robust alternative to the cross-entropy loss for handling label noise and outliers.
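For a one-hot label on class c, the per-example loss can be written out directly. The derivation below is a sketch based on the paper's Bregman-divergence construction. Here \hat{y} is the output of the t2-tempered softmax applied to the activations a, and \log_t(x) = (x^{1-t} - 1)/(1 - t) is the tempered logarithm, with \log_1 = \log. Setting t1 = 1 removes the correction term because the probabilities sum to one. For t1 < 1 the correction term is nonnegative (since \hat{y}_j^{2-t_1} \le \hat{y}_j), so the loss is capped by the bounded tempered logarithm.

```latex
\begin{aligned}
L_{t_1,t_2}(\mathbf{y}, \hat{\mathbf{y}})
  &= -\log_{t_1} \hat{y}_c \;-\; \frac{1}{2 - t_1}\Big(1 - \sum_{j} \hat{y}_j^{\,2 - t_1}\Big),
  \qquad \hat{\mathbf{y}} = \operatorname{tempered\,softmax}_{t_2}(\mathbf{a}), \\[4pt]
L_{1,t_2}(\mathbf{y}, \hat{\mathbf{y}})
  &= -\log \hat{y}_c - \Big(1 - \sum_j \hat{y}_j\Big) = -\log \hat{y}_c
  \qquad \text{(softmax cross-entropy when } t_2 = 1\text{)}, \\[4pt]
L_{t_1,t_2}(\mathbf{y}, \hat{\mathbf{y}})
  &\le -\log_{t_1} \hat{y}_c = \frac{1 - \hat{y}_c^{\,1 - t_1}}{1 - t_1} \le \frac{1}{1 - t_1}
  \qquad \text{for } 0 \le t_1 < 1.
\end{aligned}
```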

TensorFlow and JAX

A replacement for the standard logistic loss function tf.losses.softmax_cross_entropy is available here:

def bi_tempered_logistic_loss(activations,
                              labels,
                              t1,
                              t2,
                              label_smoothing=0.0,
                              num_iters=5):
  """Bi-Tempered Logistic Loss with custom gradient.
  Args:
    activations: A multi-dimensional tensor with last dimension `num_classes`.
    labels: A tensor with the same shape and dtype as activations.
    t1: Temperature 1 (< 1.0 for boundedness).
    t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
    label_smoothing: Label smoothing parameter between [0, 1).
    num_iters: Number of iterations to run the method.
  Returns:
    A loss tensor.
  """

Replacements are also available for the transfer functions:

Tempered version of tf.nn.sigmoid and jax.nn.sigmoid:

def tempered_sigmoid(activations, t, num_iters=5):
  """Tempered sigmoid function.
  Args:
    activations: Activations for the positive class for binary classification.
    t: Temperature > 0.0.
    num_iters: Number of iterations to run the method.
  Returns:
    A probabilities tensor.
  """

Tempered version of tf.nn.softmax and jax.nn.softmax:

def tempered_softmax(activations, t, num_iters=5):
  """Tempered softmax function.
  Args:
    activations: A multi-dimensional tensor with last dimension `num_classes`.
    t: Temperature > 0.0.
    num_iters: Number of iterations to run the method.
  Returns:
    A probabilities tensor.
  """

Citation

When referencing Bi-Tempered loss, cite this paper:

@inproceedings{amid2019robust,
  title={Robust bi-tempered logistic loss based on {B}regman divergences},
  author={Amid, Ehsan and Warmuth, Manfred KK and Anil, Rohan and Koren, Tomer},
  booktitle={Advances in Neural Information Processing Systems},
  pages={15013--15022},
  year={2019}
}

Contributions

We are eager to collaborate with you too! Please send us a pull request or open a GitHub issue. See this doc for further details.

Download files

Source Distribution

jax-bitempered-loss-0.0.2.tar.gz (11.8 kB)

Built Distribution

jax_bitempered_loss-0.0.2-py3-none-any.whl (12.2 kB)

File details

Details for the file jax-bitempered-loss-0.0.2.tar.gz.

File metadata

  • Download URL: jax-bitempered-loss-0.0.2.tar.gz
  • Size: 11.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.8.3 pkginfo/1.8.2 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.9

File hashes

Hashes for jax-bitempered-loss-0.0.2.tar.gz
  • SHA256: fd4bb6bdbd63fc9b3cf15ca3282e992214905941a3b79e6c2e7c9eb018edb82e
  • MD5: b96f4e02739c9fcdf5c6a5842815a296
  • BLAKE2b-256: bda1d7d508de3af6c915b5c33a9dc3fbda7321b75e5c62fd291dc89d30838b1b

Details for the file jax_bitempered_loss-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: jax_bitempered_loss-0.0.2-py3-none-any.whl
  • Size: 12.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.8.3 pkginfo/1.8.2 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.9

File hashes

Hashes for jax_bitempered_loss-0.0.2-py3-none-any.whl
  • SHA256: 480c8b80abcb00883b001fbdd6f3f70b27681bd5e204fba6669818c0f9a326f1
  • MD5: 141cafdb8ecaae17a255da6737607c91
  • BLAKE2b-256: a3010da045844722971066d9d53a9b2e31786cf5048e3b2707fd774566e5ece9
