
A performant, memory-efficient checkpointing library for PyTorch applications, designed with large, complex distributed workloads in mind.


TorchSnapshot (Beta Release)


Install

Requires Python >= 3.7 and PyTorch >= 1.12

From pip or conda:

# Stable
pip install torchsnapshot
# Or, using conda
conda install -c conda-forge torchsnapshot

# Nightly
pip install --pre torchsnapshot-nightly

From source:

git clone https://github.com/pytorch/torchsnapshot
cd torchsnapshot
pip install -r requirements.txt
python setup.py install

Why TorchSnapshot

Performance

  • TorchSnapshot provides a fast checkpointing implementation that employs several optimizations, including zero-copy serialization for most tensor types, overlapped device-to-host copy and storage I/O, and parallelized storage I/O.
  • TorchSnapshot greatly speeds up checkpointing for DistributedDataParallel workloads by distributing the write load across all ranks (benchmark).
  • When host memory is abundant, TorchSnapshot allows training to resume before all storage I/O completes, reducing the time blocked by checkpoint saving.
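To make the Performance points concrete, here is a minimal, framework-free sketch in Python. It is not TorchSnapshot's implementation: `partition_replicated`, `async_write`, and `PendingWrite` are hypothetical names, `bytes(...)` stands in for a device-to-host copy, and local files stand in for storage. The partitioning assumes every rank holds an identical (replicated) copy of each entry, as with DistributedDataParallel, so each entry only needs to be written by one rank.

```python
import concurrent.futures as cf
import os
from typing import Dict, List, Union


def partition_replicated(sizes: Dict[str, int], world_size: int) -> List[List[str]]:
    """Greedily assign each replicated entry to the currently least-loaded rank.

    Hypothetical helper: since every rank holds the same data, spreading the
    entries lets all ranks write in parallel instead of one rank writing all.
    """
    loads = [0] * world_size
    assignment: List[List[str]] = [[] for _ in range(world_size)]
    # Placing the largest entries first tightens the greedy balance.
    for name in sorted(sizes, key=lambda n: (-sizes[n], n)):
        rank = min(range(world_size), key=lambda r: (loads[r], r))
        assignment[rank].append(name)
        loads[rank] += sizes[name]
    return assignment


class PendingWrite:
    """Handle returned before storage I/O finishes (hypothetical)."""

    def __init__(self, futures: List[cf.Future], pool: cf.ThreadPoolExecutor) -> None:
        self._futures = futures
        self._pool = pool

    def done(self) -> bool:
        return all(f.done() for f in self._futures)

    def wait(self) -> None:
        for f in self._futures:
            f.result()  # re-raises any I/O error from the worker thread
        self._pool.shutdown()


def _write(path: str, data: bytes) -> None:
    with open(path, "wb") as f:
        f.write(data)


def async_write(
    root: str, entries: Dict[str, Union[bytes, bytearray]], max_workers: int = 4
) -> PendingWrite:
    """Stage each entry into an owned buffer, then write entries in parallel.

    Each write is submitted as soon as its entry is staged, so staging of the
    next entry overlaps with I/O of earlier ones.
    """
    pool = cf.ThreadPoolExecutor(max_workers=max_workers)
    futures = []
    for name, buf in entries.items():
        staged = bytes(buf)  # private copy: the caller may mutate buf afterwards
        futures.append(pool.submit(_write, os.path.join(root, name), staged))
    return PendingWrite(futures, pool)
```

Because `async_write` returns once every entry is staged, the caller can keep changing its own buffers (the analogue of resuming training) while writes proceed, and call `wait()` before relying on the checkpoint. Holding all staged copies at once is why this pattern needs ample host memory.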

Memory Usage

  • TorchSnapshot's memory usage adapts to the host's available resources, greatly reducing the chance of out-of-memory issues when saving and loading checkpoints.
  • TorchSnapshot supports efficient random access to individual objects within a snapshot, even when the snapshot is stored in cloud object storage.
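Random access within a snapshot generally relies on a manifest that records where each object lives, so a single object can be fetched without reading the rest. The sketch below assumes a hypothetical single-file layout (an 8-byte header length, a JSON manifest, then the object bytes); `pack`, `read_manifest`, and `read_one` are illustrative names, and TorchSnapshot's actual on-disk format is different and not shown here.

```python
import json
import struct
from typing import BinaryIO, Dict, List, Tuple

# Hypothetical layout (NOT TorchSnapshot's format):
# [8-byte little-endian manifest length][JSON manifest][object bytes ...]
# The manifest maps each logical path to [offset, length] in the data region.


def pack(objects: Dict[str, bytes], f: BinaryIO) -> None:
    manifest: Dict[str, Tuple[int, int]] = {}
    offset = 0
    for name, data in objects.items():
        manifest[name] = (offset, len(data))
        offset += len(data)
    header = json.dumps(manifest).encode("utf-8")
    f.write(struct.pack("<Q", len(header)))
    f.write(header)
    for data in objects.values():  # same order as the manifest loop above
        f.write(data)


def read_manifest(f: BinaryIO) -> Tuple[Dict[str, List[int]], int]:
    f.seek(0)
    (header_len,) = struct.unpack("<Q", f.read(8))
    manifest = json.loads(f.read(header_len).decode("utf-8"))
    return manifest, 8 + header_len  # data region starts right after the header


def read_one(f: BinaryIO, name: str) -> bytes:
    """Fetch one object by seeking to its byte range; other objects are skipped."""
    manifest, data_start = read_manifest(f)
    offset, length = manifest[name]
    f.seek(data_start + offset)
    return f.read(length)
```

Against cloud object storage, the `seek` and `read` pair would correspond to a ranged GET for just that object's bytes.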

Usability

  • Simple APIs that are consistent between distributed and non-distributed workloads.
  • Out-of-the-box integration with commonly used cloud object storage systems.
  • Automatic resharding (elasticity) on world size change for supported workloads (more details).
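Resharding comes down to working out which pieces of the saved shards overlap the range each rank owns under the new world size. As an illustration only, the sketch below assumes a tensor split evenly along its first dimension into contiguous row ranges; `even_shards` and `reshard_plan` are hypothetical helpers, and TorchSnapshot's real resharding covers the specific sharded types listed in its documentation.

```python
from typing import List, Tuple


def even_shards(num_rows: int, world_size: int) -> List[Tuple[int, int]]:
    """Split [0, num_rows) into contiguous, near-equal half-open [start, end) ranges."""
    base, extra = divmod(num_rows, world_size)
    bounds, start = [], 0
    for rank in range(world_size):
        end = start + base + (1 if rank < extra else 0)
        bounds.append((start, end))
        start = end
    return bounds


def reshard_plan(
    num_rows: int, old_world: int, new_world: int
) -> List[List[Tuple[int, int, int]]]:
    """For each new rank, list the (old_rank, start, end) row ranges it must read."""
    old = even_shards(num_rows, old_world)
    plan = []
    for new_start, new_end in even_shards(num_rows, new_world):
        reads = []
        for old_rank, (old_start, old_end) in enumerate(old):
            # Intersect the new rank's range with each saved shard.
            lo, hi = max(new_start, old_start), min(new_end, old_end)
            if lo < hi:
                reads.append((old_rank, lo, hi))
        plan.append(reads)
    return plan
```

For example, `reshard_plan(10, 2, 3)` tells new rank 1 to read rows [4, 5) from old rank 0 and rows [5, 7) from old rank 1.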

Security

  • Secure tensor serialization without pickle dependency [WIP].
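Since this item is marked work in progress, the following only sketches the general idea and is not TorchSnapshot's format: store metadata (dtype and shape) as JSON and the values as raw little-endian bytes, so loading never calls `pickle` and therefore never executes code embedded in the file. The header layout, the dtype table, and the `serialize`/`deserialize` names are assumptions, and Python's `array` module stands in for tensors.

```python
import array
import json
import struct
import sys
from typing import List, Tuple

# Hypothetical format: [4-byte LE header length][JSON {"dtype", "shape"}][raw LE data]
# Only whitelisted dtypes are accepted; typecodes "f"/"d"/"q" are float32/float64/int64
# on common platforms.
_DTYPES = {"float32": "f", "float64": "d", "int64": "q"}


def serialize(values: List[float], dtype: str, shape: Tuple[int, ...]) -> bytes:
    arr = array.array(_DTYPES[dtype], values)
    if sys.byteorder == "big":
        arr.byteswap()  # always store little-endian, whatever the host order
    header = json.dumps({"dtype": dtype, "shape": list(shape)}).encode("utf-8")
    return struct.pack("<I", len(header)) + header + arr.tobytes()


def deserialize(blob: bytes) -> Tuple[str, Tuple[int, ...], list]:
    (n,) = struct.unpack_from("<I", blob, 0)
    meta = json.loads(blob[4 : 4 + n].decode("utf-8"))  # data only, no code
    arr = array.array(_DTYPES[meta["dtype"]])  # KeyError for unknown dtypes
    arr.frombytes(blob[4 + n :])
    if sys.byteorder == "big":
        arr.byteswap()
    expected = 1
    for dim in meta["shape"]:
        expected *= dim
    if len(arr) != expected:
        raise ValueError("payload size does not match shape")
    return meta["dtype"], tuple(meta["shape"]), arr.tolist()
```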

Getting Started

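import torch
from torchsnapshot import Snapshot

# Placeholder model and optimizer for illustration
model = torch.nn.Linear(128, 64)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Taking a snapshot
app_state = {"model": model, "optimizer": optimizer}
snapshot = Snapshot.take(path="/path/to/snapshot", app_state=app_state)

# Restoring from a snapshot (in a new process, reopen it with Snapshot(path="/path/to/snapshot"))
snapshot.restore(app_state=app_state)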

See the documentation for more details.
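`app_state` need not hold only models and optimizers. The sketch below assumes the convention that each value exposes `state_dict()` and `load_state_dict()` (the same methods `nn.Module` and PyTorch optimizers provide); `TrainingProgress` is a hypothetical example for saving loop counters alongside the weights. Consult the documentation for the exact requirements.

```python
from typing import Any, Dict


class TrainingProgress:
    """Hypothetical user-defined object following the state_dict convention."""

    def __init__(self) -> None:
        self.epoch = 0
        self.step = 0

    def state_dict(self) -> Dict[str, Any]:
        # Return plain values so the checkpoint does not alias live attributes.
        return {"epoch": self.epoch, "step": self.step}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.epoch = state_dict["epoch"]
        self.step = state_dict["step"]
```

It would then be added as another entry, e.g. `app_state = {"model": model, "optimizer": optimizer, "progress": progress}`, and passed to `Snapshot.take` and `restore` as before.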

License

torchsnapshot is BSD licensed, as found in the LICENSE file.


Download files


Source Distribution

torchsnapshot-nightly-2023.9.6.tar.gz (53.3 kB)


Built Distribution


torchsnapshot_nightly-2023.9.6-py3-none-any.whl (70.5 kB)


File details

Hashes for torchsnapshot-nightly-2023.9.6.tar.gz
Algorithm Hash digest
SHA256 d953c19843ac6af709ebcf0f58f1a561318ef3d69a81a70ea067eabe82299727
MD5 792c55837e6644a24958f2933acb10e0
BLAKE2b-256 0c5c4f9f99fe14c6451e787d9d0e0e7a38a45eacccf58296c38578ab7e9caffa


Hashes for torchsnapshot_nightly-2023.9.6-py3-none-any.whl
Algorithm Hash digest
SHA256 f08cfb9b29a4fa8540c9c74915d8cb2aaf423d207b7f2a9d43a2708d8810153f
MD5 b89bfdf8474075706d029b1d7fecee59
BLAKE2b-256 ce590cbfbdd80878b25496d829d62f77329b0387e2f97c354bfdd8c90eef4cf8
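The digests above let you check that a downloaded file is intact. A minimal sketch using only Python's standard `hashlib`; `sha256_of` and `verify` are illustrative helper names, and the file is read in chunks so large archives are never loaded into memory at once.

```python
import hashlib


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Hash a file incrementally and return the hex digest."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path: str, expected_sha256: str) -> bool:
    """Compare against a published digest (case-insensitive)."""
    return sha256_of(path) == expected_sha256.lower()
```

For example, `verify("torchsnapshot_nightly-2023.9.6-py3-none-any.whl", "f08cfb9b29a4fa8540c9c74915d8cb2aaf423d207b7f2a9d43a2708d8810153f")` should return `True` for an unmodified wheel. pip can also enforce digests itself in `--require-hashes` mode.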

