
Deep Ensembles, in Pytorch

Implementation of Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles (NeurIPS 2017) in Pytorch.

A dead-simple, non-Bayesian approach to uncertainty quantification: train M networks independently and aggregate their predictions as a mixture of Gaussians. Matches or beats approximate Bayesian methods at a fraction of the complexity.

Install

$ pip install deep-ensembles-pytorch

Usage

Regression with Uncertainty

import torch
from deep_ensembles_pytorch import DeepEnsemble, TensorDataset, Trainer

# build ensemble of 5 members — paper default
ensemble = DeepEnsemble(
    dim_in = 1,
    dim_out = 1,
    num_members = 5,
    hidden_dim = 100,
    depth = 3,
    use_adversarial_training = True,   # FGSM smoothing (Sec. 3)
    adversarial_eps = 0.01,
)

# training
X = torch.randn(1000, 1)
y = X.sin() + 0.1 * torch.randn_like(X)

loss = ensemble(X, y)
loss.backward()

# inference — returns mixture-of-Gaussians aggregate
pred = ensemble(X)
pred.mean      # (B, 1)  — predictive mean
pred.variance  # (B, 1)  — predictive variance (epistemic + aleatoric)

Classification with Uncertainty

from deep_ensembles_pytorch import DeepEnsembleClassifier

classifier = DeepEnsembleClassifier(
    dim_in = 784,
    num_classes = 10,
    num_members = 5,
    hidden_dim = 200,
    depth = 3,
)

X = torch.randn(32, 784)
target = torch.randint(0, 10, (32,))

loss = classifier(X, target)
loss.backward()

# inference
pred = classifier(X)
pred.probs     # (B, C)  — ensemble-averaged softmax
pred.variance  # (B, 1)  — predictive entropy (higher = more uncertain)
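To make the classification outputs concrete, here is a minimal plain-PyTorch sketch of what ensemble-averaged softmax and predictive entropy mean, using random stand-in member probabilities (the shapes and names are illustrative, not the library's internals):

```python
import torch

# toy stand-in for per-member class probabilities: (M, B, C)
member_probs = torch.softmax(torch.randn(5, 32, 10), dim=-1)

# ensemble-averaged softmax: (B, C)
avg_probs = member_probs.mean(dim=0)

# predictive entropy of the averaged distribution: (B,)
# higher entropy = the ensemble is less certain about the class
entropy = -(avg_probs * avg_probs.clamp_min(1e-12).log()).sum(dim=-1)
```

Entropy is bounded by log C (here log 10 ≈ 2.30), which is attained when the averaged distribution is uniform.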

Full Training with Trainer

import torch
from deep_ensembles_pytorch import DeepEnsemble, TensorDataset, Trainer

ensemble = DeepEnsemble(
    dim_in = 13,   # e.g. Boston Housing
    dim_out = 1,
    num_members = 5,
    hidden_dim = 100,
    depth = 3,
)

X = torch.randn(500, 13)
y = torch.randn(500, 1)
dataset = TensorDataset(X, y)

trainer = Trainer(
    ensemble,
    dataset,
    train_batch_size = 100,
    train_lr = 1e-3,
    train_num_steps = 10_000,
    ema_decay = 0.995,
    amp = False,
)

trainer.train()
# checkpoints saved to ./results/

Multi-GPU

$ accelerate config
$ accelerate launch train.py

Accessing Individual Member Predictions

# useful for visualising the ensemble spread
member_preds = ensemble.sample_predictions(X)

for i, pred in enumerate(member_preds):
    print(f'member {i}: μ={pred.mean[:3]}, σ²={pred.log_var.exp()[:3]}')

How It Works

Three design choices make Deep Ensembles both simple and powerful:

  1. Proper scoring rule — each member minimises the Gaussian NLL rather than MSE, forcing the network to predict both mean and variance:

    L(θ) = 0.5 · [log σ²_θ(x) + (y − μ_θ(x))² / σ²_θ(x)]
    
  2. Random initialisation diversity — members are independently initialised; no shared weights, no shared data subsets (unlike Bagging).

  3. Mixture-of-Gaussians aggregation — at inference the ensemble forms a richer predictive distribution than any single member:

    μ* = (1/M) Σ μ_m
    σ*² = (1/M) Σ (σ²_m + μ²_m) − μ*²
    
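The aggregation formulas above can be checked numerically. A minimal sketch with made-up per-member means and variances (shapes and names are illustrative):

```python
import torch

M = 5
mu = torch.randn(M, 32, 1)        # per-member means μ_m
var = torch.rand(M, 32, 1) + 0.1  # per-member variances σ²_m

# μ* = (1/M) Σ μ_m
mu_star = mu.mean(dim=0)

# σ*² = (1/M) Σ (σ²_m + μ²_m) − μ*²
var_star = (var + mu.pow(2)).mean(dim=0) - mu_star.pow(2)
```

Note that σ*² decomposes as the mean of the member variances (aleatoric) plus the variance of the member means (epistemic), which is why the ensemble's predictive variance grows where members disagree.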

An optional FGSM adversarial step smooths each member's predictive distribution during training, shown to improve calibration in the paper.
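A minimal sketch of that FGSM step in plain PyTorch, assuming a member that outputs (μ, log σ²) and the Gaussian NLL from the scoring rule above (`model`, `gaussian_nll`, and the toy data are illustrative, not the library's API):

```python
import torch

def gaussian_nll(mu, log_var, y):
    # 0.5 * [log σ² + (y − μ)² / σ²], as in the scoring rule above
    return 0.5 * (log_var + (y - mu).pow(2) / log_var.exp()).mean()

model = torch.nn.Linear(1, 2)  # toy member predicting (μ, log σ²)
x = torch.randn(8, 1, requires_grad=True)
y = x.sin().detach()

mu, log_var = model(x).chunk(2, dim=-1)
loss = gaussian_nll(mu, log_var, y)

# FGSM: perturb inputs one ε-step along the sign of the input gradient
grad_x, = torch.autograd.grad(loss, x, retain_graph=True)
eps = 0.01
x_adv = (x + eps * grad_x.sign()).detach()

# train on clean + adversarial inputs, smoothing the predictive distribution
mu_adv, log_var_adv = model(x_adv).chunk(2, dim=-1)
total_loss = loss + gaussian_nll(mu_adv, log_var_adv, y)
```

Each perturbed input sits at most ε away from its clean counterpart in L∞ norm, so the adversarial term only enforces smoothness in a small neighbourhood of the data.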

Citation

@article{lakshminarayanan2017simple,
    title   = {Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles},
    author  = {Balaji Lakshminarayanan and Alexander Pritzel and Charles Blundell},
    journal = {Advances in Neural Information Processing Systems},
    year    = {2017},
    url     = {https://arxiv.org/abs/1612.01474}
}
