
memory-guard

Cross-platform memory guard for ML training. Prevents OOM crashes on Apple Silicon, CUDA, and CPU.

No more frozen Macs. No more CUDA out of memory. Just training that works.

pip install memory-guard

The Problem

  • Apple Silicon: No OOM exception exists. When you exceed memory, macOS silently swaps to disk, your Mac freezes for minutes, and eventually the OS kills your process. Every ML practitioner on Mac has experienced this.
  • CUDA: torch.cuda.OutOfMemoryError crashes your training run. You restart, guess a smaller batch size, and pray.
  • Containers: cgroups silently kill your process with no warning when you hit the memory limit.

Existing solutions (PyTorch Lightning's BatchSizeFinder, HuggingFace accelerate) are CUDA-only and reactive — they work by catching OOM exceptions, which are never raised on Apple Silicon.

The Solution

memory-guard is proactive, not reactive. It estimates peak memory before training starts and auto-adjusts your config to fit. During training, it monitors memory pressure and dynamically downgrades if needed.

from memory_guard import MemoryGuard

guard = MemoryGuard.auto()

# Pre-flight: estimate memory and auto-downgrade config
safe = guard.preflight(
    model_params=9_000_000_000,  # 9B parameter model
    model_bits=4,                # 4-bit quantized
    hidden_dim=4096,
    num_heads=32,
    num_layers=32,
    batch_size=4,
    seq_length=2048,
    lora_rank=32,
    lora_layers=16,
)

print(safe)
# SafeConfig (FITS):
#   batch_size:        2          <- auto-reduced from 4
#   grad_checkpoint:   True       <- auto-enabled
#   grad_accumulation: 4          <- compensates for smaller batch
#   estimated memory: 3835 MB
#   budget:           4643 MB

# Runtime monitoring: polls memory pressure every 5s
with guard.monitor(safe.batch_size) as mon:
    for step in range(1000):
        # Batch size may decrease mid-training if pressure rises
        train_step(batch_size=mon.current_batch_size)

Features

| Feature | Apple Silicon | CUDA | Linux CPU | Windows |
|---|---|---|---|---|
| Proactive memory estimation | Yes | Yes | Yes | Yes |
| Auto-downgrade config | Yes | Yes | Yes | Yes |
| Runtime pressure monitoring | Yes (Mach kernel + MLX Metal) | Yes (torch.cuda) | Yes (PSI, cgroups) | Yes (kernel32) |
| MLX Metal ground truth | Yes (mx.metal.get_active_memory) | N/A | N/A | N/A |
| OOM catch & retry | N/A (no OOM on Metal) | Yes | N/A | N/A |
| Container-aware (cgroups v1/v2) | N/A | Yes | Yes | N/A |
| Auto-calibration | Yes | Yes | Yes | Yes |
| FlashAttention-aware | Yes | Yes | Yes | Yes |
| GQA / MoE / multi-modal | Yes | Yes | Yes | Yes |

How It Works

1. Proactive Estimation

Calculates peak memory from model architecture, accounting for:

  • Per-projection LoRA input buffers (Q, K, V, O)
  • FlashAttention O(n) vs standard O(n^2) attention scores
  • GQA-aware KV cache (uses num_kv_heads, not num_heads)
  • MoE routing buffers and active expert activations
  • Optimizer states (Adam 3x, SGD 2x, Adafactor 1.5x)
  • MLX lazy evaluation discount (20% reduction on Apple Silicon)
  • Framework overhead (25% proportional + 400MB fixed runtime cost)

With gradient checkpointing enabled, activation memory scales with sqrt(layers) instead of linearly in the number of layers.
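As an illustration, the factors above can be composed into a rough estimator. This is a simplified sketch with illustrative constants — the function name and term weights are assumptions, not memory-guard's actual formula, which tracks LoRA buffers, GQA-aware KV cache, and MoE terms separately:

```python
import math

def rough_peak_mb(model_params, model_bits, hidden_dim, num_layers,
                  batch_size, seq_length, trainable_params=0,
                  optimizer_factor=3.0, grad_checkpoint=False,
                  flash_attention=True):
    """Simplified peak-memory estimate in MB (illustrative only)."""
    weights = model_params * model_bits / 8               # quantized base weights, bytes
    opt_states = trainable_params * 4 * optimizer_factor  # Adam ~3x, SGD ~2x, Adafactor ~1.5x
    # fp16 hidden states per layer; standard attention adds O(n^2) score buffers
    act_per_layer = batch_size * seq_length * hidden_dim * 2
    if not flash_attention:
        act_per_layer += batch_size * seq_length * seq_length * 2
    # Gradient checkpointing keeps roughly sqrt(layers) layers of activations live
    live_layers = math.sqrt(num_layers) if grad_checkpoint else num_layers
    activations = act_per_layer * live_layers
    total = (weights + opt_states + activations) * 1.25   # 25% proportional overhead
    total += 400 * 1024 ** 2                              # fixed runtime cost
    return total / 1024 ** 2
```

Even this toy version reproduces the qualitative behavior: checkpointing shrinks the estimate, and disabling FlashAttention inflates it quadratically in sequence length.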

2. Auto-Downgrade (quality-preserving order)

When the estimate exceeds the budget (available memory × safety ratio, 80% by default), downgrades are applied in order:

  1. Enable gradient checkpointing (no quality loss, ~40% activation savings)
  2. Halve the batch size (compensated with gradient accumulation)
  3. Halve the sequence length
  4. Halve the LoRA rank
  5. Halve the number of LoRA layers
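The ordered fallback above can be sketched as a simple loop. This is a hypothetical helper mirroring the order of the list, not the library's auto_downgrade implementation:

```python
def downgrade_to_fit(estimate_fn, config, budget_mb):
    """Apply the quality-preserving downgrade steps in order until the
    estimate fits the budget. Illustrative sketch; estimate_fn maps a
    config dict to an estimated peak in MB."""
    cfg = dict(config)  # never mutate the caller's config
    steps = [
        # 1. gradient checkpointing first: free quality
        lambda c: c.update(grad_checkpoint=True),
        # 2. halve batch, compensate with gradient accumulation
        lambda c: c.update(batch_size=max(1, c["batch_size"] // 2),
                           grad_accumulation=c.get("grad_accumulation", 1) * 2),
        # 3-5. progressively cheaper-but-lossier knobs
        lambda c: c.update(seq_length=max(128, c["seq_length"] // 2)),
        lambda c: c.update(lora_rank=max(1, c["lora_rank"] // 2)),
        lambda c: c.update(lora_layers=max(1, c["lora_layers"] // 2)),
    ]
    for step in steps:
        if estimate_fn(cfg) <= budget_mb:
            break
        step(cfg)
    return cfg
```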

3. Runtime Monitoring

Background thread polls memory pressure every 5 seconds:

  • Apple Silicon: mx.metal.get_active_memory() (ground-truth from Metal allocator), with kern.memorystatus_level as fallback. Detects the monotonic memory growth pattern from mlx-examples#1262.
  • CUDA: torch.cuda.memory_allocated() vs total VRAM
  • Linux: /proc/pressure/memory (PSI), cgroup-aware (memory.high preferred over memory.max)
  • Windows: GlobalMemoryStatusEx

When pressure exceeds 85%, batch size is halved mid-training.
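The monitoring loop amounts to a background thread around a platform-specific pressure probe. A minimal sketch, assuming a probe callable that returns used/total in [0, 1] — the class name and probe argument are hypothetical; memory-guard supplies the per-platform probes listed above:

```python
import threading
import time

class PressureMonitor:
    """Poll a pressure probe every `interval` seconds in a daemon thread;
    halve the batch size whenever pressure exceeds `threshold`."""

    def __init__(self, batch_size, probe, threshold=0.85, interval=5.0):
        self.current_batch_size = batch_size
        self._probe = probe          # () -> float in [0, 1]
        self._threshold = threshold
        self._interval = interval
        self._stop = threading.Event()

    def _run(self):
        # Event.wait doubles as a sleep that wakes immediately on stop
        while not self._stop.wait(self._interval):
            if self._probe() > self._threshold and self.current_batch_size > 1:
                self.current_batch_size //= 2   # downgrade mid-training

    def __enter__(self):
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        return self

    def __exit__(self, *exc):
        self._stop.set()
        self._thread.join()
```

The training loop only ever reads `current_batch_size`, so the downgrade takes effect on the next step without interrupting the current one.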

4. Auto-Calibration

After each training run, the actual peak memory (from mx.metal.get_peak_memory() or torch.cuda.max_memory_allocated()) is recorded alongside the formula estimate. After 3+ runs, a median correction factor is applied to future estimates, narrowing the gap between predicted and actual memory usage over time.
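In essence, calibration keeps a per-device history of actual/estimated ratios and applies their median once enough runs exist. A sketch under that reading (class and method names are illustrative, not memory-guard's API):

```python
import statistics

class Calibrator:
    """Record (estimated, actual) peak-memory pairs; once `min_runs`
    exist, scale future estimates by the median actual/estimated ratio."""

    def __init__(self, min_runs=3):
        self.min_runs = min_runs
        self.ratios = []

    def record(self, estimated_mb, actual_mb):
        self.ratios.append(actual_mb / estimated_mb)

    def correct(self, estimated_mb):
        if len(self.ratios) < self.min_runs:
            return estimated_mb   # cold start: no correction yet
        return estimated_mb * statistics.median(self.ratios)
```

The median (rather than the mean) keeps a single anomalous run from skewing all future estimates.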

Framework Integration

With mlx_lm (Apple Silicon)

import mlx.optimizers as optim
from memory_guard import MemoryGuard
from mlx_lm import load
from mlx_lm.tuner.trainer import train, TrainingArgs
from mlx_lm.tuner.utils import linear_to_lora_layers

guard = MemoryGuard.auto()
model, tokenizer = load("mlx-community/Qwen3.5-9B-MLX-4bit")

safe = guard.preflight(
    model_params=9e9, model_bits=4,
    hidden_dim=4096, num_heads=32, num_layers=32,
    batch_size=4, seq_length=2048,
    lora_rank=32, lora_layers=16,
)

model.freeze()
linear_to_lora_layers(
    model, safe.lora_layers,
    {"rank": safe.lora_rank, "scale": 20.0, "dropout": 0.0},
)
optimizer = optim.Adam(learning_rate=1e-4)

# The monitor runs in the background and logs if memory pressure rises.
# Note: mlx_lm's train() uses a fixed batch size. For dynamic adjustment,
# use a custom training loop that reads mon.current_batch_size each step.
with guard.monitor(safe.batch_size) as mon:
    train(
        model=model, optimizer=optimizer, train_dataset=train_set,
        args=TrainingArgs(
            batch_size=safe.batch_size,
            iters=1000,
            max_seq_length=safe.seq_length,
            grad_checkpoint=safe.grad_checkpoint,
            adapter_file="adapters.safetensors",
        ),
    )
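A custom loop of the kind the comment above suggests can be sketched as follows. Only `guard.monitor(...)` and `mon.current_batch_size` come from memory-guard; the slicing helper and its signature are hypothetical:

```python
def train_dynamic(guard, dataset, step_fn, batch_size, iters):
    """Re-read the monitor's batch size every step, so a mid-training
    downgrade takes effect on the very next batch (illustrative)."""
    with guard.monitor(batch_size) as mon:
        cursor = 0
        for _ in range(iters):
            bs = mon.current_batch_size        # may shrink under pressure
            batch = dataset[cursor:cursor + bs]
            cursor = (cursor + bs) % len(dataset)
            step_fn(batch)
```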

With HuggingFace Transformers (CUDA)

from memory_guard import MemoryGuard
from transformers import Trainer, TrainingArguments

guard = MemoryGuard.auto()
safe = guard.preflight(
    model_params=7e9, model_bits=16,
    hidden_dim=4096, num_heads=32, num_layers=32,
    batch_size=8, seq_length=2048,
    lora_rank=16, lora_layers=16,
)

args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=safe.batch_size,
    gradient_accumulation_steps=safe.grad_accumulation,
    gradient_checkpointing=safe.grad_checkpoint,
    max_steps=1000,
)
trainer = Trainer(model=model, args=args, train_dataset=train_set)
trainer.train()

With Unsloth

from memory_guard import MemoryGuard
from unsloth import FastLanguageModel

guard = MemoryGuard.auto()
safe = guard.preflight(
    model_params=8e9, model_bits=4,
    hidden_dim=4096, num_heads=32, num_layers=32,
    batch_size=4, seq_length=2048,
    lora_rank=16, lora_layers=16,
)

model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/Meta-Llama-3.1-8B-bnb-4bit",
    max_seq_length=safe.seq_length,
    load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(
    model, r=safe.lora_rank, lora_alpha=safe.lora_rank * 2,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

API Reference

MemoryGuard.auto(safety_ratio=0.80)

Create with auto-detected platform. safety_ratio controls headroom (0.80 = use 80% of available).

guard.preflight(**config) -> SafeConfig

Estimate memory and auto-downgrade. Returns safe config.

guard.monitor(batch_size) -> RuntimeMonitor

Context manager for runtime monitoring. Use mon.current_batch_size in training loop.

guard.estimate(**config) -> MemoryEstimate

Pure estimation without auto-downgrade.

estimate_training_memory(**config) -> MemoryEstimate

Standalone estimation function.

auto_downgrade(budget_mb, **config) -> DowngradeResult

Standalone downgrade function.

CUDAOOMRecovery(initial_batch_size)

CUDA-specific OOM catch-and-retry wrapper.
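CUDAOOMRecovery automates the classic catch-and-retry pattern, which looks roughly like this. A generic sketch, not the library's internals — on CUDA the exception passed in would be `torch.cuda.OutOfMemoryError`:

```python
def retry_with_smaller_batches(step_fn, batch_size, oom_exc, min_batch_size=1):
    """Run step_fn(batch_size); on OOM, halve the batch and retry
    (illustrative; cache cleanup is left to the caller)."""
    while batch_size >= min_batch_size:
        try:
            return step_fn(batch_size)
        except oom_exc:
            batch_size //= 2   # retry with half the batch
    raise RuntimeError("cannot fit even the minimum batch size")
```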

Estimation Accuracy

Measured accuracy on real training runs. We need your help expanding this table — see Contributing below.

| Model | Device | Batch | Seq | Rank | Estimated | Actual | Error |
|---|---|---|---|---|---|---|---|
| Qwen3.5-9B-4bit | M4 Max 36GB | 1 | 512 | 8 | 6,193 MB | 7,048 MB | 12.1% under |
| Qwen3.5-9B-4bit | M4 Max 36GB | 1 | 128 | 16 | 9,522 MB | 8,879 MB | 7.2% over |

What's tested: LoRA fine-tuning with mlx_lm on Apple Silicon (M4 Max).

What's NOT tested yet:

  • CUDA GPUs (RTX 3060/4090, A100, H100)
  • AMD ROCm (RX 7900, MI300X)
  • Smaller devices (M1/M2 MacBook Air 8-16GB)
  • Models below 7B or above 13B
  • MoE architectures (Mixtral, DeepSeek-MoE)
  • Multi-modal models (LLaVA, Qwen-VL)
  • QLoRA with double quantization
  • DoRA, full fine-tuning
  • HuggingFace Transformers, Unsloth, PyTorch Lightning

The estimation formula is based on published research (FlashAttention, HyC-LoRA, LoRA-FA) and verified on one configuration. Auto-calibration improves accuracy after 3+ runs on any given setup.

Known Limitations

  • Single validation point: Estimation accuracy is verified on one model/device combination. Your results may differ significantly — please report them.
  • Inference workloads: The runtime monitor is built for training loops. Inference serving (vLLM, SGLang) with dynamic KV cache growth is not yet monitored.
  • Calibration cold start: Auto-calibration needs 3+ training runs on a given device before corrections kick in.
  • Custom kernels: Frameworks with heavily fused kernels (Unsloth) use less memory than the formula predicts. Calibration corrects this over time.
  • MLX Metal thread safety: mx.metal.get_active_memory() is called from a background thread. MLX's Metal backend has known thread safety limitations. Memory counter reads work in practice but aren't guaranteed thread-safe by the MLX API.
  • Windows: CUDA path uses well-tested torch.cuda APIs. The CPU-only fallback (GlobalMemoryStatusEx) hasn't been validated across Windows versions.

Contributing

Help Us Benchmark

The single most valuable contribution right now is running the benchmark on your hardware and sharing the results. This directly improves estimation accuracy for everyone.

# Install
pip install memory-guard mlx-lm

# Run with default small model (fast, ~2 minutes)
python bench/bench_accuracy.py

# Run with a specific model
python bench/bench_accuracy.py --model mlx-community/Mistral-7B-Instruct-v0.3-4bit

# Generate a pre-formatted GitHub issue with your results
python bench/bench_accuracy.py --model mlx-community/Qwen3.5-9B-MLX-4bit --submit

Then open a GitHub issue with the output. We'll add your results to the accuracy table above.

Devices we especially need data from:

  • M1/M2 MacBook Air (8GB, 16GB)
  • M3/M4 MacBook Pro (18GB, 36GB)
  • RTX 3060/3090, RTX 4070/4090
  • A100, H100
  • AMD Radeon RX 7900 / MI300X
  • Docker/Kubernetes containers with memory limits

Other Contributions

  • Framework adapters: Thin wrappers for HuggingFace Trainer, Unsloth, PyTorch Lightning
  • Inference monitoring: KV cache growth tracking for serving workloads
  • Bug reports: If the estimate was off by >30%, that's a bug — please report it with your config

License

Apache 2.0
