Uncertainty Loss functions for deep learning
Uncertainty Loss
Loss functions for uncertainty quantification in deep learning.
This package implements loss functions from the following papers:
- Evidential Deep Learning to Quantify Classification Uncertainty: ul.evidential_loss
- Information Aware Max-Norm Dirichlet Networks for Predictive Uncertainty Estimation: ul.maxnorm_loss
These loss functions can be used as drop-in replacements for
torch.nn.functional.cross_entropy, provided the model outputs are first
mapped to non-negative values. See Quickstart and Examples below.
Quickstart
Install the package with pip:
```shell
pip install uncertainty-loss
```
Then use the loss in a training pipeline. For example:
```python
import torch
import uncertainty_loss as ul


def fit_step(model, x, targets, reg_factor=0):
    """Computes and returns the loss for a single training batch.

    Note that the inputs to the uncertainty loss functions need to be
    non-negative. Any non-negative transformation will work (exp, relu,
    softplus, etc.), but we have found that exp works best (in agreement
    with the original papers). For convenience we provide a clamped exp
    function to avoid overflow.
    """
    logits = model(x)
    evidence = ul.clamped_exp(logits)  # non-negative transform
    loss = ul.maxnorm_loss(evidence, targets, reg_factor=reg_factor)
    return loss
```
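The Quickstart relies on ul.clamped_exp to map logits to non-negative evidence. The plain-Python sketch below only illustrates the idea behind such a function; it is not the package's implementation, and the clamp threshold max_logit (and the name clamped_exp_sketch) are hypothetical choices made here.

```python
import math
from typing import List, Sequence


def clamped_exp_sketch(logits: Sequence[float], max_logit: float = 80.0) -> List[float]:
    """Map logits to non-negative evidence with exp, capping large inputs.

    max_logit is a hypothetical threshold: math.exp raises OverflowError for
    inputs above roughly 709, so clamping the input keeps the result finite.
    """
    return [math.exp(min(z, max_logit)) for z in logits]


evidence = clamped_exp_sketch([-2.0, 0.0, 1000.0])
print(evidence)
```

In PyTorch the same idea would typically be written as torch.exp(torch.clamp(logits, max=max_logit)).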
Examples
Replace
```python
from torch.nn import functional as F

loss = F.cross_entropy(x, y)
```
With
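```python
import uncertainty_loss as ul

evidence = ul.clamped_exp(x)  # the uncertainty losses expect non-negative inputs
loss = ul.evidential_loss(evidence, y)
# or
loss = ul.maxnorm_loss(evidence, y)
```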
The loss functions also accept a reduction parameter that behaves the
same way as the reduction argument of cross_entropy.
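To make the input convention concrete, the sketch below implements the mean-squared-error form of the evidential loss described in Evidential Deep Learning to Quantify Classification Uncertainty (the paper also discusses other variants), without its regularization term, plus a reduction argument modeled on cross_entropy. It is a plain-Python illustration: whether ul.evidential_loss uses exactly this variant is an assumption, and the name evidential_mse_loss_sketch is hypothetical.

```python
from typing import List, Sequence, Union


def evidential_mse_loss_sketch(
    evidence: Sequence[Sequence[float]],
    targets: Sequence[int],
    reduction: str = "mean",
) -> Union[float, List[float]]:
    """Dirichlet MSE loss (no regularizer) for a batch of non-negative evidence."""
    losses = []
    for ev, target in zip(evidence, targets):
        alpha = [e + 1.0 for e in ev]  # Dirichlet parameters: alpha = evidence + 1
        strength = sum(alpha)          # Dirichlet strength S = sum(alpha)
        loss = 0.0
        for k, a in enumerate(alpha):
            p = a / strength                 # expected probability of class k
            y = 1.0 if k == target else 0.0  # one-hot target
            # squared error plus the variance of p_k under the Dirichlet
            loss += (y - p) ** 2 + p * (1.0 - p) / (strength + 1.0)
        losses.append(loss)
    if reduction == "none":
        return losses
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        return sum(losses) / len(losses)
    raise ValueError(f"unknown reduction: {reduction!r}")


# Zero evidence (maximal uncertainty) versus strong, correct evidence.
print(evidential_mse_loss_sketch([[0.0, 0.0], [10.0, 0.0]], [0, 0], reduction="none"))
```

The variance term p(1 - p)/(S + 1) is what distinguishes this loss from a plain squared error: it stays large when the total evidence S is small, even if the expected probabilities happen to be right.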
Important
Each loss function includes a regularization term that has been shown to be
beneficial for learning to quantify uncertainty. In practice, to ensure that
the regularization term does not dominate early in training, we ramp it up
from 0 to a maximum factor (e.g. from 0 to 1). It is up to the user to make
sure this happens. Each loss function takes an additional parameter,
reg_factor, which can be incremented during training to accomplish this
ramp-up. By default reg_factor=0, so there is no regularization unless it is
explicitly "turned on".
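For the evidential loss, the regularizer in Evidential Deep Learning to Quantify Classification Uncertainty is a KL divergence that pushes evidence for incorrect classes toward zero: the correct-class entry of alpha is reset to 1 (giving alpha_tilde), and the term is KL(Dir(alpha_tilde) || Dir(1, ..., 1)). The plain-Python sketch below computes that term for one sample, assuming the package follows the paper; the max-norm paper defines its own regularizer, which is not covered here. Both function names are hypothetical, and digamma is a small hand-written approximation because the Python standard library has no digamma function.

```python
import math
from typing import Sequence


def digamma(x: float) -> float:
    """Approximate the digamma function for x > 0 (recurrence + asymptotic series)."""
    result = 0.0
    # Shift x upward using psi(x) = psi(x + 1) - 1/x until the series is accurate.
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    # Asymptotic expansion: ln x - 1/(2x) - 1/(12x^2) + 1/(120x^4) - 1/(252x^6)
    result += math.log(x) - 0.5 / x - inv2 * (1.0 / 12 - inv2 * (1.0 / 120 - inv2 / 252))
    return result


def evidential_kl_reg_sketch(evidence: Sequence[float], target: int) -> float:
    """KL(Dir(alpha_tilde) || Dir(1, ..., 1)) for one sample."""
    alpha = [e + 1.0 for e in evidence]
    # Remove the correct-class evidence: alpha_tilde = y + (1 - y) * alpha
    alpha_tilde = [1.0 if k == target else a for k, a in enumerate(alpha)]
    num_classes = len(alpha_tilde)
    total = sum(alpha_tilde)
    # log Gamma(sum alpha_tilde) - log Gamma(K) - sum_k log Gamma(alpha_tilde_k)
    kl = math.lgamma(total) - math.lgamma(num_classes) - sum(math.lgamma(a) for a in alpha_tilde)
    # + sum_k (alpha_tilde_k - 1) * (digamma(alpha_tilde_k) - digamma(sum alpha_tilde))
    kl += sum((a - 1.0) * (digamma(a) - digamma(total)) for a in alpha_tilde)
    return kl


print(evidential_kl_reg_sketch([5.0, 0.0, 0.0], target=0))  # evidence only on the correct class
print(evidential_kl_reg_sketch([0.0, 5.0, 0.0], target=0))  # evidence on a wrong class
```

Evidence placed only on the correct class is not penalized at all, while evidence on a wrong class gives a positive penalty. In the paper this term is added to the data loss with an annealed weight, which presumably corresponds to reg_factor here.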
Example with Regularization Annealing
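```python
import uncertainty_loss as ul

# Assumes `model`, `dataloader`, `optimizer`, and `epochs` are defined elsewhere.
reg_steps = 1000  # number of batches over which to ramp up
reg_step_size = 1 / reg_steps
reg_factor = 0

for epoch in range(epochs):
    for x, y in dataloader:
        logits = model(x)
        evidence = ul.clamped_exp(logits)
        loss = ul.maxnorm_loss(evidence, y, reg_factor=reg_factor)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Ramp the regularization weight linearly from 0 up to 1.
        reg_factor = min(reg_factor + reg_step_size, 1)
```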
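The loop above increments reg_factor by a fixed amount per batch. An equivalent formulation computes the factor directly from a global step counter, which avoids accumulating floating-point error and makes it easy to resume training from a checkpoint. The helper name linear_reg_factor and its max_factor argument are hypothetical.

```python
def linear_reg_factor(step: int, reg_steps: int = 1000, max_factor: float = 1.0) -> float:
    """Linearly ramp the regularization weight from 0 to max_factor over reg_steps steps."""
    if reg_steps <= 0:
        return max_factor  # no ramp-up requested
    return max_factor * min(step / reg_steps, 1.0)


# `global_step` counts the batches seen so far.
for global_step in (0, 250, 1000, 5000):
    print(global_step, linear_reg_factor(global_step))
```

Inside the training loop, the value would then be passed as reg_factor=linear_reg_factor(global_step).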
Motivation
Uncertainty quantification has important applications in AI safety and active learning. Neural networks trained with a traditional cross-entropy loss are often overconfident in unfamiliar situations. It's easy to see why this can be disastrous: an AI surgeon making a confident but wrong incision in an unfamiliar situation, a self-driving car making a confident but wrong turn, an AI investor making a confident but wrong buy/sell decision.
Several methods have been proposed for uncertainty quantification. Many of the popular ones require specific network architectures (e.g. Monte Carlo dropout requires dropout layers) or expensive inference (Monte Carlo dropout requires multiple forward passes through the same model, and ensemble methods require multiple models).
Recently, methods for uncertainty quantification have been proposed that require no changes to the network architecture and add no inference overhead. Instead, they learn the parameters of a "higher order distribution" (in the papers above, a Dirichlet distribution over the class probabilities) and use this distribution to quantify the uncertainty in the prediction. They have been shown to be effective in those papers.
Unfortunately, these methods haven't been integrated into any of the main deep learning packages, and the heavy math makes the implementation a bit tricky.
For these reasons we have created the uncertainty-loss package.
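As a concrete picture of the "higher order distribution" idea, the evidential deep learning paper treats the network's non-negative outputs as evidence e_k for K classes, sets the Dirichlet parameters to alpha_k = e_k + 1, reads off expected class probabilities alpha_k / S with S = sum(alpha), and defines an uncertainty mass u = K / S. The sketch below follows that formulation; the function name is hypothetical, and the max-norm paper may summarize uncertainty differently.

```python
from typing import List, Sequence, Tuple


def dirichlet_prediction(evidence: Sequence[float]) -> Tuple[List[float], float]:
    """Return expected class probabilities and the uncertainty mass u = K / S."""
    alpha = [e + 1.0 for e in evidence]  # Dirichlet parameters
    strength = sum(alpha)                # S: total evidence plus K
    probs = [a / strength for a in alpha]
    uncertainty = len(alpha) / strength  # 1.0 with no evidence, shrinks as evidence grows
    return probs, uncertainty


print(dirichlet_prediction([0.0, 0.0, 0.0]))  # no evidence: uniform probabilities, u = 1.0
print(dirichlet_prediction([9.0, 0.0, 0.0]))  # strong evidence for class 0: u = 0.25
```

Both quantities come from a single forward pass, which is why this family of methods has no inference overhead compared with Monte Carlo dropout or ensembles.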
Download files
Details for the file uncertainty_loss-0.1.2.tar.gz.
File metadata
- Download URL: uncertainty_loss-0.1.2.tar.gz
- Upload date:
- Size: 7.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.2.2 CPython/3.9.5 Linux/5.15.0-46-generic
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7d09f0e6901356ef437d2b802ed9db1e46d6970a6f795e18a0fcde9c291b0197 |
| MD5 | cc101aa28743f7eff400954f17c663f2 |
| BLAKE2b-256 | 69245c3d0a820b5ce3bd7c3336b936fc2c1dc23a923c5e336919c7c6e959424c |
File details
Details for the file uncertainty_loss-0.1.2-py3-none-any.whl.
File metadata
- Download URL: uncertainty_loss-0.1.2-py3-none-any.whl
- Upload date:
- Size: 7.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.2.2 CPython/3.9.5 Linux/5.15.0-46-generic
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 597ab50230f2c843916ed182640b7c22adf114a6f0cedd040798f77d6eaf99f0 |
| MD5 | e814f70fefb21c8376851d241d95e635 |
| BLAKE2b-256 | dc32b855fa38156bb19295f0a32d993a3d583db8a547b6b93b5524bcabf8e5fa |