Chunk and checkpoint memory optimisation
Reduce peak memory when training PyTorch models that perform batched operations internally, such as Swin Transformers.
TLDR:
from chunkcheck import chunk_and_checkpoint
...
# There is a really large batch size along dimension 0. `chunk_and_checkpoint`
# substantially reduces peak memory usage. Adjust `chunk_size` to achieve your
# preferred time vs memory tradeoff.
y = chunk_and_checkpoint(f, x1, x2, ..., chunk_size=4, batch_dim=0)
...
Installation
pip install chunkcheck
Usage
chunkcheck exports one function: chunk_and_checkpoint.
It can be fruitfully used to reduce the peak memory requirement of a programme written using PyTorch when the following hold:
- You have one or more input torch.Tensors (x1, x2, ...) whose first dimension is a "batch" dimension of equal size.
- You wish to compute f(x1, x2, ...), where f applies the same operation to each "batch" in (x1, x2, ...).
- The memory required during intermediate computations in f is large compared to the memory required to store (x1, x2, ...) and the output of f(x1, x2, ...). A canonical example of this kind of function is an MLP with large hidden dimension(s).
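The second condition can be checked empirically. The sketch below is not part of chunkcheck: `is_batchwise` is a hypothetical helper that compares `f` applied to the whole batch against `f` applied chunk by chunk. An MLP should pass, whereas `BatchNorm1d` in training mode should not, because it normalises with statistics computed across the batch.

```python
import torch
from torch import nn


def is_batchwise(f, x, chunk_size, atol=1e-5):
    """Return True if f gives the same result on chunks of x as on all of x.

    Hypothetical helper for illustration; not part of chunkcheck.
    """
    with torch.no_grad():
        full = f(x)
        chunked = torch.cat(
            [f(c) for c in torch.split(x, chunk_size, dim=0)], dim=0
        )
    return torch.allclose(full, chunked, atol=atol)


torch.manual_seed(0)
x = torch.randn(16, 8)

# Each row is processed independently of the others.
mlp = nn.Sequential(nn.Linear(8, 256), nn.GELU(), nn.Linear(256, 8))
# In training mode, the output for one row depends on the rest of the batch.
batch_norm = nn.BatchNorm1d(8).train()

print("MLP is batchwise:", is_batchwise(mlp, x, chunk_size=4))
print("BatchNorm1d (train) is batchwise:", is_batchwise(batch_norm, x, chunk_size=4))
```

If a function fails this kind of check, chunking it along the batch dimension would change its output, so it is not a good candidate.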
Instead of calling f(x1, x2, ...), call chunk_and_checkpoint(f, x1, x2, ..., chunk_size=chunk_size), for some int chunk_size.
Doing this should substantially reduce peak memory, and increase the computation time by only a small amount for a well-chosen chunk_size.
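For a rough feel of the trade-off, consider a two-layer MLP. The estimate below is a back-of-envelope sketch, not a measurement and not part of chunkcheck. It assumes float32 tensors, counts only the hidden activation as intermediate memory, and uses hypothetical names and sizes throughout.

```python
import math

BYTES_PER_FLOAT32 = 4


def mlp_memory_estimate(batch, d_in, d_hidden, d_out, chunk_size=None):
    """Rough byte counts for a two-layer MLP, optionally processed in chunks.

    Assumption: the only intermediate tensor is the hidden activation of shape
    (rows, d_hidden), and only one chunk's intermediates are alive at a time.
    """
    rows = batch if chunk_size is None else min(chunk_size, batch)
    return {
        "inputs_and_outputs": batch * (d_in + d_out) * BYTES_PER_FLOAT32,
        "peak_intermediate": rows * d_hidden * BYTES_PER_FLOAT32,
        "num_chunks": 1 if chunk_size is None else math.ceil(batch / chunk_size),
    }


sizes = dict(batch=4096, d_in=96, d_hidden=16384, d_out=96)
unchunked = mlp_memory_estimate(**sizes)
chunked = mlp_memory_estimate(**sizes, chunk_size=256)

for name, estimate in [("unchunked", unchunked), ("chunk_size=256", chunked)]:
    print(name, estimate)
```

Under these assumptions, the intermediate peak scales with `chunk_size` while the number of sequential calls to `f` scales with `batch / chunk_size`, which is where the extra computation time comes from.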
chunk_and_checkpoint will reduce peak memory further than torch.utils.checkpoint.checkpoint ("activation checkpointing") alone; the exact amount depends on chunk_size.
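To illustrate the idea, here is a minimal sketch of one plausible way to combine chunking with activation checkpointing. It is an assumption for illustration, not chunkcheck's actual implementation: `chunked_checkpoint_sketch` is a hypothetical name, and the sketch assumes `f` returns a single tensor and treats batch elements independently. Plain checkpointing recomputes the intermediates for the whole batch at once during the backward pass, whereas here only one chunk's intermediates need to exist at a time.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


def chunked_checkpoint_sketch(f, *xs, chunk_size, batch_dim=0):
    """Apply f to chunks of xs along batch_dim, checkpointing each chunk.

    Illustrative only: assumes f returns a single tensor and treats every
    batch element independently. Not chunkcheck's actual implementation.
    """
    chunks_per_input = [torch.split(x, chunk_size, dim=batch_dim) for x in xs]
    outputs = [
        # Intermediates inside f are not stored after each chunk's forward
        # pass; they are recomputed chunk by chunk during the backward pass.
        checkpoint(f, *chunk_args, use_reentrant=False)
        for chunk_args in zip(*chunks_per_input)
    ]
    return torch.cat(outputs, dim=batch_dim)


torch.manual_seed(0)
mlp = nn.Sequential(nn.Linear(8, 512), nn.GELU(), nn.Linear(512, 8))
x = torch.randn(32, 8, requires_grad=True)

# Reference: ordinary forward and backward pass over the full batch.
y_ref = mlp(x)
(grad_ref,) = torch.autograd.grad(y_ref.sum(), x)

# Chunked and checkpointed: same values, smaller live intermediates.
y_chunked = chunked_checkpoint_sketch(mlp, x, chunk_size=4)
(grad_chunked,) = torch.autograd.grad(y_chunked.sum(), x)

print(torch.allclose(y_ref, y_chunked, atol=1e-5))
print(torch.allclose(grad_ref, grad_chunked, atol=1e-5))
```

The outputs and gradients should match the unchunked computation up to floating-point tolerance; only the memory profile and run time differ.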
See the docstring for chunk_and_checkpoint for more information.
For a more detailed explanation of why this works, and some usage case studies, see our note on arXiv (TODO: write this and link to it).
Development
Clone the repo and cd into the repository.
Then create a virtual environment, enter it, and install all dependencies:
uv venv
source .venv/bin/activate
uv sync
Running the tests:
pytest -v