A library for Bayesian neural network layers and uncertainty estimation in Deep Learning
Bayesian-Torch is a library of neural network layers and utilities that extends the core of PyTorch to enable Bayesian inference in deep learning models and to quantify principled uncertainty estimates in model predictions.
Overview
Bayesian-Torch is designed to be flexible: a deterministic deep neural network model can be extended to its Bayesian form by simply replacing the deterministic layers with Bayesian layers. This lets users perform stochastic variational inference in deep neural networks.
Bayesian layers:
- Variational layers with reparameterized Monte Carlo estimators [Blundell et al. 2015]: LinearReparameterization, Conv1dReparameterization, Conv2dReparameterization, Conv3dReparameterization, ConvTranspose1dReparameterization, ConvTranspose2dReparameterization, ConvTranspose3dReparameterization, LSTMReparameterization
- Variational layers with Flipout Monte Carlo estimators [Wen et al. 2018]: LinearFlipout, Conv1dFlipout, Conv2dFlipout, Conv3dFlipout, ConvTranspose1dFlipout, ConvTranspose2dFlipout, ConvTranspose3dFlipout, LSTMFlipout
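The reparameterized layers listed above draw their weights by sampling rather than storing a single value. As a rough illustration of the idea (not the library's implementation), the sketch below assumes the common parameterization in which the standard deviation is the softplus of an unconstrained parameter rho. The function names are made up for this example, and the value -3.0 is the posterior_rho_init used in the usage examples in this README.

```python
import math
import random


def softplus(rho):
    """Map an unconstrained rho to a positive standard deviation."""
    return math.log1p(math.exp(rho))


def sample_weight(mu, rho, rng):
    """Draw w = mu + sigma * eps with eps ~ N(0, 1) (reparameterization trick)."""
    sigma = softplus(rho)
    eps = rng.gauss(0.0, 1.0)
    return mu + sigma * eps


rng = random.Random(0)  # fixed seed so repeated runs give the same draws
sigma_init = softplus(-3.0)
samples = [sample_weight(0.0, -3.0, rng) for _ in range(5)]

print(f"sigma at rho=-3.0: {sigma_init:.4f}")
print("five weight samples:", [round(w, 4) for w in samples])
```

Because the noise eps is drawn independently of mu and rho, gradients can flow to both parameters during training, which is what makes this estimator usable with standard backpropagation.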
Key features:
- dnn_to_bnn(): Converts a deterministic deep neural network (dnn) model of any architecture to a Bayesian deep neural network (bnn) model with a single line of code. The API replaces Convolutional, Linear and LSTM layers with the corresponding Bayesian layers, so the existing topology of larger models can be converted for uncertainty-aware applications without rewriting the model definition.
- MOPED: Scales Bayesian inference to large models by specifying weight priors and variational posteriors in Bayesian neural networks with Empirical Bayes [Krishnan et al. 2020]
- Quantization: Post Training Quantization of Bayesian deep neural network models for INT8 inference with simple APIs [Lin et al. 2023]
- AvUC: Accuracy versus Uncertainty Calibration loss [Krishnan and Tickoo 2020]
Installing Bayesian-Torch
To install the core library using pip:
pip install bayesian-torch
To install the latest development version from source:
git clone https://github.com/IntelLabs/bayesian-torch
cd bayesian-torch
pip install .
Usage
There are two ways to build Bayesian deep neural networks using Bayesian-Torch:
- Convert an existing deterministic deep neural network (dnn) model to a Bayesian deep neural network (bnn) model with the dnn_to_bnn() API
- Define your custom model using the Bayesian layers (Reparameterization or Flipout)
(1) For instance, building a Bayesian ResNet18 from the deterministic torchvision ResNet18 model is as simple as:
import torch
import torchvision
from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss

const_bnn_prior_parameters = {
    "prior_mu": 0.0,
    "prior_sigma": 1.0,
    "posterior_mu_init": 0.0,
    "posterior_rho_init": -3.0,
    "type": "Reparameterization",  # Flipout or Reparameterization
    "moped_enable": False,  # True to initialize mu/sigma from the pretrained dnn weights
    "moped_delta": 0.5,
}

model = torchvision.models.resnet18()
dnn_to_bnn(model, const_bnn_prior_parameters)
To use the MOPED method, i.e. to set the prior and initialize the variational parameters from a pretrained deterministic model (this helps training convergence of larger models):
const_bnn_prior_parameters = {
    "prior_mu": 0.0,
    "prior_sigma": 1.0,
    "posterior_mu_init": 0.0,
    "posterior_rho_init": -3.0,
    "type": "Reparameterization",  # Flipout or Reparameterization
    "moped_enable": True,  # True to initialize mu/sigma from the pretrained dnn weights
    "moped_delta": 0.5,
}

model = torchvision.models.resnet18(pretrained=True)
dnn_to_bnn(model, const_bnn_prior_parameters)
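To make the role of moped_delta more concrete, here is a minimal sketch of the initialization described in [Krishnan et al. 2020], written for a single scalar weight. It assumes that the posterior mean is set to the pretrained weight, that the posterior standard deviation is set to delta times the weight's magnitude, and that the standard deviation is stored as rho through a softplus. The helper names are hypothetical and are not part of the Bayesian-Torch API.

```python
import math


def softplus(rho):
    """sigma = log(1 + exp(rho)), always positive."""
    return math.log1p(math.exp(rho))


def moped_init(pretrained_weight, delta):
    """Return (mu, rho) such that softplus(rho) is about delta * |pretrained_weight|."""
    mu = pretrained_weight
    sigma = delta * abs(pretrained_weight)
    # Inverse softplus; the tiny constant avoids log(0) when the weight is exactly zero.
    rho = math.log(math.expm1(sigma) + 1e-20)
    return mu, rho


# Illustrative weights only; real values would come from the pretrained model.
for w in (0.8, -0.05, 0.002):
    mu, rho = moped_init(w, delta=0.5)
    print(f"w={w:+.3f} -> mu={mu:+.3f}, rho={rho:.3f}, sigma={softplus(rho):.5f}")
```

Under these assumptions, larger pretrained weights start with proportionally larger posterior spread, and a smaller delta keeps the initial Bayesian model closer to the deterministic one.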
Training snippet:
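criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)

optimizer.zero_grad()  # clear gradients left over from the previous step
output = model(x_train)
kl = get_kl_loss(model)  # KL term collected from the Bayesian layers
ce_loss = criterion(output, y_train)
loss = ce_loss + kl / args.batch_size  # KL term scaled by the batch size
loss.backward()
optimizer.step()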
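The loss in the training snippet combines a data-fit term with a KL term between the variational posterior and the prior. For intuition, the sketch below evaluates the closed-form KL divergence between two scalar Gaussians and plugs it into a loss of the same form. It is an illustration with made-up numbers and a hypothetical helper name, not the library's get_kl_loss().

```python
import math


def gaussian_kl(mu_q, sigma_q, mu_p, sigma_p):
    """KL( N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2) ) for scalar Gaussians."""
    return (
        math.log(sigma_p / sigma_q)
        + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2.0 * sigma_p ** 2)
        - 0.5
    )


# Hypothetical (mu, sigma) pairs for three weights; the prior is N(0, 1).
posterior = [(0.3, 0.05), (-0.1, 0.05), (0.0, 0.05)]
kl_total = sum(gaussian_kl(mu, sigma, 0.0, 1.0) for mu, sigma in posterior)

ce_loss = 0.7  # stand-in for the cross-entropy of one mini-batch
batch_size = 128
loss = ce_loss + kl_total / batch_size  # same form as the training snippet

print(f"KL = {kl_total:.4f}, loss = {loss:.4f}")
```

The KL term is zero only when posterior and prior coincide, so it acts as a regularizer that pulls the learned weight distributions toward the prior.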
Testing snippet:
model.eval()
with torch.no_grad():
    output_mc = []
    for mc_run in range(args.num_monte_carlo):
        logits = model(x_test)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        output_mc.append(probs)
    output = torch.stack(output_mc)
    pred_mean = output.mean(dim=0)
    y_pred = torch.argmax(pred_mean, dim=-1)
    test_acc = (y_pred.data.cpu().numpy() == y_test.data.cpu().numpy()).mean()
Uncertainty Quantification:
from bayesian_torch.utils.util import predictive_entropy, mutual_information
predictive_uncertainty = predictive_entropy(output.data.cpu().numpy())
model_uncertainty = mutual_information(output.data.cpu().numpy())
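For readers who want to see what these two quantities measure, the following sketch computes them for a single input from a list of Monte Carlo softmax outputs, using the standard definitions: predictive entropy is the entropy of the averaged prediction, and mutual information is that entropy minus the average entropy of the individual samples. The function names ending in _sketch are hypothetical and only mirror the library helpers in spirit; the probabilities are invented.

```python
import math


def entropy(probs):
    """Shannon entropy (in nats) of one categorical distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def predictive_entropy_sketch(mc_probs):
    """Entropy of the mean predictive distribution over Monte Carlo samples."""
    num_classes = len(mc_probs[0])
    mean = [sum(s[k] for s in mc_probs) / len(mc_probs) for k in range(num_classes)]
    return entropy(mean)


def mutual_information_sketch(mc_probs):
    """Predictive entropy minus the average entropy of the individual samples."""
    expected_entropy = sum(entropy(s) for s in mc_probs) / len(mc_probs)
    return predictive_entropy_sketch(mc_probs) - expected_entropy


# Three stochastic forward passes for one input with two classes.
agree = [[0.9, 0.1], [0.9, 0.1], [0.9, 0.1]]
disagree = [[0.9, 0.1], [0.1, 0.9], [0.5, 0.5]]

for name, mc in (("agree", agree), ("disagree", disagree)):
    pe = predictive_entropy_sketch(mc)
    mi = mutual_information_sketch(mc)
    print(f"{name}: predictive entropy = {pe:.4f}, mutual information = {mi:.4f}")
```

When every forward pass returns the same distribution, the mutual information is essentially zero, whereas disagreement between passes raises it. This is why it is read as model uncertainty.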
(2) For building custom models, we have provided example model implementations using the Bayesian layers.
Example usage (training and evaluation of models)
We have provided example usages and scripts to train and evaluate the models. The instructions for the CIFAR10 examples are provided below; similar scripts for ImageNet and MNIST are available.
cd bayesian_torch
Training
To train Bayesian ResNet on CIFAR10, run this command:
Mean-field variational inference (Reparameterized Monte Carlo estimator)
sh scripts/train_bayesian_cifar.sh
Mean-field variational inference (Flipout Monte Carlo estimator)
sh scripts/train_bayesian_flipout_cifar.sh
To train deterministic ResNet on CIFAR10, run this command:
Vanilla
sh scripts/train_deterministic_cifar.sh
Evaluation
To evaluate Bayesian ResNet on CIFAR10, run this command:
Mean-field variational inference (Reparameterized Monte Carlo estimator)
sh scripts/test_bayesian_cifar.sh
Mean-field variational inference (Flipout Monte Carlo estimator)
sh scripts/test_bayesian_flipout_cifar.sh
To evaluate deterministic ResNet on CIFAR10, run this command:
Vanilla
sh scripts/test_deterministic_cifar.sh
Post Training Quantization (PTQ)
To quantize Bayesian ResNet (convert to INT8) and evaluate on CIFAR10, run this command:
sh scripts/quantize_bayesian_cifar.sh
Citing
If you use this code, please cite as:
@software{krishnan2022bayesiantorch,
author = {Ranganath Krishnan and Pi Esposito and Mahesh Subedar},
title = {Bayesian-Torch: Bayesian neural network layers for uncertainty estimation},
month = jan,
year = 2022,
doi = {10.5281/zenodo.5908307},
url = {https://doi.org/10.5281/zenodo.5908307},
howpublished = {\url{https://github.com/IntelLabs/bayesian-torch}}
}
Accuracy versus Uncertainty Calibration (AvUC) loss
@inproceedings{NEURIPS2020_d3d94468,
title = {Improving model calibration with accuracy versus uncertainty optimization},
author = {Krishnan, Ranganath and Tickoo, Omesh},
booktitle = {Advances in Neural Information Processing Systems},
volume = {33},
pages = {18237--18248},
year = {2020},
url = {https://proceedings.neurips.cc/paper/2020/file/d3d9446802a44259755d38e6d163e820-Paper.pdf}
}
Quantization framework for Bayesian deep learning
@inproceedings{lin2023quantization,
title={Quantization for Bayesian Deep Learning: Low-Precision Characterization and Robustness},
author={Lin, Jun-Liang and Krishnan, Ranganath and Ranipa, Keyur Ruganathbhai and Subedar, Mahesh and Sanghavi, Vrushabh and Arunachalam, Meena and Tickoo, Omesh and Iyer, Ravishankar and Kandemir, Mahmut Taylan},
booktitle={2023 IEEE International Symposium on Workload Characterization (IISWC)},
pages={180--192},
year={2023},
organization={IEEE}
}
Model Priors with Empirical Bayes using DNN (MOPED)
@inproceedings{krishnan2020specifying,
title={Specifying weight priors in bayesian deep neural networks with empirical bayes},
author={Krishnan, Ranganath and Subedar, Mahesh and Tickoo, Omesh},
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
volume={34},
number={04},
pages={4477--4484},
year={2020},
url = {https://ojs.aaai.org/index.php/AAAI/article/view/5875}
}
This library and code are intended for researchers and developers. They make it possible to quantify principled uncertainty estimates in deep learning models in order to develop uncertainty-aware AI models. Feedback, issues and contributions are welcome.