Drop-in memory optimizer for PyTorch training. Reduce VRAM significantly with one line of code.

MemScale

Drop-in memory optimizer for PyTorch training. Reduce VRAM up to 88% with 1 line of code.


The problem

Training large models on GPUs hits a wall: VRAM.

  • BERT-Large with batch 16 → 17.6 GB on RTX 3090
  • 1.5B parameter model → out of memory on single 24GB GPU
  • DeepSpeed ZeRO setup → 2 weeks of configuration

MemScale solves this. Wrap your trainer in one line and get up to 88% VRAM reduction, with no other code changes.

Real benchmarks (validated on RTX 3090 24GB)

| Model | Params | Baseline | MemScale | Reduction |
|---|---|---|---|---|
| BERT-Base | 85M | 6.39 GB | 1.87 GB | 70.8% |
| BERT-Large | 302M | 17.57 GB | 2.11 GB | 88.0% |
| GPT-2 Medium | 302M | 19.02 GB | 7.16 GB | 62.4% |
| GPT-2 Large | 708M | 21.78 GB | 4.88 GB | 77.6% |
| 1.3B model | 1.3B | OOM | 8.86 GB | Enables training |
| GPT-2 XL | 1.5B | OOM | 12.72 GB | Enables training |
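The Reduction column follows directly from the two memory columns. The short sketch below recomputes it from the rounded table values; the function name `reduction_percent` is only an illustrative helper and is not part of MemScale. Because the inputs are already rounded to two decimals, the recomputed figures can differ from the table by about 0.1 percentage point.

```python
# Baseline and MemScale peak VRAM in GB, copied from the benchmark table.
BENCHMARKS = {
    "BERT-Base": (6.39, 1.87),
    "BERT-Large": (17.57, 2.11),
    "GPT-2 Medium": (19.02, 7.16),
    "GPT-2 Large": (21.78, 4.88),
}


def reduction_percent(baseline_gb: float, optimized_gb: float) -> float:
    """Return the VRAM saving as a percentage of the baseline."""
    return (1.0 - optimized_gb / baseline_gb) * 100.0


if __name__ == "__main__":
    for name, (baseline, optimized) in BENCHMARKS.items():
        print(f"{name}: {reduction_percent(baseline, optimized):.1f}%")
```

The OOM rows have no baseline figure, so no percentage can be computed for them.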

Comparison: PyTorch native checkpointing achieves about a 70% reduction on the same workloads. MemScale matches or exceeds it with the right configuration, and enables training models that PyTorch alone cannot fit.

Quick start

pip install memscale
pip install bitsandbytes  # optional, for additional 8-bit Adam savings

import memscale
from transformers import Trainer, TrainingArguments

# `model` and `dataset` are your existing model and training dataset
trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="outputs", per_device_train_batch_size=16),
    train_dataset=dataset,
)

# Add this one line:
trainer = memscale.wrap(trainer)

trainer.train()  # Up to 88% less VRAM (see the FAQ for the speed tradeoff)

That's it. MemScale automatically:

  1. Profiles your model's memory usage per layer
  2. Decides which optimization technique fits each layer best
  3. Applies boundary checkpointing, 8-bit Adam, mixed precision
  4. Reports memory savings and throughput in real time

Maximum reduction (combined techniques)

For maximum savings, enable all techniques:

import torch
from memscale import Config, OptimizationMode
from memscale.phase_f import apply_all_optimizations

model = YourModel()
optimizer = torch.optim.AdamW(model.parameters())

config = Config(
    mode=OptimizationMode.AGGRESSIVE,
    use_8bit_optimizer=True,    # bitsandbytes 8-bit Adam
    use_mixed_precision=True,   # BF16 on Ampere+, FP16 fallback
)

# One call applies all techniques
model, optimizer = apply_all_optimizations(model, optimizer, config)

# Train normally
for batch in dataloader:
    loss = model(batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

This stack achieved 88.0% reduction on BERT-Large in our benchmarks.

How it works

MemScale combines proven memory optimization techniques and chooses what fits each layer:

| Technique | Saves | When applied |
|---|---|---|
| Boundary checkpointing | ~70% (activations) | Transformer blocks (BertLayer, GPT2Block, TransformerEncoderLayer, ViTLayer, etc.) |
| 8-bit Adam (bitsandbytes) | ~75% (optimizer state) | When use_8bit_optimizer=True and bitsandbytes is installed |
| Mixed precision (BF16/FP16) | ~50% (params/activations) | When use_mixed_precision=True on Ampere+ GPUs |
| CPU offload | Variable | Large layers when checkpointing is not enough |
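The optimizer-state figure in the table can be sanity-checked with simple arithmetic. The sketch below assumes standard Adam keeps two FP32 moment tensors per parameter (4 bytes each) and that an 8-bit optimizer keeps the same two tensors at 1 byte each, ignoring the small quantization metadata. The checkpointing function uses a simplified model in which only each block's boundary tensor is kept and one block is recomputed at a time. Both helper names and the activation sizes are hypothetical, chosen only to illustrate the formulas.

```python
def adam_state_gb(num_params: int, bytes_per_state: int) -> float:
    """Optimizer-state size in GB: Adam keeps two moment tensors per parameter."""
    return 2 * num_params * bytes_per_state / 1e9


def checkpointed_activation_gb(num_blocks: int, block_gb: float, boundary_gb: float) -> float:
    """Keep one boundary tensor per block, plus one block's activations during recompute."""
    return num_blocks * boundary_gb + block_gb


params = 302_000_000                   # parameter count from the BERT-Large benchmark row
fp32_state = adam_state_gb(params, 4)  # two FP32 moments per parameter
int8_state = adam_state_gb(params, 1)  # two 8-bit moments per parameter
optimizer_saving = 1 - int8_state / fp32_state  # 0.75, matching the ~75% in the table

# Hypothetical activation sizes (not measured): 24 blocks, 0.5 GB each, 0.05 GB boundary.
full_activations = 24 * 0.5
with_checkpointing = checkpointed_activation_gb(24, 0.5, 0.05)
```

The actual activation saving depends on how large each block's internal activations are relative to its boundary tensor, which is why the table gives only an approximate figure.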

The decision engine analyzes your model and picks the right technique per layer — you don't need to configure individual layers.

Multi-GPU support

Multi-GPU training works via standard PyTorch DDP. MemScale's per-GPU optimizations apply on each GPU:

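torchrun --nproc_per_node=2 your_training_script.py

import os
import torch
from torch.nn.parallel import DistributedDataParallel
from memscale import Config
from memscale.phase_f import apply_all_optimizations

torch.distributed.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun

model = YourModel().to(local_rank)
optimizer = torch.optim.AdamW(model.parameters())
config = Config(use_8bit_optimizer=True, use_mixed_precision=True)
model, optimizer = apply_all_optimizations(model, optimizer, config)
model = DistributedDataParallel(model, device_ids=[local_rank])

# Train normally - 87% per-GPU reduction with 2x throughput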

Validated on 2x RTX 3090: 1.69 GB per GPU (vs 13 GB baseline single-GPU).

Distributed sharding (research preview)

Phase G provides ZeRO-3 inspired parameter and optimizer sharding building blocks:

from memscale.distributed import (
    init_distributed,
    shard_model_parameters,
    ShardedOptimizer,
)

Note: Phase G provides the ShardedParameter and ShardedOptimizer classes with NCCL-based all-gather. Full integration with model forward/backward hooks is planned for v1.1. For production multi-GPU training requiring 95%+ reduction today, use FSDP or DeepSpeed.
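To make the sharding idea concrete, the sketch below simulates it in plain Python with no GPUs or NCCL: a flattened parameter is split into equal, zero-padded slices (one per rank), and an all-gather concatenates the slices and drops the padding. The helpers `shard` and `all_gather` are hypothetical stand-ins for illustration and do not reflect the signatures in `memscale.distributed`.

```python
def shard(flat_params, world_size):
    """Split a flat parameter list into equal, zero-padded shards, one per rank."""
    shard_len = -(-len(flat_params) // world_size)  # ceiling division
    padding = shard_len * world_size - len(flat_params)
    padded = list(flat_params) + [0.0] * padding
    return [padded[rank * shard_len:(rank + 1) * shard_len] for rank in range(world_size)]


def all_gather(shards, original_len):
    """Concatenate every rank's shard and drop the padding."""
    full = [value for part in shards for value in part]
    return full[:original_len]


params = [0.1, 0.2, 0.3, 0.4, 0.5]
shards = shard(params, world_size=2)        # each rank stores 3 values instead of 5
restored = all_gather(shards, len(params))  # full parameter, rebuilt just before use
```

Each rank permanently stores only its own shard; the full parameter exists only temporarily around the forward and backward pass of the layer that needs it.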

Usage modes

HuggingFace Trainer

import memscale
trainer = memscale.wrap(your_hf_trainer)
trainer.train()

PyTorch Lightning

from lightning import Trainer
from memscale.integrations.lightning import MemScaleLightningCallback

trainer = Trainer(
    callbacks=[MemScaleLightningCallback()],
    max_epochs=10,
)
trainer.fit(model, dataloader)

Custom training loop

import memscale

with memscale.optimize(model, optimizer) as ms:
    for batch in dataloader:
        loss = model(batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

Configuration

Most users don't need this. Defaults work for 90% of cases.

from memscale import wrap, Config, OptimizationMode

config = Config(
    mode=OptimizationMode.AGGRESSIVE,  # or BALANCED (default), CONSERVATIVE
    enable_checkpointing=True,
    enable_offloading=True,
    use_8bit_optimizer=False,    # set True for max reduction
    use_mixed_precision=False,   # set True for max reduction
    target_gpu_utilization=0.85,
)

trainer = wrap(trainer, config=config)

Compatibility

| Component | Min Version | Tested |
|---|---|---|
| Python | 3.9 | 3.9, 3.10, 3.11, 3.12 |
| PyTorch | 2.1 | 2.1, 2.2, 2.3, 2.4 |
| CUDA | 11.8 | 11.8, 12.1, 12.4, 12.8 |
| GPU | Compute capability 7.0+ | V100, A100, H100, RTX 3090/4090 |
| BF16 mixed precision | Compute capability 8.0+ | A100, H100, RTX 3090/4090 |
| OS | Linux | Ubuntu 20.04, 22.04, 24.04 |

AMD GPU support (ROCm) coming in a future release.
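The GPU and BF16 rows of the compatibility table amount to a simple rule on the compute capability. The sketch below expresses that rule as a small helper; `pick_autocast_dtype` is a hypothetical name, and how MemScale performs this check internally is not documented here. In a real script the `(major, minor)` tuple could come from `torch.cuda.get_device_capability()`.

```python
def pick_autocast_dtype(compute_capability):
    """Map a (major, minor) compute capability to a mixed-precision dtype name.

    Returns "bfloat16" on 8.0 and newer, "float16" on 7.0 and newer,
    and None for GPUs below the minimum in the compatibility table.
    """
    capability = tuple(compute_capability)
    if capability >= (8, 0):
        return "bfloat16"
    if capability >= (7, 0):
        return "float16"
    return None


# Compute capabilities of the tested GPUs: V100 is 7.0, A100 is 8.0,
# RTX 3090 is 8.6, RTX 4090 is 8.9, H100 is 9.0.
for gpu, capability in [("V100", (7, 0)), ("A100", (8, 0)), ("RTX 3090", (8, 6))]:
    print(gpu, pick_autocast_dtype(capability))
```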

FAQ

Q: Does MemScale change my training results? Activation checkpointing and DDP are mathematically lossless. BF16/FP16 mixed precision introduces small numerical differences, the same as standard PyTorch AMP. 8-bit Adam quantizes the optimizer state, so it can also change results slightly when enabled.

Q: How does this compare to DeepSpeed and FSDP? DeepSpeed and FSDP are powerful but require significant configuration and distributed training expertise. MemScale's value is plug-and-play: 1-line wrap with auto-detection. For 95%+ reduction in production multi-GPU setups, DeepSpeed ZeRO-3 is more mature. For single-GPU and DDP workloads, MemScale is competitive and easier to use.

Q: Will this slow down my training? Activation checkpointing adds 20-30% compute overhead (the standard tradeoff). 8-bit Adam adds ~2-5%. Net effect: training is slower per step, but you can use larger batches (better hardware utilization), so end-to-end time often improves.
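Whether the larger batch pays for the recompute overhead can be checked with a quick throughput calculation. The sketch below uses made-up timings (they are not benchmark results) and two hypothetical helpers: it computes the step time below which a larger batch beats the baseline in samples per second.

```python
def throughput(batch_size: int, step_seconds: float) -> float:
    """Samples processed per second."""
    return batch_size / step_seconds


def break_even_step_seconds(base_batch: int, base_step_seconds: float, new_batch: int) -> float:
    """Longest step time at which the new batch size still matches baseline throughput."""
    return new_batch * base_step_seconds / base_batch


# Hypothetical baseline: batch 16 at 0.40 s per step.
baseline = throughput(16, 0.40)
# With the freed memory, suppose batch 64 fits.
limit = break_even_step_seconds(16, 0.40, 64)
# If step time scaled linearly with batch size and paid a 25% recompute overhead:
linear_with_overhead = (64 / 16) * 0.40 * 1.25
```

Under purely linear scaling, the 25% overhead would leave the larger batch slower overall (2.0 s against a 1.6 s break-even), so an end-to-end gain depends on the larger batch using the hardware more efficiently than the small one did. Measuring both configurations on your own workload is the only reliable check.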

Q: What if my model has custom architecture? The decision engine handles standard transformers (PyTorch native, HuggingFace BERT/GPT2/Llama/Mistral, vision transformers) automatically. Custom architectures fall back to per-module heuristics. Both are tested.

Q: Why "up to 88%" instead of a flat number? Reduction depends on model architecture, batch size, sequence length, and which techniques you enable. Our benchmarks show 62-88% on standard transformers. Smaller and older models show less; large modern models with long sequences see the most savings.

Roadmap

  • v1.0 (current): Boundary checkpointing, 8-bit Adam, BF16 mixed precision, DDP support, sharding building blocks
  • v1.1: Phase G full integration (ZeRO-3 forward/backward hooks), AMD GPU (ROCm)
  • v1.2: FSDP/DeepSpeed integration helpers, web dashboard
  • v2.0: Learned decision policy (model-specific tuning)

Architecture

MemScale's optimization happens in stages:

  1. Profiling: Static analysis via torch.fx, with empirical fallback for dynamic models
  2. Decision engine: Per-layer technique selection based on memory profile, hardware budget, and configuration
  3. Execution: Apply chosen techniques via PyTorch hooks
  4. Observation: Track memory and throughput, report to user
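As an illustration of what the profiling stage has to estimate, the size of a single activation tensor follows from its shape and dtype. The sketch below is a plain-Python approximation of that arithmetic, not MemScale's profiler; `tensor_bytes` is a hypothetical helper, and the batch size of 16 and sequence length of 512 are assumed values (1024 is BERT-Large's hidden size).

```python
from math import prod

# Storage size per element for common training dtypes.
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2}


def tensor_bytes(shape, dtype="float32"):
    """Memory needed to store one dense tensor of the given shape and dtype."""
    return prod(shape) * BYTES_PER_ELEMENT[dtype]


# One hidden-state tensor: batch 16, sequence length 512, hidden size 1024.
fp32_bytes = tensor_bytes((16, 512, 1024), "float32")
bf16_bytes = tensor_bytes((16, 512, 1024), "bfloat16")

print(f"FP32: {fp32_bytes / 2**20:.0f} MiB, BF16: {bf16_bytes / 2**20:.0f} MiB")
```

Summing such estimates over every tensor a layer keeps for the backward pass gives the per-layer memory profile that the decision engine works from.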

Source code is organized as:

memscale/
├── core/           # Profiler, decision engine, executor, config
├── techniques/     # Checkpointing, 8-bit optimizer, mixed precision
├── distributed/    # FSDP integration + ZeRO-3 inspired sharding (Phase G)
├── integrations/   # HuggingFace, Lightning adapters
└── phase_f.py      # apply_all_optimizations one-line API

Contributing

Issues and PRs welcome. Please include:

  1. Minimal reproducible example
  2. Hardware (GPU model, VRAM)
  3. PyTorch version
  4. Output of memscale.profile_model(model) if relevant

License

Proprietary. Full terms: contact team@memscale.id or visit memscale.id.

Citation

If you use MemScale in your research, please cite:

@software{memscale2026,
  title={MemScale: Drop-in Memory Optimization for PyTorch Training},
  author={MemScale Team},
  year={2026},
  url={https://github.com/MrGinkaku/MemScale}
}

Built for ML practitioners. Questions? team@memscale.id
