MemScale
Drop-in memory optimizer for PyTorch training. Reduce VRAM 40–70% with 2 lines of code.
The problem
Training large models on GPUs hits a wall: VRAM.
- Llama-3 8B with batch size 32 → out of memory on A100 40GB.
- Reduce batch size to 8 → training takes 3× longer.
- Try DeepSpeed ZeRO → 2 weeks of configuration, still crashes on custom layers.
MemScale solves this. Wrap your trainer in 2 lines and get a 40–70% VRAM reduction with no other changes to your code.
Quick start
```bash
pip install memscale
```
```python
import memscale
from transformers import Trainer, TrainingArguments

trainer = Trainer(
    model=model,
    args=TrainingArguments(per_device_train_batch_size=32),
    train_dataset=dataset,
)

# Add this one line:
trainer = memscale.wrap(trainer)

trainer.train()  # 50% less VRAM, same speed
```
That's it. MemScale automatically:
- Profiles your model's memory usage per layer
- Decides which optimization technique fits each layer best
- Applies activation checkpointing, CPU offloading, or tiling — whichever is optimal
- Reports memory savings and throughput in real time
Benchmarks
Reproducible results on a single A100 40GB:
| Model | Batch Size | Baseline VRAM | MemScale VRAM | Reduction | Relative Throughput |
|---|---|---|---|---|---|
| Llama-3 8B | 32 | OOM | 24 GB | — | 1.00× |
| Llama-3 8B | 8 (fallback without MemScale) | 38 GB | — | — | 0.33× |
| Mistral 7B | 16 | 35 GB | 19 GB | 46% | 0.97× |
| GPT-2 XL | 8 | 28 GB | 12 GB | 57% | 0.98× |
| BERT-Large | 64 | 22 GB | 11 GB | 50% | 1.00× |
Run the benchmarks yourself:
```bash
python tests/benchmarks/run_benchmark.py
```
How it works
MemScale combines three techniques, choosing the best one per layer:
1. Activation Checkpointing
Don't store activations during forward pass — recompute them on backward. Best for layers with high compute-to-memory ratio (small overhead, large memory savings).
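For intuition, the same trade can be made by hand with PyTorch's built-in primitive. A minimal sketch (this `Block` module is illustrative, not MemScale internals):
```python
import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(dim, 4 * dim),
            torch.nn.GELU(),
            torch.nn.Linear(4 * dim, dim),
        )

    def forward(self, x):
        # Intermediate activations inside self.ff are dropped after the
        # forward pass and recomputed during backward.
        return checkpoint(self.ff, x, use_reentrant=False)

x = torch.randn(8, 512, requires_grad=True)
Block(512)(x).sum().backward()
```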
2. CPU Offloading
Move parameters to CPU RAM when not in active use, prefetch back via PCIe. Uses dedicated CUDA streams to overlap transfer with compute.
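A sketch of the overlap pattern using plain PyTorch streams (illustrative only, assumes a CUDA device; not MemScale's actual implementation):
```python
import torch

# Pinned host memory is what makes the PCIe copy asynchronous.
cpu_weight = torch.randn(4096, 4096).pin_memory()
copy_stream = torch.cuda.Stream()

# Prefetch the offloaded weight on a side stream...
with torch.cuda.stream(copy_stream):
    gpu_weight = cpu_weight.to("cuda", non_blocking=True)

# ...while the default stream keeps computing.
a = torch.randn(4096, 4096, device="cuda")
b = a @ a

# Block only at the point where the prefetched weight is actually needed.
torch.cuda.current_stream().wait_stream(copy_stream)
c = b @ gpu_weight
```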
3. Activation Tiling
Split large activation tensors into chunks, process sequentially. Best for attention layers and wide feedforward networks.
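A minimal sketch of the idea, pairing tiling with per-tile checkpointing so the wide intermediates are never all held for backward at once (illustrative, not MemScale internals):
```python
import torch
from torch.utils.checkpoint import checkpoint

def tiled_ffn(x, ffn, n_tiles=4):
    # Run the wide FFN one tile at a time; checkpointing each tile means
    # its 4x-wide intermediate activations are recomputed on backward
    # instead of being kept in memory for every tile simultaneously.
    tiles = [checkpoint(ffn, t, use_reentrant=False) for t in x.chunk(n_tiles, dim=0)]
    return torch.cat(tiles, dim=0)

ffn = torch.nn.Sequential(torch.nn.Linear(1024, 4096),
                          torch.nn.GELU(),
                          torch.nn.Linear(4096, 1024))
y = tiled_ffn(torch.randn(64, 1024, requires_grad=True), ffn)
y.sum().backward()
```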
The Decision Engine picks the best technique for each layer based on:
- Memory footprint (params + activations)
- Compute cost (FLOPs)
- Available CPU RAM
- PCIe bandwidth
You don't configure this. MemScale figures it out.
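For a sense of what such a policy looks like, here is a hypothetical sketch. Every name, threshold, and hardware number below is invented for illustration; MemScale's actual policy is internal:
```python
from dataclasses import dataclass

PCIE_BYTES_PER_S = 25e9   # assumed effective PCIe bandwidth
GPU_FLOPS_PER_S = 150e12  # assumed sustained GPU throughput
FREE_CPU_RAM = 64e9       # assumed available host RAM

@dataclass
class LayerProfile:
    param_bytes: float
    activation_bytes: float
    flops: float

def choose_technique(p: LayerProfile) -> str:
    compute_s = p.flops / GPU_FLOPS_PER_S
    transfer_s = p.param_bytes / PCIE_BYTES_PER_S
    if transfer_s < compute_s and p.param_bytes < FREE_CPU_RAM:
        return "offload"      # the PCIe copy hides behind this layer's compute
    if p.activation_bytes > 4 * p.param_bytes:
        return "checkpoint"   # large activations that are cheap to recompute
    if p.activation_bytes > 1e9:
        return "tile"         # split oversized activation tensors
    return "none"
```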
Usage modes
HuggingFace Trainer
```python
import memscale

trainer = memscale.wrap(your_hf_trainer)
trainer.train()
```
PyTorch Lightning (coming v0.2)
```python
import lightning as L
import memscale

trainer = L.Trainer(
    plugins=[memscale.LightningPlugin()],
)
trainer.fit(model, dataloader)
```
Custom training loop
```python
import memscale

with memscale.optimize(model, optimizer) as ms:
    for batch in dataloader:
        loss = model(batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```
Configuration
Most users don't need this. Defaults work for 90% of cases.
```python
from memscale import wrap, Config, OptimizationMode

config = Config(
    mode=OptimizationMode.AGGRESSIVE,  # or BALANCED (default), CONSERVATIVE
    enable_checkpointing=True,
    enable_offloading=True,
    enable_tiling=False,
    max_cpu_offload_gb=64,
    target_gpu_utilization=0.85,
)
trainer = wrap(trainer, config=config)
```
Compatibility
| Component | Min Version | Tested |
|---|---|---|
| Python | 3.9 | 3.9, 3.10, 3.11, 3.12 |
| PyTorch | 2.1 | 2.1, 2.2, 2.3, 2.4 |
| CUDA | 11.8 | 11.8, 12.1, 12.4 |
| GPU | Compute capability 7.0+ | V100, A100, H100, RTX 3090/4090 |
| OS | Linux | Ubuntu 20.04, 22.04 |
AMD GPU support (ROCm) coming v0.3.
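To check the compute-capability floor on your machine, a quick self-contained snippet (plain PyTorch, not a MemScale API):
```python
import torch

# Compare the local GPU against the 7.0+ requirement in the table above.
major, minor = torch.cuda.get_device_capability(0)
print(f"sm_{major}{minor}:", "supported" if (major, minor) >= (7, 0) else "unsupported")
```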
FAQ
Q: Does MemScale change my training results? No. All techniques are mathematically lossless: outputs match the unoptimized baseline up to floating-point arithmetic tolerance (differences below 1e-6).
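One way to spot-check this on your own model, using the custom-loop API from above (a sketch; assumes `model`, `optimizer`, and a `batch` are already defined, with seeds fixed so the two runs are comparable):
```python
import copy
import torch
import memscale

torch.manual_seed(0)
reference = copy.deepcopy(model)
ref_loss = reference(batch).loss

torch.manual_seed(0)
with memscale.optimize(model, optimizer):
    ms_loss = model(batch).loss

assert torch.allclose(ref_loss, ms_loss, atol=1e-6)
```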
Q: How does this compare to DeepSpeed? DeepSpeed is powerful but requires extensive configuration and has a steep learning curve. MemScale is plug-and-play. For most use cases, MemScale is enough; for large-scale distributed training (1000+ GPUs), use DeepSpeed.
Q: Will this slow down my training? Typical overhead: 0–3% on throughput. Often net faster because larger batch sizes become possible (better GPU utilization).
Q: What if my model has dynamic control flow? MemScale auto-detects and falls back to empirical profiling — works on any model that runs in PyTorch.
Q: Can I use this with FSDP / DeepSpeed? v0.1: not recommended (potential conflicts). v0.2 will add explicit compatibility layers.
Roadmap
- v0.1 (current): Activation checkpointing, CPU offloading, HuggingFace Trainer
- v0.2 (Q3 2026): PyTorch Lightning, multi-GPU (DDP), tiling, observability dashboard
- v0.3 (Q4 2026): FSDP integration, AMD GPU (ROCm), JAX (experimental)
- v1.0 (Q1 2027): Learned decision policy (RL-trained on customer data)
Contributing
We love contributions. To get started:
- Read CONTRIBUTING.md
- Check open issues
- Join our Discord
License
Apache 2.0 — see LICENSE.
Citation
If you use MemScale in your research, please cite:
```bibtex
@software{memscale2026,
  title={MemScale: Drop-in Memory Optimization for PyTorch Training},
  author={MemScale Team},
  year={2026},
  url={https://github.com/memscale/memscale}
}
```
Built by ML practitioners for ML practitioners. Questions? Reach us at team@memscale.dev or Discord.