
Keras wrapper that autosaves what ModelCheckpoint cannot.



Keras wrapper that autosaves and auto-recovers not just the model weights but also the last epoch number and training history metrics.

See it in action in this Colab notebook!

pip install keras-buoy

Description

When training is interrupted and you rerun the whole script, ResumableModel restores the model weights and the epoch counter to their last saved values, then resumes training as if nothing had happened. At the end, the Keras History.history dictionaries from each run are combined, so the training history looks like one single, uninterrupted training run.
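The merge step can be pictured with plain dictionaries. The sketch below assumes only what Keras guarantees about History.history (each metric name maps to a list with one value per epoch) and that resumed runs are combined by simple concatenation. merge_histories is a hypothetical helper, not part of the keras-buoy API, and the library's internal implementation may differ.

```python
from typing import Dict, List

HistoryDict = Dict[str, List[float]]


def merge_histories(previous: HistoryDict, current: HistoryDict) -> HistoryDict:
    """Append each metric list from `current` to the matching list in `previous`.

    Hypothetical helper: a sketch of the idea, not keras-buoy's actual code.
    """
    # Copy the lists so the caller's `previous` dict is not mutated.
    merged: HistoryDict = {key: list(values) for key, values in previous.items()}
    for key, values in current.items():
        # A metric that only appears in the resumed run starts a new list.
        merged.setdefault(key, []).extend(values)
    return merged


# Toy values: the first run finished two epochs, the resumed run one more.
run_1 = {"loss": [0.9, 0.7], "val_loss": [1.2, 1.1]}
run_2 = {"loss": [0.6], "val_loss": [1.05]}
combined = merge_histories(run_1, run_2)
print(combined)  # {'loss': [0.9, 0.7, 0.6], 'val_loss': [1.2, 1.1, 1.05]}
```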

Example

>>> from tensorflow import keras
>>> from keras_buoy.models import ResumableModel

>>> model = keras.Sequential()
>>> # ... (add layers and compile the model)
>>> resumable_model = ResumableModel(model,
...                                  save_every_epochs=4,
...                                  custom_objects=None,
...                                  to_path='/path/to/save/model_weights.h5')
>>> history = resumable_model.fit(x=x_train,
...                               y=y_train,
...                               validation_split=0.1,
...                               batch_size=256,
...                               verbose=2,
...                               epochs=15)

Recovered model from kerascheckpoint.h5 at epoch 8.

Epoch 9/15
1125/1125 - 5s - loss: 0.4790 - top_k_categorical_accuracy: 0.9698 - val_loss: 1.1075 - val_top_k_categorical_accuracy: 0.9206
Epoch 10/15
1125/1125 - 5s - loss: 0.4758 - top_k_categorical_accuracy: 0.9701 - val_loss: 1.1119 - val_top_k_categorical_accuracy: 0.9214
Epoch 11/15
1125/1125 - 5s - loss: 0.4753 - top_k_categorical_accuracy: 0.9702 - val_loss: 1.1000 - val_top_k_categorical_accuracy: 0.9215
Epoch 12/15
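The recovery point in the log follows from save_every_epochs=4: checkpoints exist for epochs 4 and 8, so a run that stopped after completing 10 epochs recovers at epoch 8 and continues with epoch 9. The arithmetic is sketched below with a hypothetical last_saved_epoch helper, assuming a save happens after every epoch whose 1-based number is divisible by save_every_epochs.

```python
def last_saved_epoch(completed_epochs: int, save_every_epochs: int) -> int:
    """Return the most recent completed epoch that is a multiple of save_every_epochs.

    Hypothetical helper; assumes saves happen at epochs k, 2k, 3k, ...
    where k = save_every_epochs. Returns 0 if no save has happened yet.
    """
    if save_every_epochs < 1:
        raise ValueError("save_every_epochs must be >= 1")
    return (completed_epochs // save_every_epochs) * save_every_epochs


# A crash during epoch 11 means 10 epochs completed; with saves every 4 epochs
# the last checkpoint is epoch 8, so training resumes at epoch 9.
print(last_saved_epoch(completed_epochs=10, save_every_epochs=4))  # 8
```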


Docs

keras_buoy.models.ResumableModel

Creates a resumable model.

Parameters:

model (tf.keras.Model)

The instance of tf.keras.Model which you want to make resumable.

save_every_epochs (int)

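Specifies how often, in epochs, to save the model, history, and epoch counter. After a crash, training recovers from the most recent epoch that is a multiple of save_every_epochs. For example, with save_every_epochs=4 and a crash during epoch 11, recovery happens from epoch 8.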

custom_objects (dict)

At recovery time, this dict is passed as the custom_objects argument of tf.keras.models.load_model(...), exactly as described in the TensorFlow docs. This lets you load a model that uses, for example, a custom loss.

to_path (str)

Specifies the path where the model weights will be saved. If it ends with .h5, the model is saved in the Keras H5 format; otherwise, it is saved in the default TensorFlow SavedModel format.

If to_path is mymodel.h5, the files mymodel_epoch_num.pkl and mymodel_history.pkl are written to the same directory as mymodel.h5; they hold backups of the epoch counter and the history dict, respectively.
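The file layout described for to_path can be sketched in plain Python. All helper names below (backup_paths, model_format, write_backups, read_backups) are hypothetical, and two details are assumptions: that the backup names are derived by stripping the extension from to_path, and that the .pkl files are written with the standard pickle module. The sketch does not touch the model file itself.

```python
import os
import pickle
import tempfile
from typing import Dict, List, Tuple

HistoryDict = Dict[str, List[float]]


def backup_paths(to_path: str) -> Tuple[str, str]:
    """Return (epoch_counter_path, history_path) next to to_path.

    Hypothetical helper mirroring the documented naming:
    mymodel.h5 -> mymodel_epoch_num.pkl and mymodel_history.pkl.
    """
    stem, _ext = os.path.splitext(to_path)
    return stem + "_epoch_num.pkl", stem + "_history.pkl"


def model_format(to_path: str) -> str:
    """Keras H5 when to_path ends with .h5, otherwise TensorFlow SavedModel."""
    return "h5" if to_path.endswith(".h5") else "SavedModel"


def write_backups(to_path: str, epoch: int, history: HistoryDict) -> None:
    """Pickle the epoch counter and history dict (assumed backup mechanism)."""
    epoch_file, history_file = backup_paths(to_path)
    with open(epoch_file, "wb") as f:
        pickle.dump(epoch, f)
    with open(history_file, "wb") as f:
        pickle.dump(history, f)


def read_backups(to_path: str) -> Tuple[int, HistoryDict]:
    """Load the pickled epoch counter and history dict written above."""
    epoch_file, history_file = backup_paths(to_path)
    with open(epoch_file, "rb") as f:
        epoch = pickle.load(f)
    with open(history_file, "rb") as f:
        history = pickle.load(f)
    return epoch, history


print(backup_paths("mymodel.h5"))  # ('mymodel_epoch_num.pkl', 'mymodel_history.pkl')
print(model_format("mymodel.h5"))  # h5

# Round-trip the backups in a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "mymodel.h5")
    write_backups(path, 8, {"loss": [0.5]})
    print(read_backups(path))  # (8, {'loss': [0.5]})
```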

Returns:

A ResumableModel instance. You can call .fit(...) on it.




keras_buoy.models.ResumableModel.fit

Fits a resumable model.

Parameters:

The accepted parameters are the same as those of tf.keras.Model.fit(...), except that you cannot specify initial_epoch, because the starting epoch is taken from the recovered epoch counter.
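Because the starting epoch comes from the recovered counter, a wrapper like this has to supply initial_epoch itself. The sketch below shows one way to do that with a hypothetical resume_fit_kwargs helper; it only builds the keyword arguments and assumes they are then forwarded to tf.keras.Model.fit. Keras treats initial_epoch as the 0-based index to start from and epochs as the index to stop at, so initial_epoch=8 with epochs=15 trains the displayed epochs 9 through 15, matching the example log.

```python
from typing import Any, Dict


def resume_fit_kwargs(user_kwargs: Dict[str, Any], recovered_epoch: int) -> Dict[str, Any]:
    """Build keyword arguments for Model.fit when resuming training.

    Hypothetical helper: rejects a user-supplied initial_epoch and sets it
    to the recovered epoch counter instead.
    """
    if "initial_epoch" in user_kwargs:
        raise ValueError(
            "initial_epoch cannot be specified; it comes from the recovered epoch counter"
        )
    kwargs = dict(user_kwargs)  # copy so the caller's dict is left unchanged
    kwargs["initial_epoch"] = recovered_epoch
    return kwargs


kwargs = resume_fit_kwargs({"epochs": 15, "batch_size": 256}, recovered_epoch=8)
print(kwargs)  # {'epochs': 15, 'batch_size': 256, 'initial_epoch': 8}
```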

Returns:

history (dict): The history dict of the Keras History object, combined across all resumed runs. Note that fit returns only this dict (History.history), not the keras.callbacks.History object itself.
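Since fit returns the dict itself, downstream code indexes it directly (history["val_loss"]) rather than through history.history. The sketch below finds the epoch with the lowest validation loss; best_epoch and its first_epoch parameter are hypothetical, and the sample values are the three val_loss figures printed for epochs 9 to 11 in the example log.

```python
from typing import Dict, List


def best_epoch(history: Dict[str, List[float]], metric: str = "val_loss", first_epoch: int = 1) -> int:
    """Return the 1-based epoch number with the lowest value of `metric`.

    `history` is the plain dict returned by ResumableModel.fit, so it is
    indexed directly. `first_epoch` (hypothetical) offsets slices that do
    not start at epoch 1.
    """
    values = history[metric]
    best_index = min(range(len(values)), key=values.__getitem__)
    return first_epoch + best_index


# val_loss values printed for epochs 9 to 11 in the example log.
partial = {"val_loss": [1.1075, 1.1119, 1.1000]}
print(best_epoch(partial, first_epoch=9))  # 11
```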

Note

This project has been set up using PyScaffold 3.2.3. For details and usage information on PyScaffold see https://pyscaffold.org/.

Download files

Source Distribution

keras-buoy-0.1.4.1.tar.gz (19.0 kB)

File details

Details for the file keras-buoy-0.1.4.1.tar.gz.

File metadata

  • Download URL: keras-buoy-0.1.4.1.tar.gz
  • Size: 19.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.6.0.post20200814 requests-toolbelt/0.9.1 tqdm/4.49.0 CPython/3.8.5

File hashes

Hashes for keras-buoy-0.1.4.1.tar.gz
  • SHA256: 3c841ce50f8684b8fbc2f3f4bff06e885e45698c853674570c13e471f43ece0b
  • MD5: 7d8ec056def52085876080d4b7b51b60
  • BLAKE2b-256: 0c22ea1c8daacc58f1f58e99a0a46354c5fb56e767f16101f6eed90d8e903b55

