
GPU memory probing and thermal-aware CPU thread control. Works with PyTorch, CuPy, JAX, or any GPU framework.


batch-probe



Binary search with OOM recovery, configurable safety headroom, no framework required. New in v0.4.0: thermal-aware thread tuning for CPU workloads.

The Problem

Every ML practitioner has done this:

batch_size = 64   # OOM
batch_size = 32   # OOM
batch_size = 16   # OOM
batch_size = 8    # works... but am I leaving GPU memory on the table?

batch-probe automates this. It binary-searches for the largest batch size your model can handle, with a safety margin so you don't OOM during real training.

Install

pip install batch-probe

Quick Start

import torch
from batch_probe import probe_batch_size

batch_size = probe_batch_size(
    model,
    lambda bs: {
        "input_ids": torch.zeros(bs, 512, dtype=torch.long, device="cuda"),
        "attention_mask": torch.ones(bs, 512, dtype=torch.long, device="cuda"),
    },
)
# batch-probe: probing batch size (mode=train, range=[1, 4096], headroom=20%)... max=6, safe=4

That's it: one import and one call. Works with any nn.Module.

Usage

Encoder models (BERT, RoBERTa, etc.)

batch_size = probe_batch_size(
    model,
    lambda bs: {
        "input_ids": torch.zeros(bs, 128, dtype=torch.long, device="cuda"),
        "attention_mask": torch.ones(bs, 128, dtype=torch.long, device="cuda"),
    },
    mode="train",
)

Seq2seq models (T5, BART, etc.)

batch_size = probe_batch_size(
    model,
    lambda bs: {
        "input_ids": torch.zeros(bs, 512, dtype=torch.long, device="cuda"),
        "attention_mask": torch.ones(bs, 512, dtype=torch.long, device="cuda"),
        "labels": torch.zeros(bs, 512, dtype=torch.long, device="cuda"),
    },
    mode="train",
)

Vision models

batch_size = probe_batch_size(
    model,
    lambda bs: {"x": torch.randn(bs, 3, 224, 224, device="cuda")},
    mode="infer",
)

Inference-only probing

Inference typically needs about 2-4x less memory than training, because no gradients or backward-pass activations are kept. Probe each mode separately:

infer_batch = probe_batch_size(model, input_fn, mode="infer")
train_batch = probe_batch_size(model, input_fn, mode="train")
# expect infer_batch to be noticeably larger than train_batch

Custom headroom

The default safety margin is 20%. Adjust it to your risk tolerance:

# Conservative (40% headroom) — for long training runs
batch_size = probe_batch_size(model, input_fn, headroom=0.4)

# Aggressive (5% headroom) — squeeze every last sample
batch_size = probe_batch_size(model, input_fn, headroom=0.05)
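
To see what a given headroom does to the result, here is a minimal sketch of the arithmetic from How It Works, int(max_successful × (1 − headroom)). The helper safe_batch_size is hypothetical and not part of the batch-probe API; whether the library clamps the result to at least low is not stated in this README.

```python
def safe_batch_size(max_successful: int, headroom: float) -> int:
    """Hypothetical helper: apply the documented headroom formula."""
    if not 0.0 <= headroom < 1.0:
        raise ValueError("headroom must be in [0, 1)")
    # int() truncates toward zero, so the result never rounds up past the margin.
    return int(max_successful * (1.0 - headroom))


# Same probe result (largest batch that fit = 6), three risk levels.
for headroom in (0.05, 0.2, 0.4):
    print(headroom, safe_batch_size(6, headroom))
# 0.05 5
# 0.2 4
# 0.4 3
```

With a probed maximum of 6, the default 20% headroom gives 4, which matches the "max=6, safe=4" line in the Quick Start output.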

Caching

Use cached_probe to avoid re-probing the same model:

from batch_probe import cached_probe, clear_cache

batch_size = cached_probe(model, input_fn, mode="train")  # probes
batch_size = cached_probe(model, input_fn, mode="train")  # cache hit

clear_cache()  # reset if model changed

How It Works

  1. Binary search between low (default 1) and high (default 4096)
  2. At each midpoint, create dummy tensors via your input_fn
  3. Run a forward pass (+ backward pass in train mode)
  4. If OOM: upper bound ← midpoint − 1, clean GPU memory
  5. If success: lower bound ← midpoint + 1
  6. Return int(max_successful × (1 − headroom))

The OOM recovery uses gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() to fully reclaim memory between iterations.
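
The steps above can be sketched without a GPU. This is an illustration of the search loop only, not the library's source: probe and fake_model are hypothetical names, and MemoryError stands in for a CUDA out-of-memory error.

```python
def probe(fits, low=1, high=4096, headroom=0.2):
    """Binary-search the largest size in [low, high] for which fits(size) succeeds.

    `fits` stands in for "run a forward (and backward) pass without OOM".
    Returns (max_successful, safe); (0, 0) if even `low` does not fit.
    """
    best = 0
    while low <= high:
        mid = (low + high) // 2
        try:
            ok = bool(fits(mid))
        except MemoryError:  # stand-in for a CUDA out-of-memory error
            ok = False
        if ok:
            best = mid        # mid fits: search the upper half
            low = mid + 1
        else:
            high = mid - 1    # mid failed: search the lower half
            # A real probe would release GPU memory here before the next attempt.
    return best, int(best * (1.0 - headroom))


def fake_model(batch_size, capacity=6):
    """Simulated workload that runs out of memory above `capacity` samples."""
    if batch_size > capacity:
        raise MemoryError(f"simulated OOM at batch size {batch_size}")
    return True


print(probe(fake_model))  # (6, 4)
```

Binary search needs about log2(high − low) attempts, so the default range of 1 to 4096 takes roughly a dozen forward/backward passes.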

vs. Alternatives

Feature                | batch-probe   | Lightning BatchSizeFinder  | HF auto_find_batch_size
Works with raw PyTorch | Yes           | No (needs LightningModule) | No (needs HF Trainer)
Algorithm              | Binary search | Power-of-2 scaling         | Halve on OOM
Configurable headroom  | Yes           | No                         | No
Train + infer modes    | Yes           | Train only                 | Train only
Dependencies           | torch only    | pytorch-lightning          | accelerate

API Reference

probe_batch_size(model, input_fn, *, mode, low, high, headroom, device, verbose)

Find the maximum safe batch size.

  • model (nn.Module): Your model, already on the target device.
  • input_fn (Callable[[int], dict[str, Tensor]]): Takes batch size, returns dict of tensors for model(**inputs).
  • mode ("train" | "infer"): Train mode runs forward + backward. Default: "train".
  • low (int): Minimum batch size. Default: 1.
  • high (int): Upper bound for search. Default: 4096.
  • headroom (float): Safety margin. Default: 0.2 (20%).
  • device (str | torch.device | None): Override device. Default: model's device.
  • verbose (bool): Print progress. Default: True.

Returns: int — safe batch size.

cached_probe(model, input_fn, *, mode, **kwargs)

Same as probe_batch_size but caches results keyed on model class, param count, input shapes, and mode.
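
A rough sketch of what such a key could look like, using plain Python values in place of a real model. make_key, cached, and the example numbers are hypothetical; the library's actual key layout may differ.

```python
from typing import Callable, Dict, Tuple

_cache: Dict[tuple, int] = {}


def make_key(model_class: str, param_count: int,
             input_shapes: Dict[str, Tuple[int, ...]], mode: str) -> tuple:
    """Build a hashable key from the four properties the README lists."""
    # Sort by input name so dict ordering does not change the key.
    return (model_class, param_count, tuple(sorted(input_shapes.items())), mode)


def cached(key: tuple, probe_fn: Callable[[], int]) -> int:
    """Run probe_fn only on a cache miss."""
    if key not in _cache:
        _cache[key] = probe_fn()
    return _cache[key]


calls = []

def expensive_probe() -> int:
    calls.append(1)  # count how often the real probe would run
    return 4

shapes = {"input_ids": (512,), "attention_mask": (512,)}
train_key = make_key("BertModel", 110_000_000, shapes, "train")
infer_key = make_key("BertModel", 110_000_000, shapes, "infer")

cached(train_key, expensive_probe)
cached(train_key, expensive_probe)   # cache hit, no second probe
cached(infer_key, expensive_probe)   # different mode, new probe
print(len(calls))  # 2
```

Under a key like this, two models of the same class with the same parameter count and input shapes share one entry, which is why clear_cache() is worth calling after changing anything the key does not capture.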

clear_cache()

Clear all cached probe results.


Thermal-Aware Thread Control (v0.4.0)

For CPU-bound workloads, batch-probe can find the maximum thread count that keeps your CPU under a target temperature, and continuously adjust it during long-running jobs.

probe_threads(work_fn, *, max_temp, low, high, settle_time, work_time, cooldown_time, verbose)

One-shot binary search for the maximum safe thread count. Runs your workload at different thread counts and reads CPU temperature to find the thermal limit.

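import numpy as np
from threadpoolctl import threadpool_limits  # extra dependency, not required by batch-probe
from batch_probe import probe_threads

def stress(n):
    # Setting OMP_NUM_THREADS here usually has no effect: BLAS libraries read it
    # once when NumPy is imported. threadpool_limits caps the pool at runtime.
    with threadpool_limits(limits=n):
        for _ in range(100):
            a = np.random.randn(2000, 2000)
            _ = a @ a.T

threads = probe_threads(stress, max_temp=85.0)
print(f"Safe thread count: {threads}")

  • work_fn (Callable[[int], None]): Function that runs a CPU workload using the given number of threads.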
  • max_temp (float): Maximum acceptable CPU temperature in Celsius. Default: 85.0.
  • low (int): Minimum thread count. Default: 1.
  • high (int | None): Maximum thread count. Default: os.cpu_count().
  • settle_time (float): Seconds to wait before reading temperature. Default: 5.0.
  • work_time (float): Seconds to run the workload per probe. Default: 10.0.
  • cooldown_time (float): Seconds to wait between probes. Default: 15.0.
  • verbose (bool): Print progress. Default: True.

Returns: int -- safe thread count.
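
The search itself can be illustrated with a simulated sensor. This is a sketch under stated assumptions, not the library's code: probe_threads_sketch, FakeCpu, and the injected read_temp and sleep callables are hypothetical, temperature is assumed to rise monotonically with thread count, and the ordering of work, settle, and cooldown is a guess based on the parameter descriptions.

```python
import time


def probe_threads_sketch(work_fn, read_temp, max_temp=85.0, low=1, high=8,
                         settle_time=0.0, cooldown_time=0.0, sleep=time.sleep):
    """Binary-search the largest thread count whose temperature stays <= max_temp.

    Returns `low` if even the minimum thread count runs too hot.
    """
    best = low
    while low <= high:
        mid = (low + high) // 2
        work_fn(mid)            # run the workload at `mid` threads
        sleep(settle_time)      # let the sensor reading settle
        if read_temp() <= max_temp:
            best = mid          # cool enough: try more threads
            low = mid + 1
        else:
            high = mid - 1      # too hot: try fewer threads
        sleep(cooldown_time)    # cool down so probes do not bias each other
    return best


class FakeCpu:
    """Simulated machine: 50 °C idle plus 5 °C per active thread."""

    def __init__(self):
        self.threads = 0

    def work(self, n):
        self.threads = n

    def temp(self):
        return 50.0 + 5.0 * self.threads


cpu = FakeCpu()
print(probe_threads_sketch(cpu.work, cpu.temp, max_temp=85.0, high=16))  # 7
```

With the real defaults (5 s settle, 10 s work, 15 s cooldown), each probe step takes about 30 seconds, so a search over 48 threads needs roughly three minutes.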

ThermalController(target_temp, max_threads, min_threads, poll_interval, Kp, Ki, Kd, lookahead, verbose)

Continuous adaptive thread controller using a Kalman-filtered thermal state estimator. Runs a background thread that reads CPU temperature and adjusts the recommended thread count using PI+D control with feedforward.

from batch_probe import ThermalController

ctrl = ThermalController(target_temp=82.0, max_threads=48)
ctrl.start()

# In your workload loop:
while work_remaining:
    n = ctrl.get_threads()
    run_workload(n_threads=n)

ctrl.stop()
print(ctrl.summary())

  • target_temp (float): Desired CPU temperature setpoint. Default: 82.0.
  • max_threads (int | None): Maximum thread count. Default: os.cpu_count().
  • min_threads (int): Minimum thread count. Default: 1.
  • poll_interval (float): Seconds between sensor reads. Default: 2.0.
  • Kp, Ki, Kd (float): PID controller gains. Defaults: 3.0, 0.1, 10.0.
  • lookahead (float): Seconds to predict ahead for proactive control. Default: 5.0.

Methods:

  • start() -- Start background thermal monitoring.
  • stop() -- Stop the background thread.
  • get_threads() -- Get the current recommended thread count (thread-safe).
  • summary() -- Return a dict of control history statistics (temp mean/max/min, threads mean/min/max, time over target).
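
To give a feel for how the Kp, Ki, and Kd gains interact, here is a plain textbook PID step that maps a temperature reading to a thread count. It is a simplified stand-in, not ThermalController's implementation: it leaves out the Kalman filter, the feedforward term, and the lookahead, and the class name, the max_threads bias, and the anti-windup rule are all assumptions.

```python
class PidThreadSketch:
    """Hypothetical positional PID mapping CPU temperature to a thread count."""

    def __init__(self, target_temp=82.0, max_threads=48, min_threads=1,
                 Kp=3.0, Ki=0.1, Kd=10.0):
        self.target_temp = target_temp
        self.max_threads = max_threads
        self.min_threads = min_threads
        self.Kp, self.Ki, self.Kd = Kp, Ki, Kd
        self.integral = 0.0
        self.prev_error = None

    def update(self, temp, dt):
        """Return a thread count for one sensor reading taken `dt` seconds after the last."""
        error = self.target_temp - temp  # > 0 means there is thermal headroom
        derivative = 0.0 if self.prev_error is None else (error - self.prev_error) / dt
        self.prev_error = error
        candidate = self.integral + error * dt
        # Bias at max_threads: run flat out unless the error terms pull the count down.
        raw = (self.max_threads + self.Kp * error
               + self.Ki * candidate + self.Kd * derivative)
        # Anti-windup: only accumulate the integral while the output is not clamped.
        if self.min_threads < raw < self.max_threads:
            self.integral = candidate
        return int(max(self.min_threads, min(self.max_threads, round(raw))))


ctrl = PidThreadSketch()
readings = [70.0, 84.0, 84.0]  # cool, sudden jump above target, then steady
print([ctrl.update(t, dt=2.0) for t in readings])  # [48, 1, 42]
```

The middle value shows why the derivative gain is large: a fast rise cuts threads sharply even though the temperature is only 2 degrees over target, and the count recovers once the reading stops climbing.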

ThermalJobManager(target_temp, max_concurrent, settle_time, poll_interval, cooldown_margin, verbose)

Manages parallel subprocess jobs with thermal throttling. Launches jobs up to max_concurrent, but holds back new launches until the CPU temperature is at least cooldown_margin degrees below target_temp.

from batch_probe import ThermalJobManager

jobs = [
    ("dataset_A", ["python", "train.py", "--data", "A"]),
    ("dataset_B", ["python", "train.py", "--data", "B"]),
    ("dataset_C", ["python", "train.py", "--data", "C"]),
]

mgr = ThermalJobManager(target_temp=85.0, max_concurrent=4)
results = mgr.run(jobs, cwd="/path/to/workdir")
# results: {"dataset_A": 0, "dataset_B": 0, "dataset_C": 0}

  • target_temp (float): Maximum CPU temperature. Default: 85.0.
  • max_concurrent (int): Maximum simultaneous jobs. Default: 4.
  • settle_time (float): Seconds to wait after launch before reading temp. Default: 10.0.
  • poll_interval (float): Seconds between job status checks. Default: 5.0.
  • cooldown_margin (float): Must be this many degrees below target to launch. Default: 3.0.

Method:

  • run(jobs, cwd=None, log_dir=None) -- Run all jobs. Returns dict[str, int] mapping job name to exit code. Each job gets a log file at {log_dir}/{name}.log.
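
The launch rule described by target_temp, max_concurrent, and cooldown_margin can be sketched as a small scheduling loop. This is an illustration only: can_launch, run_jobs_sketch, and the injected read_temp callable are hypothetical, and the sketch skips settle_time, cwd, and log files.

```python
import subprocess
import sys
import time


def can_launch(temp, running, target_temp=85.0, max_concurrent=4, cooldown_margin=3.0):
    """A job may start only with a free slot and enough thermal margin."""
    return running < max_concurrent and temp <= target_temp - cooldown_margin


def run_jobs_sketch(jobs, read_temp, target_temp=85.0, max_concurrent=4,
                    cooldown_margin=3.0, poll_interval=0.05):
    """Run (name, argv) jobs as subprocesses; return {name: exit_code}."""
    pending = list(jobs)
    running = {}   # name -> subprocess.Popen
    results = {}
    while pending or running:
        # Collect jobs that have finished since the last pass.
        for name, proc in list(running.items()):
            code = proc.poll()
            if code is not None:
                results[name] = code
                del running[name]
        # Launch at most one job per pass so each launch sees a fresh reading.
        if pending and can_launch(read_temp(), len(running), target_temp,
                                  max_concurrent, cooldown_margin):
            name, argv = pending.pop(0)
            running[name] = subprocess.Popen(argv)
        if pending or running:
            time.sleep(poll_interval)
    return results


demo_jobs = [
    ("ok", [sys.executable, "-c", "pass"]),
    ("fail", [sys.executable, "-c", "import sys; sys.exit(3)"]),
]
demo_results = run_jobs_sketch(demo_jobs, read_temp=lambda: 70.0)
print(demo_results == {"ok": 0, "fail": 3})  # True
```

With the defaults, a reading of 83 °C blocks new launches (83 > 85 − 3) while already running jobs continue, which gives the CPU a chance to cool before more load is added.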

License

MIT
