A custom Keras callback to prevent overfitting
Stroke
While reading about dropout, which randomly drops layer outputs during training, I thought about disrupting the weights between layers instead. So I created a custom Keras callback called "Stroke", which randomizes a set percentage of the weights in a model or in one of its layers, loosely replicating what happens in the brain when a human has a stroke. The goal of the Stroke callback is to re-initialize a random subset of weights so that weights which have begun to contribute to overfitting are disrupted.
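To make the idea concrete, here is a minimal NumPy sketch of a single "stroke" applied to a list of weight arrays, like the list Keras' `get_weights()` returns. It is not the kerastroke source. The function name `stroke_weights`, striking every array (kernels and biases alike), picking exactly `round(volatility_ratio * size)` entries per array, and sampling replacements uniformly between `minweight` and `maxweight` are all assumptions for illustration.

```python
import numpy as np


def stroke_weights(weights, volatility_ratio=0.1, minweight=-0.05, maxweight=0.05, rng=None):
    """Return copies of `weights` with about `volatility_ratio` of each array re-initialized.

    Hypothetical helper, not part of kerastroke. `weights` is a list of
    NumPy arrays, e.g. the output of a Keras model's get_weights().
    """
    rng = np.random.default_rng() if rng is None else rng
    struck = []
    for w in weights:
        w = np.array(w, copy=True)   # never modify the caller's arrays
        flat = w.reshape(-1)         # view into the copy, so writes land in `w`
        n = int(round(volatility_ratio * flat.size))
        # choose n distinct positions and overwrite them with fresh random values
        idx = rng.choice(flat.size, size=n, replace=False)
        flat[idx] = rng.uniform(minweight, maxweight, size=n)
        struck.append(w)
    return struck


# Example on stand-in arrays shaped like a small Dense layer's kernel and bias
rng = np.random.default_rng(0)
kernel, bias = np.ones((4, 3)), np.ones(3)
struck = stroke_weights([kernel, bias], volatility_ratio=0.25, rng=rng)
print(int((struck[0] != 1).sum()))  # 3 of the 12 kernel entries were re-initialized
```

With a real Keras model, the same function could be applied as `model.set_weights(stroke_weights(model.get_weights()))`, for example from a callback's `on_epoch_end`; restricting it to one layer would use `model.layers[i].get_weights()` and `set_weights()` instead.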
Parameters of the callback are:
model - the model used in training (required).
minweight - the minimum value of the random weights to be generated (default value = -.05).
maxweight - the maximum value of the random weights to be generated (default value = .05).
volatility_ratio - the fraction of weights you would like to re-initialize, e.g. .1 for 10% (default value = .1).
decay - if this value is set, volatility_ratio is multiplied by decay at the end of every epoch, after the weights have been struck.
index - the index of a layer within the model whose weights you'd like to randomize. If set, all other layers are left untouched (default value = None).
verbose - if set to True, prints the model/layer name and the percentage of weights that were randomized (default value = False).
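The decay rule can be sketched in plain Python. The helper `volatility_schedule` below is hypothetical and not part of kerastroke; it assumes the current ratio is used to strike the weights for an epoch and only then multiplied by `decay`, so epoch e (counting from 0) uses volatility_ratio * decay**e.

```python
def volatility_schedule(volatility_ratio=0.1, decay=None, epochs=5):
    """Fraction of weights struck in each epoch under the decay rule.

    Hypothetical helper for illustration. With decay=None the ratio
    stays constant for every epoch.
    """
    ratios = []
    ratio = volatility_ratio
    for _ in range(epochs):
        ratios.append(ratio)   # ratio used when striking weights this epoch
        if decay is not None:
            ratio *= decay     # shrink only after this epoch's stroke
    return ratios


print(volatility_schedule(0.1, decay=0.5, epochs=4))  # → [0.1, 0.05, 0.025, 0.0125]
```

A decay below 1 makes the strokes progressively milder, so early training is disrupted the most while later epochs are left closer to what the optimizer learned.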
The following example uses the Stroke callback while training an MNIST classification model:
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Conv2D, MaxPool2D, Flatten
from kerastroke import Stroke

# Load MNIST, add a channel axis and scale pixel values to [0, 1]
(x_train, y_train), _ = mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255

model = Sequential()
model.add(Conv2D(32, (3, 3), input_shape=(28, 28, 1), activation='relu'))
model.add(MaxPool2D(pool_size=(2, 2)))
model.add(Conv2D(32, (3, 3), activation='relu'))
model.add(MaxPool2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(128, kernel_initializer='uniform', activation='relu'))
# softmax (not sigmoid) matches sparse_categorical_crossentropy over 10 classes
model.add(Dense(10, kernel_initializer='uniform', activation='softmax'))
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Stroke randomizes a fraction of the model's weights during training
_ = model.fit(x_train, y_train,
              batch_size=64,
              epochs=1,
              steps_per_epoch=5,
              verbose=0,
              callbacks=[Stroke(model)])
Download files
Source Distribution
kerastroke-1.0.2.tar.gz (3.2 kB)
Built Distribution
kerastroke-1.0.2-py3-none-any.whl

Hashes for kerastroke-1.0.2-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 2375a05c98c3a5148982ad16fe4d18f54ebdf48663391ebb0acf1ea9058aa19b
MD5 | 218c0b2e5735cc120ffb93f8a2e77a31
BLAKE2b-256 | 566d1df391af33a9f864e249375b638dbd3ce1ea2fa8e81978efbd9e7e4ea232