
Simple stochastic weight averaging callback for Keras.


Keras SWA - Stochastic Weight Averaging


This is an implementation of SWA for Keras and TF-Keras.

Introduction

Stochastic weight averaging (SWA) builds on the same principle as snapshot ensembling and fast geometric ensembling: averaging selected stages of training can lead to better models. Whereas the two former methods average by sampling and ensembling models, SWA averages the weights themselves. This has been shown to give comparable improvements while remaining confined to a single model.

(Illustration of SWA from the original paper)

Paper: Averaging Weights Leads to Wider Optima and Better Generalization (Izmailov et al., 2018)
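
For intuition, here is a minimal sketch of the running average SWA maintains (plain NumPy, illustrative only; not the callback's actual internals):

import numpy as np

# hypothetical weights captured at the end of each averaging epoch
checkpoints = [np.random.randn(50, 2) for _ in range(5)]

# after the n-th sampled weights w_n, SWA keeps
# w_swa = (w_swa * n + w_n) / (n + 1)
w_swa = checkpoints[0]
for n, w in enumerate(checkpoints[1:], start=1):
    w_swa = (w_swa * n + w) / (n + 1)

# equivalent to the plain mean of all sampled weights
assert np.allclose(w_swa, np.mean(checkpoints, axis=0))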

Installation

pip install keras-swa

SWA API

Keras callback object for SWA.

Arguments

start_epoch - Starting epoch for SWA.

lr_schedule - Learning rate schedule. 'manual', 'constant' or 'cyclic'.

swa_lr - Learning rate used when averaging weights.

swa_lr2 - Upper bound of learning rate for the cyclic schedule.

swa_freq - Frequency of weight averaging. Used with cyclic schedules.

batch_size - Batch size model is being trained with (only when using batch normalization).

verbose - Verbosity mode, 0 or 1.

Batch Normalization

For models with batch normalization, the last epoch will be a forward pass, i.e. run with the learning rate set to zero. This is because batch normalization uses the running mean and variance of its preceding layer to normalize its input, and SWA offsets this normalization by suddenly changing the weights at the end of training. The last epoch is therefore used to reset and recalculate the batch normalization running mean and variance for the updated weights. The batch normalization gamma and beta values are preserved.

When using the manual schedule: the SWA callback will set the learning rate to zero in the last epoch if batch normalization is used. This must not be undone by any external learning rate scheduler, or SWA will not work properly.
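
Conceptually, the adjustment resembles the sketch below (a rough illustration, not the package's actual implementation): the moving statistics of each batch normalization layer are reset, and a final forward pass over the training data recomputes them for the averaged weights.

import numpy as np
from tensorflow.keras import backend as K
from tensorflow.keras.layers import BatchNormalization

def reset_batch_norm_statistics(model):  # hypothetical helper
    for layer in model.layers:
        if isinstance(layer, BatchNormalization):
            # reset moving mean and variance; gamma and beta are untouched
            K.set_value(layer.moving_mean,
                        np.zeros_like(K.get_value(layer.moving_mean)))
            K.set_value(layer.moving_variance,
                        np.ones_like(K.get_value(layer.moving_variance)))

# a subsequent epoch with learning rate zero then repopulates
# the statistics without changing any trainable weights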

Learning Rate Schedules

The default schedule is 'manual', which leaves the learning rate under the control of an external learning rate scheduler or the optimizer; SWA then only affects the final weights (and the learning rate of the last epoch if batch normalization is used). The two predefined schedules, 'constant' and 'cyclic', are shown below.

(Plot of the 'constant' and 'cyclic' learning rate schedules)
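
For example, 'manual' mode can be paired with an external scheduler like this (a hypothetical sketch; the decay function and values are illustrative):

from tensorflow.keras.callbacks import LearningRateScheduler
from swa.tfkeras import SWA

# external step decay controls the learning rate throughout training
def step_decay(epoch, lr):
    return 0.1 * (0.5 ** (epoch // 25))

# SWA only averages weights from start_epoch onward
swa = SWA(start_epoch=75, lr_schedule='manual', verbose=1)

# model.fit(X, y, epochs=100,
#           callbacks=[LearningRateScheduler(step_decay), swa])

# note: with batch normalization, the external schedule must not
# override the zero learning rate SWA sets in the final epoch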

Example

For TensorFlow Keras (with constant LR)

from sklearn.datasets import make_blobs
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import SGD

from swa.tfkeras import SWA

# make dataset
X, y = make_blobs(n_samples=1000, 
                  centers=3, 
                  n_features=2, 
                  cluster_std=2, 
                  random_state=2)

y = to_categorical(y)

# build model
model = Sequential()
model.add(Dense(50, input_dim=2, activation='relu'))
model.add(Dense(3, activation='softmax'))

model.compile(loss='categorical_crossentropy', 
              optimizer=SGD(learning_rate=0.1))

epochs = 100
start_epoch = 75

# define swa callback
swa = SWA(start_epoch=start_epoch, 
          lr_schedule='constant', 
          swa_lr=0.01, 
          verbose=1)

# train
model.fit(X, y, epochs=epochs, verbose=1, callbacks=[swa])

Or for Keras (with cyclic LR)

from sklearn.datasets import make_blobs
from keras.utils import to_categorical
from keras.models import Sequential
from keras.layers import Dense, BatchNormalization
from keras.optimizers import SGD

from swa.keras import SWA

# make dataset
X, y = make_blobs(n_samples=1000, 
                  centers=3, 
                  n_features=2, 
                  cluster_std=2, 
                  random_state=2)

y = to_categorical(y)

# build model
model = Sequential()
model.add(Dense(50, input_dim=2, activation='relu'))
model.add(BatchNormalization())
model.add(Dense(3, activation='softmax'))

model.compile(loss='categorical_crossentropy', 
              optimizer=SGD(learning_rate=0.1))

epochs = 100
start_epoch = 75

# define swa callback
swa = SWA(start_epoch=start_epoch, 
          lr_schedule='cyclic', 
          swa_lr=0.001,
          swa_lr2=0.003,
          swa_freq=3,
          batch_size=32, # needed when using batch norm
          verbose=1)

# train
model.fit(X, y, batch_size=32, epochs=epochs, verbose=1, callbacks=[swa])

Output

Model uses batch normalization. SWA will require last epoch to be a forward pass and will run with no learning rate
Epoch 1/100
1000/1000 [==============================] - 1s 547us/sample - loss: 0.5529
Epoch 2/100
1000/1000 [==============================] - 0s 160us/sample - loss: 0.4720
...
Epoch 74/100
1000/1000 [==============================] - 0s 160us/sample - loss: 0.4249

Epoch 00075: starting stochastic weight averaging
Epoch 75/100
1000/1000 [==============================] - 0s 164us/sample - loss: 0.4357
Epoch 76/100
1000/1000 [==============================] - 0s 165us/sample - loss: 0.4209
...
Epoch 99/100
1000/1000 [==============================] - 0s 167us/sample - loss: 0.4263

Epoch 00100: final model weights set to stochastic weight average

Epoch 00100: reinitializing batch normalization layers

Epoch 00100: running forward pass to adjust batch normalization
Epoch 100/100
1000/1000 [==============================] - 0s 166us/sample - loss: 0.4408

