Find maximum batch size, documents, and timesteps for PyTorch models
🔍 Batch Finder
Find the largest size of any input dimension that your PyTorch model can handle without running out of memory.
Batch Finder detects your model's inputs (types and shapes), fixes the sizes you choose, and searches for the largest value of the free dimension that a run can sustain without an out-of-memory failure.
✨ Features
- 🎯 One API – find_max_minibatch for all workflows
- 🔍 Explicit inputs – Names, dtypes, and rough shapes from input_shapes / forward_params, plus get_model().config when available
- 📐 Shapes – Flat tuple/list or list of tuples per tensor (e.g. [(-1, 128, 512), (-1, 128, 512)]), text or dict when forward takes several tensors, or axis_to_maximize + fixed_axis
- 🚀 Inference or training – With or without backward
- ⚙️ Tunable search – factor_down, factor_up, n_attempts, initial_value
- 🛡️ Safe runs – Cleans up after failures; returns None if even size 1 fails
- 📊 Progress – tqdm with status in the bar
📦 Installation
pip install batch-finder
Or from source:
git clone https://github.com/yourusername/batch-finder.git
cd batch-finder
pip install -e .
🚀 Quick Start
One tensor: tuple or list
Use the integer -1 on each axis you want to maximize; the search tries a single trial size at each step. In an all-integer tuple, every -1 position shares that same trial size. Use positive integers for fixed dimensions.
Other negative integers d < -1 (e.g. -2, -3) size that dimension as round(|d| × trial), the same scaling role as the negative floats below (e.g. -2 gives 2× the trial size tied to -1).
If the tuple mixes integers and negative floats, you are in compact numeric mode: there must be at least one integer -1 (the searched axis). Any other dimension given as a negative float -x is sized as round(|x| × trial), where trial is the current value on the -1 axis, so |x| is the ratio you want between that dimension and the searched axis (e.g. -1.5 keeps that dimension at about 1.5× the trial size).
from batch_finder import find_max_minibatch
def get_model():
    # MyModel stands for your own torch.nn.Module subclass
    return MyModel()
# Maximize axis 0; other dims fixed
max_val = find_max_minibatch(get_model, input_shapes=(-1, 64, 256))
# Maximize axis 2
max_val = find_max_minibatch(get_model, input_shapes=(4, 8, -1))
# One integer -1 (searched axis) + negative float: other dims scale by |float| vs. that trial
max_val = find_max_minibatch(get_model, input_shapes=(-1, 4, -1.5, 16))
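The sizing rules above can be sketched in a few lines. The helper resolve_shape below is hypothetical and not part of batch_finder; it only illustrates how one trial size might expand a template into a concrete shape, assuming Python's built-in round and a floor of 1 per dimension (both assumptions).

```python
def resolve_shape(template, trial):
    """Expand a shape template for one trial size (hypothetical helper, not library code)."""
    if not any(isinstance(d, int) and d == -1 for d in template):
        raise ValueError("template needs at least one integer -1 (the searched axis)")
    shape = []
    for d in template:
        if d > 0:
            shape.append(int(d))  # positive value: fixed dimension
        elif d < 0:
            # -1 -> trial, -2 -> 2 * trial, -1.5 -> about 1.5 * trial.
            # The floor of 1 is an assumption so that no dimension collapses to 0.
            shape.append(max(1, round(abs(d) * trial)))
        else:
            raise ValueError("dimensions must be non-zero")
    return tuple(shape)


print(resolve_shape((-1, 64, 256), 8))      # (8, 64, 256)
print(resolve_shape((-1, 4, -1.5, 16), 8))  # (8, 4, 12, 16)
print(resolve_shape((-1, -2, 32), 8))       # (8, 16, 32)
```

Python's round sends exact halves to the nearest even integer, so 1.5 × 3 = 4.5 resolves to 4 in this sketch; the text does not say whether the library rounds the same way.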
Several tensors: list of shape tuples
Pass a list or tuple of per-tensor shapes (same order as forward / forward_params). Each entry follows the same rules as a single-tensor tuple (integer -1, integer multipliers -2, -3, …, or compact floats). One trial size is searched; every -1 and every d < -1 uses that trial value as above.
# Example: two tensors, same free shape pattern
max_val = find_max_minibatch(
    get_model,
    input_shapes=[(-1, 128, 512), (-1, 128, 512)],
    forward_params=["x", "y"],
)
For symbolic names and constraints across tensors, use the string or dict form described below.
HuggingFace-style: axis_to_maximize + fixed_axis
from transformers import AutoModelForCausalLM
from batch_finder import find_max_minibatch
def get_model():
    return AutoModelForCausalLM.from_pretrained("distilgpt2")
# ``forward_params`` defaults to GPT2-style causal LM names; pass your own list for other architectures.
max_batch = find_max_minibatch(
    get_model,
    axis_to_maximize="batch_size",
    fixed_axis={"seq_len": 32},
)
print(f"Max batch size: {max_batch}")
Several tensors: one string
When forward takes multiple tensors, pass input_shapes as text: one (…) group per argument (same order as forward), short names for axes that must match, optional equations between names, and at least one name=-1 for the size you search.
Pattern: (dims…), (dims…), … then optional name=value rules.
- Dimensions: numbers, or names like b, t that repeat across tensors.
- One name must be set to -1 (that is the searched size).
- Rules like t=1.5b tie sizes together (non-integers are rounded for tensor shapes).
import torch
from batch_finder import find_max_minibatch
class MyModel(torch.nn.Module):
    def forward(self, x, y):
        # x: (23, b, t, 45), y: (b, t, 12)
        ...
def get_model():
    return MyModel()
max_b = find_max_minibatch(
    get_model,
    input_shapes="(23, b, t, 45),(b, t, 12), t=1.5b, b=-1",
)
Do not combine this with tuple/list single-tensor mode or with axis_to_maximize. Pass input_shapes= as a keyword.
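To make the string pattern concrete, here is a minimal sketch of how such a spec could be resolved for one trial size. The function resolve_shape_string is hypothetical, its grammar is inferred only from the example above, and it is not the library's parser; it assumes exactly one name is set to -1 and that rules have the form name=number or name=factor·name.

```python
import re


def resolve_shape_string(spec, trial):
    """Resolve "(dims…), (dims…), name=value, …" for one trial size (hypothetical helper)."""
    groups = re.findall(r"\(([^)]*)\)", spec)    # one "(…)" group per tensor
    rules_text = re.sub(r"\([^)]*\)", "", spec)  # what remains holds the name=value rules
    rules = {}
    for part in rules_text.split(","):
        if "=" in part:
            name, expr = part.split("=", 1)
            rules[name.strip()] = expr.strip()

    searched = [name for name, expr in rules.items() if expr == "-1"]
    if len(searched) != 1:
        raise ValueError("this sketch expects exactly one name set to -1")
    values = {searched[0]: float(trial)}

    # Resolve rules such as t=1.5b or a=3; loop so that chained rules can settle.
    pending = {n: e for n, e in rules.items() if e != "-1"}
    while pending:
        resolved = []
        for name, expr in pending.items():
            m = re.fullmatch(r"(\d+(?:\.\d+)?)?([A-Za-z_]\w*)", expr)
            if m is None:
                values[name] = float(expr)  # plain number
                resolved.append(name)
            elif m.group(2) in values:
                values[name] = float(m.group(1) or 1) * values[m.group(2)]
                resolved.append(name)
        if not resolved:
            raise ValueError(f"cannot resolve rules for {sorted(pending)}")
        for name in resolved:
            del pending[name]

    shapes = []
    for group in groups:
        dims = [tok.strip() for tok in group.split(",")]
        # Numbers stay fixed; names are replaced by their rounded resolved value.
        shapes.append(tuple(int(d) if d.isdigit() else round(values[d]) for d in dims))
    return shapes


print(resolve_shape_string("(23, b, t, 45),(b, t, 12), t=1.5b, b=-1", 8))
# [(23, 8, 12, 45), (8, 12, 12)]
```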
Several tensors: dict (named arguments)
Same idea as the string form, but keys match forward parameters. Put shared rules in "#constraints" (must include exactly one symbol=-1). Each value is shape text, optionally followed by ", int" or ", float" to set the dtype.
def get_model():
    return MyModel()
max_b = find_max_minibatch(
    get_model,
    input_shapes={
        "input_ids": "(b, t), int",
        "attention_mask": "(b, t), int",
        "input_ids_encoder": "(d, b, t), int",
        "attention_mask_encoder": "(d, b, t), int",
        "labels": "(b, t)",
        "#constraints": "t=2b, b=-1",
    },
)
Use the literal "#constraints" for the rules entry, or import FINDER_CONSTRAINTS_KEY from batch_finder to avoid typos (CONSTRAINTS_KEY is the same string, kept for compatibility).
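As a rough illustration of the dict layout, the sketch below separates the rules entry from the per-argument entries and splits each value into its dimension names and optional dtype suffix. split_dict_spec and RULES_KEY are hypothetical names used only here; how the library parses the dict internally is not described in this document.

```python
import re

RULES_KEY = "#constraints"  # the literal key quoted in the text above


def split_dict_spec(input_shapes):
    """Return ({argument: (dims, dtype)}, rules_text) for a dict spec (hypothetical helper)."""
    rules_text = input_shapes.get(RULES_KEY, "")
    tensors = {}
    for name, text in input_shapes.items():
        if name == RULES_KEY:
            continue
        # Shape text looks like "(b, t)" with an optional ", int" or ", float" suffix.
        m = re.fullmatch(r"\s*\(([^)]*)\)\s*(?:,\s*(int|float))?\s*", text)
        if m is None:
            raise ValueError(f"bad shape text for {name!r}: {text!r}")
        dims = tuple(tok.strip() for tok in m.group(1).split(","))
        # dtype None is read here as "no suffix given" (an assumption of this sketch).
        tensors[name] = (dims, m.group(2))
    return tensors, rules_text


tensors, rules = split_dict_spec({
    "input_ids": "(b, t), int",
    "labels": "(b, t)",
    "#constraints": "t=2b, b=-1",
})
print(tensors["input_ids"])  # (('b', 't'), 'int')
print(tensors["labels"])     # (('b', 't'), None)
print(rules)                 # t=2b, b=-1
```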
Custom search parameters
def get_model():
    return MyModel()
max_val = find_max_minibatch(
    get_model,
    input_shapes=(-1, 128, 512),
    initial_value=8,
    n_attempts=30,
    factor_down=3.0,  # divide by 3 on failure
    factor_up=2.0,  # multiply by 2 on success
)
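To show how these four parameters might interact, here is a minimal sketch of a grow-on-success, shrink-on-failure loop. search_max and try_size are hypothetical names, not the library's code; the bisection step once both a passing and a failing size are known is an assumption, suggested only by the max_ok and min_fail fields in the example progress bar further down.

```python
def search_max(try_size, initial_value=8, n_attempts=30, factor_down=3.0, factor_up=2.0):
    """Return the largest size for which try_size(size) passed, or None (hypothetical sketch)."""
    value = max(1, int(initial_value))
    max_ok = None    # largest size seen to pass
    min_fail = None  # smallest size seen to fail
    for _ in range(n_attempts):
        if try_size(value):
            max_ok = value if max_ok is None else max(max_ok, value)
            nxt = max(value + 1, int(value * factor_up))  # grow on success
        else:
            if value == 1:
                return None  # even size 1 fails
            min_fail = value if min_fail is None else min(min_fail, value)
            nxt = max(1, int(value / factor_down))  # shrink on failure
        if max_ok is not None and min_fail is not None:
            if min_fail - max_ok <= 1:
                break  # the bracket cannot get tighter
            nxt = (max_ok + min_fail) // 2  # assumption: bisect inside the bracket
        value = nxt
    return max_ok


# Stand-in for a real forward/backward attempt: sizes up to 1919 "fit".
print(search_max(lambda n: n <= 1919, initial_value=512, n_attempts=50,
                 factor_down=2.0, factor_up=2.0))  # 1919
```

The sketch assumes that success is monotonic in the size (if a size fails, every larger size fails too), which real memory behavior only approximates.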
📖 API Reference
find_max_minibatch(...)
Find the largest workable size for the free axis without OOM.
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| get_model | Callable[[], nn.Module] | (required) | First argument. Fresh module per attempt. Parent calls get_model() once first, keeps only .config (if any) for vocab/shape hints, then drops weights. Picklable under spawn (module-level function, not lambda) |
| forward_params | Sequence[str] | None | When input_shapes is not a dict: ordered tensor names for forward. With axis_to_maximize only, defaults to DEFAULT_FORWARD_PARAMS_CAUSAL_LM (GPT2-style). Override or import that constant to tweak. |
| input_shapes | str, dict, or tuple/list | None | String: several tensors, shared names and rules. Dict: named shapes + "#constraints". Flat tuple/list: first tensor only (ints with -1 or compact floats). Nested list/tuple of shapes: one tuple per tensor, e.g. [(-1, 128, 512), (-1, 128, 512)]. Not used together with axis_to_maximize |
| axis_to_maximize | str | None | Name of the axis when input_shapes is omitted, e.g. "batch_size" |
| fixed_axis | Dict[str, int] | {} | Fixed sizes, e.g. {"seq_len": 128} |
| device | torch.device | auto | Device to run on |
| delay | float | None→auto | Auto: 3.0 s on CUDA, 0.75 s on CPU/MPS (pass a number to override) |
| initial_value | int | None→auto | Auto: 512 CUDA, 64 MPS, 32 CPU (pass a number to override) |
| n_attempts | int | 50 | Maximum attempts |
| inference_only | bool | False | If True, no backward. If False, forward + backward. |
| factor_down | float | 2.0 | After failure: next = value / factor_down |
| factor_up | float | 2.0 | Used when memory_guided=False for success steps |
| memory_guided | bool | True | Use peak GPU memory (and optional CPU via psutil) to pick the next size; cap with max_growth_multiplier |
| memory_target_fraction | float | 0.88 | Target peak VRAM as a fraction of total GPU memory when extrapolating |
| max_growth_multiplier | float | 6.0 | Max single-step increase after success |
| cuda_mem_devices | int, list, or "all" | None | Which GPUs to read peak memory on; default = search device's index. "all" or [0,1,…] for multi-GPU bottleneck |
| use_subprocess | bool | None | Default: subprocess on Linux/macOS for CUDA, CPU, and MPS (worker may be OOM-killed; parent continues). Set BATCH_FINDER_SUBPROCESS=0 for in-process on very tight hosts. Windows: in-process |
Returns: a shape tuple, a tuple of shape tuples (multi-tensor list input_shapes), an int (string/dict input_shapes or axis_to_maximize), or None if nothing worked.
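The memory_guided, memory_target_fraction, max_growth_multiplier, and cuda_mem_devices entries above suggest a proportional extrapolation after each successful run. The sketch below is one plausible reading, not the library's formula: next_size_memory_guided is a hypothetical name, and the code assumes peak memory grows linearly with the searched size.

```python
def next_size_memory_guided(value, peaks_and_totals,
                            memory_target_fraction=0.88, max_growth_multiplier=6.0):
    """Pick the next trial size after a successful run (hypothetical sketch).

    peaks_and_totals: one (peak_bytes, total_bytes) pair per monitored GPU.
    """
    growth = max_growth_multiplier  # never grow by more than this in one step
    for peak_bytes, total_bytes in peaks_and_totals:
        if peak_bytes > 0:
            # Factor that would bring this device's peak to the target fraction,
            # under the linear-memory assumption. The tightest device wins.
            growth = min(growth, memory_target_fraction * total_bytes / peak_bytes)
    # Assumption: always move forward by at least one after a success.
    return max(value + 1, int(value * growth))


GiB = 2 ** 30
# One GPU that peaked at 4 GiB of 16 GiB, with a 50% target: double the size.
print(next_size_memory_guided(100, [(4 * GiB, 16 * GiB)], memory_target_fraction=0.5))  # 200
# A nearly idle GPU: growth is capped by max_growth_multiplier.
print(next_size_memory_guided(100, [(1 * GiB, 16 * GiB)], memory_target_fraction=0.5))  # 600
```

Taking the minimum over devices mirrors the "multi-GPU bottleneck" wording: the GPU closest to its target limits the step.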
Modes:
- Tuple/list: first forward argument, or a list/tuple of shape tuples for several arguments; -1 marks the free axis, or add negative floats to scale other axes off that trial size.
- String: one shape group per tensor; names, one =-1, optional equations between names.
- Dict: like the string form, keys = parameter names, "#constraints" for the rules.
- axis_to_maximize + fixed_axis: when you skip input_shapes.
Example output (axis_to_maximize + fixed_axis):
--- Detected inputs (type, estimated shape) ---
input_ids: integer, (32, 64)
attention_mask: integer, (32, 64)
---
batch_size fixed={'seq_len': 32}: 100%|████████████████████| 22/50 [01:26<00:00, 3.9s/it, gpus=1, i=22/50, max_ok=1919, min_fail=1920, status=✅, value=1919
✅ Max value that passed: 1919
🔧 How It Works
- Inputs – Uses input_shapes dict keys or forward_params (no live module in the parent during the search).
- Types – Integers for *ids, *mask, labels; floats otherwise.
- Shapes – From optional config, or from get_model().config (one probe if you omit config) plus argument names.
- Search – On success, grow; on failure, shrink. Stops at failure at size 1 or when n_attempts is reached.
- Loss – Uses output.loss if present, otherwise sums output tensors.
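The Types rule in the list above can be read as simple name matching. The sketch below assumes shell-style wildcard matching via Python's fnmatch; guess_kind and INTEGER_PATTERNS are hypothetical names, and the library's actual matching logic may differ.

```python
from fnmatch import fnmatch

# Patterns quoted from the Types rule; everything else is treated as float.
INTEGER_PATTERNS = ("*ids", "*mask", "labels")


def guess_kind(param_name):
    """Guess whether a forward argument should be an integer or a float tensor."""
    if any(fnmatch(param_name, pattern) for pattern in INTEGER_PATTERNS):
        return "integer"
    return "float"


for name in ("input_ids", "attention_mask", "labels", "pixel_values", "input_ids_encoder"):
    print(f"{name}: {guess_kind(name)}")
```

Under this reading a name such as input_ids_encoder does not end in ids and falls through to float, which may be why the dict example spells out ", int" for it.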
⚠️ Important Notes
- Memory: Lower initial_value on small GPUs.
- Speed: inference_only=True is faster.
- Training: inference_only=False runs backward too.
- Size 1: If size 1 fails, the function returns None.
- OOM: CUDA OOM and similar errors are caught and the search continues. On Linux/macOS, attempts default to a subprocess so the parent keeps running if the host kills the worker (SIGKILL). Set BATCH_FINDER_SUBPROCESS=0 for in-process attempts if repeated spawn reloads are worse on your host. Each worker run ends with del model, gc.collect(), and CUDA/MPS cache clears where applicable. After failures, the same cleanup runs.
- Defaults: Omitting initial_value and delay picks smaller starts and shorter pauses on CPU/MPS than on CUDA.
- Login nodes: Prefer a real GPU job for meaningful limits. On CPU-only hosts, use a smaller initial_value and/or inference_only=True if RAM is tight.
- Memory-guided steps: After each good run, peak GPU memory vs total guides the next step (capped by max_growth_multiplier). Install psutil for a rough CPU hint. Set memory_guided=False to use only factor_up / factor_down.
- Multi-GPU: Use cuda_mem_devices="all" or a list so every GPU is considered; the smallest safe step applies.
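The OOM note describes isolating each attempt in a worker so that a killed worker only counts as a failed attempt. The sketch below illustrates that idea with subprocess.run and a fresh interpreter rather than the spawn-based worker the text mentions; attempt_in_child is a hypothetical helper, and the code strings stand in for a real model run.

```python
import subprocess
import sys


def attempt_in_child(code, timeout=60):
    """Run one attempt in a fresh interpreter; True only if the child exits cleanly."""
    try:
        result = subprocess.run([sys.executable, "-c", code], timeout=timeout)
    except subprocess.TimeoutExpired:
        return False
    # On POSIX a negative return code means the child was killed by a signal
    # (for example -9 for SIGKILL); any non-zero code counts as a failed attempt.
    return result.returncode == 0


ok = attempt_in_child("x = bytearray(1024)")
# Simulate the host killing the worker; the parent simply records a failure.
killed = attempt_in_child("import os, signal; os.kill(os.getpid(), signal.SIGKILL)")
print(ok, killed)  # True False
```

Because the parent never loads the model itself in this arrangement, a hard kill of the child cannot take the search loop down with it.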
🤝 Contributing
Contributions are welcome. Please open a Pull Request.
📝 License
MIT — see the LICENSE file.
Made with ❤️ for the PyTorch community