Find the maximum batch size that fits in GPU memory. Binary search with OOM recovery.
batch-probe
Find the maximum batch size that fits in GPU memory.
Binary search with OOM recovery, configurable safety headroom, no framework required.
The Problem
Every ML practitioner has done this:
batch_size = 64 # OOM
batch_size = 32 # OOM
batch_size = 16 # OOM
batch_size = 8 # works... but am I leaving GPU memory on the table?
batch-probe automates this. It binary-searches for the largest batch size your model can handle, with a safety margin so you don't OOM during real training.
Install
pip install batch-probe
Quick Start
import torch
from torch_probe import probe_batch_size

# `model` is any nn.Module that is already on the GPU
batch_size = probe_batch_size(
    model,
    lambda bs: {
        "input_ids": torch.zeros(bs, 512, dtype=torch.long, device="cuda"),
        "attention_mask": torch.ones(bs, 512, dtype=torch.long, device="cuda"),
    },
)
# torch-probe: probing batch size (mode=train, range=[1, 4096], headroom=20%)... max=6, safe=4
That's it: one import and one call. Works with any nn.Module.
Usage
Encoder models (BERT, RoBERTa, etc.)
batch_size = probe_batch_size(
    model,
    lambda bs: {
        "input_ids": torch.zeros(bs, 128, dtype=torch.long, device="cuda"),
        "attention_mask": torch.ones(bs, 128, dtype=torch.long, device="cuda"),
    },
    mode="train",
)
Seq2seq models (T5, BART, etc.)
batch_size = probe_batch_size(
    model,
    lambda bs: {
        "input_ids": torch.zeros(bs, 512, dtype=torch.long, device="cuda"),
        "attention_mask": torch.ones(bs, 512, dtype=torch.long, device="cuda"),
        "labels": torch.zeros(bs, 512, dtype=torch.long, device="cuda"),
    },
    mode="train",
)
Vision models
batch_size = probe_batch_size(
    model,
    lambda bs: {"x": torch.randn(bs, 3, 224, 224, device="cuda")},
    mode="infer",
)
Inference-only probing
Inference typically needs ~2-4x less memory than training, because no gradients (and no activations saved for the backward pass) are kept:
infer_batch = probe_batch_size(model, input_fn, mode="infer")
train_batch = probe_batch_size(model, input_fn, mode="train")
# infer_batch >> train_batch
Custom headroom
The default safety margin is 20%. Adjust it for your risk tolerance:
# Conservative (40% headroom) — for long training runs
batch_size = probe_batch_size(model, input_fn, headroom=0.4)
# Aggressive (5% headroom) — squeeze every last sample
batch_size = probe_batch_size(model, input_fn, headroom=0.05)
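To make the effect of headroom concrete, here is a small arithmetic sketch of the formula from How It Works, int(max_successful × (1 − headroom)). The helper name safe_batch_size and the value 64 are illustrative assumptions, not part of the package API.

```python
def safe_batch_size(max_successful: int, headroom: float = 0.2) -> int:
    """Apply the safety margin: int(max_successful * (1 - headroom))."""
    return int(max_successful * (1 - headroom))


# Suppose probing found that 64 is the largest batch that fits (made-up value).
max_successful = 64
for headroom in (0.05, 0.2, 0.4):
    safe = safe_batch_size(max_successful, headroom)
    print(f"headroom={headroom:.0%}: safe batch size = {safe}")
# headroom=5%: safe batch size = 60
# headroom=20%: safe batch size = 51
# headroom=40%: safe batch size = 38
```

Because the result is truncated with int(), the returned value always rounds down, so the margin is never smaller than requested.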
Caching
Use cached_probe to avoid re-probing the same model:
from torch_probe import cached_probe, clear_cache
batch_size = cached_probe(model, input_fn, mode="train") # probes
batch_size = cached_probe(model, input_fn, mode="train") # cache hit
clear_cache() # reset if model changed
How It Works
- Binary search between low (default 1) and high (default 4096)
- At each midpoint, create dummy tensors via your input_fn
- Run a forward pass (+ backward pass in train mode)
- If OOM: upper bound ← midpoint − 1, clean GPU memory
- If success: lower bound ← midpoint + 1
- Return int(max_successful × (1 − headroom))
The OOM recovery uses gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() to fully reclaim memory between iterations.
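The steps above can be sketched in plain Python. This is a minimal illustration, not the package's source: find_max_batch and make_fake_step are hypothetical names, MemoryError stands in for a CUDA out-of-memory error, and the fake step simply pretends that any batch above a fixed capacity runs out of memory. It assumes that if a batch size fits, every smaller one fits too.

```python
import gc
from typing import Callable


def find_max_batch(try_batch: Callable[[int], None], low: int = 1, high: int = 4096) -> int:
    """Return the largest batch size in [low, high] that try_batch survives, or 0 if none."""
    best = 0
    while low <= high:
        mid = (low + high) // 2
        try:
            try_batch(mid)      # real tool: forward pass (+ backward in train mode)
        except MemoryError:     # stand-in for a CUDA OOM error
            high = mid - 1      # too big: search the lower half
            gc.collect()        # on a GPU, torch.cuda.empty_cache() and
                                # torch.cuda.synchronize() would follow here
        else:
            best = mid          # fits: remember it and search the upper half
            low = mid + 1
    return best


def make_fake_step(capacity: int) -> Callable[[int], None]:
    """Build a fake training step that 'OOMs' above `capacity` samples."""
    def step(batch_size: int) -> None:
        if batch_size > capacity:
            raise MemoryError(f"simulated OOM at batch size {batch_size}")
    return step


max_ok = find_max_batch(make_fake_step(capacity=6))
safe = int(max_ok * (1 - 0.2))  # apply the default 20% headroom
print(f"max={max_ok}, safe={safe}")  # max=6, safe=4
```

With a search range of [1, 4096] the loop needs at most 13 probes, since each probe halves the remaining interval.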
vs. Alternatives
| Feature | batch-probe | Lightning BatchSizeFinder | HF auto_find_batch_size |
|---|---|---|---|
| Works with raw PyTorch | Yes | No (needs LightningModule) | No (needs HF Trainer) |
| Algorithm | Binary search | Power-of-2 scaling | Halve on OOM |
| Configurable headroom | Yes | No | No |
| Train + infer modes | Yes | Train only | Train only |
| Dependencies | torch only | pytorch-lightning | accelerate |
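To illustrate the Algorithm row, the sketch below runs simplified models of the three strategies against a simulated GPU that fits at most 100 samples. These are toy versions written for this comparison, not code from Lightning, HF, or batch-probe; the function names and the starting batch size of 256 for the halving strategy are assumptions.

```python
from typing import Callable, Tuple

Fits = Callable[[int], bool]


def binary_search(fits: Fits, low: int = 1, high: int = 4096) -> Tuple[int, int]:
    """Return (largest fitting batch size, number of probes)."""
    best, probes = 0, 0
    while low <= high:
        mid = (low + high) // 2
        probes += 1
        if fits(mid):
            best, low = mid, mid + 1
        else:
            high = mid - 1
    return best, probes


def power_of_two_scaling(fits: Fits, start: int = 1, limit: int = 4096) -> Tuple[int, int]:
    """Double the batch size until it stops fitting; keep the last success."""
    best, probes, bs = 0, 0, start
    while bs <= limit:
        probes += 1
        if not fits(bs):
            break
        best = bs
        bs *= 2
    return best, probes


def halve_on_oom(fits: Fits, start: int = 256) -> Tuple[int, int]:
    """Start from a guess and halve it until it fits."""
    probes, bs = 0, start
    while bs >= 1:
        probes += 1
        if fits(bs):
            return bs, probes
        bs //= 2
    return 0, probes


def fits(batch_size: int) -> bool:
    """Simulated GPU: anything above 100 samples would run out of memory."""
    return batch_size <= 100


for strategy in (binary_search, power_of_two_scaling, halve_on_oom):
    best, probes = strategy(fits)
    print(f"{strategy.__name__}: best={best} after {probes} probes")
```

In this toy setup binary search reaches the true limit of 100, while the two power-of-2 style strategies stop at 64, the nearest power of two below it.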
API Reference
probe_batch_size(model, input_fn, *, mode, low, high, headroom, device, verbose)
Find the maximum safe batch size.
- model (nn.Module): Your model, already on the target device.
- input_fn (Callable[[int], dict[str, Tensor]]): Takes batch size, returns dict of tensors for model(**inputs).
- mode ("train" | "infer"): Train mode runs forward + backward. Default: "train".
- low (int): Minimum batch size. Default: 1.
- high (int): Upper bound for search. Default: 4096.
- headroom (float): Safety margin. Default: 0.2 (20%).
- device (str | torch.device | None): Override device. Default: model's device.
- verbose (bool): Print progress. Default: True.
Returns: int — safe batch size.
cached_probe(model, input_fn, *, mode, **kwargs)
Same as probe_batch_size but caches results keyed on model class, param count, input shapes, and mode.
clear_cache()
Clear all cached probe results.
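As a rough picture of what "keyed on model class, param count, input shapes, and mode" could look like, here is a dictionary-based sketch. It is an assumption about the general shape of such a cache, not the package's implementation; make_key, cached_lookup, clear, and the example model name and parameter count are all hypothetical.

```python
from typing import Callable, Dict, Hashable, Mapping, Sequence, Tuple

_cache: Dict[Tuple[Hashable, ...], int] = {}


def make_key(model_class: str, param_count: int,
             input_shapes: Mapping[str, Sequence[int]], mode: str) -> Tuple[Hashable, ...]:
    """Build a hashable key from the four properties the README lists."""
    # Sort by input name so the key does not depend on dict ordering.
    shapes = tuple(sorted((name, tuple(shape)) for name, shape in input_shapes.items()))
    return (model_class, param_count, shapes, mode)


def cached_lookup(key: Tuple[Hashable, ...], probe: Callable[[], int]) -> int:
    """Run `probe` only on a cache miss; otherwise return the stored result."""
    if key not in _cache:
        _cache[key] = probe()
    return _cache[key]


def clear() -> None:
    """Forget all stored results."""
    _cache.clear()


calls = []

def fake_probe() -> int:
    calls.append(1)  # count how often a real probe would have run
    return 4

shapes = {"input_ids": (1, 512), "attention_mask": (1, 512)}
train_key = make_key("BertModel", 110_000_000, shapes, "train")
infer_key = make_key("BertModel", 110_000_000, shapes, "infer")

cached_lookup(train_key, fake_probe)  # miss: probes
cached_lookup(train_key, fake_probe)  # hit: no probe
cached_lookup(infer_key, fake_probe)  # different mode: probes again
print(len(calls))  # 2
```

A key like this cannot see changes that keep the class, parameter count, shapes, and mode the same, which is one reason an explicit clear_cache() is useful.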
License
MIT