
Support PyTorch checkpoints


pytorch-checkpoint


This package supports saving and loading PyTorch training checkpoints. It is useful when resuming model training from a previous step, and comes in handy when working with spot instances or when trying to reproduce results.

A model is saved not only with its weights, as one might do for later inference, but with its entire training state, including the optimizer state and parameters.

In addition, it allows saving metrics and other values generated during training, such as accuracy and loss values. This makes it possible to recreate the learning curves from past values and continue to update them as training proceeds.
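The idea of rebuilding a learning curve from stored values can be sketched without the package. The snippet below is only an illustration: `history` and `curve` are hypothetical names, and the plain dict is an assumed stand-in, not the package's internal storage.

```python
# Hypothetical stand-in for per-iteration metrics kept alongside a checkpoint
# (an assumption for illustration, not the package's internal format).
history = {"loss": {0: 1.0, 1: 0.9, 2: 0.8}}


def curve(history, var_name):
    """Return (iterations, values) sorted by iteration, ready for plotting."""
    points = sorted(history[var_name].items())
    iterations = [iteration for iteration, _ in points]
    values = [value for _, value in points]
    return iterations, values


# Curve recreated from the values recorded before training was interrupted.
iters, losses = curve(history, "loss")

# After resuming, newly recorded values extend the same curve.
history["loss"][3] = 0.7
iters_resumed, losses_resumed = curve(history, "loss")
```

Keying values by iteration means a resumed run can keep appending to the same series instead of starting a new one.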


Prerequisites

Developed with Python 3.7.3, but should be compatible with earlier Python versions.

pip install torch==1.1.0 torchvision==0.3.0

Installation

pip install pytorchcheckpoint

Usage

from pytorchcheckpoint.checkpoint import CheckpointHandler
checkpoint_handler = CheckpointHandler()

Storing a general value

checkpoint_handler.store_var(var_name='num_of_classes', value=1000)

Reading a general value

num_of_classes = checkpoint_handler.get_var(var_name='num_of_classes')

Storing values and metrics for each epoch/iteration. For example, the loss value:

checkpoint_handler.store_running_var(var_name='loss', iteration=0, value=1.0)
checkpoint_handler.store_running_var(var_name='loss', iteration=1, value=0.9)
checkpoint_handler.store_running_var(var_name='loss', iteration=2, value=0.8)

Reading stored values for epoch/iteration

loss = checkpoint_handler.get_running_var(var_name='loss', iteration=0)

Storing values and metrics per set: train/valid/test for each epoch/iteration. For example, the top1 value of the train and valid sets:

checkpoint_handler.store_running_var_with_header(header='train', var_name='top1', iteration=0, value=80)
checkpoint_handler.store_running_var_with_header(header='train', var_name='top1', iteration=1, value=85)
checkpoint_handler.store_running_var_with_header(header='train', var_name='top1', iteration=2, value=90)
checkpoint_handler.store_running_var_with_header(header='train', var_name='top1', iteration=3, value=91)

checkpoint_handler.store_running_var_with_header(header='valid', var_name='top1', iteration=0, value=70)
checkpoint_handler.store_running_var_with_header(header='valid', var_name='top1', iteration=1, value=75)
checkpoint_handler.store_running_var_with_header(header='valid', var_name='top1', iteration=2, value=80)
checkpoint_handler.store_running_var_with_header(header='valid', var_name='top1', iteration=3, value=85)

Reading stored values per set: train/valid/test for epoch/iteration

top1 = checkpoint_handler.get_running_var_with_header(header='train', var_name='top1', iteration=0)
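One reason to keep values per set is to compare train and valid at the same iteration. The sketch below uses the top1 numbers from the example above; the nested dict and the `gap_per_iteration` helper are hypothetical and do not reflect how the package stores or exposes these values.

```python
# Hypothetical stand-in for values stored per set (header) and iteration.
top1 = {
    "train": {0: 80, 1: 85, 2: 90, 3: 91},
    "valid": {0: 70, 1: 75, 2: 80, 3: 85},
}


def gap_per_iteration(values, set_a="train", set_b="valid"):
    """Difference between two sets at every iteration both of them recorded."""
    shared = sorted(set(values[set_a]) & set(values[set_b]))
    return {it: values[set_a][it] - values[set_b][it] for it in shared}


# Train-minus-valid top1 at each iteration.
gaps = gap_per_iteration(top1)
```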

Save checkpoint:

import torchvision.models as models
from torch import optim
checkpoint_handler.store_running_var(var_name='loss', iteration=0, value=1.0)
model = models.resnet18()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
path2save = '/tmp'
checkpoint_path = checkpoint_handler.generate_checkpoint_path(path2save=path2save)
checkpoint_handler.save_checkpoint(checkpoint_path=checkpoint_path, iteration=25, model=model, optimizer=optimizer)

Load checkpoint:

checkpoint_path = '<checkpoint_path>'
checkpoint_handler = checkpoint_handler.load_checkpoint(checkpoint_path)
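To make the save and load round trip concrete, here is a minimal sketch of what a full training checkpoint bundles. It is an illustration under assumptions, not the package's implementation: the file layout, the helper names, and the toy state dicts are hypothetical, and `pickle` is used only so the snippet runs with the standard library. In a real PyTorch workflow the state would typically come from `model.state_dict()` and `optimizer.state_dict()`, and be written with `torch.save`.

```python
import os
import pickle
import tempfile


def save_checkpoint(path, iteration, model_state, optimizer_state, metrics):
    """Bundle everything needed to resume training into a single file."""
    checkpoint = {
        "iteration": iteration,
        "model_state": model_state,
        "optimizer_state": optimizer_state,
        "metrics": metrics,
    }
    with open(path, "wb") as f:
        pickle.dump(checkpoint, f)


def load_checkpoint(path):
    """Read the bundle back as a dict."""
    with open(path, "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "checkpoint_25.pkl")  # hypothetical file name
    save_checkpoint(
        path,
        iteration=25,
        model_state={"fc.weight": [0.1, 0.2]},  # toy stand-in for model weights
        optimizer_state={"lr": 0.001, "momentum": 0.9},  # toy optimizer state
        metrics={"loss": {0: 1.0}},
    )
    restored = load_checkpoint(path)

# Training would resume from the iteration after the saved one.
start_iteration = restored["iteration"] + 1
```

Restoring the optimizer state along with the weights matters for optimizers such as SGD with momentum, whose internal buffers would otherwise restart from zero.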


