Deep ensembles for Keras
Keras Deep Ensemble Implementation
This is a Keras implementation of the deep ensembles paper by Lakshminarayanan et al. It creates an ensemble of models that can predict uncertainty. You provide a model that outputs two values (mean and variance), and the library ensembles it and resamples your data for ensemble training. We have made some modifications, which will be described more fully in an upcoming paper. Please no scoops.
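In the deep ensembles paper, each member is trained by minimizing the Gaussian negative log-likelihood of the label under the predicted mean and variance. The snippet below is a minimal plain-Python sketch of that loss. `gaussian_nll` and `mean_nll` are hypothetical helpers for illustration, and this is not necessarily how `kdens.neg_ll` is implemented (implementations often drop the constant term).

```python
import math


def gaussian_nll(y, mean, var):
    """Negative log-likelihood of label y under N(mean, var) for one sample.

    Hypothetical helper illustrating the loss from the deep ensembles paper.
    """
    if var <= 0:
        raise ValueError("variance must be strictly positive")
    # 0.5 * log(2*pi*var) penalizes large variances,
    # 0.5 * (y - mean)^2 / var penalizes errors relative to the variance
    return 0.5 * math.log(2 * math.pi * var) + 0.5 * (y - mean) ** 2 / var


def mean_nll(ys, outputs):
    """Average NLL over a batch; outputs is a list of (mean, var) pairs."""
    return sum(gaussian_nll(y, m, v) for y, (m, v) in zip(ys, outputs)) / len(ys)
```

For example, `gaussian_nll(0.0, 0.0, 1.0)` is `0.5 * log(2 * pi)`, about 0.919. A variance that is too small for the observed error is penalized heavily, which is what pushes the model toward honest uncertainty estimates.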
Install
pip install kdeepensemble
Quickstart
This example builds a Keras model inside a function and then resamples the data for ensemble training. Notice that a DeepEnsemble model acts just like a Keras model.
```python
import kdens
import tensorflow as tf

# this is where you define your model
def make_model():
    # assumes 10 input features; set this to match your data
    i = tf.keras.Input((10,))
    x = tf.keras.layers.Dense(10, activation="relu")(i)
    mean = tf.keras.layers.Dense(1)(x)
    # this activation makes our variance strictly positive
    var = tf.keras.layers.Dense(1, activation="softplus")(x)
    # output shape (batch, 2): last axis is (mean, variance)
    out = tf.keras.layers.Concatenate(axis=-1)([mean, var])
    model = tf.keras.Model(inputs=i, outputs=out)
    return model

# prepare data for ensemble training
# x (features) and y (labels) are your NumPy arrays
resampled_idx = kdens.resample(y)
x_train = x[resampled_idx]
y_train = y[resampled_idx]

deep_ens = kdens.DeepEnsemble(make_model)
deep_ens.compile(loss=kdens.neg_ll)
deep_ens.fit(x_train, y_train)
deep_ens(x)
```
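One common way to give each ensemble member different training data is bootstrap resampling (bagging): each member gets its own set of indices drawn with replacement. The sketch below assumes that scheme. `bootstrap_indices` is a hypothetical helper, and `kdens.resample` may work differently, including in the shape of the indices it returns.

```python
import random


def bootstrap_indices(n_samples, n_models, seed=None):
    """Draw one bootstrap resample (with replacement) per ensemble member.

    Hypothetical helper; kdens.resample may use a different scheme.
    Returns a list of n_models lists, each holding n_samples indices.
    """
    rng = random.Random(seed)
    return [[rng.randrange(n_samples) for _ in range(n_samples)]
            for _ in range(n_models)]
```

Each row can then select that member's data, e.g. `[[x[i] for i in row] for row in bootstrap_indices(len(x), 5)]`. Passing a `seed` makes the draw reproducible.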
Model Output
The output has shape `(N, 3)`, where the last axis holds the mean, variance, and epistemic variance. Epistemic variance comes from disagreement between the ensemble members and reflects model uncertainty. The variance includes both epistemic and aleatoric variance; it is the model's best estimate of the total uncertainty.
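Assuming the ensemble is treated as a uniformly weighted mixture of Gaussians, as in Lakshminarayanan et al., the three outputs for a single input can be computed from the members' (mean, variance) pairs as sketched below. `combine_ensemble` is a hypothetical helper, not part of the kdens API, and kdens may compute these quantities differently.

```python
def combine_ensemble(member_outputs):
    """Combine per-member (mean, variance) predictions for one input.

    Assumes a uniformly weighted Gaussian mixture over M members:
      mean      = (1/M) * sum(mu_i)
      epistemic = (1/M) * sum((mu_i - mean)^2)
      variance  = (1/M) * sum(var_i) + epistemic
    Returns (mean, variance, epistemic variance).
    """
    m = len(member_outputs)
    means = [mu for mu, _ in member_outputs]
    variances = [v for _, v in member_outputs]
    mean = sum(means) / m
    # spread of member means around the ensemble mean: model disagreement
    epistemic = sum((mu - mean) ** 2 for mu in means) / m
    # average of the members' own predicted (aleatoric) variances
    aleatoric = sum(variances) / m
    return mean, aleatoric + epistemic, epistemic
```

For two members predicting means 1.0 and 3.0, each with variance 0.5, `combine_ensemble([(1.0, 0.5), (3.0, 0.5)])` returns `(2.0, 1.5, 1.0)`: the members disagree, so epistemic variance is nonzero and adds to the average aleatoric variance.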
Saving/Loading
You can serialize the model with `model.save`, but note that training cannot be continued from a model saved this way. To continue training, use the `save_weights` and `load_weights` methods instead.
TensorFlow Dataset
You can use `map_reshape` when working with a TensorFlow dataset. It reshapes your data for ensemble training. If your data is already batched, add the `is_batched=True` argument.
```python
import kdens

# data is a tf.data.Dataset of (x, y) pairs
data = data.map(kdens.map_reshape()).batch(8)
deep_ens = kdens.DeepEnsemble(make_model)
deep_ens.compile(loss=kdens.neg_ll)
deep_ens.fit(data)
```
Note that `map_reshape` will not resample, only reshape. If you would like to balance your labels, consider `rejection_sample`.
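Balancing labels by rejection sampling means keeping each example with a probability inversely proportional to how common its label is. The sketch below bins continuous labels and keeps each example with probability `min_count / count(bin)`, so every occupied bin contributes about `min_count` examples in expectation. `rejection_sample_balanced` and its binning scheme are assumptions for illustration; the signature and behavior of kdens's `rejection_sample` may differ.

```python
import random


def rejection_sample_balanced(labels, n_bins=10, seed=None):
    """Return indices of a subset whose labels are roughly uniform across bins.

    Hypothetical helper; kdens's rejection_sample may work differently.
    """
    rng = random.Random(seed)
    lo, hi = min(labels), max(labels)
    # avoid a zero bin width when all labels are equal
    width = (hi - lo) / n_bins or 1.0

    def bin_of(y):
        # clamp so the maximum label falls in the last bin
        return min(int((y - lo) / width), n_bins - 1)

    bins = [bin_of(y) for y in labels]
    counts = {}
    for b in bins:
        counts[b] = counts.get(b, 0) + 1
    min_count = min(counts.values())
    # accept each example with probability min_count / count of its bin
    return [i for i, b in enumerate(bins)
            if rng.random() < min_count / counts[b]]
```

With 90 labels of 0.0 and 10 labels of 1.0 in two bins, all ten rare examples are kept and each common one survives with probability 1/9, giving roughly equal counts per bin. The returned indices can then be passed to `kdens.resample` or used to slice `x` and `y`.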
Working with multiple inputs
This library supports Keras models with multiple inputs, with the following restrictions:
- The inputs must be tuples (lists and dicts are not supported).
- The adversarial step is applied only to the first element of the input tuple.
Here's an example:
```python
import numpy as np
import tensorflow as tf
import kdens

x = np.random.randn(100, 10).astype(np.float32)
t = np.random.randn(100, 10, 3).astype(np.float32)
z = np.random.randn(100, 10, 3).astype(np.float32)
y = np.random.randn(100).astype(np.float32)

# can still use map_reshape with tuples
data = tf.data.Dataset.from_tensor_slices(
    ((x, t, z), y)).map(kdens.map_reshape()).batch(8)

# make a model that takes three inputs as a tuple
def build():
    xi = tf.keras.layers.Input(shape=(10,))
    ti = tf.keras.layers.Input(shape=(10, 3))
    zi = tf.keras.layers.Input(shape=(10, 3))
    # only xi feeds the outputs here; ti and zi show the tuple interface
    mean = tf.keras.layers.Dense(1)(xi)
    # softplus keeps the variance strictly positive
    var = tf.keras.layers.Dense(1, activation="softplus")(xi)
    out = tf.keras.layers.Concatenate(axis=-1)([mean, var])
    return tf.keras.Model(inputs=(xi, ti, zi), outputs=out)

deep_ens = kdens.DeepEnsemble(build)
deep_ens.compile(loss=kdens.neg_ll)
deep_ens.fit(data, epochs=2)
deep_ens.evaluate((x, t, z), y)
```
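In the deep ensembles paper, the adversarial step uses the fast gradient sign method: the input is nudged by `epsilon * sign(gradient of the loss with respect to the input)`, and the model is also trained on the perturbed example. The sketch below illustrates the restriction above by perturbing only the first element of an input tuple. It uses a toy linear mean model with a fixed variance so the gradient can be written by hand. All function names are hypothetical, and kdens's actual adversarial step may differ (for example, in how `epsilon` is chosen).

```python
def sign(v):
    """Return -1, 0, or 1 according to the sign of v."""
    return (v > 0) - (v < 0)


def nll_grad_wrt_x(x, y, w, b, var):
    """Gradient of the Gaussian NLL w.r.t. x for a toy model.

    Assumes mean = w . x + b and a fixed variance, so only the squared-error
    term 0.5 * (y - mean)^2 / var depends on x.
    """
    mean = sum(wi * xi for wi, xi in zip(w, x)) + b
    return [-(y - mean) / var * wi for wi in w]


def fgsm_first_input(inputs, grad_first, epsilon=0.01):
    """Fast gradient sign step applied only to inputs[0].

    inputs is a tuple of feature lists; grad_first is the loss gradient
    with respect to inputs[0]. Remaining elements pass through unchanged.
    """
    first, *rest = inputs
    perturbed = [xi + epsilon * sign(g) for xi, g in zip(first, grad_first)]
    return (perturbed, *rest)
```

For `x = [1.0, -1.0]`, label `3.0`, weights `[1.0, 2.0]`, zero bias, and unit variance, the gradient is `[-4.0, -8.0]`, so the step moves `x` to roughly `[0.99, -1.01]` and increases the loss, while the second tuple element is returned untouched.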
Citation
Deep ensemble paper:
@article{lakshminarayanan2017simple,
title={Simple and scalable predictive uncertainty estimation using deep ensembles},
author={Lakshminarayanan, Balaji and Pritzel, Alexander and Blundell, Charles},
journal={Advances in neural information processing systems},
volume={30},
year={2017}
}