MemScale

Drop-in memory optimizer for PyTorch training. Reduce VRAM 40–70% with 2 lines of code.


The problem

Training large models on GPUs hits a wall: VRAM.

  • Llama-3 8B with batch size 32 → out of memory on A100 40GB.
  • Reduce batch size to 8 → training takes 3× longer.
  • Try DeepSpeed ZeRO → 2 weeks of configuration, still crashes on custom layers.

MemScale solves this. Wrap your trainer with 2 lines (one import, one call) for a 40–70% VRAM reduction, with no changes to your model or training code.

Quick start

pip install memscale

import memscale
from transformers import Trainer, TrainingArguments

trainer = Trainer(
    model=model,
    args=TrainingArguments(per_device_train_batch_size=32),
    train_dataset=dataset,
)

# Add this one line:
trainer = memscale.wrap(trainer)

trainer.train()  # typically 40–70% less VRAM at 0–3% throughput overhead

That's it. MemScale automatically:

  1. Profiles your model's memory usage per layer
  2. Decides which optimization technique fits each layer best
  3. Applies activation checkpointing, CPU offloading, or tiling — whichever is optimal
  4. Reports memory savings and throughput in real time

Benchmarks

Reproducible results on a single A100 40GB:

| Model | Batch size | Baseline VRAM | MemScale VRAM | Reduction | Throughput |
|---|---|---|---|---|---|
| Llama-3 8B | 32 | OOM | 24 GB | | 1.0× |
| Llama-3 8B | 8 (baseline) | 38 GB | | | 0.33× (slower) |
| Mistral 7B | 16 | 35 GB | 19 GB | 46% | 0.97× |
| GPT-2 XL | 8 | 28 GB | 12 GB | 57% | 0.98× |
| BERT-Large | 64 | 22 GB | 11 GB | 50% | 1.00× |

Run benchmarks yourself: python tests/benchmarks/run_benchmark.py
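The Reduction column can be re-derived from the two VRAM columns. The short check below does that for the rows that report both values; the helper name `vram_reduction_pct` is made up for this illustration and is not part of MemScale.

```python
def vram_reduction_pct(baseline_gb: float, optimized_gb: float) -> int:
    """Share of baseline VRAM saved, as a percentage rounded to the nearest integer."""
    return round(100 * (1 - optimized_gb / baseline_gb))


# (baseline VRAM in GB, MemScale VRAM in GB), copied from the benchmark table
rows = {
    "Mistral 7B": (35, 19),
    "GPT-2 XL": (28, 12),
    "BERT-Large": (22, 11),
}

for name, (baseline, optimized) in rows.items():
    print(f"{name}: {vram_reduction_pct(baseline, optimized)}%")
# Mistral 7B: 46%
# GPT-2 XL: 57%
# BERT-Large: 50%
```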

How it works

MemScale combines three techniques, choosing the best one per layer:

1. Activation Checkpointing

Don't store activations during the forward pass; recompute them during the backward pass. Best for layers whose activations are large relative to the cost of recomputing them (small overhead, large memory savings).
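To make the trade-off concrete, here is a toy, framework-free sketch of the idea on a chain of scalar `tanh` layers. It is not MemScale's implementation: the functions `grad_full` and `grad_checkpointed` and the fixed `segment` size are assumptions chosen for illustration. The baseline keeps every layer input for the backward pass, while the checkpointed version keeps one input per segment and recomputes the rest.

```python
import math


def layer(w, x):
    """One toy layer: y = tanh(w * x)."""
    return math.tanh(w * x)


def layer_grad(w, x):
    """dy/dx for the toy layer, evaluated at input x."""
    y = math.tanh(w * x)
    return w * (1.0 - y * y)


def grad_full(weights, x0):
    """Baseline: store every layer input during forward, then backpropagate."""
    inputs = []
    x = x0
    for w in weights:
        inputs.append(x)
        x = layer(w, x)
    grad = 1.0
    for w, x_in in zip(reversed(weights), reversed(inputs)):
        grad *= layer_grad(w, x_in)
    return grad, len(inputs)


def grad_checkpointed(weights, x0, segment=4):
    """Store only each segment's first input; recompute the others on backward."""
    checkpoints = {}
    x = x0
    for i, w in enumerate(weights):
        if i % segment == 0:
            checkpoints[i] = x
        x = layer(w, x)

    grad = 1.0
    peak_stored = len(checkpoints)
    for start in sorted(checkpoints, reverse=True):
        seg_weights = weights[start:start + segment]
        # Recompute this segment's inputs from its checkpoint.
        seg_inputs = []
        x = checkpoints[start]
        for w in seg_weights:
            seg_inputs.append(x)
            x = layer(w, x)
        # Count checkpoints plus the inputs recomputed for the current segment.
        peak_stored = max(peak_stored, len(checkpoints) + len(seg_inputs))
        for w, x_in in zip(reversed(seg_weights), reversed(seg_inputs)):
            grad *= layer_grad(w, x_in)
    return grad, peak_stored


weights = [0.5 + 0.1 * i for i in range(16)]
g_full, stored_full = grad_full(weights, 0.3)
g_ckpt, stored_ckpt = grad_checkpointed(weights, 0.3, segment=4)
print(stored_full, stored_ckpt)  # 16 8
print(math.isclose(g_full, g_ckpt, rel_tol=1e-12))  # True
```

The gradient is unchanged; the cost is one extra forward pass per segment. In real PyTorch code the same idea is available through `torch.utils.checkpoint`.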

2. CPU Offloading

Move parameters to CPU RAM when not in active use, prefetch back via PCIe. Uses dedicated CUDA streams to overlap transfer with compute.
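Whether a prefetch can be fully overlapped comes down to simple arithmetic: the transfer must finish before the compute it hides behind. The sketch below works through that check. The helper names, the layer size, and the 25 GB/s effective bandwidth are all assumptions for illustration, not measured MemScale numbers.

```python
def transfer_ms(num_bytes: float, bandwidth_gb_per_s: float) -> float:
    """Time in milliseconds to move num_bytes over a link with the given bandwidth."""
    return num_bytes / (bandwidth_gb_per_s * 1e9) * 1e3


def transfer_is_hidden(num_bytes: float, bandwidth_gb_per_s: float, compute_ms: float) -> bool:
    """True when the prefetch completes within the compute time it overlaps with."""
    return transfer_ms(num_bytes, bandwidth_gb_per_s) <= compute_ms


# Assumed values, for illustration only.
layer_bytes = 200e6 * 2   # 200M parameters stored in 16-bit precision
bandwidth = 25.0          # GB/s, an assumed effective host-to-device rate

print(f"{transfer_ms(layer_bytes, bandwidth):.1f} ms to prefetch")
print(transfer_is_hidden(layer_bytes, bandwidth, compute_ms=20.0))  # True
print(transfer_is_hidden(layer_bytes, bandwidth, compute_ms=10.0))  # False
```

When the transfer takes longer than the overlapping compute, the GPU stalls waiting for parameters, so offloading such a layer costs throughput.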

3. Activation Tiling

Split large activation tensors into chunks and process them sequentially. Best for attention layers and wide feedforward networks. Per the roadmap, tiling ships in v0.2.
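As a rough, framework-free illustration of tiling, the sketch below runs a two-layer feedforward block over a few rows at a time, so only a `tile_rows x hidden` intermediate exists at any moment. The function names and the row-wise split are assumptions made for this example; MemScale's own tiling strategy may differ.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def relu(m):
    """Element-wise ReLU on a list-of-rows matrix."""
    return [[max(0.0, v) for v in row] for row in m]


def feedforward(x, w1, w2):
    """Reference version: materialises the full (rows x hidden) intermediate."""
    return matmul(relu(matmul(x, w1)), w2)


def feedforward_tiled(x, w1, w2, tile_rows):
    """Tiled version: only a (tile_rows x hidden) intermediate is alive at once."""
    out = []
    for start in range(0, len(x), tile_rows):
        tile = x[start:start + tile_rows]
        hidden = relu(matmul(tile, w1))   # small intermediate, freed each iteration
        out.extend(matmul(hidden, w2))
    return out


x = [[(i + j) % 5 - 2.0 for j in range(3)] for i in range(6)]    # 6 x 3 input
w1 = [[0.1 * (i - j) for j in range(8)] for i in range(3)]       # 3 x 8 (wide hidden)
w2 = [[0.05 * (i + j) for j in range(3)] for i in range(8)]      # 8 x 3

print(feedforward_tiled(x, w1, w2, tile_rows=2) == feedforward(x, w1, w2))  # True
```

Because each output row depends only on its own input row, splitting along rows leaves the result unchanged while shrinking the peak intermediate size.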

The Decision Engine picks the best technique for each layer based on:

  • Memory footprint (params + activations)
  • Compute cost (FLOPs)
  • Available CPU RAM
  • PCIe bandwidth

You don't configure this. MemScale figures it out.
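The Decision Engine's actual policy is not documented here. As a purely hypothetical sketch of how the four inputs listed above could feed a per-layer choice, consider the rule-based selector below. Every name (`LayerStats`, `choose_technique`), threshold, and rule is an assumption for illustration, not MemScale's API or algorithm.

```python
from dataclasses import dataclass


@dataclass
class LayerStats:
    name: str
    param_gb: float       # parameter memory
    activation_gb: float  # activations kept for the backward pass
    forward_ms: float     # forward compute time, i.e. the cost of recomputation


def choose_technique(layer: LayerStats, free_cpu_ram_gb: float,
                     pcie_gb_per_s: float, max_recompute_ms: float = 5.0) -> str:
    """Hypothetical heuristic: return 'checkpoint', 'tile', 'offload', or 'none'."""
    transfer_ms = layer.param_gb / pcie_gb_per_s * 1e3
    if layer.activation_gb >= layer.param_gb:
        # Activation-dominated layer: recompute if that is cheap, otherwise tile.
        return "checkpoint" if layer.forward_ms <= max_recompute_ms else "tile"
    if layer.param_gb <= free_cpu_ram_gb and transfer_ms <= layer.forward_ms:
        # Parameter-dominated layer whose transfer can hide behind compute.
        return "offload"
    return "none"


layers = [
    LayerStats("cheap_activation_heavy", param_gb=0.1, activation_gb=2.0, forward_ms=3.0),
    LayerStats("costly_activation_heavy", param_gb=0.1, activation_gb=2.0, forward_ms=12.0),
    LayerStats("parameter_heavy", param_gb=1.0, activation_gb=0.1, forward_ms=50.0),
]
for layer in layers:
    print(layer.name, "->", choose_technique(layer, free_cpu_ram_gb=64, pcie_gb_per_s=25))
# cheap_activation_heavy -> checkpoint
# costly_activation_heavy -> tile
# parameter_heavy -> offload
```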

Usage modes

HuggingFace Trainer

import memscale
trainer = memscale.wrap(your_hf_trainer)
trainer.train()

PyTorch Lightning (coming in v0.2)

import lightning as L
import memscale

trainer = L.Trainer(
    plugins=[memscale.LightningPlugin()],
)
trainer.fit(model, dataloader)

Custom training loop

import memscale

with memscale.optimize(model, optimizer) as ms:
    for batch in dataloader:
        loss = model(batch).loss  # assumes the model's output exposes a .loss attribute
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

Configuration

Most users don't need this. Defaults work for 90% of cases.

from memscale import wrap, Config, OptimizationMode

config = Config(
    mode=OptimizationMode.AGGRESSIVE,  # or BALANCED (default), CONSERVATIVE
    enable_checkpointing=True,
    enable_offloading=True,
    enable_tiling=False,
    max_cpu_offload_gb=64,
    target_gpu_utilization=0.85,
)

trainer = wrap(trainer, config=config)

Compatibility

| Component | Min version | Tested |
|---|---|---|
| Python | 3.9 | 3.9, 3.10, 3.11, 3.12 |
| PyTorch | 2.1 | 2.1, 2.2, 2.3, 2.4 |
| CUDA | 11.8 | 11.8, 12.1, 12.4 |
| GPU | Compute capability 7.0+ | V100, A100, H100, RTX 3090/4090 |
| OS | Linux | Ubuntu 20.04, 22.04 |

AMD GPU support (ROCm) is coming in v0.3.

FAQ

Q: Does MemScale change my training results? No. All techniques are mathematically lossless: results match the baseline within a floating-point tolerance of 1e-6.
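One way to check this on your own model is to log the loss at each step with and without MemScale (same seed, same data order) and compare the two curves. The helper below is a minimal sketch of that comparison; `losses_match` is a made-up name, and the sample values are illustrative, not measured.

```python
import math


def losses_match(baseline, optimized, abs_tol=1e-6):
    """True if both loss curves have equal length and agree step by step within abs_tol."""
    return len(baseline) == len(optimized) and all(
        math.isclose(a, b, rel_tol=0.0, abs_tol=abs_tol)
        for a, b in zip(baseline, optimized)
    )


# Illustrative values only.
baseline_losses = [2.31, 1.97, 1.64]
optimized_losses = [2.31 + 5e-7, 1.97, 1.64]

print(losses_match(baseline_losses, optimized_losses))  # True
```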

Q: How does this compare to DeepSpeed? DeepSpeed is powerful but requires extensive configuration and has a steep learning curve. MemScale is plug-and-play. For most use cases, MemScale is enough. For large-scale distributed training (1000+ GPUs), use DeepSpeed.

Q: Will this slow down my training? Typical overhead: 0–3% on throughput. Often net faster because larger batch sizes become possible (better GPU utilization).

Q: What if my model has dynamic control flow? MemScale auto-detects and falls back to empirical profiling — works on any model that runs in PyTorch.

Q: Can I use this with FSDP / DeepSpeed? v0.1: not recommended (potential conflicts). v0.2 will add explicit compatibility layers.

Roadmap

  • v0.1 (current): Activation checkpointing, CPU offloading, HuggingFace Trainer
  • v0.2 (Q3 2026): PyTorch Lightning, multi-GPU (DDP), tiling, observability dashboard
  • v0.3 (Q4 2026): FSDP integration, AMD GPU (ROCm), JAX (experimental)
  • v1.0 (Q1 2027): Learned decision policy (RL-trained on customer data)

Contributing

We love contributions. To get started:

  1. Read CONTRIBUTING.md
  2. Check open issues
  3. Join our Discord

License

Apache 2.0 — see LICENSE.

Citation

If you use MemScale in your research, please cite:

@software{memscale2026,
  title={MemScale: Drop-in Memory Optimization for PyTorch Training},
  author={MemScale Team},
  year={2026},
  url={https://github.com/memscale/memscale}
}

Built by ML practitioners for ML practitioners. Questions? Reach us at team@memscale.dev or Discord.
