
Drop-in memory optimizer for PyTorch training. Reduce VRAM significantly with one line of code.


MemScale

Drop-in memory optimizer for PyTorch training. Cut VRAM by up to ~76% and train models that otherwise don't fit, with one line of code.



The problem

Training large models on GPUs hits a wall: VRAM.

  • GPT-2 Large training → 14.9 GB on a 24 GB RTX 3090, leaving limited headroom for larger batches or longer sequences
  • 1.5B-parameter model (GPT-2 XL) → out of memory on a single 24 GB GPU
  • DeepSpeed ZeRO setup → 2 weeks of configuration

MemScale solves this. Wrap your model or trainer in one line, cut VRAM by up to ~76%, and train models that otherwise run out of memory, with no other code changes.

Benchmarks

Validated on RTX 3090 24GB (PyTorch 2.12, CUDA 13). Reproduce with: python -m memscale.benchmarks --output results.json

| Model | Params | Batch × Seq | Baseline | MemScale | Reduction |
|---|---|---|---|---|---|
| BERT-Base | 110M | 16 × 128 | 3.14 GB | 0.84 GB | 73.1% |
| BERT-Large | 340M | 16 × 128 | 7.60 GB | 2.02 GB | 73.4% |
| GPT-2 Medium | 355M | 4 × 512 | 10.87 GB | 2.61 GB | 76.0% |
| GPT-2 Large | 774M | 2 × 512 | 14.87 GB | 4.68 GB | 68.5% |
| GPT-2 XL | 1.5B | 1 × 512 | OOM | 9.25 GB | enables training |

Configuration: AGGRESSIVE mode (8-bit Adam + BF16 + checkpointing). The reduction percentage grows with workload size: larger batches and longer sequences typically yield higher percentages. See BENCHMARK_REPORT.md for methodology details.
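As a quick sanity check, the Reduction column can be recomputed from the Baseline and MemScale columns as (baseline - optimized) / baseline. The sketch below does this with the table's GB figures; the helper name `reduction_pct` is ours, not part of MemScale. Because the published GB values are rounded to two decimals, a recomputed percentage can presumably differ by about 0.1 point (BERT-Base comes out at 73.2% rather than 73.1%).

```python
"""Recompute the benchmark table's Reduction column from its GB columns."""

# (baseline_gb, memscale_gb) pairs copied from the benchmark table.
BENCHMARKS = {
    "BERT-Base": (3.14, 0.84),
    "BERT-Large": (7.60, 2.02),
    "GPT-2 Medium": (10.87, 2.61),
    "GPT-2 Large": (14.87, 4.68),
}


def reduction_pct(baseline_gb: float, optimized_gb: float) -> float:
    """Share of baseline VRAM saved, in percent, rounded to one decimal."""
    if baseline_gb <= 0:
        raise ValueError("baseline_gb must be positive")
    return round(100.0 * (baseline_gb - optimized_gb) / baseline_gb, 1)


if __name__ == "__main__":
    for name, (baseline, optimized) in BENCHMARKS.items():
        print(f"{name:<13} {reduction_pct(baseline, optimized):5.1f}%")
```

The same helper reproduces the multi-GPU figure quoted in the DDP section: 1.69 GB per GPU against a 13 GB single-GPU baseline gives `reduction_pct(13.0, 1.69) == 87.0`.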

Quick start

pip install memscale
pip install bitsandbytes  # optional, for additional 8-bit Adam savings

import memscale
from transformers import Trainer, TrainingArguments

# `model` and `dataset` are your own Hugging Face model and training dataset
trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out", per_device_train_batch_size=16),
    train_dataset=dataset,
)

# Add this one line:
trainer = memscale.wrap(trainer)

trainer.train()  # Benchmarks reached up to ~76% less VRAM (AGGRESSIVE mode); see the FAQ for speed tradeoffs

That's it. MemScale automatically:

  1. Profiles your model's memory usage per layer
  2. Decides which optimization technique fits each layer best
  3. Applies boundary checkpointing, plus 8-bit Adam and mixed precision when enabled in Config
  4. Reports memory savings and throughput in real time

No API key required. Library works fully offline. Anonymous telemetry is enabled by default to help improve the decision engine — opt out anytime with memscale.disable_telemetry() or MEMSCALE_TELEMETRY=0. See Telemetry below for what's collected.

Experimental: async CPU offload (v1.1+)

v1.1 adds an experimental async CPU offload engine. It overlaps GPU→CPU transfers with compute using tier-aware CUDA streams, while keeping the same VRAM savings as the default sync path.

import memscale

model = memscale.wrap(model, async_offload=True)
# Same VRAM savings, async transfer.
# Note: speedup is not yet benchmarked — see the v1.2 roadmap.

async_offload defaults to False; existing users are unaffected. Enabling it emits an ExperimentalFeatureWarning. Numerical equivalence with the sync path is validated (rtol 1e-6 / atol 1e-7 on RTX 3090); the training-loop speedup is deferred to a dedicated v1.2 benchmarking phase. For production, use the default sync offload. See the async offload user guide for details.
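The overlap idea can be illustrated without a GPU. In the CPU-only sketch below, a single background worker stands in for a dedicated copy stream: while the main thread runs layer i, the worker copies layer i-1's output, and the final `result()` calls play the role of a stream synchronization point. This is an analogy written for illustration (the function name and toy layers are hypothetical), not MemScale's engine; a real GPU version would instead issue `non_blocking` copies into pinned host memory on a separate `torch.cuda.Stream`.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, List, Sequence, Tuple

Layer = Callable[[List[float]], List[float]]


def run_with_async_offload(
    layers: Sequence[Layer],
    x: List[float],
    offload: Callable[[List[float]], Tuple[float, ...]],
) -> Tuple[List[float], List[Tuple[float, ...]]]:
    """Run `layers` in order while a background worker offloads each output.

    The single worker thread stands in for a copy stream: copies are queued
    in order and run concurrently with the next layer's compute.
    """
    pending: List[Future] = []
    with ThreadPoolExecutor(max_workers=1) as copier:
        for layer in layers:
            x = layer(x)
            # Snapshot the activation so later layers cannot mutate it mid-copy.
            pending.append(copier.submit(offload, list(x)))
        # Synchronization point: wait for every queued copy to finish.
        saved = [future.result() for future in pending]
    return x, saved


if __name__ == "__main__":
    toy_layers = [
        lambda v: [a + 1 for a in v],
        lambda v: [a * 2 for a in v],
    ]
    out, saved = run_with_async_offload(toy_layers, [1.0, 2.0], offload=tuple)
    print(out, saved)
```

Using one worker keeps copies ordered, which mirrors why a single copy stream is enough for correctness; the benefit comes only from overlapping copies with compute.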

| GPU tier | Examples | async_offload support |
|---|---|---|
| High | RTX 30xx+, A100, H100 | Full (dual-stream) |
| Low | RTX 20xx, GTX 10xx | Single-stream |
| CPU | No CUDA | Sync only (no async benefit) |
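The tier table implies a simple rule keyed on CUDA compute capability: every High-tier example (A100 at 8.0, RTX 30xx at 8.6, H100 at 9.0) is 8.0 or newer, while the Low-tier examples (RTX 20xx at 7.5, GTX 10xx at 6.1) are older. The sketch below encodes that inferred threshold; it is our reading of the table, not MemScale's actual tier detection, and `offload_tier` is a hypothetical name.

```python
from typing import Optional, Tuple

# Stream strategy per tier, as listed in the tier table.
STREAM_MODE = {
    "High": "dual-stream",
    "Low": "single-stream",
    "CPU": "sync only",
}


def offload_tier(capability: Optional[Tuple[int, int]]) -> str:
    """Map a (major, minor) CUDA compute capability to a tier name.

    Assumed rule inferred from the table's examples: 8.0+ is "High",
    any older CUDA GPU is "Low", and no CUDA device (None) is "CPU".
    """
    if capability is None:
        return "CPU"
    major, _minor = capability
    return "High" if major >= 8 else "Low"


if __name__ == "__main__":
    for cap in [(9, 0), (8, 6), (7, 5), (6, 1), None]:
        tier = offload_tier(cap)
        print(cap, tier, STREAM_MODE[tier])
```

With PyTorch installed, the capability tuple would come from `torch.cuda.get_device_capability()`, guarded by `torch.cuda.is_available()`.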

Maximum reduction (combined techniques)

For maximum savings, enable all techniques:

import torch
from memscale import Config, OptimizationMode
from memscale.phase_f import apply_all_optimizations

model = YourModel()
optimizer = torch.optim.AdamW(model.parameters())

config = Config(
    mode=OptimizationMode.AGGRESSIVE,
    use_8bit_optimizer=True,    # bitsandbytes 8-bit Adam
    use_mixed_precision=True,   # BF16 on Ampere+, FP16 fallback
)

# One call applies all techniques
model, optimizer = apply_all_optimizations(model, optimizer, config)

# Train normally
for batch in dataloader:
    loss = model(batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

This stack achieved 73.4% reduction on BERT-Large (7.60 GB → 2.02 GB) in the v1.1 benchmarks — see BENCHMARK_REPORT.md.

How it works

MemScale combines proven memory optimization techniques and chooses what fits each layer:

| Technique | Saves | When applied |
|---|---|---|
| Boundary checkpointing | ~70% (activations) | Transformer blocks (BertLayer, GPT2Block, TransformerEncoderLayer, ViTLayer, etc.) |
| 8-bit Adam (bitsandbytes) | ~75% (optimizer state) | When use_8bit_optimizer=True and bitsandbytes is installed |
| Mixed precision (BF16/FP16) | ~50% (params/activations) | When use_mixed_precision=True on Ampere+ GPUs |
| CPU offload (sync + experimental async) | Variable | Large layers when checkpointing is not enough |

The decision engine analyzes your model and picks the right technique per layer — you don't need to configure individual layers.
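The optimizer-state and precision rows in the table follow from bytes-per-element arithmetic. AdamW keeps two moment buffers per parameter (exp_avg and exp_avg_sq); in FP32 that is 8 bytes per parameter, and storing both in 8 bits drops it to 2 bytes, a 75% cut. Likewise, any tensor held in BF16/FP16 instead of FP32 takes half the bytes. The sketch below works this through for a 340M-parameter model (the BERT-Large size from the benchmark table). It ignores the small per-block scaling constants that blockwise 8-bit optimizers also store, so real savings land slightly under 75%; the helper names are ours.

```python
# Bytes per element for the dtypes discussed in the technique table.
BYTES_PER_ELEMENT = {"fp32": 4, "bf16": 2, "fp16": 2, "int8": 1}


def adam_state_bytes(num_params: int, state_dtype: str = "fp32") -> int:
    """Memory for AdamW's two moment buffers (exp_avg and exp_avg_sq).

    Ignores the per-block quantization constants of 8-bit optimizers.
    """
    return 2 * num_params * BYTES_PER_ELEMENT[state_dtype]


def saving_pct(before: float, after: float) -> float:
    """Percentage reduction going from `before` to `after`."""
    return 100.0 * (before - after) / before


if __name__ == "__main__":
    n_params = 340_000_000  # BERT-Large size from the benchmark table
    fp32_state = adam_state_bytes(n_params, "fp32")
    int8_state = adam_state_bytes(n_params, "int8")
    print(f"AdamW state: {fp32_state / 1e9:.2f} GB -> {int8_state / 1e9:.2f} GB "
          f"({saving_pct(fp32_state, int8_state):.0f}% less)")

    fp32_size, bf16_size = BYTES_PER_ELEMENT["fp32"], BYTES_PER_ELEMENT["bf16"]
    print(f"FP32 -> BF16 per element: {saving_pct(fp32_size, bf16_size):.0f}% less")
```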

HuggingFace integration

For HuggingFace autoregressive models (GPT-2, Llama, Mistral, T5), MemScale automatically disables config.use_cache when checkpointing is enabled. This prevents the CheckpointError that occurs when KV-cache concatenation conflicts with backward recompute. No code changes needed — just memscale.wrap().
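Conceptually, this amounts to flipping one flag on the model config before training. The sketch below shows that logic on any object exposing `config.use_cache`, such as a Hugging Face `PreTrainedModel`; `disable_kv_cache_for_checkpointing` is a hypothetical helper written for illustration, not a MemScale function. In plain transformers code without MemScale, the manual equivalent is setting `model.config.use_cache = False` before calling `model.gradient_checkpointing_enable()`.

```python
from types import SimpleNamespace


def disable_kv_cache_for_checkpointing(model, checkpointing_enabled: bool) -> bool:
    """Turn off the KV cache when activation checkpointing is active.

    Works on any object with a `config.use_cache` attribute. Returns True
    only if the flag was actually changed.
    """
    config = getattr(model, "config", None)
    if not checkpointing_enabled or config is None:
        return False
    if getattr(config, "use_cache", False):
        config.use_cache = False
        return True
    return False


if __name__ == "__main__":
    # Stand-in for a Hugging Face model; only `config.use_cache` matters here.
    model = SimpleNamespace(config=SimpleNamespace(use_cache=True))
    changed = disable_kv_cache_for_checkpointing(model, checkpointing_enabled=True)
    print(changed, model.config.use_cache)
```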

Multi-GPU support

Multi-GPU training works via standard PyTorch DDP. MemScale's per-GPU optimizations apply on each GPU:

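torchrun --nproc_per_node=2 your_training_script.py

import os
import torch
import torch.distributed as dist
import torch.nn.parallel as parallel
from memscale import Config, OptimizationMode
from memscale.phase_f import apply_all_optimizations

local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
dist.init_process_group(backend="nccl")
torch.cuda.set_device(local_rank)
model = YourModel().to(local_rank)
optimizer = torch.optim.AdamW(model.parameters())
config = Config(mode=OptimizationMode.AGGRESSIVE, use_8bit_optimizer=True, use_mixed_precision=True)
model, optimizer = apply_all_optimizations(model, optimizer, config)
model = parallel.DistributedDataParallel(model, device_ids=[local_rank])
# Train normally: ~87% per-GPU reduction with ~2x throughput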

Validated on 2x RTX 3090: 1.69 GB per GPU (vs 13 GB baseline single-GPU).

Distributed sharding (research preview)

memscale.distributed provides ZeRO-3 inspired parameter and optimizer sharding building blocks. Full integration with model forward/backward hooks is planned for v1.1. For production multi-GPU training requiring 95%+ reduction today, FSDP or DeepSpeed remain the recommended choice.
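For context, ZeRO-3 style sharding splits each flattened parameter (and its gradient and optimizer state) into equal slices, one per rank, so per-GPU model-state memory falls roughly as 1/world_size. The sketch below is a generic illustration of that partitioning, not the memscale.distributed API; the function names are hypothetical. The 16-bytes-per-parameter default follows the usual mixed-precision Adam accounting (FP16 weights and gradients plus FP32 master weights and two FP32 moments).

```python
from typing import List, Tuple


def shard_ranges(numel: int, world_size: int) -> List[Tuple[int, int]]:
    """Split a flattened tensor of `numel` elements into contiguous shards.

    Every rank gets a slot of ceil(numel / world_size) elements; trailing
    ranks may own fewer real elements (frameworks pad to equal size).
    Returns half-open [start, end) ranges into the unpadded buffer.
    """
    if world_size < 1:
        raise ValueError("world_size must be >= 1")
    slot = (numel + world_size - 1) // world_size  # integer ceil division
    return [
        (min(rank * slot, numel), min((rank + 1) * slot, numel))
        for rank in range(world_size)
    ]


def per_rank_model_state_bytes(
    num_params: int, world_size: int, bytes_per_param: int = 16
) -> int:
    """Model-state bytes each rank holds when everything is sharded."""
    return (num_params + world_size - 1) // world_size * bytes_per_param


if __name__ == "__main__":
    print(shard_ranges(10, 4))
    gb = per_rank_model_state_bytes(1_500_000_000, world_size=8) / 1e9
    print(f"1.5B params on 8 GPUs: {gb:.1f} GB of model state per GPU")
```

Full ZeRO-3 additionally all-gathers each shard just before a layer's forward/backward and frees it afterwards, which is the hook integration this section describes as still planned.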

Usage modes

HuggingFace Trainer

import memscale
trainer = memscale.wrap(your_hf_trainer)
trainer.train()

PyTorch Lightning

from lightning import Trainer
from memscale.integrations.lightning import MemScaleLightningCallback

trainer = Trainer(
    callbacks=[MemScaleLightningCallback()],
    max_epochs=10,
)
trainer.fit(model, dataloader)

Custom training loop

import memscale

with memscale.optimize(model, optimizer) as ms:
    for batch in dataloader:
        loss = model(batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

Configuration

Most users don't need this. Defaults work for 90% of cases.

from memscale import wrap, Config, OptimizationMode

config = Config(
    mode=OptimizationMode.AGGRESSIVE,  # or BALANCED (default), CONSERVATIVE
    enable_checkpointing=True,
    enable_offloading=True,
    use_8bit_optimizer=False,    # set True for max reduction
    use_mixed_precision=False,   # set True for max reduction
    target_gpu_utilization=0.85,
)

trainer = wrap(trainer, config=config)

Cost attribution

Track how much money MemScale saves on cloud GPU bills:

from memscale.cost_attribution import CostTracker, estimate_savings

# Quick estimate
report = estimate_savings(
    baseline_vram_gb=70.0,
    memscale_vram_gb=35.0,
    training_hours=10.0,
    gpu_type='A40',
    baseline_gpu_type='A100 80GB',  # GPU you'd need WITHOUT MemScale
)
print(report)
# Baseline (A100 80GB): $24.90
# MemScale (A40):       $15.30
# Savings:              $9.60 (38.6%)

Built-in pricing for 16 GPU types (V100, A100, H100, RTX series, AMD MI300X, etc.) plus auto-inference of the cheapest GPU sufficient for your workload.
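The example report's arithmetic is easy to reproduce: cost is hourly rate times hours, and the "cheapest sufficient GPU" is the lowest-priced option whose VRAM covers the requirement. The sketch below uses only the two GPUs from the example, with hourly rates back-calculated from its output ($2.49/h and $1.53/h) and the cards' standard capacities (80 GB and 48 GB). These rates, the `GpuOffer` type and the helper functions are assumptions for illustration, not MemScale's pricing data or API.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass(frozen=True)
class GpuOffer:
    vram_gb: float
    usd_per_hour: float


# Hourly rates implied by the example report ($24.90 and $15.30 over
# 10 hours). Illustrative only; not MemScale's built-in price table.
OFFERS: Dict[str, GpuOffer] = {
    "A100 80GB": GpuOffer(vram_gb=80, usd_per_hour=2.49),
    "A40": GpuOffer(vram_gb=48, usd_per_hour=1.53),
}


def cheapest_fitting_gpu(required_vram_gb: float,
                         offers: Dict[str, GpuOffer] = OFFERS) -> Optional[str]:
    """Cheapest GPU whose VRAM covers the requirement, or None if none fits."""
    fitting = [name for name, o in offers.items() if o.vram_gb >= required_vram_gb]
    return min(fitting, key=lambda name: offers[name].usd_per_hour, default=None)


def run_cost(gpu: str, hours: float, offers: Dict[str, GpuOffer] = OFFERS) -> float:
    """Dollar cost of `hours` on `gpu`, rounded to cents."""
    return round(offers[gpu].usd_per_hour * hours, 2)


if __name__ == "__main__":
    hours = 10.0
    baseline = run_cost("A100 80GB", hours)  # GPU needed for 70 GB
    chosen = cheapest_fitting_gpu(35.0)      # VRAM needed with MemScale
    optimized = run_cost(chosen, hours)
    savings = baseline - optimized
    print(f"Baseline (A100 80GB): ${baseline:.2f}")
    print(f"MemScale ({chosen}): ${optimized:.2f}")
    print(f"Savings: ${savings:.2f} ({100 * savings / baseline:.1f}%)")
```

With these inputs the totals match the example report: $24.90 versus $15.30, a $9.60 (38.6%) saving.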

OOM prediction

Catch out-of-memory before training starts:

from memscale import OOMPredictor

predictor = OOMPredictor(model_params_bytes=2_000_000_000)  # 2 GB params
risk = predictor.predict(batch_size=16, sequence_length=2048, optimizer='adamw')

if risk.level == 'CRITICAL':
    print(f"⚠️ {risk.message}")
    print(f"Recommendations: {risk.recommendations}")
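A useful first-order check before any predictor is the static footprint: weights, gradients and optimizer state, which do not depend on batch size or sequence length. The sketch below assumes FP32 parameters and gradients (so the 2 GB in the example is 500M parameters) and two AdamW moments per parameter; activations, which grow with batch_size and sequence_length, come on top. It is a back-of-the-envelope helper with a hypothetical name, not how OOMPredictor computes risk.

```python
# Optimizer-state bytes per parameter (assumed: two moments per parameter).
STATE_BYTES_PER_PARAM = {
    "adamw": 8,       # exp_avg + exp_avg_sq in FP32
    "adamw_8bit": 2,  # both moments quantized to 8 bits (scales ignored)
}


def static_training_bytes(param_bytes: int, optimizer: str = "adamw",
                          bytes_per_param: int = 4) -> int:
    """Weights + gradients + optimizer state, excluding activations."""
    num_params = param_bytes // bytes_per_param
    gradient_bytes = param_bytes  # gradients share the parameter dtype
    state_bytes = num_params * STATE_BYTES_PER_PARAM[optimizer]
    return param_bytes + gradient_bytes + state_bytes


if __name__ == "__main__":
    for opt in ("adamw", "adamw_8bit"):
        total = static_training_bytes(2_000_000_000, optimizer=opt)
        print(f"{opt:<11} {total / 1e9:.1f} GB before activations")
```

On a 24 GB card, that leaves roughly 16 GB (or 19 GB with 8-bit AdamW) for activations and framework overhead.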

Telemetry

MemScale ships with anonymous telemetry enabled by default to improve the decision engine across diverse hardware and workloads. To opt out:

import memscale
memscale.disable_telemetry()

Or via environment variable (set before importing MemScale):

export MEMSCALE_TELEMETRY=0

What's collected (~1 KB per training run)

  • Anonymous client ID (random UUID, stored locally at ~/.memscale/client_id)
  • Library version, Python version, PyTorch version, OS
  • Hardware: GPU model, VRAM, CUDA version, number of GPUs
  • Model architecture: layer types, parameter count (no weights)
  • Optimization outcome: techniques applied, memory saved, throughput overhead

What's NEVER collected

  • ❌ Model weights or training data
  • ❌ Code or scripts
  • ❌ File paths, hostnames, IP addresses
  • ❌ Email or any identifying information
  • ❌ Layer-level activations or gradients

Telemetry is fire-and-forget (silent failure on network error), sent over HTTPS to api.memscale.id/v1/telemetry. See our privacy policy for details.

Re-enable after opting out

memscale.enable_telemetry()

Or:

export MEMSCALE_TELEMETRY=1

Compatibility

| Component | Min version | Tested |
|---|---|---|
| Python | 3.10 | 3.10, 3.11, 3.12 |
| PyTorch | 2.1 | 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9 |
| CUDA | 11.8 | 11.8, 12.1, 12.4, 12.8 |
| GPU | Compute capability 7.0+ | V100, A100, H100, RTX 3090/4090 |
| BF16 mixed precision | Compute capability 8.0+ | A100, H100, RTX 3090/4090 |
| OS | Linux/macOS/Windows | Ubuntu 20.04+, macOS 14+ (arm64), Windows 11 |

AMD GPU support (ROCm) coming in a future release.
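The minimums in the compatibility table can be checked up front. The sketch below compares version strings and a compute-capability tuple against them and reports whether BF16 is available; `check_environment` and `parse_version` are hypothetical helpers, not part of MemScale. With PyTorch installed, the inputs would come from `torch.__version__`, `torch.version.cuda` (None on CPU-only builds) and `torch.cuda.get_device_capability()`.

```python
import sys
from typing import List, Optional, Tuple

# Minimums from the compatibility table.
MIN_PYTHON = (3, 10)
MIN_TORCH = (2, 1)
MIN_CUDA = (11, 8)
MIN_CAPABILITY = (7, 0)
MIN_BF16_CAPABILITY = (8, 0)


def parse_version(text: str) -> Tuple[int, int]:
    """Return (major, minor) from strings like '2.5.1+cu124' or '12.4'."""
    major, minor = text.split("+")[0].split(".")[:2]
    return int(major), int(minor)


def check_environment(
    torch_version: str,
    cuda_version: Optional[str],
    capability: Optional[Tuple[int, int]],
    python: Optional[Tuple[int, int]] = None,
) -> Tuple[List[str], bool]:
    """Return (problems, bf16_supported) for the given environment."""
    python = python or tuple(sys.version_info[:2])
    problems: List[str] = []
    if python < MIN_PYTHON:
        problems.append(f"Python {python[0]}.{python[1]} < 3.10")
    if parse_version(torch_version) < MIN_TORCH:
        problems.append(f"PyTorch {torch_version} < 2.1")
    if cuda_version is not None and parse_version(cuda_version) < MIN_CUDA:
        problems.append(f"CUDA {cuda_version} < 11.8")
    if capability is not None and capability < MIN_CAPABILITY:
        problems.append(f"compute capability {capability[0]}.{capability[1]} < 7.0")
    bf16 = capability is not None and capability >= MIN_BF16_CAPABILITY
    return problems, bf16
```

Comparing `(major, minor)` integer tuples avoids the classic string-comparison trap where "2.10" would sort before "2.9".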

FAQ

Q: Does MemScale change my training results? Activation checkpointing and DDP are mathematically lossless. BF16/FP16 mixed precision introduces small numerical differences, the same as standard PyTorch AMP, and 8-bit Adam's quantized optimizer state can also shift results slightly.

Q: How does this compare to DeepSpeed and FSDP? DeepSpeed and FSDP are powerful but require significant configuration and distributed training expertise. MemScale's value is plug-and-play: 1-line wrap with auto-detection. For 95%+ reduction in production multi-GPU setups, DeepSpeed ZeRO-3 is more mature. For single-GPU and DDP workloads, MemScale is competitive and easier to use.

Q: Will this slow down my training? Activation checkpointing adds 20-30% compute overhead (the standard tradeoff). 8-bit Adam adds ~2-5%. Net effect: training is slower per step, but you can use larger batches (better hardware utilization), so end-to-end time often improves.

Q: What if my model has custom architecture? The decision engine handles standard transformers (PyTorch native, HuggingFace BERT/GPT2/Llama/Mistral, vision transformers) automatically. Custom architectures fall back to per-module heuristics. Both are tested.

Q: Why a range instead of a flat number? Reduction depends on model architecture, batch size, sequence length, and which techniques you enable. The v1.1 benchmarks show roughly 68-76% on standard transformers at the benchmark workloads (AGGRESSIVE mode), and MemScale enables training models that otherwise run out of memory. Larger batches and longer sequences typically yield higher percentages.

Q: Is MemScale open source? Source code is currently proprietary. PyPI distribution is public (free to install and use). We may open source later based on community feedback. See memscale.id for licensing.

Roadmap

  • v1.0.4 (current): 5 medium-severity bug fixes (seq_len awareness, dtype-aware param count, tiling outputs, GPU downgrade cost, estimate_training_memory)
  • v1.1 (Q3 2026): Reproducible benchmark CLI, bug fixes (F-1…F-4), honest cost attribution, and experimental async CPU offload (opt-in, correctness-validated; speedup deferred to v1.2)
  • v1.2 (Q4 2026): Async offload speedup benchmarking + older-GPU validation, tensor splitting, ML-based decision policy trained on telemetry data
  • v2.0 (2027): Multi-framework support (JAX, TensorFlow) + MemScale Serve (inference)

Architecture

MemScale's optimization happens in stages:

  1. Profiling: Static analysis with empirical fallback for dynamic models
  2. Decision engine: Per-layer technique selection based on memory profile, hardware budget, and configuration
  3. Execution: Apply chosen techniques via PyTorch hooks
  4. Observation: Track memory and throughput, report to user
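To make the decision stage concrete, here is a toy greedy planner in the spirit of stages 1 and 2: given a per-layer activation profile and a memory budget, it visits layers from largest to smallest, checkpoints recognized transformer blocks, and offloads anything else until the estimate fits. Everything here is an assumption for illustration, including the `LayerProfile` type, the 70% checkpointing saving (taken from the "How it works" table), and the rule that offloading frees a layer's whole activation footprint; MemScale's proprietary engine may work quite differently.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

# Block types the "How it works" table lists for boundary checkpointing.
CHECKPOINTABLE = {"BertLayer", "GPT2Block", "TransformerEncoderLayer", "ViTLayer"}
CHECKPOINT_SAVING_PCT = 70  # assumed: checkpointing frees ~70% of a block's activations


@dataclass
class LayerProfile:
    name: str
    class_name: str
    activation_bytes: int


def plan_layers(layers: List[LayerProfile], budget_bytes: int) -> Tuple[Dict[str, str], int]:
    """Greedily pick a technique per layer until estimated activations fit.

    Returns ({layer name: "none" | "checkpoint" | "offload"}, estimated bytes).
    """
    plan = {layer.name: "none" for layer in layers}
    estimate = sum(layer.activation_bytes for layer in layers)
    # Largest layers first, so each decision frees as much memory as possible.
    for layer in sorted(layers, key=lambda p: p.activation_bytes, reverse=True):
        if estimate <= budget_bytes:
            break
        if layer.class_name in CHECKPOINTABLE:
            plan[layer.name] = "checkpoint"
            estimate -= layer.activation_bytes * CHECKPOINT_SAVING_PCT // 100
        else:
            plan[layer.name] = "offload"
            estimate -= layer.activation_bytes  # assumed: fully moved to CPU
    return plan, estimate


if __name__ == "__main__":
    profile = [
        LayerProfile("h.0", "GPT2Block", 100),
        LayerProfile("h.1", "GPT2Block", 100),
        LayerProfile("lm_head", "Linear", 50),
    ]
    print(plan_layers(profile, budget_bytes=200))
    print(plan_layers(profile, budget_bytes=100))
```

Preferring checkpointing over offload for known blocks reflects the usual tradeoff: recompute costs GPU time, while offload costs PCIe bandwidth that is often the scarcer resource.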

Reporting Issues

For bug reports, please include:

  1. Minimal reproducible example
  2. Hardware (GPU model, VRAM)
  3. PyTorch version
  4. Output of memscale.profile_model(model) if relevant

Email: team@memscale.id

License

Proprietary. Full terms: contact team@memscale.id or visit memscale.id.

Citation

If you use MemScale in your research, please cite:

@software{memscale2026,
  title={MemScale: Drop-in Memory Optimization for PyTorch Training},
  author={MemScale Team},
  year={2026},
  url={https://memscale.id}
}

Built for ML practitioners. Questions? team@memscale.id

Download files

Source Distributions

No source distribution files are available for this release.

Built Distributions

  • memscale-1.2.0-cp312-cp312-win_amd64.whl (2.2 MB): CPython 3.12, Windows x86-64
  • memscale-1.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.8 MB): CPython 3.12, manylinux glibc 2.17+ x86-64
  • memscale-1.2.0-cp312-cp312-macosx_11_0_arm64.whl (2.3 MB): CPython 3.12, macOS 11.0+ ARM64
  • memscale-1.2.0-cp311-cp311-win_amd64.whl (2.2 MB): CPython 3.11, Windows x86-64
  • memscale-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.1 MB): CPython 3.11, manylinux glibc 2.17+ x86-64
  • memscale-1.2.0-cp311-cp311-macosx_11_0_arm64.whl (2.3 MB): CPython 3.11, macOS 11.0+ ARM64
  • memscale-1.2.0-cp310-cp310-win_amd64.whl (2.2 MB): CPython 3.10, Windows x86-64
  • memscale-1.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.4 MB): CPython 3.10, manylinux glibc 2.17+ x86-64
  • memscale-1.2.0-cp310-cp310-macosx_11_0_arm64.whl (2.4 MB): CPython 3.10, macOS 11.0+ ARM64

File details

Details for the file memscale-1.2.0-cp312-cp312-win_amd64.whl.

File metadata

  • Download URL: memscale-1.2.0-cp312-cp312-win_amd64.whl
  • Size: 2.2 MB
  • Tags: CPython 3.12, Windows x86-64
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for memscale-1.2.0-cp312-cp312-win_amd64.whl
Algorithm Hash digest
SHA256 8f8974adc2ecf9ae2dacb7b7eecd04cecb3e5a0d160bb7adeed306c64b647643
MD5 5c46b571954181b1b0452f6d65a5a6d5
BLAKE2b-256 f076e97e5999b0bf76abe0ffc76c56078f54dd98e587332e451168791464774f

Provenance

The following attestation bundles were made for memscale-1.2.0-cp312-cp312-win_amd64.whl:

Publisher: build-wheels.yml on memscale/Memscale

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file memscale-1.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for memscale-1.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 64b26a636f1a89fd71789614d4d185eb314b5cbb88c16cd4cd0b7d90c971f77d
MD5 89459cbaae1fba57e6c657d426a80fe1
BLAKE2b-256 2414343979b149182755638b7e910ed009692e95540ced19314862457ff2a86e

Provenance

The following attestation bundles were made for memscale-1.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl:

Publisher: build-wheels.yml on memscale/Memscale

File details

Details for the file memscale-1.2.0-cp312-cp312-macosx_11_0_arm64.whl.

File hashes

Hashes for memscale-1.2.0-cp312-cp312-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 322cf609f61ac8dcdeb61add5dd1c55ab52ce0760651efd8edd9754fc32d7ab5
MD5 fe19573dedc3b4ba4f04b9f28e05ce4b
BLAKE2b-256 c24cf1619fd6a3c7267666aa7240516a0f3e6121b6fe830453971c8263ff362a

Provenance

The following attestation bundles were made for memscale-1.2.0-cp312-cp312-macosx_11_0_arm64.whl:

Publisher: build-wheels.yml on memscale/Memscale

File details

Details for the file memscale-1.2.0-cp311-cp311-win_amd64.whl.

File metadata

  • Download URL: memscale-1.2.0-cp311-cp311-win_amd64.whl
  • Size: 2.2 MB
  • Tags: CPython 3.11, Windows x86-64
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for memscale-1.2.0-cp311-cp311-win_amd64.whl
Algorithm Hash digest
SHA256 73ac091c7e3914fd31b0d3d85c840ac0a5681e4b82abaa9dcdbc860cbef1a55b
MD5 f0de58b47832047d9061d908ea38722a
BLAKE2b-256 ec1809097e8a9fff953475ad0ee50234a1e7490a39ecf91efec1b3123f20b065

Provenance

The following attestation bundles were made for memscale-1.2.0-cp311-cp311-win_amd64.whl:

Publisher: build-wheels.yml on memscale/Memscale

File details

Details for the file memscale-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for memscale-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 f450bddfe6d56d6221a57496362e967cb9b6d1acf2ba5789c8f4c6fc90d0a613
MD5 a407d852319d94c12e6d490ad8320d8d
BLAKE2b-256 c688dc4432dbe16e3d414da36e250d078c6e609d99f42a83c206b7867bf134b3

Provenance

The following attestation bundles were made for memscale-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl:

Publisher: build-wheels.yml on memscale/Memscale

File details

Details for the file memscale-1.2.0-cp311-cp311-macosx_11_0_arm64.whl.

File metadata

File hashes

Hashes for memscale-1.2.0-cp311-cp311-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 d59e621ef219ac96fadd05700ff4033a20f81a33cc0cd7eeac9032eff9d4373f
MD5 b26e01b0ecd2753b0ffe959f2cf45323
BLAKE2b-256 15ff14d569da6b5399e506b175085083923af2020e9edd21eaf66162bd9a4c96

Provenance

The following attestation bundles were made for memscale-1.2.0-cp311-cp311-macosx_11_0_arm64.whl:

Publisher: build-wheels.yml on memscale/Memscale

File details

Details for the file memscale-1.2.0-cp310-cp310-win_amd64.whl.

File metadata

  • Download URL: memscale-1.2.0-cp310-cp310-win_amd64.whl
  • Size: 2.2 MB
  • Tags: CPython 3.10, Windows x86-64
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for memscale-1.2.0-cp310-cp310-win_amd64.whl
Algorithm Hash digest
SHA256 5df4ae80a87d41c3d47171d1ac0079814c80fd6c62f5e092cb1f020c74e2f405
MD5 9df39d8afcfdfd68ee62474847f4ed4d
BLAKE2b-256 15f41413a2546e3b51ee38357f3be40079c09bb245d5203427e5eb37d95352d1

Provenance

The following attestation bundles were made for memscale-1.2.0-cp310-cp310-win_amd64.whl:

Publisher: build-wheels.yml on memscale/Memscale

File details

Details for the file memscale-1.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for memscale-1.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 8b0adb7c24c9a4b4e22830e52c584308331eb6eb634f61ade264255567502111
MD5 1c3d26d8a9eddcb4ed94efb6f881b0cd
BLAKE2b-256 e8bc04374fd0002d66fc95730d236cff85e9f6f8a013826c7ce3c1a04ea56864

Provenance

The following attestation bundles were made for memscale-1.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl:

Publisher: build-wheels.yml on memscale/Memscale

File details

Details for the file memscale-1.2.0-cp310-cp310-macosx_11_0_arm64.whl.

File metadata

File hashes

Hashes for memscale-1.2.0-cp310-cp310-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 d442f22531249885502fbcb806c3d2290f15ac9b9ca18e694acad6aa58921ebf
MD5 1dc19d8bec7cc17ab46d0197e5782879
BLAKE2b-256 486fa81722d6c5a5e254e565aa9e6c87e10abc798a9fa6835f90d4c3e40d6c2b

Provenance

The following attestation bundles were made for memscale-1.2.0-cp310-cp310-macosx_11_0_arm64.whl:

Publisher: build-wheels.yml on memscale/Memscale
