batch-probe

Find the maximum batch size that fits in GPU memory.

Binary search with OOM recovery, configurable safety headroom, no framework required.

The Problem

Every ML practitioner has done this:

batch_size = 64   # OOM
batch_size = 32   # OOM
batch_size = 16   # OOM
batch_size = 8    # works... but am I leaving GPU memory on the table?

batch-probe automates this. It binary-searches for the largest batch size your model can handle, with a safety margin so you don't OOM during real training.

Install

pip install batch-probe

Quick Start

import torch
from batch_probe import probe_batch_size

batch_size = probe_batch_size(
    model,
    lambda bs: {
        "input_ids": torch.zeros(bs, 512, dtype=torch.long, device="cuda"),
        "attention_mask": torch.ones(bs, 512, dtype=torch.long, device="cuda"),
    },
)
# batch-probe: probing batch size (mode=train, range=[1, 4096], headroom=20%)... max=6, safe=4

That's it: a single call. It works with any nn.Module whose forward accepts the returned dict as keyword arguments (model(**inputs)).

Usage

Encoder models (BERT, RoBERTa, etc.)

batch_size = probe_batch_size(
    model,
    lambda bs: {
        "input_ids": torch.zeros(bs, 128, dtype=torch.long, device="cuda"),
        "attention_mask": torch.ones(bs, 128, dtype=torch.long, device="cuda"),
    },
    mode="train",
)

Seq2seq models (T5, BART, etc.)

batch_size = probe_batch_size(
    model,
    lambda bs: {
        "input_ids": torch.zeros(bs, 512, dtype=torch.long, device="cuda"),
        "attention_mask": torch.ones(bs, 512, dtype=torch.long, device="cuda"),
        "labels": torch.zeros(bs, 512, dtype=torch.long, device="cuda"),
    },
    mode="train",
)

Vision models

# The dict key must match the forward() argument name (e.g. "x" for torchvision's ResNet)
batch_size = probe_batch_size(
    model,
    lambda bs: {"x": torch.randn(bs, 3, 224, 224, device="cuda")},
    mode="infer",
)

Inference-only probing

Inference typically uses roughly 2-4x less memory than training, since no gradients are stored, so the inference batch size usually comes out larger:

infer_batch = probe_batch_size(model, input_fn, mode="infer")
train_batch = probe_batch_size(model, input_fn, mode="train")
# infer_batch is usually well above train_batch

Custom headroom

The default safety margin is 20%. Adjust it to your risk tolerance:

# Conservative (40% headroom) — for long training runs
batch_size = probe_batch_size(model, input_fn, headroom=0.4)

# Aggressive (5% headroom) — squeeze every last sample
batch_size = probe_batch_size(model, input_fn, headroom=0.05)
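
To make the effect of headroom concrete, here is a minimal sketch of the arithmetic described under How It Works, int(max_successful × (1 − headroom)). The helper name apply_headroom and its range check are assumptions made for illustration; they are not part of the batch-probe API.

```python
def apply_headroom(max_successful: int, headroom: float = 0.2) -> int:
    """Scale the largest batch size that fit down by the safety margin."""
    # Range check is an assumption of this sketch, not documented behaviour.
    if not 0.0 <= headroom < 1.0:
        raise ValueError("headroom must be in [0, 1)")
    # int() truncates toward zero, so the result never rounds up.
    return int(max_successful * (1 - headroom))


# Suppose the search found that 64 samples fit.
aggressive = apply_headroom(64, 0.05)    # 60
default = apply_headroom(64, 0.2)        # 51
conservative = apply_headroom(64, 0.4)   # 38
print(aggressive, default, conservative)
```

Because the result is truncated, small maxima lose proportionally more: a maximum of 6 with 20% headroom yields 4, matching the sample log line in Quick Start.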

Caching

Use cached_probe to avoid re-probing the same model:

from batch_probe import cached_probe, clear_cache

batch_size = cached_probe(model, input_fn, mode="train")  # probes
batch_size = cached_probe(model, input_fn, mode="train")  # cache hit

clear_cache()  # reset if model changed

How It Works

  1. Binary search between low (default 1) and high (default 4096)
  2. At each midpoint, create dummy tensors via your input_fn
  3. Run a forward pass (+ backward pass in train mode)
  4. If OOM: upper bound ← midpoint − 1, clean GPU memory
  5. If success: lower bound ← midpoint + 1
  6. Return int(max_successful × (1 − headroom))

The OOM recovery uses gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() to fully reclaim memory between iterations.
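
The steps above can be sketched in plain Python. This is an illustration under stated assumptions, not the library's source: search_max_batch, fits, and cleanup are hypothetical names. Here fits(bs) stands in for "run one forward (+ backward) pass at batch size bs and report whether it finished without OOM", and cleanup stands in for the gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() sequence.

```python
from typing import Callable, Tuple


def search_max_batch(
    fits: Callable[[int], bool],
    low: int = 1,
    high: int = 4096,
    headroom: float = 0.2,
    cleanup: Callable[[], None] = lambda: None,
) -> Tuple[int, int]:
    """Return (max_successful, safe) for a monotone `fits` predicate."""
    best = 0  # stays 0 if not even `low` fits
    lo, hi = low, high
    while lo <= hi:
        mid = (lo + hi) // 2
        if fits(mid):
            best = mid
            lo = mid + 1  # success: lower bound <- midpoint + 1
        else:
            cleanup()     # OOM: reclaim memory before the next trial
            hi = mid - 1  # OOM: upper bound <- midpoint - 1
    return best, int(best * (1 - headroom))


# Simulated device on which at most 6 samples fit.
trials = []

def fake_fits(bs: int) -> bool:
    trials.append(bs)
    return bs <= 6

max_bs, safe_bs = search_max_batch(fake_fits)
print(f"max={max_bs}, safe={safe_bs}, trials={len(trials)}")  # max=6, safe=4, trials=12
```

With the default range [1, 4096] the search needs at most 13 trials, which is why probing stays cheap even with a generous upper bound.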

vs. Alternatives

| Feature | batch-probe | Lightning BatchSizeFinder | HF auto_find_batch_size |
|---|---|---|---|
| Works with raw PyTorch | Yes | No (needs LightningModule) | No (needs HF Trainer) |
| Algorithm | Binary search | Power-of-2 scaling | Halve on OOM |
| Configurable headroom | Yes | No | No |
| Train + infer modes | Yes | Train only | Train only |
| Dependencies | torch only | pytorch-lightning | accelerate |

API Reference

probe_batch_size(model, input_fn, *, mode, low, high, headroom, device, verbose)

Find the maximum safe batch size.

  • model (nn.Module): Your model, already on the target device.
  • input_fn (Callable[[int], dict[str, Tensor]]): Takes batch size, returns dict of tensors for model(**inputs).
  • mode ("train" | "infer"): Train mode runs forward + backward. Default: "train".
  • low (int): Minimum batch size. Default: 1.
  • high (int): Upper bound for search. Default: 4096.
  • headroom (float): Safety margin. Default: 0.2 (20%).
  • device (str | torch.device | None): Override device. Default: model's device.
  • verbose (bool): Print progress. Default: True.

Returns: int — safe batch size.

cached_probe(model, input_fn, *, mode, **kwargs)

Same as probe_batch_size but caches results keyed on model class, param count, input shapes, and mode.
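
As a rough illustration of what such a cache key might look like, the sketch below builds a hashable tuple from the four ingredients named above and memoizes a probe result under it. The names make_key, cached, and _cache are hypothetical, and the exact key layout (including which batch size the input shapes are taken at) is an assumption, not the library's implementation.

```python
from typing import Callable, Dict, Hashable, Tuple

_cache: Dict[Hashable, int] = {}


def make_key(
    model_class: str,
    param_count: int,
    input_shapes: Dict[str, Tuple[int, ...]],
    mode: str,
) -> Hashable:
    # Sort the shapes so the key does not depend on dict insertion order.
    shapes = tuple(sorted(input_shapes.items()))
    return (model_class, param_count, shapes, mode)


def cached(key: Hashable, probe: Callable[[], int]) -> int:
    # Run the (expensive) probe only on a cache miss.
    if key not in _cache:
        _cache[key] = probe()
    return _cache[key]


calls = []

def fake_probe() -> int:
    calls.append(1)
    return 4  # stand-in for a real probe result

key = make_key(
    "BertModel",
    110_000_000,
    {"input_ids": (1, 512), "attention_mask": (1, 512)},
    "train",
)
first = cached(key, fake_probe)   # probes
second = cached(key, fake_probe)  # cache hit
print(first, second, len(calls))  # 4 4 1
```

A key of this kind cannot distinguish two models that share a class, parameter count, input shapes, and mode, which is one reason to call clear_cache() after changing the model.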

clear_cache()

Clear all cached probe results.

License

MIT
