Uncertainty quantification with PyTorch
Installation | Quickstart | Methods | Friends | Contributing | Citation | Documentation | Paper
What is posteriors?

General-purpose Python library for uncertainty quantification with PyTorch.
- Composable: Use with `transformers`, `lightning`, `torchopt`, `torch.distributions` and more!
- Extensible: Add new methods! Add new models!
- Functional: Easier to test, closer to mathematics!
- Scalable: Big model? Big data? No problem!
- Swappable: Swap between algorithms with ease!
Installation
posteriors is available on PyPI and can be installed via pip:

```bash
pip install posteriors
```
Quickstart
posteriors is functional first and aims to be easy to use and extend. Let's try it out by training a simple model with variational inference:
```python
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch import nn, utils, func
import torchopt
import posteriors

dataset = MNIST(root="./data", transform=ToTensor())
train_loader = utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
num_data = len(dataset)

classifier = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 10))
params = dict(classifier.named_parameters())


def log_posterior(params, batch):
    images, labels = batch
    images = images.view(images.size(0), -1)
    output = func.functional_call(classifier, params, images)
    log_post_val = (
        -nn.functional.cross_entropy(output, labels)
        + posteriors.diag_normal_log_prob(params) / num_data
    )
    return log_post_val, output


transform = posteriors.vi.diag.build(
    log_posterior, torchopt.adam(), temperature=1 / num_data
)  # Can swap out for any posteriors algorithm

state = transform.init(params)
for batch in train_loader:
    state = transform.update(state, batch)
```
Observe that posteriors recommends specifying `log_posterior` and `temperature` such that `log_posterior` remains on the same scale for different batch sizes. posteriors algorithms are designed to be stable as `temperature` goes to zero.
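Continuing the Quickstart, swapping the VI transform for another algorithm is a one-line change under the same interface. A minimal sketch, assuming an SGHMC builder at `posteriors.sgmcmc.sghmc.build` with an `lr` argument (verify the exact module path and arguments in the API documentation):

```python
# Assumed module path and arguments; check the posteriors API docs.
# The same temperature convention keeps log_posterior on a per-sample scale.
transform = posteriors.sgmcmc.sghmc.build(
    log_posterior, lr=1e-2, temperature=1 / num_data
)

# The training loop is unchanged: same init/update interface.
state = transform.init(params)
for batch in train_loader:
    state = transform.update(state, batch)
```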
Further, the output of `log_posterior` is a tuple containing the evaluation (single-element Tensor) and an additional argument (TensorTree) containing any auxiliary information we'd like to retain from the model call, here the model predictions. If you have no auxiliary information, you can simply return `torch.tensor([])` as the second element. For more info see `torch.func.grad` (with `has_aux=True`) or the documentation.
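For instance, a minimal sketch of a log-density with no auxiliary output; the quadratic `log_posterior_no_aux` below is purely illustrative:

```python
import torch


def log_posterior_no_aux(params, batch):
    # Toy quadratic log-density standing in for a real model evaluation.
    log_post_val = -0.5 * sum((p**2).sum() for p in params.values())
    # No auxiliary information to retain: return an empty tensor to
    # satisfy the (value, aux) contract used with has_aux=True.
    return log_post_val, torch.tensor([])


params = {"w": torch.randn(3)}
grads, aux = torch.func.grad(log_posterior_no_aux, has_aux=True)(params, None)
```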
Check out the tutorials for more detailed usage!
Methods
posteriors supports a variety of methods for uncertainty quantification, with full details available in the API documentation.

posteriors is designed to be easily extensible; if your favorite method is not covered, raise an issue and we'll see what we can do!
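As a sketch of what extending posteriors involves: a new method boils down to an init/update pair matching the transform interface from the Quickstart. The `SGDState` container and `build` helper below are illustrative stand-ins, not posteriors' actual types:

```python
from dataclasses import dataclass
from types import SimpleNamespace

import torch


@dataclass
class SGDState:
    # Illustrative state container; real posteriors states carry more info.
    params: dict


def build(log_posterior, lr=1e-2):
    def init(params):
        return SGDState(params=params)

    def update(state, batch):
        # One gradient-ascent (MAP) step on the mini-batch log posterior.
        grads, _aux = torch.func.grad(log_posterior, has_aux=True)(
            state.params, batch
        )
        new_params = {k: p + lr * grads[k] for k, p in state.params.items()}
        return SGDState(params=new_params)

    return SimpleNamespace(init=init, update=update)
```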
Friends
Interfaces seamlessly with:

- `torch` and in particular `torch.func`.
- `torch.distributions` for distributions and sampling (note that it's typically required to set `validate_args=False` to conform with the control flows in `torch.func`); see the sketch after this list.
- Functional and flexible torch optimizers from `torchopt`.
- `transformers` for pre-trained models.
- `lightning` for convenient training and logging, see examples/lightning_autoencoder.py.
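For example, a minimal sketch of the `validate_args=False` workaround; the Gaussian log-density here is purely illustrative:

```python
import torch
from torch import func
from torch.distributions import Normal


def log_prob(params, batch):
    # validate_args=False skips data-dependent argument checks that can
    # break under torch.func transforms such as grad and vmap.
    dist = Normal(params["loc"], params["scale"], validate_args=False)
    return dist.log_prob(batch).sum()


params = {"loc": torch.tensor(0.0), "scale": torch.tensor(1.0)}
grads = func.grad(log_prob)(params, torch.randn(8))
```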
The functional transform interface is strongly inspired by frameworks such as optax and blackjax.
As well as other UQ libraries: fortuna, laplace, numpyro, pymc and uncertainty-baselines.
Contributing
You can report a bug or request a feature by creating a new issue on GitHub.
If you want to contribute code, please check the contributing guide.
Citation
If you use posteriors in your research, please cite the library using the following BibTeX entry:
```bibtex
@article{duffield2024scalable,
  title={Scalable Bayesian Learning with posteriors},
  author={Duffield, Samuel and Donatella, Kaelan and Chiu, Johnathan and Klett, Phoebe and Simpson, Daniel},
  journal={arXiv preprint arXiv:2406.00104},
  year={2024}
}
```