
Simple stochastic weight averaging callback for Keras.


Keras SWA - Stochastic Weight Averaging


This is an implementation of SWA for Keras and TF-Keras. It currently implements only the constant learning rate scheduler; the cyclic learning rate schedule described in the paper is coming soon.

Introduction

Stochastic weight averaging (SWA) is built upon the same principle as snapshot ensembling and fast geometric ensembling. The idea is that averaging selected stages of training can lead to better models. Whereas the two former methods average by sampling and ensembling models, SWA instead averages the weights themselves. This has been shown to yield comparable improvements within a single model.
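The averaging itself is just an incremental mean over weight snapshots. A minimal sketch (names are illustrative, not the callback's internals): each epoch past the start of averaging, fold the current weights into the running average.

```python
import numpy as np

def swa_update(w_avg, w_new, n_averaged):
    """Incremental mean: fold one more weight snapshot into the average
    of the n_averaged snapshots collected so far."""
    return (w_avg * n_averaged + w_new) / (n_averaged + 1)

# Toy example: averaging the weight snapshots 1.0, 2.0, 3.0 yields 2.0
snapshots = [np.array([1.0]), np.array([2.0]), np.array([3.0])]
w_avg = snapshots[0]
for i, w in enumerate(snapshots[1:], start=1):
    w_avg = swa_update(w_avg, w, i)
```

The incremental form avoids keeping every snapshot in memory, which matters when each snapshot is a full set of model weights.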


Paper: Averaging Weights Leads to Wider Optima and Better Generalization (Izmailov et al., 2018)

Installation

pip install keras-swa

Batch Normalization

For models with batch normalization, the last epoch will be a forward pass only, i.e. it will have the learning rate set to zero. This is because batch normalization uses the running mean and variance of its preceding layer's activations to normalize. SWA offsets this normalization by abruptly changing the weights at the end of training, so the last epoch must be used to reset and recalculate the batch normalization statistics for the updated weights.
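The recalculation step can be sketched as follows (a hypothetical illustration, not the callback's actual code): once the averaged weights are swapped in, the stored batch norm statistics still describe activations produced by the old weights, so one pass over the training data re-estimates them from the new activations.

```python
import numpy as np

def recompute_bn_stats(activation_batches):
    """Re-estimate batch norm statistics by averaging the per-batch
    means and variances of the activations produced by the new
    (averaged) weights -- the job of the zero-learning-rate final epoch."""
    means = [b.mean(axis=0) for b in activation_batches]
    variances = [b.var(axis=0) for b in activation_batches]
    return np.mean(means, axis=0), np.mean(variances, axis=0)

# Two toy activation batches whose overall per-feature mean is 2.0
batches = [np.full((4, 2), 1.0), np.full((4, 2), 3.0)]
bn_mean, bn_var = recompute_bn_stats(batches)
```

Because no gradient step is taken, this final epoch changes only the normalization statistics, never the averaged weights.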

SWA

Keras callback object for SWA.

Arguments

start_epoch - Starting epoch for SWA.

lr_schedule - Learning rate scheduler (optional), 'constant' for the non-cyclic scheduler from the paper.

swa_lr - Minimum learning rate for scheduler.

batch_size - Training batch size, only required for models that use batch normalization when fit with a generator.

verbose - Verbosity mode, 0 or 1.

Example

For Keras

from sklearn.datasets import make_blobs
from keras.utils import to_categorical
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD

from swa.keras import SWA
 
# make dataset
X, y = make_blobs(n_samples=1000, 
                  centers=3, 
                  n_features=2, 
                  cluster_std=2, 
                  random_state=2)

y = to_categorical(y)

# build model
model = Sequential()
model.add(Dense(50, input_dim=2, activation='relu'))
model.add(Dense(3, activation='softmax'))

model.compile(loss='categorical_crossentropy', 
              optimizer=SGD(learning_rate=0.1))

epochs = 100
start_epoch = 75

# define swa callback
swa = SWA(start_epoch=start_epoch, 
          lr_schedule='constant', 
          swa_lr=0.01, 
          verbose=1)

# train
model.fit(X, y, epochs=epochs, verbose=1, callbacks=[swa])

Or for Keras in TensorFlow

from sklearn.datasets import make_blobs
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import SGD

from swa.tfkeras import SWA

# make dataset
X, y = make_blobs(n_samples=1000, 
                  centers=3, 
                  n_features=2, 
                  cluster_std=2, 
                  random_state=2)

y = to_categorical(y)

# build model
model = Sequential()
model.add(Dense(50, input_dim=2, activation='relu'))
model.add(Dense(3, activation='softmax'))

model.compile(loss='categorical_crossentropy', 
              optimizer=SGD(learning_rate=0.1))

epochs = 100
start_epoch = 75

# define swa callback
swa = SWA(start_epoch=start_epoch, 
          lr_schedule='constant', 
          swa_lr=0.01, 
          verbose=1)

# train
model.fit(X, y, epochs=epochs, verbose=1, callbacks=[swa])

Output

Epoch 1/100
1000/1000 [==============================] - 1s 703us/step - loss: 0.7518
Epoch 2/100
1000/1000 [==============================] - 0s 47us/step - loss: 0.5997
...
Epoch 74/100
1000/1000 [==============================] - 0s 31us/step - loss: 0.3913
Epoch 75/100
Epoch 00075: starting stochastic weight averaging
1000/1000 [==============================] - 0s 202us/step - loss: 0.3907
Epoch 76/100
1000/1000 [==============================] - 0s 47us/step - loss: 0.3911
...
Epoch 99/100
1000/1000 [==============================] - 0s 31us/step - loss: 0.3910
Epoch 100/100
1000/1000 [==============================] - 0s 47us/step - loss: 0.3905

Epoch 00100: final model weights set to stochastic weight average
