
MC Dropout, in Pytorch

Implementation of Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning (Gal & Ghahramani, ICML 2016) in Pytorch.

Training a standard dropout network can be cast as approximate Bayesian inference in a deep Gaussian process. A trained model therefore yields calibrated uncertainty estimates essentially for free: no architectural changes, no extra parameters, and the only inference cost is T stochastic forward passes.

Install

$ pip install mc-dropout-pytorch

Usage

Regression with uncertainty

import torch
from mc_dropout_pytorch import BayesianMLP, MCDropoutInference

# build model
model = BayesianMLP(
    input_dim    = 1,
    output_dim   = 1,
    hidden_dims  = (256, 256),
    dropout_rate = 0.1,
    activation   = 'relu',
)

# wrap for MC inference (T=50 stochastic passes)
mc = MCDropoutInference(model, num_samples = 50, task = 'regression', tau = 1.0)

x = torch.linspace(-3, 3, 100).unsqueeze(-1)
out = mc(x)

out.mean      # predictive mean     — (100, 1)
out.variance  # predictive variance — (100, 1)  includes τ⁻¹ noise term
out.samples   # raw samples         — (50, 100, 1)
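
The moments above convert directly into error bars. Continuing from the regression example, and treating the predictive distribution as Gaussian, a minimal sketch:

std   = out.variance.sqrt()                    # (100, 1) predictive std dev
lower = out.mean - 1.96 * std                  # 95% band, lower
upper = out.mean + 1.96 * std                  # 95% band, upper

most_uncertain = out.variance.squeeze(-1).topk(5).indices   # least confident inputs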

Classification with predictive entropy

import torch
from mc_dropout_pytorch import BayesianCNN, MCDropoutInference

model = BayesianCNN(
    in_channels      = 1,
    num_classes      = 10,
    base_channels    = 32,
    dropout_rate     = 0.25,
    fc_dropout_rate  = 0.5,
    img_size         = 28,
)

mc = MCDropoutInference(model, num_samples = 50, task = 'classification')

x = torch.randn(8, 1, 28, 28)
out = mc(x)

out.mean      # class probabilities  — (8, 10)
out.variance  # per-class variance   — (8, 10)

# active learning signals (§6)
H  = mc.predictive_entropy(x)   # (8,)  — total uncertainty
MI = mc.mutual_information(x)   # (8,)  — epistemic uncertainty only
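
The two signals decompose as H = MI + expected per-pass entropy (the BALD decomposition). A sketch computing them by hand, assuming out.samples holds the per-pass class probabilities with shape (50, 8, 10), analogous to the regression example:

probs = out.samples                    # (T, B, C) per-pass probabilities (assumed attribute)
mean  = probs.mean(dim = 0)            # (B, C) predictive probabilities

H_total    = -(mean * mean.clamp_min(1e-12).log()).sum(-1)            # total uncertainty
H_expected = -(probs * probs.clamp_min(1e-12).log()).sum(-1).mean(0)  # aleatoric part
MI         = H_total - H_expected                                     # epistemic part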

Full training loop with the Trainer

import torch
from torch.utils.data import TensorDataset
from mc_dropout_pytorch import BayesianMLP, Trainer

# synthetic regression dataset
X = torch.randn(1000, 4)
y = (X[:, 0] * 2 + X[:, 1] - X[:, 2] + torch.randn(1000) * 0.1).unsqueeze(-1)  # (1000, 1) to match output_dim
dataset = TensorDataset(X, y)

model = BayesianMLP(
    input_dim    = 4,
    output_dim   = 1,
    hidden_dims  = (128, 128),
    dropout_rate = 0.1,
)

trainer = Trainer(
    model,
    dataset,
    task             = 'regression',
    train_lr         = 1e-3,
    train_num_steps  = 5_000,
    train_batch_size = 64,
    ema_decay        = 0.995,
    amp              = False,
    weight_decay     = 1e-4,   # ≡ prior precision in §3
    tau              = 1.0,    # noise precision
    num_mc_samples   = 50,
)

trainer.train()

# inference via EMA model
mc = trainer.inference
out = mc(X[:10])
print(out.mean, out.variance)
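
A quick sanity check on the fitted model: epistemic uncertainty should grow away from the training distribution. A sketch using only the mc(x) call shown above (the out-of-distribution inputs are illustrative):

x_in  = torch.randn(100, 4)          # matches the training distribution
x_out = torch.randn(100, 4) * 5.0    # far outside the training range

var_in  = mc(x_in).variance.mean().item()
var_out = mc(x_out).variance.mean().item()
print(f'mean variance in-dist: {var_in:.3f}  out-of-dist: {var_out:.3f}')
# expect var_out > var_in if the uncertainty estimates behave sensibly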

Multi-GPU

$ accelerate config
$ accelerate launch train.py

Key ideas from the paper

The insight (§3): training a NN with dropout and L2 regularisation implicitly minimises the KL divergence between an approximate weight distribution and the posterior of a deep Gaussian process. No variational EM and no explicit weight sampling are needed during training.

Test-time dropout (MC Dropout):

for t = 1 … T:
    ŷ_t = f^ω_t(x)    # ω_t ~ q(ω)  via Bernoulli dropout

E[y*]   ≈ (1/T) Σ ŷ_t                                    # predictive mean
Var[y*] ≈ τ⁻¹ I + (1/T) Σ ŷ_t ŷ_tᵀ − E[y*] E[y*]ᵀ        # predictive variance  (Eq. 9)
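
In plain Pytorch this amounts to keeping the dropout layers stochastic at test time. A minimal sketch of the estimator above for regression, not this library's internals (net and tau are placeholders):

import torch
import torch.nn as nn

@torch.no_grad()
def mc_dropout_predict(net: nn.Module, x: torch.Tensor, T: int = 50, tau: float = 1.0):
    net.eval()
    for m in net.modules():
        if isinstance(m, nn.Dropout):
            m.train()                                    # ω_t ~ q(ω) via live Bernoulli masks

    samples = torch.stack([net(x) for _ in range(T)])    # (T, B, D)
    mean = samples.mean(dim = 0)                         # predictive mean
    var  = samples.var(dim = 0, unbiased = False) + 1.0 / tau  # diagonal of Eq. 9
    return mean, var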

Active learning (§6): Use mc.mutual_information(x) to identify the most informative unlabelled points — pure epistemic uncertainty, disentangled from aleatoric noise.
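
A typical acquisition step over an unlabelled pool, continuing from the classification example and using only the mutual_information call shown there (the pool tensor and query size are illustrative):

pool = torch.randn(1000, 1, 28, 28)          # illustrative unlabelled pool

scores   = mc.mutual_information(pool)       # (1000,) epistemic uncertainty
query_ix = scores.topk(k = 16).indices       # most informative points to label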

Weight correspondence (§3.2):

Dropout training          Bayesian GP posterior
------------------        ----------------------
dropout probability p     variational parameter
L2 weight decay λ         prior precision
noise precision τ         τ = (2N λ) / (1 − p)
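
Using the last row, the tau passed to the Trainer can be kept consistent with the chosen weight_decay. A sketch plugging in the values from the training example above, via the formula as stated here:

N   = 1000    # dataset size
lam = 1e-4    # L2 weight decay λ
p   = 0.1     # dropout rate

tau = (2 * N * lam) / (1 - p)    # ≈ 0.22 noise precision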

Citations

@inproceedings{Gal2016Dropout,
    title     = {Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning},
    author    = {Yarin Gal and Zoubin Ghahramani},
    booktitle = {Proceedings of the 33rd International Conference on Machine Learning (ICML)},
    year      = {2016},
    url       = {https://arxiv.org/abs/1506.02142}
}
