Accelerate training of neural networks using importance sampling.

## Project description

This Python package provides a library that accelerates the training of
arbitrary neural networks created with Keras using
**importance sampling**.

```python
# Keras imports
from importance_sampling.training import ImportanceTraining

# load_data() and create_keras_model() stand for your own
# data-loading and model-building code
x_train, y_train, x_val, y_val = load_data()
model = create_keras_model()
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"]
)

# Wrap the compiled model and train with importance sampling
ImportanceTraining(model).fit(
    x_train, y_train,
    batch_size=32,
    epochs=10,
    verbose=1,
    validation_data=(x_val, y_val)
)

model.evaluate(x_val, y_val)
```

Importance sampling for deep learning is an active research field, and this library is still under development, so your mileage may vary.
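For readers new to the idea, the following is a minimal, library-independent NumPy sketch of importance sampling for minibatch training: samples are drawn with probability proportional to a per-sample score, and each sampled point is reweighted by 1/(N p_i) so that the weighted average of per-sample gradients remains an unbiased estimate of the full-data average. It is not the implementation used by `ImportanceTraining`; the function name `importance_sample_batch` and the fixed scores are hypothetical, chosen only for illustration.

```python
import numpy as np


def importance_sample_batch(scores, batch_size, rng):
    """Sample `batch_size` indices with probability proportional to `scores`.

    Returns the indices and the weights 1 / (N * p_i) that make the weighted
    average of per-sample gradients an unbiased estimate of the full-data mean.
    """
    scores = np.asarray(scores, dtype=np.float64)
    n = scores.shape[0]
    p = scores / scores.sum()                 # sampling distribution over the N samples
    idx = rng.choice(n, size=batch_size, replace=True, p=p)
    weights = 1.0 / (n * p[idx])              # importance weight of each sampled point
    return idx, weights


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Hypothetical per-sample scores (e.g. current losses): the last sample
    # is "hard", so it is drawn far more often but receives a small weight.
    scores = np.array([0.1, 0.1, 0.1, 5.0])
    idx, weights = importance_sample_batch(scores, batch_size=8, rng=rng)
    print(idx, weights)
```

In practice the scores would be something like the current per-sample losses or gradient norms, recomputed as training progresses; sampling with replacement keeps the weights in this simple closed form.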

## Relevant Research

**Ours**

- Not All Samples Are Created Equal: Deep Learning with Importance Sampling [preprint]
- Biased Importance Sampling for Deep Neural Network Training [preprint]

**By others**

## Dependencies & Installation

Normally, if you already have a functional Keras installation, you just need to
run `pip install keras-importance-sampling`.

- `Keras` > 2
- A Keras backend among *Tensorflow*, *Theano* and *CNTK*
- `blinker`
- `numpy`
- `matplotlib`, `seaborn`, `scikit-learn` are optional (used by the plot scripts)
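A quick way to see which of these packages are available is a small script like the one below. It is a sketch, not part of the package: the import names (for example `sklearn` for scikit-learn and `cntk` for CNTK) are assumptions, and it only checks that each package can be found on the import path, not which version of Keras is installed.

```python
from importlib.util import find_spec

# Assumed import names for the packages listed above
REQUIRED = ["keras", "blinker", "numpy"]
BACKENDS = ["tensorflow", "theano", "cntk"]       # at least one is needed
OPTIONAL = ["matplotlib", "seaborn", "sklearn"]   # only used by the plot scripts


def dependency_status():
    """Map each package name to True if it can be found (without importing it)."""
    return {name: find_spec(name) is not None
            for name in REQUIRED + BACKENDS + OPTIONAL}


def missing_requirements(status):
    """List missing required packages, plus a backend entry if none is found."""
    missing = [name for name in REQUIRED if not status[name]]
    if not any(status[name] for name in BACKENDS):
        missing.append("one of: " + ", ".join(BACKENDS))
    return missing


if __name__ == "__main__":
    status = dependency_status()
    for name, found in status.items():
        print(f"{name:12s} {'found' if found else 'missing'}")
    print("Missing requirements:", missing_requirements(status) or "none")
```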

## Documentation

The module has a dedicated documentation site, but you can also read the source code and the examples to get an idea of how the library should be used and extended.

## Examples

In the `examples` folder you can find some Keras examples that have been edited
to use importance sampling.

### Code examples

In this section we will showcase part of the API that can be used to train neural networks with importance sampling.

```python
# Import what is needed to build the Keras model
from keras import backend as K
from keras.layers import Dense, Activation, Flatten
from keras.models import Sequential

# Import a toy dataset and the importance training
from importance_sampling.datasets import MNIST
from importance_sampling.training import ImportanceTraining


def create_nn():
    """Build a simple fully connected NN"""
    model = Sequential([
        Flatten(input_shape=(28, 28, 1)),
        Dense(40, activation="tanh"),
        Dense(40, activation="tanh"),
        Dense(10),
        Activation("softmax")  # Needs to be separate to automatically
                               # get the preactivation outputs
    ])

    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=["accuracy"]
    )

    return model


if __name__ == "__main__":
    # Load the data
    dataset = MNIST()
    x_train, y_train = dataset.train_data[:]
    x_test, y_test = dataset.test_data[:]

    # Create the NN and keep the initial weights
    model = create_nn()
    weights = model.get_weights()

    # Train with uniform sampling
    K.set_value(model.optimizer.lr, 0.01)
    model.fit(
        x_train, y_train,
        batch_size=64, epochs=10,
        validation_data=(x_test, y_test)
    )

    # Train with importance sampling
    model.set_weights(weights)
    K.set_value(model.optimizer.lr, 0.01)
    ImportanceTraining(model).fit(
        x_train, y_train,
        batch_size=64, epochs=2,
        validation_data=(x_test, y_test)
    )
```

### Using the script

The following terminal commands train a small VGG-like network to ~0.65% error on MNIST (the timings were measured on a CPU).

```console
$ # Train a small cnn with mnist for 500 mini-batches using importance
$ # sampling with bias to achieve ~ 0.65% error (on the CPU).
$ time ./importance_sampling.py \
>   small_cnn \
>   oracle-gnorm \
>   model \
>   predicted \
>   mnist \
>   /tmp/is \
>   --hyperparams 'batch_size=i128;lr=f0.003;lr_reductions=I10000' \
>   --train_for 500 --validate_every 500

real    6m16.476s
user    24m46.800s
sys     5m36.592s
$
$ # And with uniform sampling to achieve ~ 0.9% error.
$ time ./importance_sampling.py \
>   small_cnn \
>   oracle-loss \
>   uniform \
>   unweighted \
>   mnist \
>   /tmp/uniform \
>   --hyperparams 'batch_size=i128;lr=f0.003;lr_reductions=I10000' \
>   --train_for 3000 --validate_every 3000

real    10m36.836s
user    47m36.316s
sys     7m14.412s
```
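The two runs can be compared with a little arithmetic on the reported `real` (wall-clock) times and mini-batch counts. The short script below, which uses only the numbers printed above, shows that the importance-sampling run is about 1.69 times faster overall even though each of its mini-batches takes roughly 3.5 times longer (about 0.753 s versus 0.212 s), because it trains for 6 times fewer mini-batches. The helper name `to_seconds` is hypothetical.

```python
def to_seconds(t):
    """Convert a `time` duration such as '6m16.476s' to seconds."""
    minutes, seconds = t.rstrip("s").split("m")
    return 60 * int(minutes) + float(seconds)


# Wall-clock ("real") times and number of mini-batches from the two runs above
runs = {
    "importance": {"real": "6m16.476s", "batches": 500},
    "uniform": {"real": "10m36.836s", "batches": 3000},
}

for name, run in runs.items():
    secs = to_seconds(run["real"])
    print(f"{name:10s} {secs:8.3f} s total, "
          f"{secs / run['batches']:.3f} s per mini-batch")

speedup = to_seconds(runs["uniform"]["real"]) / to_seconds(runs["importance"]["real"])
print(f"wall-clock speedup: {speedup:.2f}x")
```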


## Download files


| Filename | Size | File type | Python version |
|---|---|---|---|
| `keras-importance-sampling-0.8.tar.gz` | 40.5 kB | Source | None |