
snapshot_ensemble

Train TensorFlow Keras models with cosine annealing and save an ensemble of models with no additional computational expense.

Ensembles of machine learning models have empirically demonstrated state-of-the-art results in many regression and classification tasks. Deep neural networks are popular models given their flexibility and theoretical properties, but ensembling several independent neural networks is often impractical due to the computational expense.

Huang et al. (2017) propose the simple idea of Snapshot Ensembling, in which a single neural network is trained via a cyclic learning rate schedule such as cosine annealing (Loshchilov and Hutter, 2017). At the end of each annealing cycle the model parameters are saved, so we obtain an ensemble of trained neural networks at the cost of training a single one.

Conceptually, we may think of this as letting the neural network converge quickly under a decaying learning rate, then saving the model at several local minima of the loss surface. We may then use the saved models as an ensemble for prediction or inference.
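At prediction time, the ensemble's output is typically just the average of the individual snapshots' predicted probabilities. A minimal NumPy sketch (the probabilities below are made-up numbers for illustration):

```python
import numpy as np

# Hypothetical softmax outputs from M = 3 snapshot models,
# for 2 samples and 3 classes (shape: M x samples x classes)
preds = np.array([
    [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]],
    [[0.6, 0.3, 0.1], [0.2, 0.7, 0.1]],
    [[0.8, 0.1, 0.1], [0.1, 0.6, 0.3]],
])

ensemble_probs = preds.mean(axis=0)          # average over the M snapshots
ensemble_labels = ensemble_probs.argmax(axis=1)  # predicted class per sample
```

Averaging probabilities (rather than majority-voting labels) lets disagreeing snapshots contribute proportionally to their confidence.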

This simple library is an implementation of their ideas as a TensorFlow 2 Keras Callback to be used during training.

Documentation

Getting Started

Installation

pip install snapshot_ensemble

Dependencies:

# Required
python >= 3.6
numpy
tensorflow >= 2.0

# Suggested
matplotlib

Usage

from snapshot_ensemble import SnapshotEnsembleCallback

model = ...  # Compiled TensorFlow 2 Keras model

# Train the Keras model with cosine annealing + snapshot ensembling
snapshotCB = SnapshotEnsembleCallback()
model.fit(*args,
          callbacks=[snapshotCB])



# Snapshotted models are then saved automatically (default: `Ensemble/`)
# and may be loaded for ensembled prediction or inference
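For instance, the saved snapshots could be loaded back and averaged at prediction time. The snippet below is an illustrative sketch, not part of the library's API: the directory layout and loading via `tf.keras.models.load_model` are assumptions.

```python
import glob

import numpy as np

def snapshot_paths(ensemble_dir="Ensemble"):
    # Collect the saved snapshot paths in a stable order
    return sorted(glob.glob(f"{ensemble_dir}/*"))

def ensemble_predict(models, x):
    # Average the predictions of all snapshot models
    return np.mean([m.predict(x) for m in models], axis=0)

# Assumed usage, if the snapshots were saved as loadable Keras models:
# models = [tf.keras.models.load_model(p) for p in snapshot_paths()]
# y_hat = ensemble_predict(models, X_test)
```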

Dynamic Learning Rate Schedule

The learning rate schedule inside SnapshotEnsembleCallback takes the following parameters:

- `cycle_length` : Initial number of epochs per cycle
- `cycle_length_multiplier` : Multiplier on the number of epochs per cycle
- `lr_init` : Initial maximum learning rate
- `lr_min` : Initial minimum learning rate
- `lr_multiplier` : Multiplier on the learning rate per cycle

The `cycle_length`, `lr_init`, and `lr_min` parameters control the initial length and learning rate bounds of each cycle. The `*_multiplier` parameters allow the length and/or learning rate bounds to be adjusted dynamically as training progresses. The default parameters are very likely suboptimal for your task, so these hyperparameters will need to be tuned. The helper function `VisualizeLR()` visualizes the learning rate schedule.

<img src="assets/LR0.png" width="32%" />

<img src="assets/LR1.png" width="32%" />

<img src="assets/LR2.png" width="32%" />

<em>
(Left) Standard cosine annealing. (Middle) Dynamic cycle length. (Right) Dynamic cycle length and learning rate bounds.
</em>
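Within a single cycle, the learning rate follows the cosine annealing schedule of Loshchilov and Hutter (2017). The sketch below reuses the callback's parameter names for illustration; it is not the library's internal implementation:

```python
import math

def cosine_annealing_lr(epoch_in_cycle, cycle_length, lr_init=0.1, lr_min=1e-5):
    # Anneal from lr_init down to lr_min over one cycle of `cycle_length` epochs:
    # lr = lr_min + (lr_init - lr_min) * (1 + cos(pi * t / T)) / 2
    fraction = epoch_in_cycle / cycle_length
    return lr_min + 0.5 * (lr_init - lr_min) * (1 + math.cos(math.pi * fraction))
```

At the start of a cycle the rate equals `lr_init`, and it decays to `lr_min` by the cycle's end, after which the rate resets (a "warm restart") and the next cycle begins.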

Example

For a full example, see this notebook.

References

Huang, G., Li, Y., & Pleiss, G. (2017). Snapshot Ensembles: Train 1, Get M for Free. International Conference on Learning Representations. https://doi.org/10.48550/arXiv.1704.00109

Loshchilov, I., & Hutter, F. (2017). SGDR: Stochastic Gradient Descent with Warm Restarts. International Conference on Learning Representations. https://doi.org/10.48550/arXiv.1608.03983
