Accelerate training of neural networks using importance sampling.
This Python package provides a library that uses importance sampling to accelerate the training of arbitrary neural networks created with Keras.
# Keras imports
from importance_sampling.training import ImportanceTraining, \
    ApproximateImportanceTraining

# load_data() and create_keras_model() are placeholders for your own
# data loading and model construction code
x_train, y_train, x_val, y_val = load_data()
model = create_keras_model()
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"]
)

# Train with importance sampling instead of calling model.fit()
ImportanceTraining(model).fit(
    x_train, y_train,
    batch_size=32,
    epochs=10,
    verbose=1,
    validation_data=(x_val, y_val)
)
model.evaluate(x_val, y_val)
Importance sampling for deep learning is an active research field, and this library is still under development, so your mileage may vary.
Relevant Research
Ours
Biased Importance Sampling for Deep Neural Network Training [preprint]
By others
Dependencies & Installation
If you already have a working Keras installation, you normally only need to run pip install keras-importance-sampling. The dependencies are:
Keras > 2
A Keras backend: TensorFlow, Theano or CNTK
transparent-keras
blinker
numpy
matplotlib, seaborn and scikit-learn (optional, used by the plot scripts)
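To see which of the listed dependencies are already available before installing, a small check with the standard library can help. This is a sketch, not part of the package: the import names are assumptions (transparent-keras is assumed to be importable as `transparent_keras`, scikit-learn is imported as `sklearn`).

```python
import importlib.util

# Import names for the packages listed above. ASSUMPTION: transparent-keras
# is importable as "transparent_keras"; scikit-learn is imported as "sklearn".
REQUIRED = ["keras", "transparent_keras", "blinker", "numpy"]
BACKENDS = ["tensorflow", "theano", "cntk"]  # at least one is needed
OPTIONAL = ["matplotlib", "seaborn", "sklearn"]  # only for the plot scripts


def check_modules(names):
    """Map each top-level module name to True if it can be found."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for group, names in (("required", REQUIRED), ("backend", BACKENDS),
                         ("optional", OPTIONAL)):
        for name, found in check_modules(names).items():
            print(f"{group:9} {name:18} {'ok' if found else 'missing'}")
    if not any(check_modules(BACKENDS).values()):
        print("No Keras backend found")
```

`find_spec` only locates a module without importing it, so this does not verify versions such as Keras > 2.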
Documentation
The module has a dedicated documentation site, but you can also read the source code and the examples to see how the library is meant to be used and extended.
Examples
The examples folder contains several Keras examples that have been minimally edited to use importance sampling.
Code examples
This section showcases part of the API for training neural networks with importance sampling.
# Import what is needed to build the Keras model
from keras import backend as K
from keras.layers import Dense
from keras.models import Sequential

# Import a toy dataset and the importance training
from importance_sampling.datasets import CanevetICML2016
from importance_sampling.training import ImportanceTraining


def create_nn():
    """Build a simple fully connected NN"""
    model = Sequential([
        Dense(40, activation="tanh", input_shape=(2,)),
        Dense(40, activation="tanh"),
        Dense(1, activation="sigmoid")
    ])
    model.compile(
        optimizer="adam",
        loss="binary_crossentropy",
        metrics=["accuracy"]
    )
    return model


if __name__ == "__main__":
    # Load the data
    dataset = CanevetICML2016(N=1024)
    x_train, y_train = dataset.train_data[:]
    x_test, y_test = dataset.test_data[:]
    # The labels are one-hot; convert them to 0/1 targets for the
    # single sigmoid output
    y_train, y_test = y_train.argmax(axis=1), y_test.argmax(axis=1)

    # Create the NN and keep the initial weights
    model = create_nn()
    weights = model.get_weights()

    # Train with uniform sampling
    K.set_value(model.optimizer.lr, 0.01)
    model.fit(
        x_train, y_train,
        batch_size=64, epochs=10,
        validation_data=(x_test, y_test)
    )

    # Train with biased importance sampling, starting again from the
    # same initial weights
    model.set_weights(weights)
    K.set_value(model.optimizer.lr, 0.01)
    ImportanceTraining(model, forward_batch_size=1024).fit(
        x_train, y_train,
        batch_size=64, epochs=3,
        validation_data=(x_test, y_test)
    )
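The ImportanceTraining call above passes forward_batch_size=1024 alongside batch_size=64. One plausible reading, which is an assumption and not documented behaviour here, is that a larger pool of samples is scored with forward passes and the 64-sample mini-batch is then drawn from that pool. The NumPy sketch below illustrates only the generic, unbiased version of that sampling step: samples are drawn with probability p_i proportional to a score (here a hypothetical per-sample loss) and reweighted by 1/(N p_i). It is not the library's implementation, which the example above describes as biased importance sampling.

```python
import numpy as np


def importance_sample(scores, batch_size, rng):
    """Draw a mini-batch from a pool with probability proportional to scores.

    Scores must be non-negative with a positive sum. Returns the sampled
    indices and weights that make the weighted mini-batch mean an unbiased
    estimate of the mean over the whole pool.
    """
    scores = np.asarray(scores, dtype=np.float64)
    probs = scores / scores.sum()  # p_i proportional to the score
    idx = rng.choice(len(scores), size=batch_size, replace=True, p=probs)
    weights = 1.0 / (len(scores) * probs[idx])  # w_i = 1 / (N p_i)
    return idx, weights


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Hypothetical per-sample losses for a pool of 1024 samples
    pool_losses = rng.exponential(scale=1.0, size=1024)
    idx, w = importance_sample(pool_losses, batch_size=64, rng=rng)
    # The weighted mean over 64 samples estimates the mean over all 1024
    print("pool mean loss:   ", pool_losses.mean())
    print("weighted estimate:", np.mean(w * pool_losses[idx]))
```

In this toy case the score is exactly the quantity being averaged, so every term w_i·L_i equals the pool mean and the estimate is exact. When the score only approximates that quantity, the weights still keep the estimate unbiased, but it is no longer exact.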
Using the script
The following terminal commands train a small VGG-like network to ~0.55% error on MNIST (the timings were measured on a CPU). The setup is not optimized; it only shows that, in this case, importance sampling requires 6 times fewer iterations (500 instead of 3000 mini-batches).
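The --hyperparams option in the commands below takes a semicolon-separated string such as 'batch_size=i128;lr=f0.003;lr_reductions=I10000;k=f0.5'. Judging only from these examples, each value seems to carry a type prefix (i for an integer, f for a float); that reading is an assumption, and the meaning of the I prefix is not stated here. The hypothetical parser below therefore keeps values with any other prefix as raw strings.

```python
# Hypothetical parser for the --hyperparams string format used below.
# ASSUMPTION: "i" marks an int and "f" a float; any other prefix
# (such as "I") is kept as the raw, unparsed string.
CASTS = {"i": int, "f": float}


def parse_hyperparams(spec):
    """Parse 'name=<type><value>;...' into a dict."""
    params = {}
    for item in filter(None, spec.split(";")):
        name, raw = item.split("=", 1)
        cast = CASTS.get(raw[:1])
        params[name.strip()] = cast(raw[1:]) if cast else raw
    return params


if __name__ == "__main__":
    print(parse_hyperparams(
        "batch_size=i128;lr=f0.003;lr_reductions=I10000;k=f0.5"))
    # {'batch_size': 128, 'lr': 0.003, 'lr_reductions': 'I10000', 'k': 0.5}
```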
$ # Train a small cnn with mnist for 500 mini-batches using importance
$ # sampling with bias to achieve ~ 0.55% error (on the CPU)
$ time ./importance_sampling.py \
> small_cnn \
> oracle-loss \
> model \
> predicted \
> mnist \
> /tmp/is \
> --hyperparams 'batch_size=i128;lr=f0.003;lr_reductions=I10000;k=f0.5' \
> --train_for 500 --validate_every 500
real 6m16.476s
user 24m46.800s
sys 5m36.592s
$
$ # And with uniform sampling to achieve the same accuracy (learning rate is
$ # smaller because with uniform sampling the variance is too big)
$ time ./importance_sampling.py \
> small_cnn \
> oracle-loss \
> uniform \
> unweighted \
> mnist \
> /tmp/uniform \
> --hyperparams 'batch_size=i128;lr=f0.001;lr_reductions=I1000' \
> --train_for 3000 --validate_every 3000
real 10m36.836s
user 47m36.316s
sys 7m14.412s
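The two time reports can be turned into a quick comparison. The sketch below only does arithmetic on the numbers printed above: 6 times fewer iterations (500 vs 3000) translate into roughly a 1.7x shorter wall-clock time, because each importance-sampling iteration takes longer here (about 0.75 s vs 0.21 s), presumably due to the extra work spent scoring samples.

```python
import re


def to_seconds(t):
    """Convert a `time` value such as '6m16.476s' to seconds."""
    m = re.fullmatch(r"(\d+)m([\d.]+)s", t)
    if m is None:
        raise ValueError(f"unexpected time format: {t!r}")
    return int(m.group(1)) * 60 + float(m.group(2))


# Numbers copied from the two runs above (real time only)
runs = {
    "importance": {"iterations": 500, "real": "6m16.476s"},
    "uniform": {"iterations": 3000, "real": "10m36.836s"},
}

if __name__ == "__main__":
    imp, uni = runs["importance"], runs["uniform"]
    t_imp, t_uni = to_seconds(imp["real"]), to_seconds(uni["real"])
    print(f"iteration ratio:   {uni['iterations'] / imp['iterations']:.1f}x")
    print(f"wall-clock ratio:  {t_uni / t_imp:.2f}x")
    print(f"s/iteration (IS):  {t_imp / imp['iterations']:.3f}")
    print(f"s/iteration (uni): {t_uni / uni['iterations']:.3f}")
```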
Hashes for keras-importance-sampling-0.3.tar.gz (source distribution)
Algorithm | Hash digest
---|---
SHA256 | 4617c9ecba024d6daf7820bad0f00a88d3dde739f50e060ac57ebbad6de4a9e2
MD5 | f8c1e0432a6cb8751d6769d3ba7e97d3
BLAKE2b-256 | 5849b5d22c7a556a524e008e38c448b4122b586303d2f901bc2f1934894f9276