
Drop-in memory optimizer for PyTorch training. Reduce VRAM significantly with one line of code.


MemScale

Drop-in memory optimizer for PyTorch training. Reduce VRAM usage by up to 88% with one line of code.



The problem

Training large models on GPUs hits a wall: VRAM.

  • BERT-Large with batch 16 → 17.6 GB on RTX 3090
  • 1.5B parameter model → out of memory on a single 24 GB GPU
  • DeepSpeed ZeRO setup → 2 weeks of configuration

MemScale solves this: add one line to your training script and get up to 88% VRAM reduction, with no other code changes.

Real benchmarks (validated on RTX 3090 24GB)

| Model | Params | Baseline VRAM | MemScale VRAM | Reduction |
|---|---|---|---|---|
| BERT-Base | 85M | 6.39 GB | 1.87 GB | 70.8% |
| BERT-Large | 302M | 17.57 GB | 2.11 GB | 88.0% |
| GPT-2 Medium | 302M | 19.02 GB | 7.16 GB | 62.4% |
| GPT-2 Large | 708M | 21.78 GB | 4.88 GB | 77.6% |
| 1.3B model | 1.3B | OOM | 8.86 GB | Enables training |
| GPT-2 XL | 1.5B | OOM | 12.72 GB | Enables training |

Comparison: PyTorch's native activation checkpointing achieves 70% VRAM reduction on the same workloads. MemScale matches or exceeds it with the right configuration, and it makes the 1.3B and GPT-2 XL models trainable where the baseline runs out of memory.
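For reference, PyTorch's native activation checkpointing (the comparison point above) is applied by wrapping each block's forward call in torch.utils.checkpoint.checkpoint, which keeps only the block's input and recomputes its internals during backward. The sketch below uses a toy stack; Block, Stack and grads are hypothetical names for illustration, not MemScale code. It also checks that the gradients match an uncheckpointed run, which is why checkpointing is considered lossless.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Toy residual MLP block standing in for one transformer layer."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.net(x)


class Stack(nn.Module):
    """A stack of blocks, optionally checkpointed at every block boundary."""

    def __init__(self, dim: int, depth: int, use_checkpoint: bool) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(depth))
        self.use_checkpoint = use_checkpoint

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # Save only this block's input; its internals are recomputed in backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


def grads(use_checkpoint: bool) -> list[torch.Tensor]:
    """Gradients of a toy loss, with identical weights and inputs in both modes."""
    torch.manual_seed(0)
    model = Stack(dim=64, depth=4, use_checkpoint=use_checkpoint)
    x = torch.randn(8, 64)
    model(x).pow(2).mean().backward()
    return [p.grad.clone() for p in model.parameters()]


plain, ckpt = grads(use_checkpoint=False), grads(use_checkpoint=True)
print(all(torch.allclose(a, b) for a, b in zip(plain, ckpt)))  # True
```

The trade is compute for memory: each checkpointed block runs its forward pass twice, which is where the overhead discussed in the FAQ comes from.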

Quick start

pip install memscale
pip install bitsandbytes  # optional, for additional 8-bit Adam savings

import memscale
from transformers import Trainer, TrainingArguments

# model, dataset: your existing HuggingFace model and training dataset
trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out", per_device_train_batch_size=16),
    train_dataset=dataset,
)

# Add this one line:
trainer = memscale.wrap(trainer)

trainer.train()  # Less VRAM (up to 88% with all techniques enabled); steps may be slower, see FAQ

That's it. MemScale automatically:

  1. Profiles your model's memory usage per layer
  2. Decides which optimization technique fits each layer best
  3. Applies boundary checkpointing, plus 8-bit Adam and mixed precision when enabled in Config (both are off by default)
  4. Reports memory savings and throughput in real time
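Step 1 can be approximated in plain PyTorch with forward hooks. The sketch below is a hypothetical illustration, not MemScale's profiler (which, per the Architecture section, uses torch.fx with an empirical fallback): it runs one forward pass and records the output size of each leaf module as a rough proxy for the activations autograd would keep for backward. The function name activation_bytes_per_module is made up for this example.

```python
import torch
from torch import nn


def activation_bytes_per_module(model: nn.Module, example: torch.Tensor) -> dict[str, int]:
    """Run one forward pass and record the output size of every leaf module.

    Output size is only a proxy for what autograd keeps for backward; a real
    profiler would also account for saved inputs, masks and temporary buffers.
    """
    sizes: dict[str, int] = {}
    handles = []

    def make_hook(name: str):
        def hook(module: nn.Module, inputs, output) -> None:
            if isinstance(output, torch.Tensor):
                sizes[name] = sizes.get(name, 0) + output.numel() * output.element_size()
        return hook

    for name, module in model.named_modules():
        if next(module.children(), None) is None:  # leaf modules only
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(example)
    finally:
        for handle in handles:
            handle.remove()  # leave the model unmodified
    return sizes


model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 8))
print(activation_bytes_per_module(model, torch.randn(4, 16)))
# {'0': 1024, '1': 1024, '2': 128}
```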

Maximum reduction (combined techniques)

For maximum savings, enable all techniques:

import torch
from memscale import Config, OptimizationMode
from memscale.phase_f import apply_all_optimizations

model = YourModel().cuda()  # YourModel: placeholder for your own nn.Module
optimizer = torch.optim.AdamW(model.parameters())

config = Config(
    mode=OptimizationMode.AGGRESSIVE,
    use_8bit_optimizer=True,    # bitsandbytes 8-bit Adam
    use_mixed_precision=True,   # BF16 on Ampere+, FP16 fallback
)

# One call applies all techniques
model, optimizer = apply_all_optimizations(model, optimizer, config)

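# Train normally (dataloader: your own DataLoader)
for batch in dataloader:
    loss = model(batch).loss  # assumes the model output has a .loss attribute (HuggingFace-style)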
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

This stack achieved 88.0% reduction on BERT-Large in our benchmarks.
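Much of the optimizer-side saving in this stack comes from storing Adam's two moment tensors in 8 bits instead of 32. The arithmetic below is a back-of-the-envelope sketch, assuming FP32 weights and gradients and ignoring bitsandbytes' small per-block quantization constants; it uses the 302M parameter count listed for BERT-Large in the benchmark table and is not a MemScale measurement.

```python
def adam_state_bytes(num_params: int, bytes_per_value: int) -> int:
    """Adam keeps two values per parameter: the first and second moment."""
    return 2 * num_params * bytes_per_value


def model_state_bytes(num_params: int, eight_bit_adam: bool) -> int:
    """FP32 weights + FP32 gradients + Adam moments (activations not included)."""
    weights = 4 * num_params
    grads = 4 * num_params
    moments = adam_state_bytes(num_params, 1 if eight_bit_adam else 4)
    return weights + grads + moments


n = 302_000_000  # BERT-Large parameter count from the benchmark table
fp32_state = adam_state_bytes(n, 4)
int8_state = adam_state_bytes(n, 1)
print(f"Adam state: {fp32_state / 1e9:.2f} GB -> {int8_state / 1e9:.2f} GB "
      f"({1 - int8_state / fp32_state:.0%} less)")
# Adam state: 2.42 GB -> 0.60 GB (75% less)
print(f"Model state: {model_state_bytes(n, False) / 1e9:.2f} GB -> "
      f"{model_state_bytes(n, True) / 1e9:.2f} GB")
# Model state: 4.83 GB -> 3.02 GB
```

The 75% figure is exact for the moments alone (4 bytes down to 1 byte per value); the total shrinks less because weights and gradients are untouched, which is why the other techniques in the stack matter.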

How it works

MemScale combines proven memory optimization techniques and chooses what fits each layer:

| Technique | Saves | When applied |
|---|---|---|
| Boundary checkpointing | ~70% (activations) | Transformer blocks (BertLayer, GPT2Block, TransformerEncoderLayer, ViTLayer, etc.) |
| 8-bit Adam (bitsandbytes) | ~75% (optimizer state) | When use_8bit_optimizer=True and bitsandbytes is installed |
| Mixed precision (BF16/FP16) | ~50% (params/activations) | When use_mixed_precision=True (BF16 on Ampere+ GPUs, FP16 otherwise) |
| CPU offload | Variable | Large layers, when checkpointing alone is not enough |

The decision engine analyzes your model and picks a technique for each layer, so you don't need to configure individual layers.
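Boundary checkpointing saves the most because, without it, every transformer layer keeps all its intermediate activations for backward. A rough estimate from Korthikanti et al. (2022) puts this at s*b*h*(34 + 5*a*s/h) bytes per layer with 16-bit activations (s = sequence length, b = batch size, h = hidden size, a = attention heads), whereas checkpointing at block boundaries keeps only each layer's 2*s*b*h-byte input, plus one layer's full activations while it is being recomputed. The sketch applies this to a BERT-Large-shaped model (24 layers, hidden size 1024, 16 heads) at an assumed sequence length of 512 and batch size 16; these settings are illustrative, and the result is an estimate, not a reproduction of the benchmark above.

```python
def layer_activation_bytes(seq: int, batch: int, hidden: int, heads: int) -> float:
    """Activations one transformer layer keeps for backward (16-bit, no recomputation)."""
    return seq * batch * hidden * (34 + 5 * heads * seq / hidden)


def layer_input_bytes(seq: int, batch: int, hidden: int) -> int:
    """The 16-bit layer input, which is all a checkpointed layer keeps."""
    return 2 * seq * batch * hidden


def activation_estimate(layers: int, seq: int, batch: int, hidden: int, heads: int) -> tuple[float, float]:
    """Return (without checkpointing, with boundary checkpointing) in bytes."""
    full_layer = layer_activation_bytes(seq, batch, hidden, heads)
    without = layers * full_layer
    # Checkpointed: every layer input is kept, plus one layer's activations during recomputation.
    with_ckpt = layers * layer_input_bytes(seq, batch, hidden) + full_layer
    return without, with_ckpt


without, with_ckpt = activation_estimate(layers=24, seq=512, batch=16, hidden=1024, heads=16)
print(f"{without / 1e9:.1f} GB -> {with_ckpt / 1e9:.1f} GB")  # 14.9 GB -> 1.0 GB
```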

Multi-GPU support

Multi-GPU training works via standard PyTorch DDP. MemScale's per-GPU optimizations apply on each GPU:

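torchrun --nproc_per_node=2 your_training_script.py

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from memscale import Config, OptimizationMode
from memscale.phase_f import apply_all_optimizations

local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
dist.init_process_group(backend="nccl")
torch.cuda.set_device(local_rank)
model = YourModel().to(local_rank)  # YourModel: placeholder for your own nn.Module
optimizer = torch.optim.AdamW(model.parameters())
config = Config(mode=OptimizationMode.AGGRESSIVE, use_8bit_optimizer=True, use_mixed_precision=True)
model, optimizer = apply_all_optimizations(model, optimizer, config)
model = DDP(model, device_ids=[local_rank])

# Train normally: 87% per-GPU reduction with 2x throughput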

Validated on 2x RTX 3090: 1.69 GB per GPU (vs 13 GB baseline single-GPU).

Distributed sharding (research preview)

Phase G (the memscale.distributed module) provides ZeRO-3-inspired building blocks for sharding parameters and optimizer state:

from memscale.distributed import (
    init_distributed,
    shard_model_parameters,
    ShardedOptimizer,
)

Note: Phase G provides the ShardedParameter and ShardedOptimizer classes with NCCL-based all-gather. Full integration with model forward/backward hooks is planned for v1.1. For production multi-GPU training requiring 95%+ reduction today, use FSDP or DeepSpeed.
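Sharding can go further than per-GPU optimizations because DDP replicates all model state on every GPU, while ZeRO-style sharding splits it across the N GPUs. The sketch below uses the standard accounting from the ZeRO paper for mixed-precision Adam: 2 bytes of FP16 weights, 2 bytes of FP16 gradients and 12 bytes of FP32 optimizer state (master weights, momentum, variance) per parameter. It counts model state only (no activations) and illustrates the general technique; it does not describe how memscale.distributed partitions tensors, and zero_model_state_bytes is a hypothetical helper.

```python
def zero_model_state_bytes(num_params: int, num_gpus: int, stage: int, k: int = 12) -> float:
    """Per-GPU model-state bytes for mixed-precision Adam under ZeRO stages 0-3.

    Per parameter: 2 bytes FP16 weights, 2 bytes FP16 gradients and k bytes of
    FP32 optimizer state (master weights, momentum, variance: k = 12).
    Stage 0 is plain data parallelism (DDP), where everything is replicated.
    """
    p, n = num_params, num_gpus
    if stage == 0:
        return (2 + 2 + k) * p
    if stage == 1:  # shard optimizer state
        return (2 + 2) * p + k * p / n
    if stage == 2:  # also shard gradients
        return 2 * p + (2 + k) * p / n
    if stage == 3:  # also shard parameters
        return (2 + 2 + k) * p / n
    raise ValueError(f"unknown ZeRO stage: {stage}")


# GPT-2 XL size from the benchmark table, split across 2 GPUs
for stage in range(4):
    gb = zero_model_state_bytes(1_500_000_000, num_gpus=2, stage=stage) / 1e9
    print(f"stage {stage}: {gb:.1f} GB per GPU")
# stage 0: 24.0 GB per GPU
# stage 1: 15.0 GB per GPU
# stage 2: 13.5 GB per GPU
# stage 3: 12.0 GB per GPU
```

Only stage 3 divides every term by N, so its savings keep growing with the number of GPUs; this is the part that needs the forward/backward all-gather hooks planned for v1.1.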

Usage modes

HuggingFace Trainer

import memscale
trainer = memscale.wrap(your_hf_trainer)
trainer.train()

PyTorch Lightning

from lightning import Trainer
from memscale.integrations.lightning import MemScaleLightningCallback

trainer = Trainer(
    callbacks=[MemScaleLightningCallback()],
    max_epochs=10,
)
trainer.fit(model, dataloader)

Custom training loop

import memscale

with memscale.optimize(model, optimizer) as ms:
    for batch in dataloader:
        loss = model(batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

Configuration

Most users don't need this: the defaults are designed to cover common cases.

from memscale import wrap, Config, OptimizationMode

config = Config(
    mode=OptimizationMode.AGGRESSIVE,  # or BALANCED (default), CONSERVATIVE
    enable_checkpointing=True,
    enable_offloading=True,
    use_8bit_optimizer=False,    # set True for max reduction
    use_mixed_precision=False,   # set True for max reduction
    target_gpu_utilization=0.85,
)

trainer = wrap(trainer, config=config)

Compatibility

| Component | Min version | Tested |
|---|---|---|
| Python | 3.9 | 3.9, 3.10, 3.11, 3.12 |
| PyTorch | 2.1 | 2.1, 2.2, 2.3, 2.4 |
| CUDA | 11.8 | 11.8, 12.1, 12.4, 12.8 |
| GPU | Compute capability 7.0+ | V100, A100, H100, RTX 3090/4090 |
| BF16 mixed precision | Compute capability 8.0+ | A100, H100, RTX 3090/4090 |
| OS | Linux | Ubuntu 20.04, 22.04, 24.04 |

AMD GPU support (ROCm) coming in a future release.
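The BF16 row of the compatibility table reflects a hardware rule: BF16 needs compute capability 8.0 or newer (Ampere and later), so on 7.x GPUs such as the V100 mixed precision falls back to FP16. A minimal sketch of that selection; mixed_precision_dtype is a hypothetical helper, not part of MemScale.

```python
def mixed_precision_dtype(capability: tuple[int, int]) -> str:
    """Return the autocast dtype name for a CUDA compute capability (major, minor).

    BF16 needs compute capability 8.0+ (Ampere and newer); 7.x GPUs such as
    the V100 (7.0) fall back to FP16. Below 7.0 is outside the supported range.
    """
    if capability < (7, 0):
        raise ValueError(f"compute capability {capability} is below the 7.0 minimum")
    return "bfloat16" if capability >= (8, 0) else "float16"


print(mixed_precision_dtype((8, 6)))  # RTX 3090 -> bfloat16
print(mixed_precision_dtype((7, 0)))  # V100 -> float16
```

On a real machine the capability comes from torch.cuda.get_device_capability(), and getattr(torch, name) turns the returned name into the dtype to pass to torch.autocast("cuda", dtype=...). FP16 training is usually paired with a gradient scaler (GradScaler) to avoid gradient underflow, which BF16 generally does not need.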

FAQ

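Q: Does MemScale change my training results? Activation checkpointing and DDP are mathematically lossless: checkpointing recomputes exactly the same forward pass, and DDP averages gradients across GPUs. BF16/FP16 mixed precision introduces small numerical differences, the same as standard PyTorch AMP. 8-bit Adam stores quantized optimizer state, so it can also change results slightly.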

Q: How does this compare to DeepSpeed and FSDP? DeepSpeed and FSDP are powerful but require significant configuration and distributed training expertise. MemScale's value is plug-and-play: 1-line wrap with auto-detection. For 95%+ reduction in production multi-GPU setups, DeepSpeed ZeRO-3 is more mature. For single-GPU and DDP workloads, MemScale is competitive and easier to use.

Q: Will this slow down my training? Activation checkpointing adds 20-30% compute overhead (the standard tradeoff). 8-bit Adam adds ~2-5%. Net effect: training is slower per step, but you can use larger batches (better hardware utilization), so end-to-end time often improves.
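The checkpointing overhead can be sanity-checked with a common rule of thumb: a backward pass costs roughly twice the forward pass, so a training step is about three forward-pass units, and recomputation adds at most one more. The sketch below turns that into a formula; the 2x backward-to-forward ratio is an approximation, and recompute_overhead is a hypothetical helper for illustration.

```python
def recompute_overhead(recomputed_fraction: float, backward_to_forward: float = 2.0) -> float:
    """Extra compute per training step from activation checkpointing.

    One step costs about F (forward) + backward_to_forward * F (backward);
    re-running `recomputed_fraction` of the forward pass adds that much F.
    """
    if not 0.0 <= recomputed_fraction <= 1.0:
        raise ValueError("recomputed_fraction must be in [0, 1]")
    return recomputed_fraction / (1.0 + backward_to_forward)


print(f"{recompute_overhead(1.0):.0%}")  # 33%: full recomputation
print(f"{recompute_overhead(0.6):.0%} to {recompute_overhead(0.9):.0%}")  # 20% to 30%
```

Under this approximation, full recomputation caps the overhead at about a third, and recomputing 60 to 90 percent of the forward pass lands in the 20-30% range quoted above.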

Q: What if my model has custom architecture? The decision engine handles standard transformers (PyTorch native, HuggingFace BERT/GPT2/Llama/Mistral, vision transformers) automatically. Custom architectures fall back to per-module heuristics. Both are tested.

Q: Why "up to 88%" instead of a flat number? Reduction depends on model architecture, batch size, sequence length, and which techniques you enable. Our benchmarks show 62-88% on standard transformers. Smaller and older models show less; large modern models with long sequences see the most savings.

Roadmap

  • v1.0 (current): Boundary checkpointing, 8-bit Adam, BF16 mixed precision, DDP support, sharding building blocks
  • v1.1: Phase G full integration (ZeRO-3 forward/backward hooks), AMD GPU (ROCm)
  • v1.2: FSDP/DeepSpeed integration helpers, web dashboard
  • v2.0: Learned decision policy (model-specific tuning)

Architecture

MemScale's optimization happens in stages:

  1. Profiling: Static analysis via torch.fx, with empirical fallback for dynamic models
  2. Decision engine: Per-layer technique selection based on memory profile, hardware budget, and configuration
  3. Execution: Apply chosen techniques via PyTorch hooks
  4. Observation: Track memory and throughput, report to user
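Stage 2 can be pictured as a budgeted search over per-layer options. The sketch below is a simplified, hypothetical version with a single technique: it greedily checkpoints the layers whose checkpointing saves the most memory until the estimated activation total fits a budget. LayerProfile and choose_checkpointed_layers are illustrative names, not MemScale's API; the engine described above also considers 8-bit Adam, mixed precision and CPU offload.

```python
from dataclasses import dataclass


@dataclass
class LayerProfile:
    name: str
    activation_bytes: int  # kept for backward when the layer is not checkpointed
    input_bytes: int  # kept when the layer is checkpointed (its input only)

    @property
    def saving(self) -> int:
        return self.activation_bytes - self.input_bytes


def choose_checkpointed_layers(layers: list[LayerProfile], budget_bytes: int) -> tuple[list[str], int]:
    """Greedily checkpoint the biggest savers until the estimate fits the budget.

    Returns the chosen layer names and the remaining estimated activation bytes.
    If the estimate still exceeds the budget, a fuller engine would move on to
    other techniques (e.g. offloading).
    """
    total = sum(layer.activation_bytes for layer in layers)
    chosen: list[str] = []
    for layer in sorted(layers, key=lambda lay: lay.saving, reverse=True):
        if total <= budget_bytes or layer.saving <= 0:
            break
        total -= layer.saving
        chosen.append(layer.name)
    return chosen, total


layers = [
    LayerProfile("embed", 100, 100),
    LayerProfile("block0", 600, 50),
    LayerProfile("block1", 600, 50),
    LayerProfile("head", 200, 20),
]
print(choose_checkpointed_layers(layers, budget_bytes=1000))  # (['block0'], 950)
```

Greedy selection by savings is a simple heuristic; it ignores the recompute cost of each layer, which a real engine would weigh against memory.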

Source code is organized as:

memscale/
├── core/           # Profiler, decision engine, executor, config
├── techniques/     # Checkpointing, 8-bit optimizer, mixed precision
├── distributed/    # FSDP integration + ZeRO-3 inspired sharding (Phase G)
├── integrations/   # HuggingFace, Lightning adapters
└── phase_f.py      # apply_all_optimizations one-line API

Contributing

Issues and PRs welcome. Please include:

  1. Minimal reproducible example
  2. Hardware (GPU model, VRAM)
  3. PyTorch version
  4. Output of memscale.profile_model(model) if relevant

License

Apache 2.0 — see LICENSE.

Citation

If you use MemScale in your research, please cite:

@software{memscale2026,
  title={MemScale: Drop-in Memory Optimization for PyTorch Training},
  author={MemScale Team},
  year={2026},
  url={https://github.com/MrGinkaku/MemScale}
}

Built for ML practitioners. Questions? team@memscale.id



Download files


Source Distributions

No source distribution files are available for this release.

Built Distributions


| File | Size | Python | Platform |
|---|---|---|---|
| memscale-1.0.2-cp312-cp312-win_amd64.whl | 1.4 MB | CPython 3.12 | Windows x86-64 |
| memscale-1.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 9.8 MB | CPython 3.12 | manylinux (glibc 2.17+) x86-64 |
| memscale-1.0.2-cp312-cp312-macosx_11_0_arm64.whl | 1.4 MB | CPython 3.12 | macOS 11.0+ ARM64 |
| memscale-1.0.2-cp311-cp311-win_amd64.whl | 1.4 MB | CPython 3.11 | Windows x86-64 |
| memscale-1.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 9.2 MB | CPython 3.11 | manylinux (glibc 2.17+) x86-64 |
| memscale-1.0.2-cp311-cp311-macosx_11_0_arm64.whl | 1.4 MB | CPython 3.11 | macOS 11.0+ ARM64 |
| memscale-1.0.2-cp310-cp310-win_amd64.whl | 1.4 MB | CPython 3.10 | Windows x86-64 |
| memscale-1.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 8.7 MB | CPython 3.10 | manylinux (glibc 2.17+) x86-64 |
| memscale-1.0.2-cp310-cp310-macosx_11_0_arm64.whl | 1.4 MB | CPython 3.10 | macOS 11.0+ ARM64 |

