
An implementation of weight pruning and re-initialization for Keras to improve training performance

Project description


Stroke is a Keras callback that implements weight pruning and weight re-initialization. At the end of each epoch it re-initializes a random fraction of the weights, aiming to reset weights that have begun to contribute to overfitting, and it prunes (sets to 0) weights that are already effectively 0.
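As a rough illustration of the two operations just described, the sketch below applies one pruning pass and one re-initialization pass to a NumPy weight array. stroke_step is a hypothetical helper, not part of kerastroke, and its defaults simply mirror the callback's documented default values. It assumes that pruning compares the magnitude |w| against the bounds and that pruning runs before re-initialization; the library itself may do either differently.

```python
import numpy as np


def stroke_step(weights, stroke=True, pruning=True,
                volatility_ratio=0.05, minweight=-0.05, maxweight=0.05,
                pruningmin=0.0, pruningmax=0.02, rng=None):
    """Hypothetical helper, not part of kerastroke.

    Returns a copy of `weights` after one pruning pass and one
    re-initialization pass (assumed order: prune first, then re-initialize).
    """
    rng = np.random.default_rng() if rng is None else rng
    w = np.array(weights, dtype=float)  # copy; the caller's array is left unchanged
    flat = w.reshape(-1)                # flat view that writes through to w

    if pruning:
        # Assumption: the bounds are compared against |w|, so small negative
        # weights are zeroed as well as small positive ones.
        mag = np.abs(flat)
        flat[(mag >= pruningmin) & (mag <= pruningmax)] = 0.0

    if stroke:
        # Choose volatility_ratio of all entries without replacement and
        # overwrite them with uniform random values in [minweight, maxweight).
        n = int(round(volatility_ratio * flat.size))
        idx = rng.choice(flat.size, size=n, replace=False)
        flat[idx] = rng.uniform(minweight, maxweight, size=n)

    return w


# Usage: apply one step to a stand-in for a Dense layer's kernel
rng = np.random.default_rng(0)
kernel = rng.normal(0.0, 0.1, size=(128, 10))
new_kernel = stroke_step(kernel, rng=rng)
print(f"zeroed: {np.sum(new_kernel == 0.0)}, changed: {np.sum(new_kernel != kernel)}")
```

Pruning first keeps the freshly drawn values from being zeroed straight away, since many of them fall inside the default pruning range.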

Keep in mind that using Stroke on larger models may significantly slow down training.

Parameters of the callback are:

  • stroke - enables weight re-initialization: randomly re-initializes a fraction of the weights (set by volatility_ratio) with values between minweight and maxweight. (default value = True)
  • minweight - the minimum value of the random weights to be generated. (default value = -.05)
  • maxweight - the maximum value of the random weights to be generated. (default value = .05)
  • volatility_ratio - the fraction of weights to re-initialize. Must satisfy .01 < volatility_ratio < 1.0. (default value = .05)
  • s_cutoff - disables Stroke once the number of completed epochs exceeds s_cutoff; with s_cutoff = -1, Stroke is never disabled. (default value = -1)
  • decay - at the end of every epoch, after the weights have been re-initialized, volatility_ratio is multiplied by decay. (default value = None, i.e. no decay)
  • index - the index of the layer within the model whose weights should be randomized. When set, all other layers are left unchanged. (default value = None)
  • verbose - prints the model/layer name and the percentage of weights that were randomized. (default value = False)
  • pruning - enables weight pruning: weights that fall between pruningmin and pruningmax are set to 0. (default value = True)
  • pruningmin - lower bound for weight pruning. This usually shouldn't be altered. (default value = 0.0)
  • pruningmax - upper bound for weight pruning. (default value = .02)
  • p_cutoff - disables pruning once the number of completed epochs exceeds p_cutoff; with p_cutoff = -1, pruning is never disabled. (default value = -1)
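The interaction of volatility_ratio, decay and the two cutoffs can be sketched as a per-epoch schedule. stroke_schedule is a hypothetical helper, not kerastroke's code; it assumes that a cutoff of -1 never disables its step and that any other cutoff disables it once the number of completed epochs exceeds the cutoff. With decay = .9, the ratio used at the k-th epoch end is volatility_ratio × 0.9^(k-1), so the default .05 becomes .045 and then .0405.

```python
def stroke_schedule(num_epochs, volatility_ratio=0.05, decay=None,
                    s_cutoff=-1, p_cutoff=-1):
    """Hypothetical helper, not part of kerastroke.

    Returns one (completed_epochs, stroke_on, pruning_on, ratio) tuple per
    epoch end; ratio is None when Stroke is cut off.
    """
    if not 0.01 < volatility_ratio < 1.0:
        raise ValueError("volatility_ratio must satisfy .01 < volatility_ratio < 1.0")

    schedule = []
    ratio = volatility_ratio
    for completed in range(1, num_epochs + 1):
        # Assumed cutoff reading: -1 never cuts off; otherwise a step stops
        # running once the number of completed epochs exceeds its cutoff.
        stroke_on = s_cutoff == -1 or completed <= s_cutoff
        pruning_on = p_cutoff == -1 or completed <= p_cutoff
        schedule.append((completed, stroke_on, pruning_on,
                         ratio if stroke_on else None))
        # decay shrinks the ratio after the weights have been re-initialized
        if stroke_on and decay is not None:
            ratio *= decay
    return schedule


for row in stroke_schedule(5, decay=0.9, s_cutoff=3):
    print(row)
```

For example, stroke_schedule(5, decay=.9, s_cutoff=3) reports Stroke running at the first three epoch ends and pruning at all five.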

The example below adds the Stroke callback to an MNIST classification model:

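from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Conv2D, MaxPool2D, Flatten
from kerastroke import Stroke

# Load MNIST, add a channel axis and scale pixel values to [0, 1]
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.0
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.0

model = Sequential()

model.add(Conv2D(32, (3, 3), input_shape=(28, 28, 1), activation='relu'))
model.add(MaxPool2D(pool_size=(2, 2)))

model.add(Conv2D(32, (3, 3), activation='relu'))
model.add(MaxPool2D(pool_size=(2, 2)))

# Flatten the feature maps before the fully connected layers
model.add(Flatten())
model.add(Dense(128, kernel_initializer='random_uniform', activation='relu'))

# softmax so the 10 outputs form a probability distribution,
# as sparse_categorical_crossentropy expects
model.add(Dense(10, kernel_initializer='random_uniform', activation='softmax'))

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Stroke acts at the end of each epoch: weights are re-initialized
# (volatility_ratio decays by 0.9 per epoch) and pruning is turned off
history = model.fit(x_train, y_train,
                    epochs=5,
                    validation_data=(x_test, y_test),
                    callbacks=[Stroke(decay=.9, pruning=False)])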


Download files


Files for kerastroke, version 1.3.0:

  • kerastroke-1.3.0-py3-none-any.whl (4.2 kB), Wheel, Python py3
  • kerastroke-1.3.0.tar.gz (3.7 kB), Source
