
Chunk and checkpoint memory optimisation


Chunk and Checkpoint Memory Optimisation


Reduce peak memory when training PyTorch models that require batched operations internally, such as Swin Transformers.

TLDR:

from chunkcheck import chunk_and_checkpoint

...

# Suppose x1, x2, ... share a very large batch dimension (dimension 0).
# `chunk_and_checkpoint` substantially reduces peak memory usage; adjust
# `chunk_size` to trade computation time against memory.
y = chunk_and_checkpoint(f, x1, x2, ..., chunk_size=4, batch_dim=0)

...

Installation

pip install chunkcheck

Usage

chunkcheck exports one function: chunk_and_checkpoint. It can reduce the peak memory requirement of a PyTorch programme when all of the following hold:

  • You have one or more input torch.Tensors (x1, x2, ...) whose first dimension is a "batch" dimension, and all inputs have the same size along it.
  • You wish to compute f(x1, x2, ...), where f applies the same operation independently to each element of the batch dimension, so that computing f on slices of the batch and concatenating the results gives the same answer.
  • The memory required for intermediate computations in f is large compared to the memory required to store (x1, x2, ...) and the output of f(x1, x2, ...). A canonical example of this kind of function is an MLP with large hidden dimension(s).
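As a rough, hypothetical illustration of the third condition (and of why chunking helps), the sketch below estimates how much memory the hidden activation of an MLP occupies when the whole batch is processed at once versus one chunk at a time. It assumes float32 activations, that chunk_size counts batch elements per chunk, and illustrative sizes (4096 batch rows, hidden width 3072); the helper names are made up for this example and are not part of chunkcheck.

```python
import math

MIB = 2**20  # bytes per mebibyte


def activation_bytes(rows: int, hidden_dim: int, bytes_per_element: int = 4) -> int:
    """Bytes needed to hold one (rows, hidden_dim) activation tensor (float32 by default)."""
    return rows * hidden_dim * bytes_per_element


def chunked_activation_bytes(batch: int, chunk_size: int, hidden_dim: int,
                             bytes_per_element: int = 4) -> int:
    """Largest hidden activation alive at once if chunks are processed one at a time.

    Assumes chunk_size is the number of batch elements per chunk (hypothetical model).
    """
    return activation_bytes(min(chunk_size, batch), hidden_dim, bytes_per_element)


def num_chunks(batch: int, chunk_size: int) -> int:
    """Number of chunks needed to cover the batch (the last one may be smaller)."""
    return math.ceil(batch / chunk_size)


batch, hidden = 4096, 3072  # hypothetical sizes
print(activation_bytes(batch, hidden) / MIB)               # 48.0
print(chunked_activation_bytes(batch, 256, hidden) / MIB)  # 3.0
print(num_chunks(batch, 256))                              # 16
```

Under these assumptions the transient hidden activation shrinks by roughly batch / chunk_size. When inputs and outputs are small relative to such intermediates, the intermediates dominate peak memory, which is the situation the third condition describes.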

Instead of calling f(x1, x2, ...), call chunk_and_checkpoint(f, x1, x2, ..., chunk_size=chunk_size) for some int chunk_size. This should substantially reduce peak memory while, for a well-chosen chunk_size, increasing computation time only slightly. chunk_and_checkpoint reduces peak memory further than torch.utils.checkpoint.checkpoint ("activation checkpointing") alone; the exact saving depends on chunk_size.
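The following is a minimal sketch of the idea behind chunk_and_checkpoint, not chunkcheck's actual implementation. It assumes that chunk_size is the number of batch elements per chunk, that f returns a single tensor, and that all inputs share the same batch size; the name chunked_checkpoint_sketch is hypothetical. Each chunk is wrapped in its own torch.utils.checkpoint.checkpoint call, so the forward pass does not keep f's intermediate activations, and during backward they are recomputed roughly one chunk at a time rather than for the whole batch at once.

```python
import torch
from torch.utils.checkpoint import checkpoint


def chunked_checkpoint_sketch(f, *xs, chunk_size, batch_dim=0):
    """Hypothetical illustration: apply f chunk by chunk, checkpointing each chunk."""
    batch_sizes = {x.shape[batch_dim] for x in xs}
    if len(batch_sizes) != 1:
        raise ValueError(f"inputs disagree on batch size along dim {batch_dim}: {batch_sizes}")

    # Matching chunks of at most `chunk_size` elements for every input.
    split_inputs = [torch.split(x, chunk_size, dim=batch_dim) for x in xs]

    outputs = []
    for chunk_args in zip(*split_inputs):
        # Intermediates of f on this chunk are not stored; they are recomputed
        # when the backward pass reaches this chunk.
        outputs.append(checkpoint(f, *chunk_args, use_reentrant=False))

    # Reassemble the per-chunk outputs along the batch dimension.
    return torch.cat(outputs, dim=batch_dim)


# Usage with an MLP whose hidden width is large relative to its input width.
torch.manual_seed(0)
mlp = torch.nn.Sequential(
    torch.nn.Linear(64, 1024),
    torch.nn.GELU(),
    torch.nn.Linear(1024, 64),
)
x = torch.randn(512, 64, requires_grad=True)
y = chunked_checkpoint_sketch(mlp, x, chunk_size=128)
y.sum().backward()
```

Wrapping all of f in a single checkpoint call would also avoid storing intermediates during the forward pass, but the backward-pass recomputation would still materialise them for the entire batch at once. Checkpointing per chunk bounds that recomputation peak by the chunk size, which is consistent with the claim that chunk_and_checkpoint goes further than activation checkpointing alone.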

See the docstring for chunk_and_checkpoint for more information. For a more detailed explanation of why this works, and some usage case studies, see our note on arXiv (TODO: write this and link to it).

Development

Clone the repository and cd into it. Then create a virtual environment, activate it, and install all dependencies:

uv venv
source .venv/bin/activate
uv sync

Running the tests:

pytest -v


Download files


Source Distribution

chunkcheck-0.1.1.tar.gz (3.0 kB)

Built Distribution

chunkcheck-0.1.1-py3-none-any.whl (3.4 kB)

File details

Details for the file chunkcheck-0.1.1.tar.gz.

File metadata

  • Download URL: chunkcheck-0.1.1.tar.gz
  • Size: 3.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.9.10

File hashes

Hashes for chunkcheck-0.1.1.tar.gz:

  • SHA256: e72289972e10482df84f4e9c3191fbedca27a566da9e934abd660109f0985851
  • MD5: a8f4c46cc3f017ed9215a647842a27cb
  • BLAKE2b-256: 095de4a4d060b2374a3804262f19c8b86d44b012064258fdf209b81b47f56cf6

File details

Details for the file chunkcheck-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: chunkcheck-0.1.1-py3-none-any.whl
  • Size: 3.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.9.10

File hashes

Hashes for chunkcheck-0.1.1-py3-none-any.whl:

  • SHA256: b32dc9614bdfb5298bc98d5f8814e0cc8790417c5d99bb07499e7a4ff718a7ef
  • MD5: 3b9ccf703fa5d1d9c8cf7132f8daee4f
  • BLAKE2b-256: 1db763fbb29a575f4b80d60e325cd9f3cbef6acde3340c1123a8a50b3d77b59c
