An implementation of weight pruning and re-initialization for Keras to improve training performance

Project description

Stroke

Stroke implements the concepts of weight pruning and re-initialization. The goal of the Stroke callback is to re-initialize weights that have begun to contribute to overfitting, or weights that are effectively 0.

Keep in mind that using Stroke on larger models may introduce significant slowdown while training.

Parameters of the callback are:

  • minweight - the minimum value of the random weights to be generated. (default value = -.05)
  • maxweight - the maximum value of the random weights to be generated. (default value = .05)
  • volatility_ratio - the percentage of weights you would like to re-initialize. (default value = .05)
  • decay - volatility_ratio will be multiplied by decay at the end of every epoch, after weights have been stricken. (default value = None)
  • index - the index of a layer within the model that you'd like to randomize the weights of. This will prevent randomization of all other layers. (default value = None)
  • verbose - Prints the model/layer name and the percentage of weights that were randomized. (default value = False)
  • pruning - Implements weight pruning: weights that fall between the pruning bounds are set to 0. (default value = True)
  • pruningmin - Lower bound for weight pruning. This usually shouldn't be altered. (default value = 0.0)
  • pruningmax - Upper bound for weight pruning. (default value = .02)
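The epoch-end behaviour these parameters describe can be sketched in plain NumPy. Note that `stroke_step` is a hypothetical helper written for illustration, not part of the kerastroke API: it zeroes weights inside the pruning bounds, then redraws a `volatility_ratio` fraction of weights uniformly in `[minweight, maxweight]`.

```python
import numpy as np

def stroke_step(weights, minweight=-0.05, maxweight=0.05,
                volatility_ratio=0.05, pruningmin=0.0, pruningmax=0.02,
                rng=np.random.default_rng(0)):
    """Hypothetical sketch of one epoch-end pass: prune near-zero
    weights, then re-initialize a random fraction of all weights."""
    w = weights.copy()

    # Pruning: zero out weights inside [pruningmin, pruningmax)
    prune_mask = (w >= pruningmin) & (w < pruningmax)
    w[prune_mask] = 0.0

    # Re-initialization: redraw volatility_ratio of the weights
    # uniformly in [minweight, maxweight]
    flat = w.ravel()  # view into w, so writes below modify w
    n_reinit = int(volatility_ratio * flat.size)
    idx = rng.choice(flat.size, size=n_reinit, replace=False)
    flat[idx] = rng.uniform(minweight, maxweight, size=n_reinit)
    return w

w = np.random.default_rng(1).normal(0.0, 0.1, size=(128, 10))
w2 = stroke_step(w)
```

With `decay` set, `volatility_ratio` would additionally be multiplied by `decay` after each such pass.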

An implementation of the Stroke callback on an MNIST classification model can be seen below:

from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Conv2D, MaxPool2D, Flatten
from kerastroke import Stroke

# Load MNIST and scale pixel values to [0, 1]
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.0

model = Sequential()

model.add(Conv2D(32, (3, 3), input_shape=(28, 28, 1), activation='relu'))
model.add(MaxPool2D(pool_size=(2, 2)))

model.add(Conv2D(32, (3, 3), activation='relu'))
model.add(MaxPool2D(pool_size=(2, 2)))

model.add(Flatten())

model.add(Dense(128, kernel_initializer='uniform', activation='relu'))
model.add(Dense(10, kernel_initializer='uniform', activation='softmax'))

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

history = model.fit(x_train, y_train,
                    batch_size=64,
                    epochs=1,
                    verbose=0,
                    callbacks=[Stroke()])
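If `decay` is set on the callback, the fraction of re-initialized weights shrinks geometrically as training progresses. A short sketch of that schedule (`decay = 0.9` is an arbitrary example value, not a library default):

```python
# Per the parameter description above, volatility_ratio is multiplied
# by decay at the end of each epoch, so fewer and fewer weights are
# re-initialized as the model converges.
volatility_ratio = 0.05  # default starting value
decay = 0.9              # hypothetical decay setting
schedule = []
for epoch in range(5):
    schedule.append(volatility_ratio)
    volatility_ratio *= decay
print(schedule)
```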

