torch-optstate
Optimizer State Virtualization for PyTorch
torch-optstate wraps existing PyTorch optimizers (Adam/AdamW/SGD) to virtualize their state. It compresses and offloads optimizer state when not in use, then materializes it on the fly during step(), with chunked execution to keep peak memory low. New defaults make it plug-and-play on CPU and CUDA.
Why use this?
- Optimizer state often costs 2–3× the memory of the model parameters themselves; saving there unlocks larger batches/models.
- Compress and offload momentum/variance to CPU while keeping training unchanged (default minimizes VRAM).
- Chunked step avoids double-residency spikes; pinned CPU offload keeps VRAM low.
- Policy-driven: choose FP32/FP16/INT8 per-state with warmup or adaptive triggers.
What’s new (plug-and-play defaults)
- Auto-chunked step: always chunked; default chunk ≤8, first chunk = 1 to tame peaks.
- CUDA-aware pinned offload: compressed state is pinned on CPU automatically when params are on CUDA.
- auto_wrap helper: one-call wrapping with the defaults above.
- Low-memory AdamW helper: wrap_low_memory_adamw defaults to tiny first chunk + auto pin.
- Adaptive warmup policy (optional): switch to compression when loss plateaus.
- Decode scratch cache: reuses decode buffers to reduce per-step allocations.
- Chunk-only path: closures are not supported; keeps peak usage low.
- Small-tensor bypass: int8 compression skips tiny tensors by default (configurable via WarmupPolicy).
- CUDA path: compressed state is offloaded to CPU by default to minimize VRAM; set device_resident=True to keep it on GPU.
- Max-compression preset: wrap_max_compression_adamw for int8-all state with GPU-friendly chunking.
Installation
pip install torch-optstate
Optional (Linux): pip install "torch-optstate[triton]" to enable torch.compile acceleration.
(Research preview.)
Usage
1) Drop-in (recommended default)
import torch
from torch.optim import AdamW
import torch_optstate as topt
model = torch.nn.Linear(10, 1).to("cuda") # or cpu
opt = AdamW(model.parameters(), lr=1e-3)
# One call: auto chunking, tiny first chunk, auto pin if on CUDA
# Default policy after warmup: int8 momentum + fp32 variance, offloaded to CPU.
opt = topt.auto_wrap(opt)
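The wrapped optimizer keeps the standard interface, so a normal training loop works unchanged (a minimal sketch with placeholder data; match the device to your model):
x = torch.randn(32, 10, device="cuda")
y = torch.randn(32, 1, device="cuda")
for _ in range(100):
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()  # materialize -> inner step -> compress/offload, chunk by chunk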
2) Low peak memory AdamW preset
import torch_optstate as topt
opt = topt.wrap_low_memory_adamw(
    model.parameters(),
    variance_mode="int8",        # or "fp16"/"fp32"
    chunk_size=None,             # auto small chunk
    initial_chunk_size=None,     # defaults to 1
    # pin_memory=None -> auto on CUDA
    # device_resident=True,      # keep compressed state on GPU instead of CPU offload
)
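Note that, unlike auto_wrap, this preset takes the parameter iterable directly rather than an existing optimizer instance.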
3) Max compression (int8-all, GPU-friendly)
import torch_optstate as topt
opt = topt.wrap_max_compression_adamw(
    model.parameters(),
    chunk_size_on_cuda=256,      # defaults to 256 if chunk_size is None
    initial_chunk_size=1,
    # device_resident=True,      # keep compressed state on GPU instead of CPU offload
)
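To check the effect on your own model, PyTorch's allocator counters are enough (a quick sketch; run the same steps with and without wrapping and compare):
import torch

torch.cuda.reset_peak_memory_stats()
# ... run a few training steps with the wrapped optimizer ...
peak_mb = torch.cuda.max_memory_allocated() / 2**20
print(f"peak allocated: {peak_mb:.1f} MB")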
4) Custom policies (int8 / FP16 / BF16)
from torch_optstate import wrap, WarmupPolicy, FP16Codec, FP32Codec, Int8MomentumCodec
policy = WarmupPolicy(
    warmup_steps=100,
    momentum_key="exp_avg",
    variance_key="exp_avg_sq",
    variance_codec=Int8MomentumCodec(),  # int8 variance
    # min_int8_elements=4096,   # default: skip int8 for tiny tensors
    # device_resident=False,    # force CPU offload even on CUDA
)
opt = wrap(opt, policy=policy, chunk_size=8, initial_chunk_size=1)
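The imported FP16Codec and FP32Codec plug into the same slot; for instance, an fp16 variance codec as an alternative to the int8 policy above (a sketch reusing only the parameters shown here):
policy_fp16 = WarmupPolicy(
    warmup_steps=100,
    momentum_key="exp_avg",
    variance_key="exp_avg_sq",
    variance_codec=FP16Codec(),  # fp16 variance instead of int8
)
opt = wrap(opt, policy=policy_fp16, chunk_size=8, initial_chunk_size=1)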
5) GPU offload (pinned CPU) and chunking
- Offload is the default (including after compression); pinning is automatic on CUDA (override with pin_memory=False).
- Set device_resident=True if you want compressed state to stay on GPU instead.
- Chunked step is always on; defaults are small to reduce VRAM overlap.
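For intuition, the pinned-offload round trip in plain PyTorch looks roughly like this (a conceptual sketch, not the library's internals):
import torch

state = torch.randn(1024, 1024, device="cuda")        # stand-in for a compressed state buffer
host = torch.empty(state.shape, dtype=state.dtype, pin_memory=True)
host.copy_(state, non_blocking=True)                  # async device-to-host copy into pinned RAM
torch.cuda.synchronize()                              # ensure the copy has landed
del state                                             # VRAM can now be freed
restored = host.to("cuda", non_blocking=True)         # later: materialize back on GPU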
Example CLI (demo) for GPU:
poetry run python -m examples.finetune_demo \
    --steps 10 \
    --small_llm \
    --compression_mode default \
    --metrics_csv gpu_metrics.csv
Generates memory_comparison.png with GPU VRAM and CPU RAM traces.
Default benchmark (example)
From gpu_metrics.csv (10 steps, small_11m, default compression):
- Peak GPU allocated (gpu_mem_mb): 2074.91 MB -> 1116.60 MB (-958.32 MB, -46.2%)
- Peak GPU high-water mark (gpu_peak_mb): 4715.73 MB -> 3755.91 MB (-959.82 MB, -20.4%)
- Peak CPU RSS (cpu_mem_mb): 1698.05 MB -> 2582.95 MB (+884.90 MB, +52.1%)
- Peak tensor state (tensor_mem_mb): 1023.24 MB -> 640.89 MB (-382.35 MB, -37.4%)
These numbers reflect the expected trade-off: GPU memory drops while CPU memory rises due to offload.
Metrics glossary (CSV)
- run: baseline or optstate.
- step: Step index (1-based).
- loss: Training loss for the step.
- accuracy: Running training accuracy up to that step.
- val_accuracy: Validation accuracy (filled after eval).
- step_time_ms: Total wall-clock time per step.
- tensor_mem_mb: Estimated optimizer state size (compressed + uncompressed) in MB.
- materialize_ms: Time to decode and load optimizer state for the step.
- inner_step_ms: Time spent inside the wrapped optimizer's step().
- commit_ms: Time to compress and store optimizer state after the step.
- overhead_ms: Extra time in the wrapper beyond materialize/step/commit.
- cpu_mem_mb: Process RSS in MB.
- gpu_mem_mb: Current allocated VRAM (post-step sample) in MB.
- gpu_peak_mb: Peak allocated VRAM since last reset, in MB.
- compression_active: Whether compression is active for the step.
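Given these columns, the peak numbers above can be reproduced with a few lines of pandas (a sketch assuming gpu_metrics.csv contains both runs):
import pandas as pd

df = pd.read_csv("gpu_metrics.csv")
cols = ["gpu_mem_mb", "gpu_peak_mb", "cpu_mem_mb", "tensor_mem_mb"]
peaks = df.groupby("run")[cols].max()                  # per-run peak of each memory column
print(peaks.loc["optstate"] - peaks.loc["baseline"])   # baseline -> optstate deltas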
How it works
- Virtualization: OptimizerWrapper intercepts step().
- Materialize: compressed state is decoded to full precision for the chunk.
- Execute: the inner optimizer runs normally.
- Commit: updated state is compressed and offloaded; FP32 copies are freed.
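As a mental model of the materialize/commit round trip, per-tensor absmax int8 quantization looks like this (a conceptual sketch, not the library's Int8MomentumCodec):
import torch

def int8_encode(t: torch.Tensor):
    # Per-tensor absmax scaling: map [-absmax, absmax] onto [-127, 127].
    scale = t.abs().max().clamp(min=1e-12) / 127.0
    return (t / scale).round().clamp(-127, 127).to(torch.int8), scale

def int8_decode(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.to(torch.float32) * scale

m = torch.randn(1_000_000)            # stand-in for exp_avg
q, s = int8_encode(m)                 # commit: 4 bytes/elem -> 1 byte/elem (+ one scale)
m_hat = int8_decode(q, s)             # materialize: decode before the inner step
print((m - m_hat).abs().max())        # per-element error bounded by ~scale/2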
Limitations
- Closures are not supported (step is chunked-only).
- Tested mainly on AdamW/SGD; other optimizers may need custom policies.
Future work: speed improvements are planned in upcoming releases.
License
MIT