# Deep Ensembles - Pytorch
Implementation of Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles (NeurIPS 2017) in Pytorch.
A dead-simple, non-Bayesian approach to uncertainty quantification: train M networks independently and aggregate their predictions as a mixture of Gaussians. Matches or beats approximate Bayesian methods at a fraction of the complexity.
## Install

```bash
$ pip install deep-ensembles-pytorch
```
## Usage

### Regression with Uncertainty

```python
import torch
from deep_ensembles_pytorch import DeepEnsemble

# build ensemble of 5 members — paper default
ensemble = DeepEnsemble(
    dim_in = 1,
    dim_out = 1,
    num_members = 5,
    hidden_dim = 100,
    depth = 3,
    use_adversarial_training = True,  # FGSM smoothing (Sec. 3)
    adversarial_eps = 0.01,
)

# training
X = torch.randn(1000, 1)
y = X.sin() + 0.1 * torch.randn_like(X)

loss = ensemble(X, y)
loss.backward()

# inference — returns mixture-of-Gaussians aggregate
pred = ensemble(X)
pred.mean      # (B, 1) — predictive mean
pred.variance  # (B, 1) — predictive variance (epistemic + aleatoric)
```
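Since each member's head outputs Gaussian parameters, the aggregated mean and variance translate directly into prediction intervals. A plain-Python sketch, with made-up numbers standing in for one entry of `pred.mean` / `pred.variance`:

```python
import math

# made-up stand-ins for one entry of pred.mean / pred.variance
mean, variance = 0.84, 0.09

# central 95% interval of the Gaussian predictive distribution
std = math.sqrt(variance)
lo, hi = mean - 1.96 * std, mean + 1.96 * std
print(f'95% interval: [{lo:.3f}, {hi:.3f}]')  # [0.252, 1.428]
```

A well-calibrated ensemble should cover roughly 95% of held-out targets with intervals built this way.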
### Classification with Uncertainty

```python
import torch
from deep_ensembles_pytorch import DeepEnsembleClassifier

classifier = DeepEnsembleClassifier(
    dim_in = 784,
    num_classes = 10,
    num_members = 5,
    hidden_dim = 200,
    depth = 3,
)

X = torch.randn(32, 784)
target = torch.randint(0, 10, (32,))

loss = classifier(X, target)
loss.backward()

# inference
pred = classifier(X)
pred.probs     # (B, C) — ensemble-averaged softmax
pred.variance  # (B, 1) — predictive entropy (higher = more uncertain)
```
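One way to read the entropy number: it is the Shannon entropy of the averaged softmax, zero for a one-hot prediction and maximal for a uniform one. A plain-Python sketch with made-up probability vectors (not the classifier's internals):

```python
import math

def entropy(probs):
    # H(p) = -sum_c p_c * log(p_c), in nats
    return -sum(p * math.log(p) for p in probs if p > 0)

# hypothetical ensemble-averaged softmax outputs for two inputs
confident = [0.97, 0.01, 0.01, 0.01]
uncertain = [0.25, 0.25, 0.25, 0.25]

print(entropy(confident))  # ≈ 0.168
print(entropy(uncertain))  # = log(4) ≈ 1.386, the maximum for 4 classes
```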
### Full Training with Trainer

```python
import torch
from deep_ensembles_pytorch import DeepEnsemble, TensorDataset, Trainer

ensemble = DeepEnsemble(
    dim_in = 13,   # e.g. Boston Housing
    dim_out = 1,
    num_members = 5,
    hidden_dim = 100,
    depth = 3,
)

X = torch.randn(500, 13)
y = torch.randn(500, 1)
dataset = TensorDataset(X, y)

trainer = Trainer(
    ensemble,
    dataset,
    train_batch_size = 100,
    train_lr = 1e-3,
    train_num_steps = 10_000,
    ema_decay = 0.995,
    amp = False,
)

trainer.train()
# checkpoints saved to ./results/
```
### Multi-GPU

```bash
$ accelerate config
$ accelerate launch train.py
```
### Accessing Individual Member Predictions

```python
# useful for visualising the ensemble spread
member_preds = ensemble.sample_predictions(X)

for i, pred in enumerate(member_preds):
    print(f'member {i}: μ={pred.mean[:3]}, σ²={pred.log_var.exp()[:3]}')
```
## How It Works

Three design choices make Deep Ensembles both simple and powerful:

- **Proper scoring rule** — each member minimises the Gaussian NLL rather than MSE, forcing the network to predict both a mean and a variance:

  L(θ) = 0.5 · [ log σ²_θ(x) + (y − μ_θ(x))² / σ²_θ(x) ]

- **Random initialisation diversity** — members are independently initialised; no shared weights, no shared data subsets (unlike bagging).

- **Mixture-of-Gaussians aggregation** — at inference the ensemble forms a richer predictive distribution than any single member:

  μ* = (1/M) Σ_m μ_m
  σ*² = (1/M) Σ_m (σ²_m + μ²_m) − μ*²

An optional FGSM adversarial step smooths each member's predictive distribution during training, shown in the paper to improve calibration.
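The NLL objective is easy to sanity-check numerically. A minimal plain-Python sketch (not the package's loss function), dropping the constant term:

```python
import math

def gaussian_nll(mu, var, y):
    # 0.5 * [log σ² + (y − μ)² / σ²], constant term dropped
    return 0.5 * (math.log(var) + (y - mu) ** 2 / var)

# being overconfident (tiny variance) about a wrong mean is penalised hard,
# which is what pushes the network to report honest variances
print(gaussian_nll(mu = 0.0, var = 1.0,  y = 1.0))  # 0.5
print(gaussian_nll(mu = 0.0, var = 0.01, y = 1.0))  # ≈ 47.7
```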
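The aggregation formulas can likewise be verified with toy numbers. The member means and variances below are made up for illustration; the sketch also shows that σ*² splits exactly into an aleatoric term (average member variance) plus an epistemic term (spread of member means):

```python
# member predictive means and variances at one input (made-up numbers, M = 5)
means     = [0.9, 1.1, 1.0, 0.8, 1.2]
variances = [0.04, 0.05, 0.04, 0.06, 0.05]
M = len(means)

# mixture-of-Gaussians moments
mu_star  = sum(means) / M
var_star = sum(v + m ** 2 for v, m in zip(variances, means)) / M - mu_star ** 2

# identical decomposition: aleatoric (average variance) + epistemic (spread of means)
aleatoric = sum(variances) / M                          # 0.048
epistemic = sum((m - mu_star) ** 2 for m in means) / M  # 0.02
print(var_star)  # ≈ 0.068 (= 0.048 + 0.02)
```

Disagreement between members (the epistemic term) is what grows away from the training data, so the ensemble widens its predictive distribution exactly where a single network would be silently wrong.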
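As for the FGSM step, the core idea is a single perturbation x + ε · sign(∇ₓL) followed by training on the perturbed input. A toy sketch with a closed-form gradient (a made-up quadratic loss, not the package's training loop):

```python
# one FGSM step on a toy quadratic loss L(x) = (w·x − y)², made-up numbers
w, y, eps = 2.0, 1.0, 0.01

x = 0.3
grad_x = 2 * (w * x - y) * w                        # dL/dx = 2(wx − y)w = −1.6 here
x_adv = x + eps * (1.0 if grad_x > 0 else -1.0)     # step against the gradient sign

# the perturbed input raises the loss; also training on (x_adv, y)
# smooths the predictive distribution in an ε-neighbourhood of x
loss, loss_adv = (w * x - y) ** 2, (w * x_adv - y) ** 2
```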
## Citation

```bibtex
@inproceedings{lakshminarayanan2017simple,
    title     = {Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles},
    author    = {Balaji Lakshminarayanan and Alexander Pritzel and Charles Blundell},
    booktitle = {Advances in Neural Information Processing Systems},
    year      = {2017},
    url       = {https://arxiv.org/abs/1612.01474}
}
```
## File details

Details for the file `deep_ensembles_pytorch-0.1.2.tar.gz`.

### File metadata

- Download URL: deep_ensembles_pytorch-0.1.2.tar.gz
- Upload date:
- Size: 7.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.12

### File hashes

| Algorithm | Hash digest |
|-----------|-------------|
| SHA256 | `0b640e2d0629e8d05d40fdc196ddaff207fac861c8fb4f27c8d43c0d95043a99` |
| MD5 | `e5f56961ce5bd8a1b362683fa3b4720f` |
| BLAKE2b-256 | `04a8e8c589396f8684490e4bf7a359ea0f2336df27a61cb27344d2b7cbb7401a` |
## File details

Details for the file `deep_ensembles_pytorch-0.1.2-py3-none-any.whl`.

### File metadata

- Download URL: deep_ensembles_pytorch-0.1.2-py3-none-any.whl
- Upload date:
- Size: 8.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.12

### File hashes

| Algorithm | Hash digest |
|-----------|-------------|
| SHA256 | `f8aba727002357889208e10b1949e0c9d4465d14c57868b82f5bf5c776ba5604` |
| MD5 | `77844049d5d7c6298d0041c560803824` |
| BLAKE2b-256 | `955d5daf256682701fe6bb6910c98380fe2102fd42f48d513907e8d6ce0eeefa` |