MC Dropout (Gal & Ghahramani, 2016) - Pytorch
MC Dropout, in Pytorch
Implementation of Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning (Gal & Ghahramani, ICML 2016) in Pytorch.
Standard dropout networks are cast as approximate Bayesian inference in a deep Gaussian process, so uncertainty estimates come essentially for free: no architectural changes, no extra parameters, and inference costs just T stochastic forward passes.
Install
```bash
$ pip install mc-dropout-pytorch
```
Usage
Regression with uncertainty
```python
import torch
from mc_dropout_pytorch import BayesianMLP, MCDropoutInference

# build model
model = BayesianMLP(
    input_dim = 1,
    output_dim = 1,
    hidden_dims = (256, 256),
    dropout_rate = 0.1,
    activation = 'relu',
)

# wrap for MC inference (T=50 stochastic passes)
mc = MCDropoutInference(model, num_samples = 50, task = 'regression', tau = 1.0)

x = torch.linspace(-3, 3, 100).unsqueeze(-1)
out = mc(x)

out.mean      # predictive mean — (100, 1)
out.variance  # predictive variance — (100, 1), includes the τ⁻¹ noise term
out.samples   # raw samples — (50, 100, 1)
```
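Under the hood, MC dropout needs nothing more than keeping the dropout layers stochastic at test time. A minimal sketch of the mechanism in plain PyTorch (the helpers below are illustrative, not this package's internals; the τ⁻¹ noise term follows the paper's predictive variance):

```python
import torch
from torch import nn

def enable_dropout(model: nn.Module):
    # switch only the nn.Dropout layers back to train mode
    # so they keep sampling masks while everything else stays in eval
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            m.train()

@torch.no_grad()
def mc_predict(model: nn.Module, x: torch.Tensor, T: int = 50, tau: float = 1.0):
    model.eval()
    enable_dropout(model)
    samples = torch.stack([model(x) for _ in range(T)])     # (T, batch, out)
    mean = samples.mean(dim = 0)                            # predictive mean
    # inherent noise tau^-1 plus the sample variance (biased, to match the
    # moment-matching estimator of Eq. 9)
    var = 1.0 / tau + samples.var(dim = 0, unbiased = False)
    return mean, var, samples
```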
Classification with predictive entropy
```python
import torch
from mc_dropout_pytorch import BayesianCNN, MCDropoutInference

model = BayesianCNN(
    in_channels = 1,
    num_classes = 10,
    base_channels = 32,
    dropout_rate = 0.25,
    fc_dropout_rate = 0.5,
    img_size = 28,
)

mc = MCDropoutInference(model, num_samples = 50, task = 'classification')

x = torch.randn(8, 1, 28, 28)
out = mc(x)

out.mean      # class probabilities — (8, 10)
out.variance  # per-class variance — (8, 10)

# active learning signals (§6)
H = mc.predictive_entropy(x)   # (8,) — total uncertainty
MI = mc.mutual_information(x)  # (8,) — epistemic uncertainty only
```
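Both acquisition scores reduce to a few tensor ops over the stochastic softmax samples. A sketch of the computation following the paper's definitions (the function below is illustrative; the package's internals may differ):

```python
import torch

def predictive_entropy_and_mi(probs: torch.Tensor, eps: float = 1e-12):
    # probs: (T, batch, classes), softmax outputs from T stochastic passes
    mean_probs = probs.mean(dim = 0)                                # (batch, classes)
    # total uncertainty: entropy of the averaged predictive distribution
    H = -(mean_probs * (mean_probs + eps).log()).sum(dim = -1)      # (batch,)
    # average entropy of the individual samples: the aleatoric part
    E_H = -(probs * (probs + eps).log()).sum(dim = -1).mean(dim = 0)
    # BALD mutual information: epistemic part = total minus aleatoric
    return H, H - E_H
```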
Full training loop with the Trainer
```python
import torch
from torch.utils.data import TensorDataset
from mc_dropout_pytorch import BayesianMLP, Trainer

# synthetic regression dataset
X = torch.randn(1000, 4)
y = (X[:, 0] * 2 + X[:, 1] - X[:, 2] + torch.randn(1000) * 0.1).unsqueeze(-1)  # (1000, 1) to match output_dim
dataset = TensorDataset(X, y)

model = BayesianMLP(
    input_dim = 4,
    output_dim = 1,
    hidden_dims = (128, 128),
    dropout_rate = 0.1,
)

trainer = Trainer(
    model,
    dataset,
    task = 'regression',
    train_lr = 1e-3,
    train_num_steps = 5_000,
    train_batch_size = 64,
    ema_decay = 0.995,
    amp = False,
    weight_decay = 1e-4,   # ≡ prior precision in §3
    tau = 1.0,             # noise precision
    num_mc_samples = 50,
)

trainer.train()

# inference via EMA model
mc = trainer.inference
out = mc(X[:10])
print(out.mean, out.variance)
```
Multi-GPU
```bash
$ accelerate config
$ accelerate launch train.py
```
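Here `train.py` is just the ordinary single-process script. A minimal example reusing the Trainer from above, assuming (as the launch command implies) that the Trainer defers device placement and distribution to accelerate internally:

```python
# train.py
import torch
from torch.utils.data import TensorDataset
from mc_dropout_pytorch import BayesianMLP, Trainer

# same synthetic regression dataset as in the usage section
X = torch.randn(1000, 4)
y = (X[:, 0] * 2 + X[:, 1] - X[:, 2] + torch.randn(1000) * 0.1).unsqueeze(-1)

model = BayesianMLP(input_dim = 4, output_dim = 1, hidden_dims = (128, 128), dropout_rate = 0.1)

trainer = Trainer(model, TensorDataset(X, y), task = 'regression', train_num_steps = 5_000)
trainer.train()
```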
Key ideas from the paper
The insight (§3): training a neural network with dropout and L2 regularisation implicitly minimises the KL divergence between an approximating distribution over the weights and the posterior of a deep Gaussian process. No variational EM and no explicit weight sampling are required during training.
Test-time dropout (MC Dropout):

```
for t = 1 … T:
    ŷ_t = f^{ω_t}(x)                                      # ω_t ~ q(ω) via Bernoulli dropout

E[y*]   ≈ (1/T) Σ_t ŷ_t                                   # predictive mean
Var[y*] ≈ τ⁻¹ I + (1/T) Σ_t ŷ_t ŷ_tᵀ − E[y*] E[y*]ᵀ       # predictive variance (Eq. 9)
```
Active learning (§6): Use mc.mutual_information(x) to identify the most informative unlabelled points — pure epistemic uncertainty, disentangled from aleatoric noise.
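As a sketch of one acquisition round, reusing the classification wrapper `mc` from the usage section (the pool and query size below are made up for illustration):

```python
import torch

# hypothetical unlabelled pool of MNIST-sized images
pool_x = torch.randn(1_000, 1, 28, 28)

with torch.no_grad():
    scores = mc.mutual_information(pool_x)  # (1000,) epistemic uncertainty per point

query = scores.topk(k = 32).indices         # indices of the 32 most informative points
# label pool_x[query], add them to the training set, retrain, repeat
```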
Weight correspondence (§3.2):

| Dropout training | Bayesian GP posterior |
|---|---|
| dropout probability p | variational parameter of the Bernoulli approximating distribution |
| L2 weight decay λ | prior precision |
| noise precision τ | τ = l²(1 − p) / (2Nλ), with prior length-scale l |
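In code, the identity pins down τ from quantities fixed at training time. A sketch assuming the paper's correspondence, where the prior length-scale l is a modelling choice (set by hand or by cross-validation); the helper is illustrative, not part of the package:

```python
def model_precision(length_scale: float, p_drop: float, n_train: int, weight_decay: float) -> float:
    # tau = l^2 (1 - p) / (2 N lambda), the dropout/GP correspondence of §3.2
    return length_scale ** 2 * (1 - p_drop) / (2 * n_train * weight_decay)

tau = model_precision(length_scale = 1e-2, p_drop = 0.1, n_train = 1000, weight_decay = 1e-4)  # ≈ 4.5e-4
```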
Citations
```bibtex
@inproceedings{Gal2016Dropout,
    title     = {Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning},
    author    = {Yarin Gal and Zoubin Ghahramani},
    booktitle = {Proceedings of the 33rd International Conference on Machine Learning (ICML)},
    year      = {2016},
    url       = {https://arxiv.org/abs/1506.02142}
}
```