MC Dropout, in Pytorch

Implementation of Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning (Gal & Ghahramani, ICML 2016) in Pytorch.

Standard dropout networks can be cast as approximate Bayesian inference in deep Gaussian processes. The payoff is calibrated uncertainty estimates with no architectural changes and no extra cost at inference beyond running T stochastic forward passes instead of one.
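
The whole trick fits in a few lines of plain PyTorch, independent of this package: leave dropout active at test time and average T stochastic forward passes. A minimal sketch (the architecture and T are illustrative, and the τ⁻¹ noise term of Eq. 9 is omitted here):

import torch
import torch.nn as nn

# any network with dropout layers works; nothing about the architecture changes
net = nn.Sequential(nn.Linear(1, 256), nn.ReLU(), nn.Dropout(0.1), nn.Linear(256, 1))

def mc_forward(net, x, T = 50):
    net.train()                                          # keep Bernoulli dropout active at test time
    with torch.no_grad():
        preds = torch.stack([net(x) for _ in range(T)])  # (T, N, D) stochastic predictions
    return preds.mean(0), preds.var(0)                   # MC estimates of predictive mean / variance

mean, var = mc_forward(net, torch.linspace(-3, 3, 100).unsqueeze(-1))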

Install

$ pip install mc-dropout-pytorch

Usage

Regression with uncertainty

import torch
from mc_dropout_pytorch import BayesianMLP, MCDropoutInference

# build model
model = BayesianMLP(
    input_dim    = 1,
    output_dim   = 1,
    hidden_dims  = (256, 256),
    dropout_rate = 0.1,
    activation   = 'relu',
)

# wrap for MC inference (T=50 stochastic passes)
mc = MCDropoutInference(model, num_samples = 50, task = 'regression', tau = 1.0)

x = torch.linspace(-3, 3, 100).unsqueeze(-1)
out = mc(x)

out.mean      # predictive mean     — (100, 1)
out.variance  # predictive variance — (100, 1)  includes τ⁻¹ noise term
out.samples   # raw samples         — (50, 100, 1)
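
Since out.variance already includes the τ⁻¹ observation-noise term, a Gaussian predictive interval is one line away (a sketch, assuming the predictive distribution is approximately Gaussian):

std = out.variance.sqrt()
lower, upper = out.mean - 1.96 * std, out.mean + 1.96 * std   # ~95% predictive interval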

Classification with predictive entropy

import torch
from mc_dropout_pytorch import BayesianCNN, MCDropoutInference

model = BayesianCNN(
    in_channels      = 1,
    num_classes      = 10,
    base_channels    = 32,
    dropout_rate     = 0.25,
    fc_dropout_rate  = 0.5,
    img_size         = 28,
)

mc = MCDropoutInference(model, num_samples = 50, task = 'classification')

x = torch.randn(8, 1, 28, 28)
out = mc(x)

out.mean      # class probabilities  — (8, 10)
out.variance  # per-class variance   — (8, 10)

# active learning signals (§6)
H  = mc.predictive_entropy(x)   # (8,)  — total uncertainty
MI = mc.mutual_information(x)   # (8,)  — epistemic uncertainty only
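
Both signals can be recomputed from the raw samples: predictive entropy is the entropy of the averaged probabilities, and mutual information subtracts the mean per-pass entropy, leaving only the epistemic part. A sketch, assuming out.samples holds per-pass class probabilities of shape (50, 8, 10):

probs = out.samples                                   # (T, N, C), assumed to be probabilities
p_bar = probs.mean(0)                                 # (N, C) posterior-averaged probabilities

H_total   = -(p_bar * p_bar.clamp_min(1e-12).log()).sum(-1)           # H[y|x]
H_per_t   = -(probs * probs.clamp_min(1e-12).log()).sum(-1).mean(0)   # E_ω H[y|x, ω]
MI_manual = H_total - H_per_t                                         # epistemic (BALD) score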

Full training loop with the Trainer

import torch
from torch.utils.data import TensorDataset
from mc_dropout_pytorch import BayesianMLP, Trainer

# synthetic regression dataset
X = torch.randn(1000, 4)
y = (X[:, 0] * 2 + X[:, 1] - X[:, 2]).unsqueeze(-1) + torch.randn(1000, 1) * 0.1   # (1000, 1) to match output_dim
dataset = TensorDataset(X, y)

model = BayesianMLP(
    input_dim    = 4,
    output_dim   = 1,
    hidden_dims  = (128, 128),
    dropout_rate = 0.1,
)

trainer = Trainer(
    model,
    dataset,
    task             = 'regression',
    train_lr         = 1e-3,
    train_num_steps  = 5_000,
    train_batch_size = 64,
    ema_decay        = 0.995,
    amp              = False,
    weight_decay     = 1e-4,   # ≡ prior precision in §3
    tau              = 1.0,    # noise precision
    num_mc_samples   = 50,
)

trainer.train()

# inference via EMA model
mc = trainer.inference
out = mc(X[:10])
print(out.mean, out.variance)
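
For held-out evaluation, the paper scores models by predictive log-likelihood, estimated with a log-sum-exp over the T stochastic passes. A sketch, assuming out.samples has shape (T, N, D) as above:

import math

def predictive_log_likelihood(samples, y, tau = 1.0):
    # samples: (T, N, D) stochastic predictions, y: (N, D) targets
    T, _, D = samples.shape
    sq = ((y.unsqueeze(0) - samples) ** 2).sum(-1)               # (T, N) squared errors
    ll = torch.logsumexp(-0.5 * tau * sq, dim = 0) - math.log(T)
    return ll + 0.5 * D * (math.log(tau) - math.log(2 * math.pi))

print(predictive_log_likelihood(out.samples, y[:10]).mean())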

Multi-GPU

$ accelerate config
$ accelerate launch train.py
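
where train.py is any script that builds a Trainer and calls .train(). A stripped-down, illustrative example:

# train.py
import torch
from torch.utils.data import TensorDataset
from mc_dropout_pytorch import BayesianMLP, Trainer

X = torch.randn(10_000, 4)
y = (X[:, 0] * 2 + X[:, 1] - X[:, 2]).unsqueeze(-1) + torch.randn(10_000, 1) * 0.1

model = BayesianMLP(input_dim = 4, output_dim = 1, hidden_dims = (128, 128), dropout_rate = 0.1)

Trainer(model, TensorDataset(X, y), task = 'regression', train_num_steps = 10_000).train()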

Key ideas from the paper

The insight (§3): Training a NN with dropout and L2 regularisation minimises a KL divergence to the posterior of a deep Gaussian process — no variational EM, no weight sampling required.

Test-time dropout (MC Dropout):

for t = 1 … T:
    ŷ_t = f^ω_t(x)    # ω_t ~ q(ω)  via Bernoulli dropout

E[y*]   ≈ (1/T) Σ ŷ_t                              # predictive mean
Var[y*] ≈ τ⁻¹ I + (1/T) Σ ŷ_t ŷ_tᵀ − E[y*] E[y*]ᵀ   # predictive variance  (Eq. 9)
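
In code, the estimator is a sample mean plus a sample second moment with the τ⁻¹ noise floor added back. A sketch for any dropout model f, computing the diagonal of Eq. 9:

def mc_predict(f, x, T = 50, tau = 1.0):
    f.train()                                          # ω_t ~ q(ω): Bernoulli dropout stays on
    with torch.no_grad():
        y_t = torch.stack([f(x) for _ in range(T)])    # (T, N, D)
    mean = y_t.mean(0)                                 # (1/T) Σ ŷ_t
    var  = 1.0 / tau + (y_t ** 2).mean(0) - mean ** 2  # diagonal of Eq. 9
    return mean, var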

Active learning (§6): Use mc.mutual_information(x) to identify the most informative unlabelled points — pure epistemic uncertainty, disentangled from aleatoric noise.
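
A sketch of one acquisition round on a hypothetical pool of unlabelled MNIST-sized images:

pool  = torch.randn(1024, 1, 28, 28)                    # hypothetical unlabelled pool
query = mc.mutual_information(pool).topk(16).indices    # label the 16 highest-BALD points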

Weight correspondence (§3.2):

Dropout training          Bayesian GP posterior
dropout probability p     variational Bernoulli parameter
L2 weight decay λ         prior precision (length-scale l)
noise precision τ         τ = (1 − p) l² / (2 N λ)
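
So, to target a given noise precision, the weight decay can be set from the correspondence rather than tuned blindly (a sketch; the length-scale l is a modelling choice from the paper, not a package argument):

l, p, N, tau = 1e-2, 0.1, 1000, 1.0                # length-scale, dropout rate, data size, precision
weight_decay = (1 - p) * l ** 2 / (2 * N * tau)    # λ from the table above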

Citations

@inproceedings{Gal2016Dropout,
    title     = {Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning},
    author    = {Yarin Gal and Zoubin Ghahramani},
    booktitle = {Proceedings of the 33rd International Conference on Machine Learning (ICML)},
    year      = {2016},
    url       = {https://arxiv.org/abs/1506.02142}
}
