MemScale

Drop-in memory optimizer for PyTorch training. Reduce VRAM up to 88% with 1 line of code.


The problem

Training large models on GPUs hits a wall: VRAM.

  • BERT-Large with batch 16 → 17.6 GB on RTX 3090
  • 1.5B parameter model → out of memory on single 24GB GPU
  • DeepSpeed ZeRO setup → 2 weeks of configuration

MemScale addresses this. Add one wrap call and get up to 88% VRAM reduction with no other code changes.

Real benchmarks (validated on RTX 3090 24GB)

Model          Params   Baseline   MemScale   Reduction
BERT-Base      85M      6.39 GB    1.87 GB    70.8%
BERT-Large     302M     17.57 GB   2.11 GB    88.0%
GPT-2 Medium   302M     19.02 GB   7.16 GB    62.4%
GPT-2 Large    708M     21.78 GB   4.88 GB    77.6%
1.3B model     1.3B     OOM        8.86 GB    Enables training
GPT-2 XL       1.5B     OOM        12.72 GB   Enables training

Comparison: PyTorch's native activation checkpointing achieves about 70% on the same workloads. MemScale matches or exceeds it with the right configuration, and enables training models that PyTorch alone cannot fit.
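The Reduction column can be recomputed from the Baseline and MemScale columns. The sketch below assumes the usual definition, one minus the ratio of optimized to baseline peak memory; the helper name `reduction_pct` is illustrative and not part of MemScale.

```python
def reduction_pct(baseline_gb: float, optimized_gb: float) -> float:
    """Percentage of peak VRAM saved relative to the baseline run."""
    return round(100.0 * (1.0 - optimized_gb / baseline_gb), 1)


# (baseline GB, MemScale GB) copied from the benchmark table
rows = {
    "BERT-Large": (17.57, 2.11),
    "GPT-2 Medium": (19.02, 7.16),
    "GPT-2 Large": (21.78, 4.88),
}

results = {name: reduction_pct(b, o) for name, (b, o) in rows.items()}
for name, pct in results.items():
    print(f"{name}: {pct}%")
```

Because the table's GB figures are already rounded, a recomputed value can differ from the published one in the last digit (BERT-Base comes out at 70.7% this way).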

Quick start

pip install memscale
pip install bitsandbytes  # optional, for additional 8-bit Adam savings

import memscale
from transformers import Trainer, TrainingArguments

trainer = Trainer(
    model=model,
    args=TrainingArguments(per_device_train_batch_size=16),
    train_dataset=dataset,
)

# Add this one line:
trainer = memscale.wrap(trainer)

trainer.train()  # Up to 88% less VRAM; see the FAQ for the speed tradeoff

That's it. MemScale automatically:

  1. Profiles your model's memory usage per layer
  2. Decides which optimization technique fits each layer best
  3. Applies boundary checkpointing, 8-bit Adam, mixed precision
  4. Reports memory savings and throughput in real time
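To check the reported savings independently, peak memory can be read from PyTorch's own CUDA allocator statistics. This is a minimal sketch using standard `torch.cuda` calls; the helper `peak_vram_gb` and the choice of 1 GB = 1e9 bytes are assumptions for illustration, not MemScale's reporting code.

```python
import torch


def peak_vram_gb(step_fn, device: int = 0):
    """Run step_fn once and return peak allocated CUDA memory in GB.

    Returns None when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        return None
    torch.cuda.synchronize(device)
    torch.cuda.reset_peak_memory_stats(device)  # start from a clean peak counter
    step_fn()
    torch.cuda.synchronize(device)              # wait for queued kernels to finish
    return torch.cuda.max_memory_allocated(device) / 1e9


# Usage: pass a callable that runs one forward/backward/optimizer step.
baseline_peak = peak_vram_gb(lambda: None)
```

Measure one run with and one without the wrap, using the same batch size and sequence length, to get comparable numbers.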

No API key is required, and the library works fully offline. Anonymous telemetry is enabled by default to help improve the decision engine; it fails silently without a network connection, and you can opt out at any time with memscale.disable_telemetry() or MEMSCALE_TELEMETRY=0. See Telemetry below for what's collected.

Maximum reduction (combined techniques)

For maximum savings, enable all techniques:

import torch
from memscale import Config, OptimizationMode
from memscale.phase_f import apply_all_optimizations

model = YourModel()
optimizer = torch.optim.AdamW(model.parameters())

config = Config(
    mode=OptimizationMode.AGGRESSIVE,
    use_8bit_optimizer=True,    # bitsandbytes 8-bit Adam
    use_mixed_precision=True,   # BF16 on Ampere+, FP16 fallback
)

# One call applies all techniques
model, optimizer = apply_all_optimizations(model, optimizer, config)

# Train normally
for batch in dataloader:
    loss = model(batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

This stack achieved 88.0% reduction on BERT-Large in our benchmarks.

How it works

MemScale combines proven memory optimization techniques and chooses what fits each layer:

Technique                     Saves                       When applied
Boundary checkpointing        ~70% (activations)          Transformer blocks (BertLayer, GPT2Block, TransformerEncoderLayer, ViTLayer, etc.)
8-bit Adam (bitsandbytes)     ~75% (optimizer state)      When use_8bit_optimizer=True and bitsandbytes is installed
Mixed precision (BF16/FP16)   ~50% (params/activations)   When use_mixed_precision=True (BF16 on Ampere+ GPUs, FP16 fallback otherwise)
CPU offload                   Variable                    Large layers when checkpointing is not enough

The decision engine analyzes your model and picks the right technique per layer — you don't need to configure individual layers.
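For readers unfamiliar with activation checkpointing, the sketch below shows the general technique at block boundaries using PyTorch's native `torch.utils.checkpoint`. It is not MemScale's implementation; the `CheckpointedStack` module and its sizes are made up for illustration.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointedStack(nn.Module):
    """Stack of blocks whose internal activations are recomputed in backward."""

    def __init__(self, dim: int = 32, depth: int = 4, use_checkpoint: bool = True):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
             for _ in range(depth)]
        )
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # Only the block input is kept; the rest is recomputed in backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


torch.manual_seed(0)
model = CheckpointedStack()
x = torch.randn(8, 32, requires_grad=True)

# Gradient with checkpointing enabled
model(x).sum().backward()
grad_checkpointed = x.grad.clone()

# Gradient with checkpointing disabled, same weights and input
x.grad = None
model.zero_grad()
model.use_checkpoint = False
model(x).sum().backward()
grads_match = torch.allclose(grad_checkpointed, x.grad)
```

The two gradients agree because checkpointing trades extra forward compute for memory without changing the math.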

HuggingFace integration

For HuggingFace autoregressive models (GPT-2, Llama, Mistral, T5), MemScale automatically disables config.use_cache when checkpointing is enabled. This prevents the CheckpointError that occurs when KV-cache concatenation conflicts with backward recompute. No code changes needed — just memscale.wrap().
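Done by hand, the equivalent step is to turn off the `use_cache` flag on the model's config before training. The helper below is a hypothetical sketch of that step, not MemScale's code; it uses a stand-in object so it runs without downloading a model.

```python
from types import SimpleNamespace


def disable_kv_cache(model) -> bool:
    """Set config.use_cache = False if the model exposes it.

    Returns True when the flag was changed, False otherwise.
    """
    config = getattr(model, "config", None)
    if config is not None and getattr(config, "use_cache", False):
        config.use_cache = False
        return True
    return False


# Stand-in with the same attribute layout as a Hugging Face model
fake_model = SimpleNamespace(config=SimpleNamespace(use_cache=True))
changed = disable_kv_cache(fake_model)
```

The KV cache only helps during generation, so disabling it for training does not change the loss.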

Multi-GPU support

Multi-GPU training works via standard PyTorch DDP. MemScale's per-GPU optimizations apply on each GPU:

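torchrun --nproc_per_node=2 your_training_script.py

import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from memscale.phase_f import apply_all_optimizations

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun

model = YourModel().to(local_rank)
# Build optimizer and config as in the maximum-reduction example above
model, optimizer = apply_all_optimizations(model, optimizer, config)
model = DistributedDataParallel(model, device_ids=[local_rank])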

# Train normally - 87% per-GPU reduction with 2x throughput

Validated on 2x RTX 3090: 1.69 GB per GPU (vs 13 GB baseline single-GPU).

Distributed sharding (research preview)

memscale.distributed provides ZeRO-3-inspired building blocks for parameter and optimizer sharding. Full integration with model forward/backward hooks is planned for v1.1. For production multi-GPU training that requires 95%+ reduction today, FSDP or DeepSpeed remain the recommended choice.

Usage modes

HuggingFace Trainer

import memscale
trainer = memscale.wrap(your_hf_trainer)
trainer.train()

PyTorch Lightning

from lightning import Trainer
from memscale.integrations.lightning import MemScaleLightningCallback

trainer = Trainer(
    callbacks=[MemScaleLightningCallback()],
    max_epochs=10,
)
trainer.fit(model, dataloader)

Custom training loop

import memscale

with memscale.optimize(model, optimizer) as ms:
    for batch in dataloader:
        loss = model(batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

Configuration

Most users don't need this. Defaults work for 90% of cases.

from memscale import wrap, Config, OptimizationMode

config = Config(
    mode=OptimizationMode.AGGRESSIVE,  # or BALANCED (default), CONSERVATIVE
    enable_checkpointing=True,
    enable_offloading=True,
    use_8bit_optimizer=False,    # set True for max reduction
    use_mixed_precision=False,   # set True for max reduction
    target_gpu_utilization=0.85,
)

trainer = wrap(trainer, config=config)

Cost attribution

Track how much money MemScale saves on cloud GPU bills:

from memscale.cost_attribution import CostTracker, estimate_savings

# Quick estimate
report = estimate_savings(
    baseline_vram_gb=70.0,
    memscale_vram_gb=35.0,
    training_hours=10.0,
    gpu_type='A40',
    baseline_gpu_type='A100 80GB',  # GPU you'd need WITHOUT MemScale
)
print(report)
# Baseline (A100 80GB): $24.90
# MemScale (A40):       $15.30
# Savings:              $9.60 (38.6%)

Built-in pricing for 16 GPU types (V100, A100, H100, RTX series, AMD MI300X, etc.) plus auto-inference of the cheapest GPU sufficient for your workload.
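The arithmetic behind the report is hourly rate times training hours for each GPU. The sketch below reproduces the example figures; the hourly rates ($2.49 and $1.53) are inferred from the example output divided by 10 hours and are assumptions, as is the helper name `savings_report`.

```python
def savings_report(baseline_rate: float, optimized_rate: float, hours: float) -> dict:
    """Compare total cost of two GPUs at the given hourly rates (USD)."""
    baseline_cost = baseline_rate * hours
    optimized_cost = optimized_rate * hours
    savings = baseline_cost - optimized_cost
    return {
        "baseline_cost": round(baseline_cost, 2),
        "optimized_cost": round(optimized_cost, 2),
        "savings": round(savings, 2),
        "savings_pct": round(100.0 * savings / baseline_cost, 1),
    }


# Assumed rates: A100 80GB at $2.49/h, A40 at $1.53/h
report = savings_report(baseline_rate=2.49, optimized_rate=1.53, hours=10.0)
print(report)
```

The saving comes from fitting the job on a cheaper GPU, so it only applies when the reduced footprint actually fits the smaller card.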

OOM prediction

Catch out-of-memory before training starts:

from memscale import OOMPredictor

predictor = OOMPredictor(model_params_bytes=2_000_000_000)  # 2 GB params
risk = predictor.predict(batch_size=16, sequence_length=2048, optimizer='adamw')

if risk.level == 'CRITICAL':
    print(f"⚠️ {risk.message}")
    print(f"Recommendations: {risk.recommendations}")
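As a rough cross-check on any prediction, the static part of training memory can be estimated by hand. This is a back-of-envelope sketch, not OOMPredictor's method: it assumes FP32 weights and gradients (4 bytes each per parameter) and two FP32 Adam states (8 bytes per parameter), and it ignores activations entirely. The function names and the 0.85 headroom default are illustrative.

```python
def static_training_bytes(num_params: int, param_bytes: int = 4,
                          grad_bytes: int = 4, optimizer_state_bytes: int = 8) -> int:
    """Weights + gradients + optimizer state. Activations are NOT included."""
    return num_params * (param_bytes + grad_bytes + optimizer_state_bytes)


def fits_in_vram(num_params: int, vram_gb: float, headroom: float = 0.85, **kwargs) -> bool:
    """True if the static footprint stays under headroom * VRAM (1 GB = 1e9 bytes)."""
    return static_training_bytes(num_params, **kwargs) <= headroom * vram_gb * 1e9


# 1.5B parameters with FP32 AdamW: 16 bytes per parameter
fp32_adam = static_training_bytes(1_500_000_000)
# Same model with 8-bit optimizer states (about 2 bytes per parameter)
adam_8bit = static_training_bytes(1_500_000_000, optimizer_state_bytes=2)

print(fp32_adam / 1e9, adam_8bit / 1e9)
```

Under these assumptions a 1.5B-parameter model needs about 24 GB before any activations are stored, which is consistent with the out-of-memory baseline on a 24 GB card.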

Telemetry

MemScale ships with anonymous telemetry enabled by default to improve the decision engine across diverse hardware and workloads. To opt out:

import memscale
memscale.disable_telemetry()

Or via environment variable (set before importing MemScale):

export MEMSCALE_TELEMETRY=0
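In a notebook or a launcher script where exporting a shell variable is inconvenient, the same variable can be set from Python. A minimal sketch, relying only on the variable name given above; the assignment has to run before MemScale is first imported.

```python
import os

# Must run before the first `import memscale` in this process.
os.environ["MEMSCALE_TELEMETRY"] = "0"

# import memscale  # import only after the variable is set

telemetry_setting = os.environ["MEMSCALE_TELEMETRY"]
```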

What's collected (~1 KB per training run)

  • Anonymous client ID (random UUID, stored locally at ~/.memscale/client_id)
  • Library version, Python version, PyTorch version, OS
  • Hardware: GPU model, VRAM, CUDA version, number of GPUs
  • Model architecture: layer types, parameter count (no weights)
  • Optimization outcome: techniques applied, memory saved, throughput overhead

What's NEVER collected

  • ❌ Model weights or training data
  • ❌ Code or scripts
  • ❌ File paths, hostnames, IP addresses
  • ❌ Email or any identifying information
  • ❌ Layer-level activations or gradients

Telemetry is fire-and-forget (silent failure on network error), sent over HTTPS to api.memscale.id/v1/telemetry. See our privacy policy for details.

Re-enable after opting out

memscale.enable_telemetry()

Or:

export MEMSCALE_TELEMETRY=1

Compatibility

Component              Min version               Tested
Python                 3.10                      3.10, 3.11, 3.12
PyTorch                2.1                       2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9
CUDA                   11.8                      11.8, 12.1, 12.4, 12.8
GPU                    Compute capability 7.0+   V100, A100, H100, RTX 3090/4090
BF16 mixed precision   Compute capability 8.0+   A100, H100, RTX 3090/4090
OS                     Linux/macOS/Windows       Ubuntu 20.04+, macOS 14+ (arm64), Windows 11

AMD GPU support (ROCm) coming in a future release.

FAQ

Q: Does MemScale change my training results?
Activation checkpointing and DDP are mathematically lossless. BF16/FP16 mixed precision introduces small numerical differences, the same as standard PyTorch AMP. 8-bit Adam quantizes optimizer state, so it can also introduce small differences.

Q: How does this compare to DeepSpeed and FSDP?
DeepSpeed and FSDP are powerful but require significant configuration and distributed training expertise. MemScale's value is plug-and-play: a 1-line wrap with auto-detection. For 95%+ reduction in production multi-GPU setups, DeepSpeed ZeRO-3 is more mature. For single-GPU and DDP workloads, MemScale is competitive and easier to use.

Q: Will this slow down my training?
Activation checkpointing adds 20-30% compute overhead (the standard tradeoff). 8-bit Adam adds ~2-5%. Net effect: each step is slower, but the freed memory allows larger batches (better hardware utilization), so end-to-end time often improves.

Q: What if my model has a custom architecture?
The decision engine handles standard transformers (PyTorch native, HuggingFace BERT/GPT2/Llama/Mistral, vision transformers) automatically. Custom architectures fall back to per-module heuristics. Both paths are tested.

Q: Why "up to 88%" instead of a flat number?
Reduction depends on model architecture, batch size, sequence length, and which techniques you enable. Our benchmarks show 62-88% on standard transformers. Smaller and older models show less; large modern models with long sequences see the most savings.

Q: Is MemScale open source?
No. The source code is currently proprietary, but the PyPI distribution is public and free to install and use. We may open source it later based on community feedback. See memscale.id for licensing.

Roadmap

  • v1.0.4 (current): five medium-severity bug fixes (seq_len awareness, dtype-aware param count, tiling outputs, GPU downgrade cost, estimate_training_memory)
  • v1.1 (Q3 2026): Stability release, multi-GPU verified, AMD GPU (ROCm), full ZeRO-3 integration
  • v1.2 (Q4 2026): ML-based decision policy trained on telemetry data
  • v2.0 (2027): Multi-framework support (JAX, TensorFlow) + MemScale Serve (inference)

Architecture

MemScale's optimization happens in stages:

  1. Profiling: Static analysis with empirical fallback for dynamic models
  2. Decision engine: Per-layer technique selection based on memory profile, hardware budget, and configuration
  3. Execution: Apply chosen techniques via PyTorch hooks
  4. Observation: Track memory and throughput, report to user
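The profiling and hook-based stages can be illustrated with PyTorch's forward hooks. The sketch below measures the output size of each leaf module in one forward pass and picks the largest as a checkpointing candidate. It is an illustration of the general approach, not MemScale's profiler; `profile_activation_bytes` is a hypothetical name.

```python
import torch
from torch import nn


def profile_activation_bytes(model: nn.Module, example_input: torch.Tensor) -> dict:
    """Record output-tensor size in bytes for each leaf module in one forward pass."""
    sizes, handles = {}, []

    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor):
                sizes[name] = output.numel() * output.element_size()
        return hook

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(example_input)
    finally:
        for handle in handles:
            handle.remove()  # always detach hooks, even if forward fails
    return sizes


model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 8))
sizes = profile_activation_bytes(model, torch.randn(4, 16))

# Layer with the largest output: a natural checkpointing candidate
largest = max(sizes, key=sizes.get)
print(sizes, largest)
```

A real decision step would weigh these sizes against the hardware budget and the recompute cost of each layer rather than only taking the maximum.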

Reporting Issues

For bug reports, please include:

  1. Minimal reproducible example
  2. Hardware (GPU model, VRAM)
  3. PyTorch version
  4. Output of memscale.profile_model(model) if relevant

Email: team@memscale.id

License

Proprietary. Full terms: contact team@memscale.id or visit memscale.id.

Citation

If you use MemScale in your research, please cite:

@software{memscale2026,
  title={MemScale: Drop-in Memory Optimization for PyTorch Training},
  author={MemScale Team},
  year={2026},
  url={https://memscale.id}
}

Built for ML practitioners. Questions? team@memscale.id
