
PyTorch Model State Save & Load


torch-model-state


Installation

Requires Python 3.6+.

pip install torch-model-state

Usage

Python:

import box
import torch_model_state
from torch.optim import SGD

config = {
  'type': 'MobileNetV2'  # requires the torch-basic-models package
}
model = box.factory(config=config, tag='model')
optimizer = SGD(model.parameters(), lr=0.1)

state = torch_model_state.to_state(model=model, config=config, optimizers=[optimizer])
torch_model_state.save_state_file(state=state, file_path='checkpoint.sf')

state = torch_model_state.load_state_file(file_path='checkpoint.sf', device='cpu')
torch_model_state.from_state(state, model, [optimizer], device='cpu')
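The round trip above stores the model configuration, the weights and the optimizer states in a single state file. The sketch below illustrates that idea with the standard library only; it is not the package's implementation. The helpers `make_state`, `save_state` and `load_state` and the `model`/`optimizers` keys are hypothetical; only the `config`, `info` and `timestamp` keys come from the CLI output shown further down. `pickle` stands in for `torch.save`/`torch.load` (which also serialize via pickle), and plain dicts stand in for `state_dict()` results so the sketch runs without PyTorch.

```python
import datetime
import os
import pickle
import tempfile


def make_state(config, model_state, optimizer_states, info=None):
    """Bundle everything needed to rebuild and resume a model (hypothetical layout)."""
    return {
        'config': config,                      # how to rebuild the model
        'model': model_state,                  # stand-in for model.state_dict()
        'optimizers': list(optimizer_states),  # stand-ins for optimizer.state_dict()
        'info': info,                          # free-form user metadata
        'timestamp': str(datetime.datetime.now()),
    }


def save_state(state, file_path):
    with open(file_path, 'wb') as f:
        pickle.dump(state, f)


def load_state(file_path):
    with open(file_path, 'rb') as f:
        return pickle.load(f)


if __name__ == '__main__':
    state = make_state(
        config={'type': 'MobileNetV2'},
        model_state={'blocks.0.0.weight': [0.1, 0.2]},
        optimizer_states=[{'param_groups': [{'lr': 0.1}]}],
    )
    path = os.path.join(tempfile.mkdtemp(), 'checkpoint.sf')
    save_state(state, path)
    restored = load_state(path)
    print(restored['config'], restored['optimizers'][0]['param_groups'][0]['lr'])
```

Keeping the config next to the weights is what makes it possible to rebuild the model later without the original construction code at hand.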

Load a model directly from a state file (.sf):

import torch_model_state

model = torch_model_state.load_model_from_state(file_path='checkpoint.sf', device='cpu')
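`load_model_from_state` rebuilds the model without the caller constructing it, presumably by looking up the model class named in the stored `config`. A minimal sketch of such a name-to-class lookup follows, assuming a simple registry; `MODEL_REGISTRY`, `register` and `build_model` are hypothetical and are not the API of `box` or `torch_model_state`. The `MobileNetV2` class here is a placeholder so the sketch runs without PyTorch.

```python
MODEL_REGISTRY = {}


def register(cls):
    """Class decorator: make a model class constructible by name (hypothetical)."""
    MODEL_REGISTRY[cls.__name__] = cls
    return cls


@register
class MobileNetV2:
    """Placeholder standing in for the real torch-basic-models class."""

    def __init__(self, width_mult=1.0):
        self.width_mult = width_mult


def build_model(config):
    """Look up config['type'] and pass the remaining keys as constructor kwargs."""
    kwargs = dict(config)
    model_type = kwargs.pop('type')
    try:
        cls = MODEL_REGISTRY[model_type]
    except KeyError:
        raise KeyError(f'unknown model type {model_type!r}; '
                       'was the module defining it imported?') from None
    return cls(**kwargs)


model = build_model({'type': 'MobileNetV2'})
print(type(model).__name__, model.width_mult)
```

Under this assumption, a class can only be found by name once the module defining it has been imported, which is one plausible reading of the CLI's `--extra_import` option.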

CLI:

# show help
torch-model-state -h
#> usage: torch-model-state [-h] [--load_model] [--extra_import EXTRA_IMPORT]
#>                          [--device DEVICE]
#>                          state_file
#>
#> Viewer of PyTorch State File [.sf]
#>
#> positional arguments:
#>   state_file            path of PyTorch state file
#>
#> optional arguments:
#>   -h, --help            show this help message and exit
#>   --load_model          load model and show
#>   --extra_import EXTRA_IMPORT
#>                         import extra models
#>   --device DEVICE       load device, cpu in default
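For reference, the interface in the help text above corresponds to an `argparse` parser along these lines. This is a reconstruction from the printed help, not the tool's source; `build_parser` and `parse_cli` are hypothetical names, and treating `--extra_import` as a module path passed to `importlib.import_module` is an assumption.

```python
import argparse
import importlib


def build_parser():
    parser = argparse.ArgumentParser(
        prog='torch-model-state',
        description='Viewer of PyTorch State File [.sf]')
    parser.add_argument('state_file', help='path of PyTorch state file')
    parser.add_argument('--load_model', action='store_true',
                        help='load model and show')
    parser.add_argument('--extra_import', help='import extra models')
    parser.add_argument('--device', default='cpu',
                        help='load device, cpu in default')
    return parser


def parse_cli(argv=None):
    args = build_parser().parse_args(argv)
    if args.extra_import:
        # Assumption: importing the module registers the model classes it defines.
        importlib.import_module(args.extra_import)
    return args
```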

# view basic info of state file
torch-model-state checkpoint.sf
#> {
#>   "config": {
#>     "type": "MobileNetV2"
#>   },
#>   "info": null,
#>   "timestamp": "2019-04-27 22:42:55.345000"
#> }
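The summary above has the shape of `json.dumps` with a two-space indent applied to the `config`, `info` and `timestamp` entries of the state. The stdlib sketch below reproduces that output from an in-memory state; `summarize` is a hypothetical helper, and the `model` entry is illustrative data that the summary deliberately leaves out.

```python
import datetime
import json


def summarize(state):
    """Pick the human-readable metadata out of a state dict and pretty-print it."""
    meta = {key: state.get(key) for key in ('config', 'info', 'timestamp')}
    # default=str converts a datetime (if stored as one) to '2019-04-27 22:42:55.345000'
    return json.dumps(meta, indent=2, default=str)


state = {
    'config': {'type': 'MobileNetV2'},
    'info': None,
    'timestamp': datetime.datetime(2019, 4, 27, 22, 42, 55, 345000),
    'model': {'weights': '...'},  # large tensors are left out of the summary
}
print(summarize(state))
```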

# view basic info and load the model
torch-model-state checkpoint.sf --load_model
#> {
#>   "config": {
#>     "type": "MobileNetV2"
#>   },
#>   "info": null,
#>   "timestamp": "2019-04-27 22:42:55.345000"
#> }
#> MobileNetV2(
#>   (blocks): Sequential(
#>     (0): Sequential(
#>       (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
#>       (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#>       (2): InplaceReLU6(inplace)
#>     )
#>   ...

# export to ONNX
torch-model-state checkpoint.sf --export_onnx checkpoint.onnx


Files for torch-model-state, version 0.0.9: torch-model-state-0.0.9.tar.gz (5.6 kB, source distribution)
