torch-optstate
Optimizer State Virtualization for PyTorch
torch-optstate is a core infrastructure library that wraps existing PyTorch optimizers (like Adam, AdamW, SGD) to virtualize their state. It enables significant memory savings by compressing optimizer states (e.g., momentum) when they are not in use, and seamlessly materializing them on-the-fly during optimization steps.
❓ Why use this?
Training large models is often memory-bound. While parameters and gradients take up space, optimizer states (like momentum and variance in Adam) can consume 2x to 3x the memory of the model parameters themselves.
torch-optstate solves this by:
- Reducing Memory Footprint: Compresses optimizer state by 25% to 75% with minimal accuracy loss.
- Drop-in Compatibility: Works with your existing `torch.optim` optimizers and training loops.
- CPU Offloading Ready: Designed to manage state on CPU, freeing up precious GPU memory (GPU support is planned).
- Policy-Driven: You control the trade-off between precision, speed, and memory.
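As a back-of-envelope illustration (plain arithmetic, not output of this library), the overhead follows from the number of state slots each optimizer keeps per parameter:

```python
def optimizer_state_mb(n_params, slots_per_param, bytes_per_element=4):
    """Rough optimizer-state size in MB (FP32 = 4 bytes per element)."""
    return n_params * slots_per_param * bytes_per_element / (1024 ** 2)

n = 10_000_000  # a 10M-parameter model: ~38 MB of FP32 weights

adam_mb = optimizer_state_mb(n, slots_per_param=2)  # exp_avg + exp_avg_sq
sgd_mb = optimizer_state_mb(n, slots_per_param=1)   # momentum_buffer only

print(f"Adam state:   {adam_mb:.1f} MB")  # ~76.3 MB, 2x the weights
print(f"SGD+momentum: {sgd_mb:.1f} MB")   # ~38.1 MB, 1x the weights
```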
Benchmark Highlights (CPU)
| Model | Optimizer | Policy | State Size | Reduction | Status |
|---|---|---|---|---|---|
| MLP | AdamW | Baseline | 76.45 MB | - | OK |
| MLP | AdamW | Int8 Momentum | 47.78 MB | 37.5% | OK |
| MLP | AdamW | Mixed FP16 | 57.34 MB | 25.0% | OK |
| MLP | SGD+Mom | Baseline | 38.23 MB | - | OK |
| MLP | SGD+Mom | Int8 Momentum | 9.56 MB | 75.0% | OK |
| Transformer | AdamW | Baseline | 2.60 MB | - | OK |
| Transformer | AdamW | Int8 Momentum | 1.62 MB | 37.5% | OK |
Benchmarks run on CPU. Int8 Momentum policy keeps variance in FP32 for stability.
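The reduction figures follow directly from per-slot byte widths. This sketch (simple arithmetic, not library output) reproduces the percentages in the table:

```python
def reduction(before_bytes_per_el, after_bytes_per_el):
    """Fractional state-size reduction across all state slots."""
    return 1 - sum(after_bytes_per_el) / sum(before_bytes_per_el)

# AdamW keeps two FP32 slots (exp_avg + exp_avg_sq, 4 bytes each).
int8_momentum = reduction([4, 4], [1, 4])  # momentum -> INT8, variance FP32
mixed_fp16    = reduction([4, 4], [2, 4])  # momentum -> FP16, variance FP32

# SGD+momentum keeps a single momentum_buffer slot.
sgd_int8 = reduction([4], [1])             # momentum -> INT8

print(int8_momentum)  # 0.375 -> 37.5%
print(mixed_fp16)     # 0.25  -> 25.0%
print(sgd_int8)       # 0.75  -> 75.0%
```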
📦 Installation
```shell
pip install torch-optstate
```
(Note: This package is currently a research preview.)
🛠️ Usage
1. Basic Usage (Drop-in)
Simply wrap your existing optimizer. By default, it uses a WarmupPolicy that keeps state in FP32 for a few steps before compressing momentum to INT8.
```python
import torch
from torch.optim import AdamW
from torch_optstate import wrap

model = torch.nn.Linear(10, 1)
optimizer = AdamW(model.parameters(), lr=1e-3)

# Wrap the optimizer.
# This will automatically manage state compression.
optimizer = wrap(optimizer)

# Training loop (standard PyTorch)
loss_fn = torch.nn.MSELoss()
for input, target in dataset:  # your existing DataLoader
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
```
2. Custom Policies
You can define custom policies to control compression behavior.
Int8 Momentum (Aggressive Compression):
```python
from torch_optstate import wrap, WarmupPolicy

# Keep state in FP32 for 1000 steps, then compress momentum to INT8.
# Variance (if present) stays in FP32 for stability.
policy = WarmupPolicy(warmup_steps=1000)
optimizer = wrap(optimizer, policy=policy)
```
Mixed Precision (FP16 Momentum, FP32 Variance):
```python
from torch_optstate import wrap, ConfigurablePolicy, FP16Codec, FP32Codec

# Store 'exp_avg' (AdamW) and 'momentum_buffer' (SGD) in FP16;
# anything else, e.g. 'exp_avg_sq' (variance), falls back to FP32.
policy = ConfigurablePolicy(
    codecs_map={
        'exp_avg': FP16Codec(),
        'momentum_buffer': FP16Codec(),
    },
    default_codec=FP32Codec(),  # fallback for variance
)
optimizer = wrap(optimizer, policy=policy)
```
3. Advanced Configuration
You can combine warmup with custom codecs using ConfigurablePolicy.
```python
from torch_optstate import wrap, ConfigurablePolicy, FP16Codec, FP32Codec

# Warm up for 100 steps, then switch to FP16 for momentum
policy = ConfigurablePolicy(
    codecs_map={'exp_avg': FP16Codec()},
    default_codec=FP32Codec(),
    warmup_steps=100,
)
optimizer = wrap(optimizer, policy=policy)
```
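The intended policy semantics can be sketched as a simple lookup. This is an illustration of the behavior described above, not the library's actual implementation; `select_codec` and the string codec names are hypothetical:

```python
def select_codec(state_key, step, codecs_map, default_codec, warmup_steps=0):
    """Pick a codec for one state tensor: full precision during warmup,
    then the key-specific codec, falling back to the default."""
    if step < warmup_steps:
        return 'fp32'
    return codecs_map.get(state_key, default_codec)

codecs = {'exp_avg': 'fp16'}

# During warmup, everything stays FP32.
assert select_codec('exp_avg', 50, codecs, 'fp32', warmup_steps=100) == 'fp32'
# After warmup, mapped keys use their codec...
assert select_codec('exp_avg', 100, codecs, 'fp32', warmup_steps=100) == 'fp16'
# ...and unmapped keys (e.g. variance) use the default.
assert select_codec('exp_avg_sq', 500, codecs, 'fp32') == 'fp32'
```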
🧠 How It Works
- Virtualization: The `OptimizerWrapper` intercepts `step()` calls.
- Materialization: Before the inner optimizer runs, compressed state is decoded to full precision (FP32).
- Execution: The inner optimizer performs the update using standard PyTorch kernels.
- Commit: After the update, the new state is compressed (e.g., quantized) and stored in the `StateStore`, and the full-precision copy is freed.
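The materialize/execute/commit cycle can be illustrated with a toy symmetric INT8 codec. This is a minimal plain-Python sketch of the idea, not the library's `StateStore` (the helper names here are hypothetical):

```python
def encode_int8(values):
    """Quantize a list of floats to INT8 with a per-tensor scale."""
    scale = max((abs(v) for v in values), default=0.0) / 127 or 1.0
    return [round(v / scale) for v in values], scale

def decode_int8(qvalues, scale):
    """Materialize full-precision values from the compressed form."""
    return [q * scale for q in qvalues]

def wrapped_step(store, key, inner_step):
    # 1. Materialize: decode compressed state to full precision.
    state = decode_int8(*store[key])
    # 2. Execute: the inner optimizer updates the state.
    state = inner_step(state)
    # 3. Commit: re-compress and store; the FP32 copy is dropped.
    store[key] = encode_int8(state)

store = {'exp_avg': encode_int8([0.5, -1.0, 0.25])}
wrapped_step(store, 'exp_avg', lambda s: [0.9 * v for v in s])
print(decode_int8(*store['exp_avg']))  # close to [0.45, -0.9, 0.225]
```

Quantization error is bounded by half the scale per element, which is why aggressive INT8 policies are typically reserved for momentum rather than variance.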
⚠️ Limitations
- CPU Only: State management is currently optimized for CPU; GPU support is planned.
- Step Overhead: Decompression/compression adds overhead to each `step()` call. This is often negligible compared to the forward/backward pass of large models.
- Optimizer Support: Tested primarily with AdamW and SGD. Other optimizers should work but may require custom policies if they use non-standard state keys.
🤝 Contributing
Contributions are welcome! Please check the tests/ folder for coverage requirements.
📄 License
MIT