
alt-model-checkpoint

An adapter callback for Keras ModelCheckpoint that allows checkpointing an alternate model (often the submodel of a multi-GPU model).

Installation

pip install alt-model-checkpoint

Usage

You must provide your own Keras or TensorFlow installation; this package does not install either. See the Pipfile for preferred versions.

If using the Keras bundled with TensorFlow:

from alt_model_checkpoint.tensorflow import AltModelCheckpoint

If using Keras standalone:

from alt_model_checkpoint.keras import AltModelCheckpoint

Common usage involving multi-GPU models built with Keras multi_gpu_model():

from alt_model_checkpoint.keras import AltModelCheckpoint
from keras.models import Model
from keras.utils import multi_gpu_model

def compile_model(m):
    """Implement with your model compile logic; both base and GPU models should be compiled identically"""
    m.compile(...)

base_model = Model(...)
gpu_model = multi_gpu_model(base_model)
compile_model(base_model)
compile_model(gpu_model)

gpu_model.fit(..., callbacks=[
    AltModelCheckpoint('save/path/for/model.hdf5', base_model)
])

Constructor args

filepath

Model save file path; see the underlying Keras ModelCheckpoint docs for details.
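Keras ModelCheckpoint can template filepath with the epoch number and with values from the epoch's logs. The sketch below imitates that substitution in plain Python, assuming the Keras 2.x behavior of formatting the path with epoch + 1 plus the logs dict; format_checkpoint_path is a hypothetical helper, not part of Keras or this package.

```python
# Hypothetical helper illustrating Keras-style checkpoint path templating.
# Assumes Keras 2.x ModelCheckpoint behavior: the 0-based epoch index is shown
# as epoch + 1, and every key in the epoch's logs dict is available to the template.
def format_checkpoint_path(template, epoch, logs):
    """Return the concrete save path for a given epoch index and logs dict."""
    return template.format(epoch=epoch + 1, **logs)


path = format_checkpoint_path(
    "save/path/model.{epoch:02d}-{val_loss:.2f}.hdf5",
    epoch=2,
    logs={"loss": 0.5301, "val_loss": 0.4213},
)
print(path)  # save/path/model.03-0.42.hdf5
```

Unused log keys (here, loss) are simply ignored by str.format, so one template can work across models that report different metrics.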

alternate_model

Keras model to save instead of the default (the model being trained). This is mainly used when training multi-GPU models built with Keras multi_gpu_model(); in that case, pass the original "template model" so that it is saved at each checkpoint.
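To illustrate how an adapter like this can redirect saving, here is a minimal, Keras-free sketch. It assumes the parent callback saves whatever its model attribute points to, and temporarily swaps in the alternate model around the parent's on_epoch_end. StandInModel, StandInCheckpoint and AltCheckpointSketch are hypothetical stand-ins, not the package's actual source.

```python
# Hypothetical stand-ins; a real implementation would subclass Keras ModelCheckpoint.
class StandInModel:
    def __init__(self, name):
        self.name = name


class StandInCheckpoint:
    """Mimics a checkpoint callback that saves whatever self.model points to."""

    def __init__(self, filepath):
        self.filepath = filepath
        self.model = None  # assigned by the training loop, as Keras does
        self.saved = []  # records (filepath, model name) instead of writing files

    def on_epoch_end(self, epoch, logs=None):
        self.saved.append((self.filepath, self.model.name))


class AltCheckpointSketch(StandInCheckpoint):
    """Saves alternate_model instead of the model being trained."""

    def __init__(self, filepath, alternate_model):
        super().__init__(filepath)
        self.alternate_model = alternate_model

    def on_epoch_end(self, epoch, logs=None):
        trained_model = self.model
        self.model = self.alternate_model  # point the parent's save logic at the alternate
        try:
            super().on_epoch_end(epoch, logs)
        finally:
            self.model = trained_model  # restore so other hooks still see the trained model


base_model = StandInModel("base")
gpu_model = StandInModel("gpu")
cb = AltCheckpointSketch("model.hdf5", base_model)
cb.model = gpu_model  # what fit() would do before training starts
cb.on_epoch_end(0)
print(cb.saved)  # [('model.hdf5', 'base')]
print(cb.model.name)  # gpu
```

The try/finally keeps the swap from leaking if saving raises, so the callback never ends up permanently pointing at the wrong model.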

inherit_optimizer

If True (default), saves the optimizer of the model being trained (e.g. the multi-GPU model) with the alternate model. This is necessary if you want to resume training from a saved alternate model. If False, the alternate model's own optimizer is saved as-is.
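The optimizer inheritance can be pictured as a borrow-and-restore around the save call. This sketch assumes the alternate model's optimizer attribute is replaced by the trained model's optimizer only while saving; save_with_optimizer, fake_save and StandInModel are hypothetical names, and real code would call Keras model saving instead of fake_save.

```python
# Hypothetical stand-ins; a real implementation would operate on Keras models.
class StandInModel:
    def __init__(self, name, optimizer):
        self.name = name
        self.optimizer = optimizer


def save_with_optimizer(alternate_model, trained_model, save_fn, inherit_optimizer=True):
    """Call save_fn(alternate_model), optionally borrowing trained_model's optimizer."""
    if not inherit_optimizer:
        return save_fn(alternate_model)
    own_optimizer = alternate_model.optimizer
    alternate_model.optimizer = trained_model.optimizer  # borrow the trained optimizer state
    try:
        return save_fn(alternate_model)
    finally:
        alternate_model.optimizer = own_optimizer  # always restore the original


def fake_save(model):
    """Stand-in for model saving: reports which optimizer would be written."""
    return (model.name, model.optimizer)


base_model = StandInModel("base", optimizer="base-adam")
gpu_model = StandInModel("gpu", optimizer="gpu-adam (trained state)")

print(save_with_optimizer(base_model, gpu_model, fake_save))
# ('base', 'gpu-adam (trained state)')
print(save_with_optimizer(base_model, gpu_model, fake_save, inherit_optimizer=False))
# ('base', 'base-adam')
print(base_model.optimizer)  # base-adam
```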

*args, **kwargs

These are passed as-is to the underlying ModelCheckpoint constructor.
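As an example of this pass-through, the sketch below mirrors the argument order documented above (filepath, alternate_model, inherit_optimizer, then everything else) using a hypothetical stand-in parent whose keyword names match a few real Keras ModelCheckpoint parameters. With the real package, the equivalent call would look like AltModelCheckpoint('best.hdf5', base_model, monitor='val_loss', save_best_only=True).

```python
# StandInModelCheckpoint is a hypothetical stand-in whose keyword names mirror a few
# real Keras ModelCheckpoint parameters (monitor, verbose, save_best_only, save_weights_only).
class StandInModelCheckpoint:
    def __init__(self, filepath, monitor="val_loss", verbose=0,
                 save_best_only=False, save_weights_only=False):
        self.filepath = filepath
        self.monitor = monitor
        self.verbose = verbose
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only


class AltCheckpointSketch(StandInModelCheckpoint):
    def __init__(self, filepath, alternate_model, inherit_optimizer=True, *args, **kwargs):
        self.alternate_model = alternate_model
        self.inherit_optimizer = inherit_optimizer
        # Everything not consumed above is handed to the parent unchanged.
        super().__init__(filepath, *args, **kwargs)


cb = AltCheckpointSketch("best.hdf5", "base", monitor="val_acc", save_best_only=True)
print(cb.monitor, cb.save_best_only)  # val_acc True
```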

Dev environment setup

  1. Install pipenv.
  2. Run make test (this automatically runs make test-build first to install dependencies).

Download files

Source Distribution

alt-model-checkpoint-2.0.3.tar.gz (3.7 kB)
