
Keras wrapper that autosaves what ModelCheckpoint cannot.



Keras wrapper that autosaves and auto-recovers not just the model weights but also the last epoch number and training history metrics.

pip install keras-buoy
>>> resumableModel = ResumableModel(model, save_every_epochs=4, to_path='kerascheckpoint.h5')
>>> history = resumableModel.fit(x = x_train, y = y_train, validation_split=0.1, batch_size = 256, verbose=2, epochs=15)

Recovered model from kerascheckpoint.h5 at epoch 8.

Epoch 9/15
1125/1125 - 5s - loss: 0.4790 - top_k_categorical_accuracy: 0.9698 - val_loss: 1.1075 - val_top_k_categorical_accuracy: 0.9206
Epoch 10/15
1125/1125 - 5s - loss: 0.4758 - top_k_categorical_accuracy: 0.9701 - val_loss: 1.1119 - val_top_k_categorical_accuracy: 0.9214
Epoch 11/15
1125/1125 - 5s - loss: 0.4753 - top_k_categorical_accuracy: 0.9702 - val_loss: 1.1000 - val_top_k_categorical_accuracy: 0.9215
Epoch 12/15
...

Description

When training is interrupted by a crash or an accidental Ctrl+C and you rerun the same script, ResumableModel restores the model weights and the epoch counter from the last save, then resumes training from that point. When training finishes, the Keras History.history dictionaries from each run are combined, so the training history looks like one uninterrupted run.
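The history-combining step can be pictured with a small sketch. This is not the library's own code: `merge_histories` is a hypothetical helper, and the metric values are made up for illustration. It only assumes that each history is a dict mapping a metric name to a list with one value per epoch, as Keras History.history is.

```python
def merge_histories(previous, latest):
    """Concatenate per-epoch metric lists from two Keras-style history dicts."""
    # Copy the lists so the saved history is not mutated in place.
    merged = {name: list(values) for name, values in previous.items()}
    for name, values in latest.items():
        merged.setdefault(name, []).extend(values)
    return merged


# Illustrative values only, not real training output.
before_crash = {"loss": [0.9, 0.7], "val_loss": [1.3, 1.2]}
after_resume = {"loss": [0.6], "val_loss": [1.15]}

combined = merge_histories(before_crash, after_resume)
print(combined)
```

Each metric ends up with one entry per epoch across both runs, which is what makes the result read like a single training run.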

Example

from tensorflow import keras
from keras_buoy.models import ResumableModel

model = keras.Sequential()
...
resumable_model = ResumableModel(model, save_every_epochs=4, custom_objects=None, to_path='/path/to/save/model_weights.h5')
history = resumable_model.fit(x=x_train, y=y_train, validation_split=0.1, batch_size=256, verbose=2, epochs=12)

Usage

custom_objects (dict) is passed to tf.keras.models.load_model(...), so you can load a model that uses, for example, a custom loss.

save_every_epochs (int) sets how often, in epochs, the model, history, and epoch counter are saved. After a crash, training recovers from the last saved epoch, which is a multiple of save_every_epochs.
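As a rough illustration of the recovery point, the arithmetic below assumes that a save happens exactly when the number of completed epochs is a multiple of save_every_epochs. `last_saved_epoch` is a hypothetical helper written for this sketch, not part of keras-buoy.

```python
def last_saved_epoch(completed_epochs, save_every_epochs):
    """Return the most recent epoch at which a save would have happened.

    Assumes saves occur only when the completed-epoch count is a
    multiple of save_every_epochs.
    """
    if save_every_epochs < 1:
        raise ValueError("save_every_epochs must be a positive integer")
    return (completed_epochs // save_every_epochs) * save_every_epochs


# With save_every_epochs=4, a crash after 10 completed epochs
# would resume from the checkpoint written at epoch 8.
resume_from = last_saved_epoch(completed_epochs=10, save_every_epochs=4)
print(resume_from)
```

Under this assumption, a smaller save_every_epochs loses fewer epochs on a crash at the cost of writing the checkpoint more often.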

to_path (str) is the path where the model weights are saved; it must have the .h5 extension.

resumable_model.fit(...) is the same as Keras’ model.fit(...).

It returns history, which is the history dict of the Keras History object. Note that it does not return the Keras History object itself, only the dict.
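Because the return value is a plain dict, metrics are indexed directly rather than through a `.history` attribute. The sketch below uses made-up values in place of a real training result, and the metric names are only examples.

```python
# Illustrative values only; a real dict would come from resumable_model.fit(...).
history = {
    "loss": [0.52, 0.49, 0.48],
    "val_loss": [1.12, 1.10, 1.11],
}

# Index the dict directly: history["val_loss"], not history.history["val_loss"].
val_loss = history["val_loss"]

# Find the epoch (1-based) with the lowest validation loss.
best_index = min(range(len(val_loss)), key=val_loss.__getitem__)
best_epoch = best_index + 1
print(f"Lowest val_loss {val_loss[best_index]} at epoch {best_epoch}")
```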

If to_path is mymodel.h5, then mymodel_epoch_num.pkl and mymodel_history.pkl are created in the same directory as mymodel.h5. They hold backups of the epoch counter and the history dict, respectively.
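The naming scheme can be reproduced with pathlib, which is handy for cleaning up or inspecting the backup files. `sidecar_paths` is a hypothetical helper for this sketch, and the `checkpoints` directory is an example; the library may build these paths differently internally.

```python
from pathlib import Path


def sidecar_paths(to_path):
    """Derive the epoch-counter and history backup paths for a given to_path."""
    model_path = Path(to_path)
    if model_path.suffix != ".h5":
        raise ValueError("to_path must have the .h5 extension")
    stem = model_path.stem  # "mymodel" for "mymodel.h5"
    epoch_path = model_path.with_name(f"{stem}_epoch_num.pkl")
    history_path = model_path.with_name(f"{stem}_history.pkl")
    return epoch_path, history_path


epoch_path, history_path = sidecar_paths("checkpoints/mymodel.h5")
print(epoch_path.name)
print(history_path.name)
```

Deleting all three files (the .h5 file and both .pkl files) should be enough to start training from scratch, assuming these are the only files the wrapper reads on startup.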

Note

This project has been set up using PyScaffold 3.2.3. For details and usage information on PyScaffold see https://pyscaffold.org/.


