Support PyTorch checkpoints

pytorch-checkpoint

This package supports saving and loading PyTorch training checkpoints. It is useful for resuming model training from a previous step, and comes in handy when working with spot instances or when trying to reproduce results.

A checkpoint stores more than the model weights one would keep for later inference: it captures the entire training state, including the optimizer state and parameters.

In addition, it allows saving metrics and other values generated during training, such as accuracy and loss values. This makes it possible to recreate the learning curves from past values and to keep updating them as training proceeds.
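The idea of rebuilding a learning curve can be pictured with a small stdlib-only sketch. It assumes the per-iteration values are kept as a mapping from iteration to value. The `history` dict and the `curve` helper are hypothetical names used for illustration and are not part of pytorchcheckpoint.

```python
# Hypothetical illustration: per-iteration values kept as {iteration: value}.
history = {0: 1.0, 1: 0.9, 2: 0.8}  # values recorded before training stopped


def curve(values):
    """Return (iterations, values) sorted by iteration, ready for plotting."""
    iterations = sorted(values)
    return iterations, [values[i] for i in iterations]


# After resuming, new values extend the same mapping.
history[3] = 0.7
history[4] = 0.65

xs, ys = curve(history)
print(xs)  # [0, 1, 2, 3, 4]
print(ys)  # [1.0, 0.9, 0.8, 0.7, 0.65]
```

Because past values travel with the checkpoint, the curve after a restart covers the whole run rather than only the iterations since the restart.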


Prerequisites

Developed with Python 3.7.3, but should be compatible with earlier Python 3 versions.

pip install torch==1.1.0 torchvision==0.3.0

Installation

pip install pytorchcheckpoint

Usage

from pytorchcheckpoint.checkpoint import CheckpointHandler
checkpoint_handler = CheckpointHandler()

Storing a general value

checkpoint_handler.store_var(var_name='num_of_classes', value=1000)

Reading a general value

num_of_classes = checkpoint_handler.get_var(var_name='num_of_classes')

Storing values and metrics for each epoch/iteration. For example, the loss value:

checkpoint_handler.store_running_var(var_name='loss', iteration=0, value=1.0)
checkpoint_handler.store_running_var(var_name='loss', iteration=1, value=0.9)
checkpoint_handler.store_running_var(var_name='loss', iteration=2, value=0.8)

Reading stored values for epoch/iteration

loss = checkpoint_handler.get_running_var(var_name='loss', iteration=0)

Storing values and metrics per set: train/valid/test for each epoch/iteration. For example, the top1 value of the train and valid sets:

checkpoint_handler.store_running_var_with_header(header='train', var_name='top1', iteration=0, value=80)
checkpoint_handler.store_running_var_with_header(header='train', var_name='top1', iteration=1, value=85)
checkpoint_handler.store_running_var_with_header(header='train', var_name='top1', iteration=2, value=90)
checkpoint_handler.store_running_var_with_header(header='train', var_name='top1', iteration=3, value=91)

checkpoint_handler.store_running_var_with_header(header='valid', var_name='top1', iteration=0, value=70)
checkpoint_handler.store_running_var_with_header(header='valid', var_name='top1', iteration=1, value=75)
checkpoint_handler.store_running_var_with_header(header='valid', var_name='top1', iteration=2, value=80)
checkpoint_handler.store_running_var_with_header(header='valid', var_name='top1', iteration=3, value=85)

Reading stored values per set: train/valid/test for epoch/iteration

top1 = checkpoint_handler.get_running_var_with_header(header='train', var_name='top1', iteration=0)
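One way to picture per-set metrics is a nested mapping of header, then variable name, then iteration. The sketch below is stdlib-only and reuses the top1 values stored above. The `running` layout and the `store` and `get` helpers are assumptions for illustration, not the package's internals.

```python
from collections import defaultdict

# Hypothetical layout: header -> var_name -> iteration -> value.
running = defaultdict(lambda: defaultdict(dict))


def store(header, var_name, iteration, value):
    running[header][var_name][iteration] = value


def get(header, var_name, iteration):
    return running[header][var_name][iteration]


top1_pairs = [(80, 70), (85, 75), (90, 80), (91, 85)]  # (train, valid)
for it, (train_top1, valid_top1) in enumerate(top1_pairs):
    store('train', 'top1', it, train_top1)
    store('valid', 'top1', it, valid_top1)

# Gap between train and valid accuracy at each iteration.
gap = [get('train', 'top1', it) - get('valid', 'top1', it) for it in range(4)]
print(gap)  # [10, 10, 10, 6]
```

Keeping the sets under separate headers lets the same variable name, here `top1`, be compared across train and valid at a given iteration.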

Save checkpoint:

import torchvision.models as models
from torch import optim
checkpoint_handler.store_running_var(var_name='loss', iteration=0, value=1.0)
model = models.resnet18()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
path2save = '/tmp'
checkpoint_path = checkpoint_handler.generate_checkpoint_path(path2save=path2save)
checkpoint_handler.save_checkpoint(checkpoint_path=checkpoint_path, iteration=25, model=model, optimizer=optimizer)

Load checkpoint:

checkpoint_path = '<checkpoint_path>'
checkpoint_handler = checkpoint_handler.load_checkpoint(checkpoint_path)
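To show how a loaded checkpoint might feed back into a training loop, here is a stdlib-only sketch of a save, reload, and resume cycle. It uses `pickle` and a plain dict as stand-ins. The `state` layout, the `save_state`, `load_state`, and `train` helpers, and the placeholder loss are all assumptions for illustration, not the package's file format or API.

```python
import os
import pickle
import tempfile


def save_state(path, state):
    with open(path, 'wb') as f:
        pickle.dump(state, f)


def load_state(path):
    with open(path, 'rb') as f:
        return pickle.load(f)


def train(state, until):
    """Run iterations after state['iteration'] up to `until` (exclusive)."""
    for it in range(state['iteration'] + 1, until):
        state['loss'][it] = 1.0 / (it + 1)  # placeholder for a real training step
        state['iteration'] = it
    return state


path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pkl')

# First run: train iterations 0..2, then checkpoint.
state = train({'iteration': -1, 'loss': {}}, until=3)
save_state(path, state)

# Later run: reload and continue with iterations 3..4.
resumed = train(load_state(path), until=5)
print(resumed['iteration'])     # 4
print(sorted(resumed['loss']))  # [0, 1, 2, 3, 4]
```

The point of the sketch is that the last completed iteration is stored alongside the metrics, so a restarted process can pick up at the next iteration instead of starting from zero.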

Files for pytorchcheckpoint, version 0.0.5

Filename, size | File type | Python version
pytorchcheckpoint-0.0.5-py3-none-any.whl (4.2 kB) | Wheel | py3
pytorchcheckpoint-0.0.5.tar.gz (3.8 kB) | Source | None