
Optimizer state virtualization and compression for PyTorch


torch-optstate

Optimizer State Virtualization for PyTorch

torch-optstate is a core infrastructure library that wraps existing PyTorch optimizers (like Adam, AdamW, SGD) to virtualize their state. It enables significant memory savings by compressing optimizer states (e.g., momentum) when they are not in use, and seamlessly materializing them on-the-fly during optimization steps.

❓ Why use this?

Training large models is often memory-bound. Parameters and gradients already take up space, but optimizer states (like the momentum and variance tensors in Adam) can consume another 2x to 3x the memory of the model parameters themselves.

torch-optstate addresses this with:

  1. Reduced Memory Footprint: Compresses optimizer state by 25% to 75% with minimal accuracy loss.
  2. Drop-in Compatibility: Works with your existing torch.optim optimizers and training loops.
  3. CPU Offloading Ready: Designed to manage state on CPU, freeing up precious GPU memory (GPU support is planned).
  4. Policy-Driven: You control the trade-off between precision, speed, and memory.

Benchmark Highlights (CPU)

| Model       | Optimizer | Policy        | State Size | Reduction | Status |
|-------------|-----------|---------------|------------|-----------|--------|
| MLP         | AdamW     | Baseline      | 76.45 MB   | -         | OK     |
| MLP         | AdamW     | Int8 Momentum | 47.78 MB   | 37.5%     | OK     |
| MLP         | AdamW     | Mixed FP16    | 57.34 MB   | 25.0%     | OK     |
| SGD         | SGD+Mom   | Baseline      | 38.23 MB   | -         | OK     |
| SGD         | SGD+Mom   | Int8 Momentum | 9.56 MB    | 75.0%     | OK     |
| Transformer | AdamW     | Baseline      | 2.60 MB    | -         | OK     |
| Transformer | AdamW     | Int8 Momentum | 1.62 MB    | 37.5%     | OK     |

Benchmarks run on CPU. Int8 Momentum policy keeps variance in FP32 for stability.
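The reduction figures follow directly from the per-parameter byte counts of each state layout. A quick back-of-the-envelope check (assuming FP32 baselines and ignoring per-tensor quantization metadata such as scales, which the measured sizes suggest is negligible):

adamw_baseline = 4 + 4  # exp_avg (FP32) + exp_avg_sq (FP32), bytes per parameter
adamw_int8     = 1 + 4  # exp_avg (INT8) + exp_avg_sq (FP32)
adamw_fp16     = 2 + 4  # exp_avg (FP16) + exp_avg_sq (FP32)
sgd_baseline   = 4      # momentum_buffer (FP32)
sgd_int8       = 1      # momentum_buffer (INT8)

print(1 - adamw_int8 / adamw_baseline)  # 0.375 -> 37.5%
print(1 - adamw_fp16 / adamw_baseline)  # 0.25  -> 25.0%
print(1 - sgd_int8 / sgd_baseline)      # 0.75  -> 75.0%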

📦 Installation

pip install torch-optstate

(Note: This package is currently a research preview.)

🛠️ Usage

1. Basic Usage (Drop-in)

Simply wrap your existing optimizer. By default, it uses a WarmupPolicy that keeps state in FP32 for a few steps before compressing momentum to INT8.

import torch
from torch.optim import AdamW
from torch_optstate import wrap

model = torch.nn.Linear(10, 1)
optimizer = AdamW(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

# Wrap the optimizer
# This will automatically manage state compression
optimizer = wrap(optimizer)

# Training loop (standard PyTorch)
dataset = [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(10)]  # toy data
for input, target in dataset:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

2. Custom Policies

You can define custom policies to control compression behavior.

Int8 Momentum (Aggressive Compression):

from torch_optstate import wrap, WarmupPolicy

# Keep in FP32 for 1000 steps, then compress momentum to INT8
# Variance (if present) stays in FP32 for stability.
policy = WarmupPolicy(warmup_steps=1000)
optimizer = wrap(optimizer, policy=policy)

Mixed Precision (FP16 Momentum, FP32 Variance):

from torch_optstate import wrap, ConfigurablePolicy, FP16Codec, FP32Codec

# Store Adam's 'exp_avg' (and SGD's 'momentum_buffer') in FP16;
# everything else, including 'exp_avg_sq', falls back to FP32.
policy = ConfigurablePolicy(
    codecs_map={
        'exp_avg': FP16Codec(),
        'momentum_buffer': FP16Codec()
    },
    default_codec=FP32Codec()  # Fallback, e.g. for the variance ('exp_avg_sq')
)
optimizer = wrap(optimizer, policy=policy)

3. Advanced Configuration

You can combine warmup with custom codecs using ConfigurablePolicy.

from torch_optstate import wrap, ConfigurablePolicy, FP16Codec, FP32Codec

# Warmup for 100 steps, then switch to FP16 for momentum
policy = ConfigurablePolicy(
    codecs_map={'exp_avg': FP16Codec()},
    default_codec=FP32Codec(),
    warmup_steps=100
)
optimizer = wrap(optimizer, policy=policy)

🧠 How It Works

  1. Virtualization: The OptimizerWrapper intercepts step() calls.
  2. Materialization: Before the inner optimizer runs, compressed state is decoded to full precision (FP32).
  3. Execution: The inner optimizer performs the update using standard PyTorch kernels.
  4. Commit: After the update, the new state is compressed (e.g., quantized) and stored in the StateStore, and the full-precision copy is freed. (A simplified sketch of this cycle follows below.)
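torch-optstate's actual classes are not reproduced here, but the cycle itself is easy to picture. Below is a minimal, self-contained sketch of the same materialize → step → commit pattern; the CompressingWrapper class and its blanket FP16 encode/decode are illustrative stand-ins (the real policies are more selective, e.g. variance stays in FP32), not the library's implementation:

import torch
from torch.optim import AdamW

class CompressingWrapper:
    """Illustrative only: keeps non-scalar optimizer state in FP16 between steps."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def zero_grad(self, set_to_none=True):
        self.optimizer.zero_grad(set_to_none=set_to_none)

    def step(self):
        # Materialize: decode FP16 state back to full precision (FP32)
        for state in self.optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value) and value.dim() > 0 and value.dtype == torch.float16:
                    state[key] = value.float()
        # Execute: the inner optimizer updates with standard PyTorch kernels
        self.optimizer.step()
        # Commit: re-compress the state and drop the full-precision copies
        for state in self.optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value) and value.dim() > 0 and value.dtype == torch.float32:
                    state[key] = value.half()

model = torch.nn.Linear(10, 1)
optimizer = CompressingWrapper(AdamW(model.parameters(), lr=1e-3))

x, y = torch.randn(8, 10), torch.randn(8, 1)
torch.nn.functional.mse_loss(model(x), y).backward()
optimizer.step()  # materialize -> update -> commit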

⚠️ Limitations

  • CPU Only: Currently, state management is optimized for CPU. GPU support is planned.
  • Step Overhead: Decompression/Compression adds overhead to the step() call. This is often negligible compared to the forward/backward pass of large models.
  • Optimizer Support: Tested primarily with AdamW and SGD. Other optimizers should work but may require custom policies if they use non-standard state keys (see the sketch after this list).
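As an illustration of the last point, RMSprop keeps its running average under 'square_avg' and its momentum under 'momentum_buffer' rather than Adam's 'exp_avg'/'exp_avg_sq'. Here is a sketch of how those keys might be mapped with the ConfigurablePolicy API from the Usage section (an untested assumption; keeping 'square_avg' in FP32 mirrors the keep-variance-in-FP32 default):

import torch
from torch.optim import RMSprop
from torch_optstate import wrap, ConfigurablePolicy, FP16Codec, FP32Codec

model = torch.nn.Linear(10, 1)
optimizer = RMSprop(model.parameters(), lr=1e-2, momentum=0.9)

# Compress the momentum buffer to FP16; leave 'square_avg' (and anything
# unrecognized) in FP32 via the default codec.
policy = ConfigurablePolicy(
    codecs_map={'momentum_buffer': FP16Codec()},
    default_codec=FP32Codec(),
)
optimizer = wrap(optimizer, policy=policy)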

🤝 Contributing

Contributions are welcome! Please check the tests/ folder for coverage requirements.

📄 License

MIT

