
Accelerate training of neural networks using importance sampling.

Project description

This Python package provides a library that uses importance sampling to accelerate the training of arbitrary neural networks created with Keras.

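# Keras imports

from importance_sampling.training import ImportanceTraining

# load_data() and create_keras_model() stand for your own data loading
# and model building code
x_train, y_train, x_val, y_val = load_data()
model = create_keras_model()
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"]
)

# Wrap the compiled model and train it with importance sampling
ImportanceTraining(model).fit(
    x_train, y_train,
    validation_data=(x_val, y_val)
)

model.evaluate(x_val, y_val)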

Importance sampling for deep learning is an active research field and this library is still under development, so your mileage may vary.
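The core idea can be illustrated without Keras. The sketch below is a minimal pure-Python illustration, not the library's implementation: it draws a mini-batch with probability proportional to a per-sample score (a made-up list standing in for, e.g., the loss) and attaches the weight 1/(N·p_i) to each drawn index, which keeps the weighted average of per-sample gradients an unbiased estimate of the full-dataset mean. The function names `importance_probabilities`, `sample_batch` and `expected_weighted_mean` are hypothetical.

```python
import random


def importance_probabilities(scores):
    """Turn strictly positive per-sample scores into sampling probabilities."""
    total = sum(scores)
    return [s / total for s in scores]


def sample_batch(scores, batch_size, rng=random):
    """Draw indices proportionally to `scores` (with replacement) and return
    them together with the unbiasedness weights 1 / (N * p_i)."""
    n = len(scores)
    probs = importance_probabilities(scores)
    idx = rng.choices(range(n), weights=probs, k=batch_size)
    weights = [1.0 / (n * probs[i]) for i in idx]
    return idx, weights


def expected_weighted_mean(values, scores):
    """Exact expectation of w_i * values[i] for one importance-sampled index.
    Each term is p_i * (1 / (N * p_i)) * v_i = v_i / N, so the result equals
    the plain mean of `values` (scores must be strictly positive)."""
    n = len(values)
    probs = importance_probabilities(scores)
    return sum(p * (1.0 / (n * p)) * v for p, v in zip(probs, values))


if __name__ == "__main__":
    # Hypothetical per-sample losses: the high-loss sample is drawn most often
    losses = [0.1, 0.1, 0.2, 2.0]
    grads = [1.0, -2.0, 0.5, 3.0]  # stand-ins for per-sample gradients
    idx, w = sample_batch(losses, batch_size=8, rng=random.Random(0))
    print(idx, [round(x, 3) for x in w])
    print(expected_weighted_mean(grads, losses), sum(grads) / len(grads))
```

The weights are what distinguish this from simply training more on hard examples: without them the gradient estimate would be biased toward high-score samples.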

Relevant Research

By us

  • Not All Samples Are Created Equal: Deep Learning with Importance Sampling [preprint]
  • Biased Importance Sampling for Deep Neural Network Training [preprint]

By others

  • Stochastic optimization with importance sampling for regularized loss minimization [pdf]
  • Variance reduction in SGD by distributed importance sampling [pdf]

Dependencies & Installation

If you already have a working Keras installation, you normally only need to run pip install keras-importance-sampling.

  • Keras > 2
  • A Keras backend: TensorFlow, Theano or CNTK
  • blinker
  • numpy
  • matplotlib, seaborn, scikit-learn are optional (used by the plot scripts)
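To see which of these dependencies are already importable, a small standard-library check such as the one below can help. It is a convenience sketch, not part of the package; the import names (for example `sklearn` for scikit-learn, and `importance_sampling` for this package, as used in the code examples on this page) are assumptions, and it only reports the installed Keras major version rather than enforcing the requirement.

```python
import importlib.util
from importlib import metadata

# Assumed import names for the dependencies listed above
REQUIRED = ["keras", "blinker", "numpy", "importance_sampling"]
BACKENDS = ["tensorflow", "theano", "cntk"]  # at least one is needed
OPTIONAL = ["matplotlib", "seaborn", "sklearn"]  # only for the plot scripts


def check_modules(names):
    """Map each top-level module name to True if it can be imported."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


def keras_major_version():
    """Return the installed Keras major version, or None if Keras is absent."""
    try:
        return int(metadata.version("keras").split(".")[0])
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    for group, names in [("required", REQUIRED), ("backends", BACKENDS),
                         ("optional", OPTIONAL)]:
        status = check_modules(names)
        missing = [n for n, ok in status.items() if not ok]
        print(f"{group}: missing {missing or 'nothing'}")
    print("Keras major version:", keras_major_version())
```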

Documentation

The module has a dedicated documentation site, but you can also read the source code and the examples to get an idea of how the library is meant to be used and extended.

Examples

In the examples folder you can find some Keras examples that have been edited to use importance sampling.

Code examples

This section showcases part of the API that can be used to train neural networks with importance sampling.

# Import what is needed to build the Keras model
from keras import backend as K
from keras.layers import Dense, Activation, Flatten
from keras.models import Sequential

# Import a toy dataset and the importance training
from importance_sampling.datasets import MNIST
from importance_sampling.training import ImportanceTraining

def create_nn():
    """Build and compile a simple fully connected NN"""
    model = Sequential([
        Flatten(input_shape=(28, 28, 1)),
        Dense(40, activation="tanh"),
        Dense(40, activation="tanh"),
        Dense(10),            # one output per MNIST class
        Activation("softmax") # Needs to be separate to automatically
                              # get the preactivation outputs
    ])

    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=["accuracy"]
    )

    return model

if __name__ == "__main__":
    # Load the data
    dataset = MNIST()
    x_train, y_train = dataset.train_data[:]
    x_test, y_test = dataset.test_data[:]

    # Create the NN and keep the initial weights
    model = create_nn()
    weights = model.get_weights()

    # Train with uniform sampling
    K.set_value(model.optimizer.lr, 0.01)
    model.fit(
        x_train, y_train,
        batch_size=64, epochs=10,
        validation_data=(x_test, y_test)
    )

    # Train with importance sampling, starting again from the initial weights
    model.set_weights(weights)
    K.set_value(model.optimizer.lr, 0.01)
    ImportanceTraining(model).fit(
        x_train, y_train,
        batch_size=64, epochs=2,
        validation_data=(x_test, y_test)
    )
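The comment about keeping the softmax as a separate Activation layer hints at why the preactivation outputs matter. For a softmax followed by categorical cross-entropy, the gradient of the loss with respect to the logits z is softmax(z) - y, and, as we understand it, the "Not All Samples Are Created Equal" paper uses the norm of the gradient with respect to the last layer's preactivation outputs as a cheap per-sample importance score. The pure-Python sketch below computes that norm for single samples; it illustrates the idea only and is not the library's code, and the function names are hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax of a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def preactivation_grad_norm(logits, one_hot):
    """L2 norm of d(cross-entropy)/d(logits), which equals
    softmax(logits) - one_hot for a softmax output layer."""
    probs = softmax(logits)
    return math.sqrt(sum((p - y) ** 2 for p, y in zip(probs, one_hot)))


if __name__ == "__main__":
    # Two hypothetical samples of class 0: one confidently correct and one
    # confidently wrong. The wrong one receives a much larger score.
    easy = preactivation_grad_norm([5.0, 0.0, 0.0], [1, 0, 0])
    hard = preactivation_grad_norm([0.0, 5.0, 0.0], [1, 0, 0])
    print(round(easy, 4), round(hard, 4))
```

Because this score only needs the forward pass up to the logits, it is far cheaper than computing the full per-sample parameter gradient.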

Using the script

The following terminal commands train a small VGG-like network on MNIST, reaching ~0.65% error with importance sampling and ~0.9% error with uniform sampling (timings measured on a CPU).

$ # Train a small cnn with mnist for 500 mini-batches using importance
$ # sampling with bias to achieve ~ 0.65% error (on the CPU).
$ time ./ \
>   small_cnn \
>   oracle-gnorm \
>   model \
>   predicted \
>   mnist \
>   /tmp/is \
>   --hyperparams 'batch_size=i128;lr=f0.003;lr_reductions=I10000' \
>   --train_for 500 --validate_every 500
real    6m16.476s
user    24m46.800s
sys     5m36.592s
$ # And with uniform sampling to achieve ~ 0.9% error.
$ time ./ \
>   small_cnn \
>   oracle-loss \
>   uniform \
>   unweighted \
>   mnist \
>   /tmp/uniform \
>   --hyperparams 'batch_size=i128;lr=f0.003;lr_reductions=I10000' \
>   --train_for 3000 --validate_every 3000
real    10m36.836s
user    47m36.316s
sys     7m14.412s
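The timings above can be compared directly. The short calculation below converts the reported wall-clock ("real") times to seconds and divides by the number of mini-batches of each run: roughly 376 s versus 637 s, about a 1.7x wall-clock speedup, even though each importance-sampled mini-batch costs roughly 3.5 times as much as a uniform one. These are single runs, and "real" time also includes start-up and validation, so the per-batch figures are only indicative.

```python
def to_seconds(t):
    """Convert a `time` output such as '6m16.476s' to seconds."""
    minutes, rest = t.rstrip("s").split("m")
    return int(minutes) * 60 + float(rest)


# Wall-clock ("real") times and mini-batch counts of the two runs above
runs = {
    "importance": {"real": "6m16.476s", "batches": 500},
    "uniform": {"real": "10m36.836s", "batches": 3000},
}

for name, run in runs.items():
    secs = to_seconds(run["real"])
    print(f"{name}: {secs:.1f} s total, "
          f"{secs / run['batches']:.3f} s per mini-batch")

speedup = to_seconds(runs["uniform"]["real"]) / to_seconds(runs["importance"]["real"])
print(f"wall-clock speedup: {speedup:.2f}x")
```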


Download files

keras-importance-sampling-0.8.tar.gz (40.5 kB), source distribution
