
PyTorch uncertainty quantification toolkit with Bayes-by-Backprop VI, Laplace, SGLD, MC Dropout, and Gaussian Processes.


deepuq

Unified deep learning uncertainty quantification (UQ) toolkit in PyTorch.

Implements five widely used methods:

  1. Variational Inference (VI) — Bayes by Backprop with BayesianLinear layers.
  2. Laplace Approximation — native diagonal-family backends (diag, fisher_diag, lowrank_diag, block_diag) plus laplace-torch backends (kron, full).
  3. MCMC (SGLD) — Stochastic Gradient Langevin Dynamics sampler for NN posteriors.
  4. MC Dropout — Keep dropout active at test-time and aggregate Monte Carlo predictions.
  5. Gaussian Processes (GPs) — Exact regression and sparse inducing-point approximations with RBF kernels.
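Item 1 can be made concrete with a minimal Bayes-by-Backprop layer in plain PyTorch. The layer keeps a Gaussian posterior over each weight, samples weights with the reparameterization trick, and exposes the closed-form KL divergence to the prior. This is a sketch under stated assumptions, not deepuq's BayesianLinear. The class name MeanFieldLinear, the softplus scale parameterization, the N(0, 1) prior and the fixed noise level are all hypothetical choices made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MeanFieldLinear(nn.Module):
    """Hypothetical Bayes-by-Backprop layer for illustration (not deepuq's BayesianLinear)."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Variational means and unconstrained scales (sigma = softplus(rho)) for weights and biases
        self.w_mu = nn.Parameter(0.1 * torch.randn(out_features, in_features))
        self.w_rho = nn.Parameter(torch.full((out_features, in_features), -5.0))
        self.b_mu = nn.Parameter(torch.zeros(out_features))
        self.b_rho = nn.Parameter(torch.full((out_features,), -5.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reparameterization trick: w = mu + sigma * eps with eps ~ N(0, I), so gradients reach mu and rho
        w_sigma = F.softplus(self.w_rho)
        b_sigma = F.softplus(self.b_rho)
        w = self.w_mu + w_sigma * torch.randn_like(w_sigma)
        b = self.b_mu + b_sigma * torch.randn_like(b_sigma)
        return F.linear(x, w, b)

    def kl(self) -> torch.Tensor:
        # Closed-form KL(q || p) between the diagonal Gaussian q and an assumed N(0, 1) prior p
        total = torch.zeros(())
        for mu, rho in ((self.w_mu, self.w_rho), (self.b_mu, self.b_rho)):
            sigma = F.softplus(rho)
            total = total + 0.5 * (sigma**2 + mu**2 - 1.0 - 2.0 * torch.log(sigma)).sum()
        return total


# One optimization step on toy data
torch.manual_seed(0)
x = torch.linspace(-1.0, 1.0, 64).unsqueeze(-1)
y = torch.sin(3.0 * x)
layer = MeanFieldLinear(1, 1)
opt = torch.optim.Adam(layer.parameters(), lr=1e-2)
opt.zero_grad()
nll = 0.5 * ((layer(x) - y) ** 2).sum() / 0.1**2  # Gaussian NLL, noise std 0.1, constants dropped
loss = nll + layer.kl()  # full-batch negative ELBO estimated with one weight sample
loss.backward()
opt.step()
```

With minibatches, the KL term is usually scaled by batch size / dataset size. This makes the per-batch losses add up to the full-data objective over an epoch.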

Most examples and tutorials use a synthetic Euler-Bernoulli beam deflection regression task to illustrate predictive uncertainty bounds.
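The bundled examples do not state their boundary conditions or loading. The sketch below therefore assumes the classical case for illustration only: a cantilever fixed at x = 0 under a uniformly distributed load q. Its Euler-Bernoulli deflection is w(x) = q x^2 (6L^2 - 4Lx + x^2) / (24 EI). The function names, default parameters and noise level are hypothetical.

```python
import torch


def cantilever_udl_deflection(x, L=2.0, q=1.0, EI=1.0):
    """Euler-Bernoulli deflection of a cantilever (fixed at x=0, free at x=L) under uniform load q."""
    return q * x**2 * (6 * L**2 - 4 * L * x + x**2) / (24 * EI)


def make_beam_dataset(n=200, L=2.0, noise_std=0.01, seed=0):
    # Noisy deflection observations on a uniform grid over [0, L], shaped (n, 1) for an MLP
    g = torch.Generator().manual_seed(seed)
    x = torch.linspace(0.0, L, n).unsqueeze(-1)
    w = cantilever_udl_deflection(x, L=L)
    y = w + noise_std * torch.randn(x.shape, generator=g)
    return x, y


x, y = make_beam_dataset()
```

A quick sanity check is the tip deflection, which reduces to qL^4 / (8EI). That is 2.0 for the default L = 2.0, q = 1 and EI = 1.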

Method Summary

| Method Family | Implemented Variants | Main Wrapper / Class | Tutorial |
|---|---|---|---|
| Variational Inference | Bayes by Backprop | BayesianLinear, vi_elbo_step | notebooks/BayesByBackprop_Tutorial.ipynb |
| Laplace Approximation | diag, fisher_diag, lowrank_diag, block_diag, kron, full | LaplaceWrapper | notebooks/laplace/Laplace_HessianComparison_Tutorial.ipynb |
| MCMC | Stochastic Gradient Langevin Dynamics | SGLDSampler | notebooks/SGLD_Tutorial.ipynb |
| MC Dropout | Monte Carlo dropout inference | MCDropoutWrapper | notebooks/MC_Dropout_Tutorial.ipynb |
| Gaussian Process | Exact GP (RBFKernel) | GaussianProcessRegressor | notebooks/GaussianProcess_Tutorial.ipynb |
| Sparse GP | Variational inducing-point GP | SparseGaussianProcessRegressor | notebooks/SparseGaussianProcess_Tutorial.ipynb |

Install (local)

git clone https://github.com/Vispikarkaria/Deep-UQ.git
cd Deep-UQ
pip install -e .

Install (PyPI)

The distribution is published on PyPI as uqdeepnn. The import name is deepuq:

pip install uqdeepnn

The kron and full LaplaceWrapper structures require the optional laplace-torch backend:

pip install "laplace-torch>=0.1.7"
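A script may need to pick a Laplace structure that works whether or not the optional backend is present. The sketch below checks for laplace-torch first. It assumes laplace-torch is importable as the top-level package laplace. The helpers laplace_torch_available and pick_laplace_structure, and the fallback to diag, are hypothetical and not part of deepuq.

```python
import importlib.util


def laplace_torch_available() -> bool:
    # laplace-torch installs the top-level package "laplace"
    return importlib.util.find_spec("laplace") is not None


def pick_laplace_structure(preferred: str = "kron") -> str:
    """Hypothetical helper: fall back to the native diag structure when laplace-torch is missing."""
    native = {"diag", "fisher_diag", "lowrank_diag", "block_diag"}
    if preferred in native:
        return preferred
    if preferred in {"kron", "full"}:
        return preferred if laplace_torch_available() else "diag"
    raise ValueError(f"unknown Laplace structure: {preferred!r}")
```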

Publish / Update PyPI Release

Follow these steps to publish a new version to PyPI.

  1. Bump the version in pyproject.toml:
[project]
version = "0.1.2"
  2. Commit and push the version bump:
git add pyproject.toml
git commit -m "Bump version to 0.1.2"
git push
  3. Build the source distribution and wheel (clear old files out of dist/ first so they are not uploaded again):
python -m pip install --upgrade build twine
python -m build
  4. Validate the package metadata:
python -m twine check dist/*
  5. Upload to TestPyPI first (recommended):
python -m twine upload --repository testpypi dist/*
  6. Upload to PyPI:
python -m twine upload dist/*
  7. Verify the installation:
pip install -U uqdeepnn
python -c "import deepuq; print('deepuq import ok')"

Notes:

  • Authenticate Twine with a PyPI API token (username __token__). PyPI no longer accepts account passwords for uploads.
  • Revoke and rotate any token immediately if it is ever exposed.

Quickstart

import torch
from deepuq.models import MLP
from deepuq.methods import MCDropoutWrapper

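# Beam deflection regression: inputs on a grid along the beam length L
L = 2.0
x = torch.linspace(0.0, L, 200).unsqueeze(-1)  # shape (200, 1)

# Build the MLP and train it with your usual loop first; an untrained
# network gives meaningless uncertainty estimates
model = MLP(input_dim=1, hidden_dims=[128, 128], output_dim=1, p_drop=0.15)

# Keep dropout active and aggregate n_mc stochastic forward passes (no softmax for regression)
uq = MCDropoutWrapper(model, n_mc=200, apply_softmax=False)
mean, var = uq.predict(x)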
print(mean.shape, var.shape)
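The sketch below shows what MC Dropout prediction amounts to in plain PyTorch. The network runs in eval mode except for its dropout layers, which stay in training mode. The predictive mean and variance are then taken over repeated stochastic forward passes. This is an assumption about what MCDropoutWrapper computes, not its source. The helper mc_dropout_predict and the nn.Sequential stand-in for MLP are hypothetical.

```python
import torch
import torch.nn as nn


def mc_dropout_predict(model: nn.Module, x: torch.Tensor, n_mc: int = 200):
    """Hypothetical helper: predictive mean and variance from n_mc dropout forward passes."""
    model.eval()
    # Re-enable only the dropout layers so other layers (e.g. BatchNorm) stay in eval mode
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            m.train()
    with torch.no_grad():
        samples = torch.stack([model(x) for _ in range(n_mc)])  # (n_mc, N, out)
    return samples.mean(dim=0), samples.var(dim=0)


torch.manual_seed(0)
net = nn.Sequential(
    nn.Linear(1, 128), nn.ReLU(), nn.Dropout(0.15),
    nn.Linear(128, 128), nn.ReLU(), nn.Dropout(0.15),
    nn.Linear(128, 1),
)
x = torch.linspace(0.0, 2.0, 200).unsqueeze(-1)
mean, var = mc_dropout_predict(net, x, n_mc=50)
```

Only nn.Dropout is switched back to training mode here. Other dropout variants such as nn.Dropout2d would need to be added to the isinstance check.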

See the examples/ folder for end-to-end regression scripts on the Euler-Bernoulli beam deflection problem.

Methods

  • VI: Place Gaussian variational posteriors over the weights and train them with the reparameterization trick, regularized by a KL term toward the prior.
  • Laplace: Fit a Gaussian centered at a MAP solution, with covariance from one of several curvature structures (diag, fisher_diag, lowrank_diag, block_diag, kron, full), and calibrate it through the prior precision.
  • MCMC (SGLD): Add Gaussian noise, scaled to the step size, to each minibatch gradient update so that the iterates approximate samples from the weight posterior.
  • MC Dropout: Keep dropout active at inference and take the mean and variance over Monte Carlo forward passes.
  • Gaussian Processes: Closed-form posterior inference with RBF kernels (exact GP), plus a variational inducing-point approximation (sparse GP), for regression and uncertainty-aware interpolation.
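The SGLD bullet corresponds to the update θ ← θ + (ε/2) ∇θ log p(θ | D) + η, where η ~ N(0, ε I) and the gradient is estimated on a minibatch. The sketch below is a plain-PyTorch illustration, not the SGLDSampler implementation. The function name sgld_step and the constant step size are assumptions. It samples a 1-D Gaussian target, N(3, 1), so the output can be checked.

```python
import torch


def sgld_step(params, log_post_fn, step_size: float):
    """Hypothetical SGLD update: theta += (eps/2) * grad log p(theta|D) + N(0, eps) noise."""
    loss = -log_post_fn()  # negative (possibly minibatch-estimated) log posterior
    grads = torch.autograd.grad(loss, params)
    with torch.no_grad():
        for p, g in zip(params, grads):
            p.add_(-0.5 * step_size * g)                  # half-step of gradient ascent on log p
            p.add_(torch.randn_like(p) * step_size**0.5)  # injected Gaussian noise with variance eps


# Toy target: log p(theta) = -0.5 * (theta - 3)^2 + const, i.e. N(3, 1)
torch.manual_seed(0)
theta = torch.zeros(1, requires_grad=True)
samples = []
for t in range(10_000):
    sgld_step([theta], lambda: -0.5 * (theta - 3.0).pow(2).sum(), step_size=0.1)
    if t >= 1_000:  # discard burn-in
        samples.append(theta.item())
samples = torch.tensor(samples)
```

With a constant ε = 0.1, the chain's stationary variance for this target is 1 / (1 - ε/4), about 1.026 rather than exactly 1. This discretization bias shrinks as ε decreases, which is why SGLD is often run with a decaying step size.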

For Laplace users:

  • Native backends (diag, fisher_diag, lowrank_diag, block_diag) work without extra runtime dependencies.
  • kron and full use laplace-torch under the hood.
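The diagonal idea behind the native backends can be sketched in plain PyTorch. The steps are: take point-estimate weights, estimate a diagonal curvature, add the prior precision, and sample weights from the resulting Gaussian to get predictive uncertainty. This page does not say which curvature deepuq's diag or fisher_diag backends use. The choice of the diagonal empirical Fisher (summed squared per-example gradients) is therefore an assumption, as are the noise level, the prior precision and the helper names diag_laplace_precision and laplace_predict.

```python
import torch
import torch.nn as nn


def diag_laplace_precision(model, x, y, noise_std=0.1, prior_prec=1.0):
    """Hypothetical helper: diagonal empirical Fisher of the Gaussian NLL plus prior precision."""
    params = list(model.parameters())
    fisher = [torch.zeros_like(p) for p in params]
    for xi, yi in zip(x, y):  # per-example gradients, one backward pass each
        nll = 0.5 * ((model(xi.unsqueeze(0)) - yi) ** 2).sum() / noise_std**2
        grads = torch.autograd.grad(nll, params)
        for f, g in zip(fisher, grads):
            f += g**2
    return [f + prior_prec for f in fisher]


def laplace_predict(model, x, precision, n_samples=100):
    """Hypothetical helper: sample weights from N(theta_MAP, diag(1 / precision))."""
    params = list(model.parameters())
    theta_map = [p.detach().clone() for p in params]
    preds = []
    with torch.no_grad():
        for _ in range(n_samples):
            for p, mu, prec in zip(params, theta_map, precision):
                p.copy_(mu + torch.randn_like(mu) / prec.sqrt())
            preds.append(model(x))
        for p, mu in zip(params, theta_map):  # restore the point-estimate weights
            p.copy_(mu)
    preds = torch.stack(preds)
    return preds.mean(dim=0), preds.var(dim=0)


# Crude point estimate on toy data (a true MAP fit would also include the prior term)
torch.manual_seed(0)
x = torch.linspace(-1.0, 1.0, 50).unsqueeze(-1)
y = torch.sin(3.0 * x) + 0.1 * torch.randn_like(x)
model = nn.Sequential(nn.Linear(1, 32), nn.Tanh(), nn.Linear(32, 1))
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
for _ in range(500):
    opt.zero_grad()
    ((model(x) - y) ** 2).mean().backward()
    opt.step()

precision = diag_laplace_precision(model, x, y)
x_test = torch.linspace(-2.0, 2.0, 100).unsqueeze(-1)
mean, var = laplace_predict(model, x_test, precision)
```

The empirical Fisher is only one diagonal curvature estimate. The generalized Gauss-Newton diagonal is another common choice, and the prior precision is the usual knob for calibrating the resulting intervals.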

Tutorials

  • notebooks/BayesByBackprop_Tutorial.ipynb: Variational Inference (Bayes by Backprop) for regression with predictive uncertainty.
  • notebooks/MC_Dropout_Tutorial.ipynb: MC Dropout tutorial on a nonlinear beam-style regression case.
  • notebooks/laplace/Laplace_Tutorial.ipynb: Core Laplace workflow around a MAP model.
  • notebooks/laplace/Laplace_FullHessian_Tutorial.ipynb: Full-Hessian Laplace example (requires laplace-torch).
  • notebooks/laplace/Laplace_HessianComparison_Tutorial.ipynb: Side-by-side comparison of all Hessian structures (diag, fisher_diag, lowrank_diag, block_diag, kron, full) using shared MAP weights and common metrics (RMSE, NLL, coverage, interval width, ID/OOD uncertainty ratio).
  • notebooks/SGLD_Tutorial.ipynb: MCMC posterior sampling with SGLD.
  • notebooks/GaussianProcess_Tutorial.ipynb: Exact Gaussian Process regression.
  • notebooks/SparseGaussianProcess_Tutorial.ipynb: Sparse variational GP with inducing points.

Gaussian Processes

The module deepuq.models.gaussian_process provides lightweight GP utilities implemented entirely in PyTorch so everything can run on CPU or GPU.

Exact GP

GaussianProcessRegressor implements closed-form inference with an RBF kernel. Its fit/predict interface follows scikit-learn conventions, while all tensors stay on the chosen device.

import torch
from deepuq.models import GaussianProcessRegressor, RBFKernel

# Training data
x = torch.linspace(-1.0, 1.0, 40).unsqueeze(-1)
y = torch.sin(2 * torch.pi * x) + 0.05 * torch.randn_like(x)

# Model setup
kernel = RBFKernel(lengthscale=0.5, outputscale=1.0)
gp = GaussianProcessRegressor(kernel=kernel, noise=0.02)
gp.fit(x, y)

# Posterior predictions
x_star = torch.linspace(-1.5, 1.5, 200).unsqueeze(-1)
mean, var = gp.predict(x_star)
samples = gp.posterior_samples(x_star, n_samples=5)
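For readers who want to see the closed form, the sketch below writes out the standard exact-GP posterior in plain PyTorch. With K = k(X, X) + σ²I = LLᵀ, the predictive mean is k*ᵀ K⁻¹ y and the predictive variance is k(x*, x*) - k*ᵀ K⁻¹ k*, both computed through the Cholesky factor L. The function names are hypothetical. Treating noise as a variance and outputscale as the kernel variance are assumptions, and GaussianProcessRegressor's internals may differ.

```python
import torch


def rbf(a, b, lengthscale=0.5, outputscale=1.0):
    # k(a, b) = outputscale * exp(-||a - b||^2 / (2 * lengthscale^2))
    sq = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(-1)
    return outputscale * torch.exp(-0.5 * sq / lengthscale**2)


def gp_posterior(x, y, x_star, noise=0.02, lengthscale=0.5, outputscale=1.0):
    """Exact GP posterior mean and variance via a Cholesky factorisation."""
    K = rbf(x, x, lengthscale, outputscale) + noise * torch.eye(len(x), dtype=x.dtype)
    L = torch.linalg.cholesky(K)                               # K = L L^T
    alpha = torch.cholesky_solve(y, L)                         # K^{-1} y
    K_s = rbf(x, x_star, lengthscale, outputscale)             # (N, M) cross-covariances
    mean = K_s.T @ alpha                                       # (M, 1)
    v = torch.linalg.solve_triangular(L, K_s, upper=False)     # L^{-1} K_s
    prior_var = torch.full((len(x_star),), outputscale, dtype=x.dtype)  # RBF diagonal
    var = prior_var - v.pow(2).sum(dim=0)
    return mean, var.clamp_min(0.0).unsqueeze(-1)


torch.manual_seed(0)
x = torch.linspace(-1.0, 1.0, 40).unsqueeze(-1)
y = torch.sin(2 * torch.pi * x) + 0.05 * torch.randn_like(x)
x_star = torch.linspace(-1.5, 1.5, 200).unsqueeze(-1)
mean, var = gp_posterior(x, y, x_star)
```

Solving with the Cholesky factor instead of inverting K is the usual choice. It is cheaper and numerically more stable, and the same factor serves both the mean and the variance.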

Sparse Variational GP

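SparseGaussianProcessRegressor follows the variational inducing-point approach of Titsias (2009). It optimises the kernel hyperparameters and the inducing-point locations with Adam. The approach reduces the O(N³) cost of exact GP inference to O(NM²) for N training points and M inducing points.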

import torch
from deepuq.models import SparseGaussianProcessRegressor

x = torch.linspace(-2.0, 2.0, 500).unsqueeze(-1)
y = torch.sin(2 * torch.pi * x) + 0.1 * torch.randn_like(x)

sparse_gp = SparseGaussianProcessRegressor(num_inducing=40, num_iterations=800)
sparse_gp.fit(x, y)
mean, var = sparse_gp.predict(x[:50])
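The Titsias approach maximises a collapsed lower bound on the log marginal likelihood: log N(y | 0, Q + σ²I) - tr(K_nn - Q) / (2σ²), where Q = K_nm K_mm⁻¹ K_mn. The sketch below evaluates that bound in O(NM²) through Cholesky factors, without forming an N × N matrix, and optimises it with Adam. The function names, the log-parameterisation, the grid initialisation of the inducing inputs, the jitter and the learning rate are all assumptions, and SparseGaussianProcessRegressor may compute the bound differently.

```python
import math
import torch


def rbf(a, b, log_ls, log_os):
    sq = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(-1)
    return log_os.exp() * torch.exp(-0.5 * sq / log_ls.exp() ** 2)


def titsias_bound(x, y, z, log_ls, log_os, log_noise, jitter=1e-6):
    """Collapsed Titsias (2009) bound, evaluated in O(N M^2) without an N x N matrix."""
    n, m = x.shape[0], z.shape[0]
    sigma2 = log_noise.exp()
    eye = torch.eye(m, dtype=x.dtype)
    Lm = torch.linalg.cholesky(rbf(z, z, log_ls, log_os) + jitter * eye)
    A = torch.linalg.solve_triangular(Lm, rbf(z, x, log_ls, log_os), upper=False) / sigma2.sqrt()
    LB = torch.linalg.cholesky(eye + A @ A.T)                   # B = I + A A^T
    c = torch.linalg.solve_triangular(LB, A @ y, upper=False) / sigma2.sqrt()
    trace_knn = n * log_os.exp()                                # RBF diagonal equals the output scale
    return (
        -0.5 * n * math.log(2 * math.pi)
        - torch.log(torch.diagonal(LB)).sum()
        - 0.5 * n * torch.log(sigma2)
        - 0.5 * (y * y).sum() / sigma2
        + 0.5 * (c * c).sum()
        - 0.5 * trace_knn / sigma2                              # -tr(K_nn) / (2 sigma^2)
        + 0.5 * (A * A).sum()                                   # +tr(Q_nn) / (2 sigma^2)
    )


torch.manual_seed(0)
x = torch.linspace(-2.0, 2.0, 500, dtype=torch.float64).unsqueeze(-1)
y = torch.sin(2 * math.pi * x) + 0.1 * torch.randn_like(x)
z = torch.linspace(-2.0, 2.0, 40, dtype=torch.float64).unsqueeze(-1).clone().requires_grad_(True)
log_ls = torch.tensor(math.log(0.5), dtype=torch.float64, requires_grad=True)
log_os = torch.tensor(0.0, dtype=torch.float64, requires_grad=True)
log_noise = torch.tensor(math.log(0.1), dtype=torch.float64, requires_grad=True)

opt = torch.optim.Adam([z, log_ls, log_os, log_noise], lr=0.01)
history = []
for _ in range(300):
    opt.zero_grad()
    loss = -titsias_bound(x, y, z, log_ls, log_os, log_noise)  # maximise the bound
    loss.backward()
    opt.step()
    history.append(loss.item())
```

Because the trace term penalises the gap between K_nn and its Nyström approximation Q, moving the inducing inputs to explain the data better also tightens the bound.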

Explore both flavours in the notebooks notebooks/GaussianProcess_Tutorial.ipynb (exact GP) and notebooks/SparseGaussianProcess_Tutorial.ipynb (sparse GP), each of which visualises posterior means, credible intervals, and posterior samples on toy datasets.

Documentation

  • API documentation lives in the module docstrings and in the sections of this README.
  • Run python -m pydoc deepuq.methods.vi (or any other module), or browse the examples/ folder.

Contributing

PRs welcome. Please add tests under tests/ and run pytest.

License

MIT



Download files


Source Distribution

uqdeepnn-0.1.2.tar.gz (24.2 kB)

Uploaded Source

Built Distribution


uqdeepnn-0.1.2-py3-none-any.whl (20.7 kB)

Uploaded Python 3

File details

Details for the file uqdeepnn-0.1.2.tar.gz.

File metadata

  • Download URL: uqdeepnn-0.1.2.tar.gz
  • Upload date:
  • Size: 24.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.4

File hashes

Hashes for uqdeepnn-0.1.2.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | fd9a8696147e93f7ba112324db1769c7eacd15528614f7342ec600b2a1f81779 |
| MD5 | 42e3b6808d2c88fec01f941aa1599a27 |
| BLAKE2b-256 | 50caeb5d48d07d7a7f2bf6af5ff3b827073e612374a857a234e9f53776aa494d |


File details

Details for the file uqdeepnn-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: uqdeepnn-0.1.2-py3-none-any.whl
  • Upload date:
  • Size: 20.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.4

File hashes

Hashes for uqdeepnn-0.1.2-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | d9e88a287019b43a3d39a0b9a266ee9644a7ed392b79e4820a58b89a20cb86cd |
| MD5 | 041fbdc7369a7730b2fd250cb6dba5f2 |
| BLAKE2b-256 | 0ac38d9a885e1a76ca32983f8af23c29b3511f371f9ccb7eb46a7d39921fe2fa |

