Accelerate training of neural networks using importance sampling.
This Python package provides a library that uses importance sampling to accelerate the training of arbitrary neural networks built with Keras.
```python
# Keras imports
from importance_sampling.training import ImportanceTraining

# load_data() and create_keras_model() stand for your own data loading
# and model construction code
x_train, y_train, x_val, y_val = load_data()
model = create_keras_model()
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"]
)

ImportanceTraining(model).fit(
    x_train, y_train,
    batch_size=32,
    epochs=10,
    verbose=1,
    validation_data=(x_val, y_val)
)

model.evaluate(x_val, y_val)
```
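ImportanceTraining hides the sampling logic behind fit(). As a rough illustration of the general idea, and not of this library's implementation, the NumPy sketch below draws a mini-batch with probabilities proportional to per-sample importance scores and attaches the weights 1/(N·p_i) that keep the weighted gradient estimate unbiased. The helper name importance_sample and the choice of using the raw scores directly as sampling probabilities are assumptions made for illustration.

```python
import numpy as np


def importance_sample(scores, batch_size, rng=None):
    """Hypothetical helper: sample a mini-batch proportionally to `scores`.

    `scores` must be non-negative and not all zero. Returns the sampled
    indices and the importance weights 1 / (N * p_i), which make the
    weighted average of per-sample gradients an unbiased estimate of the
    average over the whole training set.
    """
    rng = np.random.default_rng() if rng is None else rng
    scores = np.asarray(scores, dtype=np.float64)
    n = len(scores)
    probs = scores / scores.sum()            # p_i proportional to the score
    idx = rng.choice(n, size=batch_size, replace=True, p=probs)
    weights = 1.0 / (n * probs[idx])         # importance weights
    return idx, weights


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    scores = rng.uniform(0.1, 2.0, size=1000)  # stand-in per-sample losses
    idx, w = importance_sample(scores, batch_size=32, rng=rng)
    print(idx[:5], w[:5])
```

With equal scores every weight is 1 and the procedure reduces to uniform sampling. In a hand-written Keras loop, the weights would be passed as sample_weight to model.train_on_batch.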
Importance sampling for deep learning is an active research field, and this library is still under development, so your mileage may vary.
- Not All Samples Are Created Equal: Deep Learning with Importance Sampling [preprint]
- Biased Importance Sampling for Deep Neural Network Training [preprint]
Dependencies & Installation
Normally, if you already have a working Keras installation, you only need to run pip install keras-importance-sampling. The requirements are:
- Keras > 2
- A Keras backend: TensorFlow, Theano, or CNTK
- matplotlib, seaborn, and scikit-learn (optional; used by the plot scripts)
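To see which of these packages are present in the current environment, a small check along the following lines can help. It is a sketch, not part of the package: the helper check_dependencies is hypothetical, it only tests whether each top-level module can be found (so the "Keras > 2" version requirement is not verified), and it assumes scikit-learn is importable under its usual module name sklearn.

```python
import importlib.util

# Module names taken from the requirement list; scikit-learn is imported
# as "sklearn".
REQUIRED = ["keras"]
BACKENDS = ["tensorflow", "theano", "cntk"]       # at least one is needed
OPTIONAL = ["matplotlib", "seaborn", "sklearn"]   # only for the plot scripts


def check_dependencies(required=REQUIRED, backends=BACKENDS, optional=OPTIONAL):
    """Hypothetical helper: map each module name to whether it can be found.

    Only the presence of a top-level module is checked; versions are not.
    """
    def available(name):
        return importlib.util.find_spec(name) is not None

    return {
        "required": {name: available(name) for name in required},
        "backends": {name: available(name) for name in backends},
        "optional": {name: available(name) for name in optional},
    }


if __name__ == "__main__":
    report = check_dependencies()
    for group, packages in report.items():
        for name, ok in packages.items():
            print(f"{group:9s} {name:11s} {'ok' if ok else 'missing'}")
    if not any(report["backends"].values()):
        print("No Keras backend found")
```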
The examples folder contains several Keras examples that have been modified to use importance sampling.
This section showcases part of the API that can be used to train neural networks with importance sampling.
```python
# Import what is needed to build the Keras model
from keras import backend as K
from keras.layers import Dense, Activation, Flatten
from keras.models import Sequential

# Import a toy dataset and the importance training
from importance_sampling.datasets import MNIST
from importance_sampling.training import ImportanceTraining


def create_nn():
    """Build a simple fully connected NN"""
    model = Sequential([
        Flatten(input_shape=(28, 28, 1)),
        Dense(40, activation="tanh"),
        Dense(40, activation="tanh"),
        Dense(10),
        Activation("softmax")  # Needs to be separate to automatically
                               # get the preactivation outputs
    ])

    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=["accuracy"]
    )

    return model


if __name__ == "__main__":
    # Load the data
    dataset = MNIST()
    x_train, y_train = dataset.train_data[:]
    x_test, y_test = dataset.test_data[:]

    # Create the NN and keep the initial weights
    model = create_nn()
    weights = model.get_weights()

    # Train with uniform sampling
    K.set_value(model.optimizer.lr, 0.01)
    model.fit(
        x_train, y_train,
        batch_size=64, epochs=10,
        validation_data=(x_test, y_test)
    )

    # Train with importance sampling, restarting from the same initial
    # weights so that the two runs are comparable
    model.set_weights(weights)
    K.set_value(model.optimizer.lr, 0.01)
    ImportanceTraining(model).fit(
        x_train, y_train,
        batch_size=64, epochs=2,
        validation_data=(x_test, y_test)
    )
```
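The comment about keeping Activation("softmax") as a separate layer indicates that the library works with the pre-activation outputs (logits). One reason this is convenient: for softmax followed by categorical cross-entropy, the gradient of the loss with respect to the logits z is softmax(z) - y, so its norm can be computed per sample from a forward pass alone. The NumPy sketch below computes that quantity. Using it as an importance score follows the idea of the first paper listed above, but the function name and details are assumptions, not the library's code.

```python
import numpy as np


def logit_gradient_norm(logits, y_true):
    """Per-sample L2 norm of d(cross-entropy)/d(logits) = softmax(z) - y.

    `logits` has shape (N, C) (outputs before the softmax) and `y_true`
    is one-hot encoded with the same shape.
    """
    z = logits - logits.max(axis=1, keepdims=True)  # numerically stable softmax
    probs = np.exp(z)
    probs /= probs.sum(axis=1, keepdims=True)
    return np.linalg.norm(probs - y_true, axis=1)


if __name__ == "__main__":
    logits = np.array([[0.0, 0.0], [10.0, -10.0]])
    y_true = np.array([[1.0, 0.0], [1.0, 0.0]])
    print(logit_gradient_norm(logits, y_true))
```

Samples that the network already classifies confidently and correctly get scores close to 0 and would rarely be drawn, while confidently misclassified samples approach the maximum score of sqrt(2).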
Using the script
The following terminal commands train a small VGG-like network to ~0.65% error on MNIST (the timings were measured on a CPU).

```
$ # Train a small cnn with mnist for 500 mini-batches using importance
$ # sampling with bias to achieve ~ 0.65% error (on the CPU).
$ time ./importance_sampling.py \
>   small_cnn \
>   oracle-gnorm \
>   model \
>   predicted \
>   mnist \
>   /tmp/is \
>   --hyperparams 'batch_size=i128;lr=f0.003;lr_reductions=I10000' \
>   --train_for 500 --validate_every 500

real    6m16.476s
user    24m46.800s
sys     5m36.592s
$
$ # And with uniform sampling to achieve ~ 0.9% error.
$ time ./importance_sampling.py \
>   small_cnn \
>   oracle-loss \
>   uniform \
>   unweighted \
>   mnist \
>   /tmp/uniform \
>   --hyperparams 'batch_size=i128;lr=f0.003;lr_reductions=I10000' \
>   --train_for 3000 --validate_every 3000

real    10m36.836s
user    47m36.316s
sys     7m14.412s
```
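The --hyperparams argument packs several values into a single string. Reading the example, each value appears to carry a one-letter type prefix (i128 for an integer, f0.003 for a float); we guess that the uppercase I in lr_reductions=I10000 denotes a comma-separated list of integers, since the name is plural. The sketch below parses such a string under those assumptions. parse_hyperparams and the prefix table are hypothetical and are not taken from the script.

```python
# Type prefixes as read from the example string; the meaning of the
# uppercase "I" (comma-separated list of ints) is an assumption.
CASTS = {
    "i": int,
    "f": float,
    "I": lambda s: [int(v) for v in s.split(",")],
}


def parse_hyperparams(spec):
    """Hypothetical parser for strings like 'batch_size=i128;lr=f0.003'."""
    params = {}
    for item in filter(None, spec.split(";")):
        key, value = item.split("=", 1)
        prefix, raw = value[:1], value[1:]
        if prefix not in CASTS:
            raise ValueError(f"unknown type prefix {prefix!r} in {item!r}")
        params[key.strip()] = CASTS[prefix](raw)
    return params


if __name__ == "__main__":
    print(parse_hyperparams("batch_size=i128;lr=f0.003;lr_reductions=I10000"))
    # {'batch_size': 128, 'lr': 0.003, 'lr_reductions': [10000]}
```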