memory-guard
Cross-platform memory guard for ML training. Prevents OOM crashes on Apple Silicon, CUDA, and CPU.
No more frozen Macs. No more CUDA out of memory. Just training that works.
```bash
pip install memory-guard
```
The Problem
- Apple Silicon: No OOM exception exists. When you exceed memory, macOS silently swaps to disk, your Mac freezes for minutes, and eventually the OS kills your process. Every ML practitioner on Mac has experienced this.
- CUDA: `torch.cuda.OutOfMemoryError` crashes your training run. You restart, guess a smaller batch size, and pray.
- Containers: cgroups silently kill your process with no warning when you hit the memory limit.
Existing solutions (PyTorch Lightning BatchSizeFinder, HuggingFace accelerate) are CUDA-only and reactive — they catch OOM exceptions that don't exist on Apple Silicon.
The Solution
memory-guard is proactive, not reactive. It estimates peak memory before training starts and auto-adjusts your config to fit. During training, it monitors memory pressure and dynamically downgrades if needed.
```python
from memory_guard import MemoryGuard

guard = MemoryGuard.auto()

# Pre-flight: estimate memory and auto-downgrade config
safe = guard.preflight(
    model_params=9_000_000_000,  # 9B parameter model
    model_bits=4,                # 4-bit quantized
    hidden_dim=4096,
    num_heads=32,
    num_layers=32,
    batch_size=4,
    seq_length=2048,
    lora_rank=32,
    lora_layers=16,
)
print(safe)
# SafeConfig (FITS):
#   batch_size: 2          <- auto-reduced from 4
#   grad_checkpoint: True  <- auto-enabled
#   grad_accumulation: 4   <- compensates for smaller batch
#   estimated memory: 3835 MB
#   budget: 4643 MB

# Runtime monitoring: polls memory pressure every 5s
with guard.monitor(safe.batch_size) as mon:
    for step in range(1000):
        # Batch size may decrease mid-training if pressure rises
        train_step(batch_size=mon.current_batch_size)
```
Features
| Feature | Apple Silicon | CUDA | Linux CPU | Windows |
|---|---|---|---|---|
| Proactive memory estimation | Yes | Yes | Yes | Yes |
| Auto-downgrade config | Yes | Yes | Yes | Yes |
| Runtime pressure monitoring | Yes (Mach kernel + MLX Metal) | Yes (torch.cuda) | Yes (PSI, cgroups) | Yes (kernel32) |
| MLX Metal ground-truth | Yes (mx.metal.get_active_memory) | N/A | N/A | N/A |
| OOM catch & retry | N/A (no OOM on Metal) | Yes | N/A | N/A |
| Container-aware (cgroups v1/v2) | N/A | Yes | Yes | N/A |
| Auto-calibration | Yes | Yes | Yes | Yes |
| FlashAttention-aware | Yes | Yes | Yes | Yes |
| GQA / MoE / Multi-modal | Yes | Yes | Yes | Yes |
How It Works
1. Proactive Estimation
Calculates peak memory from model architecture, accounting for:
- Per-projection LoRA input buffers (Q, K, V, O)
- FlashAttention O(n) vs standard O(n^2) attention scores
- GQA-aware KV cache (uses `num_kv_heads`, not `num_heads`)
- MoE routing buffers and active expert activations
- Optimizer states (Adam 3x, SGD 2x, Adafactor 1.5x)
- MLX lazy evaluation discount (20% reduction on Apple Silicon)
- Framework overhead (25% proportional + 400MB fixed runtime cost)
With gradient checkpointing, activation memory scales as O(sqrt(layers)) instead of O(layers).
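As a rough illustration of that checkpointing term (toy numbers and a hypothetical helper name, not the library's actual formula):

```python
import math

def activation_mb(num_layers, per_layer_mb, grad_checkpoint):
    # With checkpointing, only ~sqrt(L) layers of activations stay live;
    # the rest are recomputed during the backward pass.
    kept = math.sqrt(num_layers) if grad_checkpoint else num_layers
    return kept * per_layer_mb

print(activation_mb(32, 100, False))  # 3200.0 MB kept
print(activation_mb(32, 100, True))   # ~565.7 MB kept
```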
2. Auto-Downgrade (quality-preserving order)
When the estimate exceeds the budget (available memory × safety ratio, 80% by default), downgrades are applied in order:
- Enable gradient checkpointing (no quality loss, ~40% activation savings)
- Halve batch size (compensated with gradient accumulation)
- Halve sequence length
- Halve LoRA rank
- Halve LoRA layers
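That ordered loop can be sketched as follows; the `estimate` callback stands in for the library's architectural formula, and the function and dict keys are illustrative, not the real API:

```python
def downgrade(cfg, estimate, budget_mb):
    # Apply quality-preserving steps in order, stopping as soon as the
    # estimate fits the budget. cfg is a plain dict here for illustration.
    steps = [
        lambda c: c.update(grad_checkpoint=True),
        lambda c: c.update(batch_size=max(1, c["batch_size"] // 2),
                           grad_accumulation=c["grad_accumulation"] * 2),
        lambda c: c.update(seq_length=max(128, c["seq_length"] // 2)),
        lambda c: c.update(lora_rank=max(4, c["lora_rank"] // 2)),
        lambda c: c.update(lora_layers=max(1, c["lora_layers"] // 2)),
    ]
    for step in steps:
        if estimate(cfg) <= budget_mb:
            return cfg
        step(cfg)
    return cfg  # may still exceed the budget; SafeConfig reports FITS or not

def toy_estimate(c):
    # Toy model: memory grows with batch*seq; checkpointing saves 40%
    return c["batch_size"] * c["seq_length"] * (0.6 if c["grad_checkpoint"] else 1.0)

cfg = dict(batch_size=4, seq_length=2048, grad_checkpoint=False,
           grad_accumulation=1, lora_rank=32, lora_layers=16)
safe = downgrade(cfg, toy_estimate, budget_mb=3000)
```

With this toy estimator, enabling checkpointing and one batch-size halving suffice; sequence length and LoRA settings are never touched.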
3. Runtime Monitoring
Background thread polls memory pressure every 5 seconds:
- Apple Silicon: `mx.metal.get_active_memory()` (ground truth from the Metal allocator), with `kern.memorystatus_level` as fallback. Detects the monotonic memory-growth pattern from mlx-examples#1262.
- CUDA: `torch.cuda.memory_allocated()` vs total VRAM
- Linux: `/proc/pressure/memory` (PSI), cgroup-aware (`memory.high` preferred over `memory.max`)
- Windows: `GlobalMemoryStatusEx`
When pressure exceeds 85%, batch size is halved mid-training.
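The poll-and-halve mechanism can be sketched like this; the class name is hypothetical, and `read_pressure` stands in for the per-platform readers listed above (Metal, torch.cuda, PSI, kernel32):

```python
import threading

class PressureMonitor:
    """Illustrative sketch of a background pressure poller."""

    def __init__(self, batch_size, read_pressure, threshold=0.85, interval=5.0):
        # read_pressure: callable returning used/total memory in [0, 1]
        self.current_batch_size = batch_size
        self._read = read_pressure
        self._threshold = threshold
        self._interval = interval
        self._stop = threading.Event()

    def check(self):
        # One poll: halve the batch size when pressure crosses the threshold.
        if self._read() > self._threshold and self.current_batch_size > 1:
            self.current_batch_size //= 2

    def _poll(self):
        while not self._stop.wait(self._interval):
            self.check()

    def __enter__(self):
        threading.Thread(target=self._poll, daemon=True).start()
        return self

    def __exit__(self, *exc):
        self._stop.set()
```

The training loop only ever reads `current_batch_size`, so a single stale read is harmless; the next step picks up the reduced value.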
4. Auto-Calibration
After each training run, the actual peak memory (from `mx.metal.get_peak_memory()` or `torch.cuda.max_memory_allocated()`) is recorded alongside the formula estimate. After 3+ runs, a median correction factor is applied to future estimates, narrowing the gap between predicted and actual memory usage over time.
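The median-correction idea reduces to a few lines; the class name is illustrative, and the example feeds in the two runs from the accuracy table below plus one hypothetical third run:

```python
import statistics

class Calibrator:
    """Sketch of auto-calibration via a median correction factor."""

    MIN_RUNS = 3  # cold start: no correction until 3 runs are recorded

    def __init__(self):
        self.ratios = []  # actual_mb / estimated_mb, one per completed run

    def record(self, estimated_mb, actual_mb):
        self.ratios.append(actual_mb / estimated_mb)

    def correct(self, estimated_mb):
        if len(self.ratios) < self.MIN_RUNS:
            return estimated_mb
        # Median rather than mean: robust to the occasional outlier run.
        return estimated_mb * statistics.median(self.ratios)

cal = Calibrator()
cal.record(6193, 7048)   # 12.1% under-estimate
cal.record(9522, 8879)   # 7.2% over-estimate
cal.record(5000, 5500)   # hypothetical third run
```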
Framework Integration
With mlx_lm (Apple Silicon)
```python
import mlx.optimizers as optim
from memory_guard import MemoryGuard
from mlx_lm import load
from mlx_lm.tuner.trainer import train, TrainingArgs
from mlx_lm.tuner.utils import linear_to_lora_layers

guard = MemoryGuard.auto()
model, tokenizer = load("mlx-community/Qwen3.5-9B-MLX-4bit")

safe = guard.preflight(
    model_params=9e9, model_bits=4,
    hidden_dim=4096, num_heads=32, num_layers=32,
    batch_size=4, seq_length=2048,
    lora_rank=32, lora_layers=16,
)

model.freeze()
linear_to_lora_layers(
    model, safe.lora_layers,
    {"rank": safe.lora_rank, "scale": 20.0, "dropout": 0.0},
)
optimizer = optim.Adam(learning_rate=1e-4)

# The monitor runs in the background and logs if memory pressure rises.
# Note: mlx_lm's train() uses a fixed batch size. For dynamic adjustment,
# use a custom training loop that reads mon.current_batch_size each step.
with guard.monitor(safe.batch_size) as mon:
    train(
        model=model, optimizer=optimizer, train_dataset=train_set,
        args=TrainingArgs(
            batch_size=safe.batch_size,
            iters=1000,
            max_seq_length=safe.seq_length,
            grad_checkpoint=safe.grad_checkpoint,
            adapter_file="adapters.safetensors",
        ),
    )
```
With HuggingFace Transformers (CUDA)
```python
from memory_guard import MemoryGuard
from transformers import Trainer, TrainingArguments

guard = MemoryGuard.auto()
safe = guard.preflight(
    model_params=7e9, model_bits=16,
    hidden_dim=4096, num_heads=32, num_layers=32,
    batch_size=8, seq_length=2048,
    lora_rank=16, lora_layers=16,
)

args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=safe.batch_size,
    gradient_accumulation_steps=safe.grad_accumulation,
    gradient_checkpointing=safe.grad_checkpoint,
    max_steps=1000,
)
trainer = Trainer(model=model, args=args, train_dataset=train_set)
trainer.train()
```
With Unsloth
```python
from memory_guard import MemoryGuard
from unsloth import FastLanguageModel

guard = MemoryGuard.auto()
safe = guard.preflight(
    model_params=8e9, model_bits=4,
    hidden_dim=4096, num_heads=32, num_layers=32,
    batch_size=4, seq_length=2048,
    lora_rank=16, lora_layers=16,
)

model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/Meta-Llama-3.1-8B-bnb-4bit",
    max_seq_length=safe.seq_length,
    load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(
    model, r=safe.lora_rank, lora_alpha=safe.lora_rank * 2,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
```
API Reference
`MemoryGuard.auto(safety_ratio=0.80)`
Create a guard with the auto-detected platform. `safety_ratio` controls headroom (0.80 = use 80% of available memory).

`guard.preflight(**config) -> SafeConfig`
Estimate memory and auto-downgrade. Returns a safe config.

`guard.monitor(batch_size) -> RuntimeMonitor`
Context manager for runtime monitoring. Read `mon.current_batch_size` in the training loop.

`guard.estimate(**config) -> MemoryEstimate`
Pure estimation without auto-downgrade.

`estimate_training_memory(**config) -> MemoryEstimate`
Standalone estimation function.

`auto_downgrade(budget_mb, **config) -> DowngradeResult`
Standalone downgrade function.

`CUDAOOMRecovery(initial_batch_size)`
CUDA-specific OOM catch-and-retry wrapper.
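The catch-and-retry idea behind `CUDAOOMRecovery` can be sketched generically; on real CUDA the exception would be `torch.cuda.OutOfMemoryError`, typically with `torch.cuda.empty_cache()` before the retry. This standalone version takes the exception type as a parameter, and the function name is illustrative:

```python
def run_with_retry(step_fn, batch_size, oom_exc=MemoryError):
    # Halve the batch size on each OOM until the step succeeds.
    while batch_size >= 1:
        try:
            return step_fn(batch_size), batch_size
        except oom_exc:
            if batch_size == 1:
                raise  # nothing left to shrink
            batch_size //= 2
```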
Estimation Accuracy
Measured accuracy on real training runs. We need your help expanding this table — see Contributing below.
| Model | Device | Batch | Seq | Rank | Estimated | Actual | Error |
|---|---|---|---|---|---|---|---|
| Qwen3.5-9B-4bit | M4 Max 36GB | 1 | 512 | 8 | 6,193 MB | 7,048 MB | 12.1% under |
| Qwen3.5-9B-4bit | M4 Max 36GB | 1 | 128 | 16 | 9,522 MB | 8,879 MB | 7.2% over |
What's tested: LoRA fine-tuning with mlx_lm on Apple Silicon (M4 Max).
What's NOT tested yet:
- CUDA GPUs (RTX 3060/4090, A100, H100)
- AMD ROCm (RX 7900, MI300X)
- Smaller devices (M1/M2 MacBook Air 8-16GB)
- Models below 7B or above 13B
- MoE architectures (Mixtral, DeepSeek-MoE)
- Multi-modal models (LLaVA, Qwen-VL)
- QLoRA with double quantization
- DoRA, full fine-tuning
- HuggingFace Transformers, Unsloth, PyTorch Lightning
The estimation formula is based on published research (FlashAttention, HyC-LoRA, LoRA-FA) and verified on one configuration. Auto-calibration improves accuracy after 3+ runs on any given setup.
Known Limitations
- Single validation point: Estimation accuracy is verified on one model/device combination. Your results may differ significantly — please report them.
- Inference workloads: The runtime monitor is built for training loops. Inference serving (vLLM, SGLang) with dynamic KV cache growth is not yet monitored.
- Calibration cold start: Auto-calibration needs 3+ training runs on a given device before corrections kick in.
- Custom kernels: Frameworks with heavily fused kernels (Unsloth) use less memory than the formula predicts. Calibration corrects this over time.
- MLX Metal thread safety: `mx.metal.get_active_memory()` is called from a background thread. MLX's Metal backend has known thread-safety limitations. Memory counter reads work in practice but aren't guaranteed thread-safe by the MLX API.
- Windows: the CUDA path uses well-tested `torch.cuda` APIs. The CPU-only fallback (`GlobalMemoryStatusEx`) hasn't been validated across Windows versions.
Contributing
Help Us Benchmark
The single most valuable contribution right now is running the benchmark on your hardware and sharing the results. This directly improves estimation accuracy for everyone.
```bash
# Install
pip install memory-guard mlx-lm

# Run with the default small model (fast, ~2 minutes)
python bench/bench_accuracy.py

# Run with a specific model
python bench/bench_accuracy.py --model mlx-community/Mistral-7B-Instruct-v0.3-4bit

# Generate a pre-formatted GitHub issue with your results
python bench/bench_accuracy.py --model mlx-community/Qwen3.5-9B-MLX-4bit --submit
```
Then open a GitHub issue with the output. We'll add your results to the accuracy table above.
Devices we especially need data from:
- M1/M2 MacBook Air (8GB, 16GB)
- M3/M4 MacBook Pro (18GB, 36GB)
- RTX 3060/3090, RTX 4070/4090
- A100, H100
- AMD Radeon RX 7900 / MI300X
- Docker/Kubernetes containers with memory limits
Other Contributions
- Framework adapters: Thin wrappers for HuggingFace Trainer, Unsloth, PyTorch Lightning
- Inference monitoring: KV cache growth tracking for serving workloads
- Bug reports: If the estimate was off by >30%, that's a bug — please report it with your config
License
Apache 2.0