
Deep Ensembles - PyTorch

Project description

Implementation of Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles (NeurIPS 2017) in PyTorch.

A dead-simple, non-Bayesian approach to uncertainty quantification: train M networks independently from random initialisations and aggregate their predictions as a uniformly weighted mixture of Gaussians. Matches or beats approximate Bayesian methods such as MC dropout on calibration and out-of-distribution detection, at a fraction of the implementation complexity.

Install

$ pip install deep-ensembles-pytorch

Usage

Regression with Uncertainty

import torch
from deep_ensembles_pytorch import DeepEnsemble

# build ensemble of 5 members — paper default
ensemble = DeepEnsemble(
    dim_in = 1,
    dim_out = 1,
    num_members = 5,
    hidden_dim = 100,
    depth = 3,
    use_adversarial_training = True,   # FGSM smoothing (Sec. 3)
    adversarial_eps = 0.01,
)

# training
X = torch.randn(1000, 1)
y = X.sin() + 0.1 * torch.randn_like(X)

loss = ensemble(X, y)
loss.backward()

# inference — returns mixture-of-Gaussians aggregate
pred = ensemble(X)
pred.mean      # (B, 1)  — predictive mean
pred.variance  # (B, 1)  — predictive variance (epistemic + aleatoric)

Classification with Uncertainty

from deep_ensembles_pytorch import DeepEnsembleClassifier

classifier = DeepEnsembleClassifier(
    dim_in = 784,
    num_classes = 10,
    num_members = 5,
    hidden_dim = 200,
    depth = 3,
)

X = torch.randn(32, 784)
target = torch.randint(0, 10, (32,))

loss = classifier(X, target)
loss.backward()

# inference
pred = classifier(X)
pred.probs     # (B, C)  — ensemble-averaged softmax
pred.variance  # (B, 1)  — predictive entropy (higher = more uncertain)
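The predictive entropy reported above can be computed directly from the ensemble-averaged softmax. The sketch below is an illustrative stand-alone re-implementation in plain PyTorch, not the library's internals; `predictive_entropy` and its `(M, B, C)` input convention are assumptions for the example.

```python
import torch

def predictive_entropy(member_logits: torch.Tensor) -> torch.Tensor:
    """Entropy of the ensemble-averaged softmax.

    member_logits: (M, B, C) — one logit tensor per ensemble member.
    Returns: (B, 1) — higher values mean more uncertain predictions.
    """
    probs = member_logits.softmax(dim=-1).mean(dim=0)          # (B, C) averaged softmax
    entropy = -(probs * probs.clamp_min(1e-12).log()).sum(-1)  # H = -Σ p log p
    return entropy.unsqueeze(-1)

# a uniform prediction is maximally uncertain: H = log C
uniform_logits = torch.zeros(5, 2, 10)
print(predictive_entropy(uniform_logits))  # each entry ≈ log(10) ≈ 2.3026
```

Averaging the member softmaxes before taking the entropy captures both disagreement between members (epistemic) and per-member softmax spread (aleatoric).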

Full Training with Trainer

import torch
from deep_ensembles_pytorch import DeepEnsemble, TensorDataset, Trainer

ensemble = DeepEnsemble(
    dim_in = 13,   # e.g. Boston Housing
    dim_out = 1,
    num_members = 5,
    hidden_dim = 100,
    depth = 3,
)

X = torch.randn(500, 13)
y = torch.randn(500, 1)
dataset = TensorDataset(X, y)

trainer = Trainer(
    ensemble,
    dataset,
    train_batch_size = 100,
    train_lr = 1e-3,
    train_num_steps = 10_000,
    ema_decay = 0.995,
    amp = False,
)

trainer.train()
# checkpoints saved to ./results/

Multi-GPU

$ accelerate config
$ accelerate launch train.py

Accessing Individual Member Predictions

# useful for visualising the ensemble spread
member_preds = ensemble.sample_predictions(X)

for i, pred in enumerate(member_preds):
    print(f'member {i}: μ={pred.mean[:3]}, σ²={pred.log_var.exp()[:3]}')

How It Works

Three design choices make Deep Ensembles both simple and powerful:

  1. Proper scoring rule — each member minimises the Gaussian NLL rather than MSE, forcing the network to predict both mean and variance:

    L(θ) = 0.5 · [log σ²_θ(x) + (y − μ_θ(x))² / σ²_θ(x)]
    
  2. Random initialisation diversity — members are independently initialised; no shared weights, no shared data subsets (unlike Bagging).

  3. Mixture-of-Gaussians aggregation — at inference the ensemble forms a richer predictive distribution than any single member:

    μ* = (1/M) Σ μ_m
    σ*² = (1/M) Σ (σ²_m + μ²_m) − μ*²
    
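Steps 1 and 3 above fit in a few lines of plain PyTorch. This is an illustrative sketch of the math, not the library's internals; the function names `gaussian_nll` and `mixture_aggregate` are assumptions for the example.

```python
import torch

def gaussian_nll(mu, log_var, y):
    """Per-sample Gaussian NLL (step 1), up to an additive constant."""
    return 0.5 * (log_var + (y - mu) ** 2 / log_var.exp())

def mixture_aggregate(mus, variances):
    """Combine M member Gaussians into one predictive Gaussian (step 3).

    mus, variances: (M, B, 1) stacked member means / variances.
    Returns the predictive mean and variance, each (B, 1).
    """
    mu_star = mus.mean(dim=0)                                   # μ* = (1/M) Σ μ_m
    var_star = (variances + mus ** 2).mean(dim=0) - mu_star ** 2  # σ*² = (1/M) Σ (σ²_m + μ²_m) − μ*²
    return mu_star, var_star

# two members that agree on the mean: the predictive variance is just the
# average aleatoric variance, with no epistemic spread added on top
mus = torch.tensor([[[1.0]], [[1.0]]])
vs = torch.tensor([[[0.5]], [[0.3]]])
mu, var = mixture_aggregate(mus, vs)
print(mu.item(), var.item())  # 1.0 0.4
```

When the member means disagree, the `− μ*²` correction makes `var_star` grow with their spread, which is exactly the epistemic component of the uncertainty.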

An optional FGSM adversarial step smooths each member's predictive distribution during training, which the paper shows improves calibration.
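The adversarial step can be sketched as follows, assuming a model that maps `x` to `(mu, log_var)`. This is a minimal illustration of the FGSM smoothing idea, not the code behind `use_adversarial_training`; `fgsm_augmented_loss` is a hypothetical helper for the example.

```python
import torch

def fgsm_augmented_loss(model, x, y, eps=0.01):
    """Average the Gaussian NLL on the clean batch and on an FGSM-perturbed copy.

    model: callable mapping x -> (mu, log_var). The perturbation x + eps·sign(∇_x L)
    nudges inputs in the direction that most increases the loss, smoothing the
    predictive distribution in an eps-neighbourhood of the data.
    """
    x = x.clone().requires_grad_(True)
    mu, log_var = model(x)
    nll = 0.5 * (log_var + (y - mu) ** 2 / log_var.exp()).mean()
    grad, = torch.autograd.grad(nll, x, retain_graph=True)
    x_adv = x.detach() + eps * grad.sign()          # fast gradient sign step
    mu_a, log_var_a = model(x_adv)
    nll_adv = 0.5 * (log_var_a + (y - mu_a) ** 2 / log_var_a.exp()).mean()
    return 0.5 * (nll + nll_adv)
```

Because `eps` bounds the perturbation in the infinity norm, the extra cost is one additional forward/backward pass per step rather than an inner optimisation loop.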

Citation

@article{lakshminarayanan2017simple,
    title   = {Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles},
    author  = {Balaji Lakshminarayanan and Alexander Pritzel and Charles Blundell},
    journal = {Advances in Neural Information Processing Systems},
    year    = {2017},
    url     = {https://arxiv.org/abs/1612.01474}
}
