
MC Dropout (Gal & Ghahramani, 2016) - PyTorch

Project description

MC Dropout, in PyTorch


Implementation of Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning (Gal & Ghahramani, ICML 2016) in PyTorch.

Standard dropout networks can be cast as approximate Bayesian inference over deep Gaussian processes, yielding uncertainty estimates with no architectural changes: the only extra inference cost is the T stochastic forward passes themselves.
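
The recipe is a few lines of plain PyTorch, sketched here only to illustrate the idea (toy model and T chosen arbitrarily): leave dropout active at test time and aggregate T stochastic passes.

import torch
import torch.nn as nn

# toy regressor; nn.Dropout stays stochastic whenever the module is in train mode
net = nn.Sequential(
    nn.Linear(1, 256), nn.ReLU(), nn.Dropout(0.1),
    nn.Linear(256, 1),
)
net.train()  # keep dropout on at test time (safe here: no batchnorm layers)

x = torch.linspace(-3, 3, 100).unsqueeze(-1)
with torch.no_grad():
    samples = torch.stack([net(x) for _ in range(50)])  # (50, 100, 1)

mean = samples.mean(dim = 0)  # predictive mean
var  = samples.var(dim = 0)   # epistemic spread; add τ⁻¹ for observation noise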

Install

$ pip install mc-dropout-pytorch

Usage

Regression with uncertainty

import torch
from torch.utils.data import TensorDataset
from mc_dropout_pytorch import BayesianMLP, MCDropoutInference, Trainer

# build model
model = BayesianMLP(
    input_dim    = 1,
    output_dim   = 1,
    hidden_dims  = (256, 256),
    dropout_rate = 0.1,
    activation   = 'relu',
)

# wrap for MC inference (T=50 stochastic passes)
mc = MCDropoutInference(model, num_samples = 50, task = 'regression', tau = 1.0)

x = torch.linspace(-3, 3, 100).unsqueeze(-1)
out = mc(x)

out.mean      # predictive mean     — (100, 1)
out.variance  # predictive variance — (100, 1)  includes τ⁻¹ noise term
out.samples   # raw samples         — (50, 100, 1)
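
The two sources of uncertainty can also be recovered from out.samples by hand. A sketch, assuming out.variance follows Eq. 9 below with the tau passed in above:

tau = 1.0                             # must match the value given to MCDropoutInference
epistemic = out.samples.var(dim = 0)  # spread across the 50 stochastic passes, (100, 1)
aleatoric = 1.0 / tau                 # homoscedastic observation noise τ⁻¹
total     = epistemic + aleatoric     # agrees with out.variance, up to the biased vs.
                                      # unbiased variance estimator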

Classification with predictive entropy

import torch
from mc_dropout_pytorch import BayesianCNN, MCDropoutInference

model = BayesianCNN(
    in_channels      = 1,
    num_classes      = 10,
    base_channels    = 32,
    dropout_rate     = 0.25,
    fc_dropout_rate  = 0.5,
    img_size         = 28,
)

mc = MCDropoutInference(model, num_samples = 50, task = 'classification')

x = torch.randn(8, 1, 28, 28)
out = mc(x)

out.mean      # class probabilities  — (8, 10)
out.variance  # per-class variance   — (8, 10)

# active learning signals (§6)
H  = mc.predictive_entropy(x)   # (8,)  — total uncertainty
MI = mc.mutual_information(x)   # (8,)  — epistemic uncertainty only
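
Both signals are plain functions of the softmax samples, so they can be computed by hand as well. A sketch of the standard decomposition, assuming out.samples holds per-pass class probabilities (the package's internals may differ):

probs = out.samples          # (50, 8, 10) class probabilities per stochastic pass
p_bar = probs.mean(dim = 0)  # (8, 10) predictive distribution

H   = -(p_bar * p_bar.clamp_min(1e-12).log()).sum(dim = -1)                # total uncertainty
E_H = -(probs * probs.clamp_min(1e-12).log()).sum(dim = -1).mean(dim = 0)  # expected entropy
MI  = H - E_H                # epistemic remainder, as in mc.mutual_information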

Full training loop with the Trainer

import torch
from torch.utils.data import TensorDataset
from mc_dropout_pytorch import BayesianMLP, Trainer

# synthetic regression dataset
X = torch.randn(1000, 4)
y = (X[:, 0] * 2 + X[:, 1] - X[:, 2] + torch.randn(1000) * 0.1).unsqueeze(-1)  # (1000, 1) to match output_dim
dataset = TensorDataset(X, y)

model = BayesianMLP(
    input_dim    = 4,
    output_dim   = 1,
    hidden_dims  = (128, 128),
    dropout_rate = 0.1,
)

trainer = Trainer(
    model,
    dataset,
    task             = 'regression',
    train_lr         = 1e-3,
    train_num_steps  = 5_000,
    train_batch_size = 64,
    ema_decay        = 0.995,
    amp              = False,
    weight_decay     = 1e-4,   # ≡ prior precision in §3
    tau              = 1.0,    # noise precision
    num_mc_samples   = 50,
)

trainer.train()

# inference via EMA model
mc = trainer.inference
out = mc(X[:10])
print(out.mean, out.variance)
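
A quick sanity check on the uncertainties is to hold out fresh points and see whether larger predicted spread goes with larger error. A sketch continuing from the objects above:

X_test = torch.randn(200, 4)
y_test = (X_test[:, 0] * 2 + X_test[:, 1] - X_test[:, 2] + torch.randn(200) * 0.1).unsqueeze(-1)

out = mc(X_test)
err = (out.mean - y_test).abs().squeeze(-1)  # per-point absolute error
std = out.variance.sqrt().squeeze(-1)        # per-point predictive std

print(f'RMSE: {err.pow(2).mean().sqrt():.3f}')
print(f'error/std correlation: {torch.corrcoef(torch.stack([err, std]))[0, 1]:.3f}')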

Multi-GPU

$ accelerate config
$ accelerate launch train.py
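
A minimal train.py to pair with the launcher; the package does not document its accelerate integration explicitly, so this assumes the Trainer handles device placement internally:

# train.py (assumes Trainer wires up accelerate internally)
import torch
from torch.utils.data import TensorDataset
from mc_dropout_pytorch import BayesianMLP, Trainer

X = torch.randn(10_000, 4)
y = (X[:, 0] * 2 + X[:, 1] - X[:, 2] + torch.randn(10_000) * 0.1).unsqueeze(-1)

model   = BayesianMLP(input_dim = 4, output_dim = 1, hidden_dims = (128, 128), dropout_rate = 0.1)
trainer = Trainer(model, TensorDataset(X, y), task = 'regression', train_num_steps = 5_000)
trainer.train()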

Key ideas from the paper

The insight (§3): training a neural network with dropout and L2 regularisation already minimises the KL divergence between an approximating distribution and the posterior of a deep Gaussian process, so no variational EM and no explicit weight sampling are needed.

Test-time dropout (MC Dropout):

for t = 1 … T:
    ŷ_t = f^ω_t(x)    # ω_t ~ q(ω)  via Bernoulli dropout

E[y*]   ≈ (1/T) Σ ŷ_t                              # predictive mean
Var[y*] ≈ τ⁻¹ I + (1/T) Σ ŷ_t ŷ_tᵀ − E[y*] E[y*]ᵀ     # predictive variance  (Eq. 9)

Active learning (§6): Use mc.mutual_information(x) to identify the most informative unlabelled points — pure epistemic uncertainty, disentangled from aleatoric noise.
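
In a labelling loop this reduces to scoring the unlabelled pool and querying the top k. A sketch with placeholder pool data, reusing the classification mc from above:

pool = torch.randn(1024, 1, 28, 28)      # unlabelled candidates (placeholder data)
scores = mc.mutual_information(pool)     # (1024,) epistemic uncertainty per point
query_idx = scores.topk(k = 10).indices  # the 10 points most worth labelling next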

Weight correspondence (§3.2):

Dropout training          Bayesian GP posterior
----------------          ---------------------
dropout probability p     variational (Bernoulli) parameter
L2 weight decay λ         prior precision (with prior length-scale l)
noise precision τ         τ = (1 − p) l² / (2 N λ)
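
Plugging the Trainer example above into the identity gives a concrete noise precision. A sketch; the prior length-scale l is a modelling choice not exposed in the snippets, so l = 1 is assumed:

p, lam, N = 0.1, 1e-4, 1_000  # dropout rate, weight decay, dataset size (Trainer example)
l = 1.0                       # assumed prior length-scale
tau = (1 - p) * l ** 2 / (2 * N * lam)
print(tau)                    # 4.5, the matched value to pass as `tau = ...`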

Citations

@inproceedings{Gal2016Dropout,
    title     = {Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning},
    author    = {Yarin Gal and Zoubin Ghahramani},
    booktitle = {Proceedings of the 33rd International Conference on Machine Learning (ICML)},
    year      = {2016},
    url       = {https://arxiv.org/abs/1506.02142}
}


Download files

Download the file for your platform.

Source Distribution

mc_dropout_pytorch-0.1.2.tar.gz (8.8 kB)

Uploaded Source

Built Distribution

mc_dropout_pytorch-0.1.2-py3-none-any.whl (9.2 kB)

Uploaded Python 3

File details

Details for the file mc_dropout_pytorch-0.1.2.tar.gz.

File metadata

  • Download URL: mc_dropout_pytorch-0.1.2.tar.gz
  • Size: 8.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.8

File hashes

Hashes for mc_dropout_pytorch-0.1.2.tar.gz

Algorithm     Hash digest
SHA256        22731d24c9731aa7034ab7e3ae748bcca4567f85f81a5cba335a0243d942eef6
MD5           efad497876eb3045c21e64c9b2b67ca8
BLAKE2b-256   76a4c629b1fe6fe725df46d833ae557a1253998fcf981a51faacd7d51297c092


File details

Details for the file mc_dropout_pytorch-0.1.2-py3-none-any.whl.

File hashes

Hashes for mc_dropout_pytorch-0.1.2-py3-none-any.whl

Algorithm     Hash digest
SHA256        132ff61b00a2a5934684316585df7e6c5c46f1312d6781964b2017ea1af5f7b0
MD5           9adf2c954968c7200a0b4a0b0d4b5495
BLAKE2b-256   131b5d773fb1e2c3151f36037b2bdca6d21d105e27fc2a6cbf88531eec4b7ef1

