Find maximum batch size, documents, and timesteps for PyTorch models
🔍 Batch Finder
Find the maximum value for any dimension your PyTorch models can handle without running out of memory.
Batch Finder automatically detects your model's inputs (type and shape), fixes the dimensions you specify, and finds the maximum value for the remaining axis using a configurable search strategy.
✨ Features
- 🎯 Single unified API – One function, `find_max_minibatch`, for all cases
- 🔍 Automatic input detection – Infers input names, types (int/float), and shapes from the model
- 📐 Flexible shape specification – `input_shapes` as a tuple/list (single tensor), a DSL string (multi-tensor), or `axis_to_maximize` + `fixed_axis`
- 🚀 Inference or full backward – Test with or without gradients
- ⚙️ Configurable search – Customize `factor_down`, `factor_up`, `n_attempts`, `initial_value`
- 🛡️ Safe testing – Error handling and memory cleanup; returns `None` if it fails at value 1
- 📊 Progress tracking – tqdm progress bar with status in the postfix
📦 Installation
```bash
pip install batch-finder
```
Or install from source:
```bash
git clone https://github.com/yourusername/batch-finder.git
cd batch-finder
pip install -e .
```
🚀 Quick Start
Mode 1: input_shapes as tuple or list (single-input models)
Pass a tuple or list of ints: `-1` marks the axis to maximize, and every other entry is a fixed dimension size:
```python
from batch_finder import find_max_minibatch

model = MyModel()  # your own torch.nn.Module

# Maximize axis 0, fix (64, 256)
max_val = find_max_minibatch(model=model, input_shapes=(-1, 64, 256))
# Maximize axis 2, fix (4, 8)
max_val = find_max_minibatch(model=model, input_shapes=(4, 8, -1))
# Multiple -1: the same value is used for all variable axes
max_val = find_max_minibatch(model=model, input_shapes=(-1, 4, -1, 16))
```
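A minimal sketch of the `-1` rule, assuming only what this section states: every `-1` receives the same trial value and the other entries stay fixed. `materialize_shape` is a hypothetical helper for illustration, not a batch_finder function.

```python
from typing import Sequence, Tuple


def materialize_shape(input_shapes: Sequence[int], value: int) -> Tuple[int, ...]:
    """Replace every -1 in a Mode 1 shape spec with the current trial value.

    Hypothetical helper for illustration; not part of batch_finder.
    """
    if -1 not in input_shapes:
        raise ValueError("input_shapes must contain at least one -1")
    # Fixed dimensions pass through unchanged; all variable axes share one value.
    return tuple(value if dim == -1 else dim for dim in input_shapes)


print(materialize_shape((-1, 64, 256), 32))   # (32, 64, 256)
print(materialize_shape((4, 8, -1), 100))     # (4, 8, 100)
print(materialize_shape((-1, 4, -1, 16), 7))  # (7, 4, 7, 16)
```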
Mode 2: axis_to_maximize + fixed_axis (multi-input models, e.g. HuggingFace)
```python
from transformers import AutoModelForCausalLM
from batch_finder import find_max_minibatch

model = AutoModelForCausalLM.from_pretrained("distilgpt2")
max_batch = find_max_minibatch(
    model=model,
    axis_to_maximize="batch_size",
    fixed_axis={"seq_len": 32},
)
print(f"Max batch size: {max_batch}")
```
Mode 3: input_shapes (compact multi-tensor string)
For models whose `forward` takes several tensors, pass `input_shapes` as a string written in a compact DSL: one shape tuple per argument (in the same order as `forward`), names for dimensions that must match across tensors, optional constraints between names, and exactly one searched dimension written as `name=-1`.
Layout:
```text
(dims...), (dims...), ... , constraint, constraint, ...
```
- Dimensions: non-negative integers or identifiers (`b`, `t`, …). The same name in different positions ties those axes to the same size when the tensors are materialized.
- Search: exactly one assignment must be `symbol=-1`. Batch Finder binary-searches that symbol (e.g. to maximize `b`).
- Constraints: `name=rhs`, where `rhs` can be an integer, another name, `coef*name`, a coefficient stuck to a name (e.g. `1.5b`), or `name*name`. Non-integer values are rounded to the nearest integer for tensor sizes.
```python
import torch
from batch_finder import find_max_minibatch


class MyModel(torch.nn.Module):
    def forward(self, x, y):
        # Example layout: x is (23, b, t, 45), y is (b, t, 12)
        ...  # your computation goes here


model = MyModel()
max_b = find_max_minibatch(
    model=model,
    input_shapes="(23, b, t, 45),(b, t, 12), t=1.5b, b=-1",
)
# Searches b; sets t = round(1.5 * b) on each trial. Returns the max b (int).
```
The string form is mutually exclusive with both the tuple/list single-tensor mode and `axis_to_maximize`. Pass `input_shapes=` as a keyword argument.
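The DSL rules above can be sketched in plain Python. This is not batch_finder's parser: `parse_dsl`, `eval_rhs`, `resolve`, and `materialize` are hypothetical helpers that implement only the rules listed in this section (shared names, one `symbol=-1`, integer / name / `coef*name` / `coefname` / `name*name` right-hand sides, rounding to the nearest integer). A name that appears in a shape but in no constraint (such as `d` in the Mode 3b example) makes this sketch raise `KeyError`; how the library sizes such names is not described here.

```python
import re
from typing import Dict, List, Tuple

Shape = Tuple[str, ...]
# "2b", "1.5b", "2*b": a numeric coefficient, an optional '*', then a name.
_COEF_NAME = re.compile(r"(\d+(?:\.\d+)?)\*?([A-Za-z_]\w*)")


def parse_dsl(spec: str) -> Tuple[List[Shape], Dict[str, str]]:
    """Split '(dims...), (dims...), name=rhs, ...' into shape tuples and constraints."""
    shapes = [tuple(d.strip() for d in group.split(","))
              for group in re.findall(r"\(([^)]*)\)", spec)]
    constraints: Dict[str, str] = {}
    for item in re.sub(r"\([^)]*\)", "", spec).split(","):
        if item.strip():
            name, rhs = item.split("=", 1)
            constraints[name.strip()] = rhs.replace(" ", "")
    return shapes, constraints


def eval_rhs(rhs: str, env: Dict[str, float]) -> float:
    """Evaluate one right-hand side; raises KeyError if a name is not resolved yet."""
    if re.fullmatch(r"-?\d+", rhs):          # integer literal
        return float(rhs)
    if re.fullmatch(r"[A-Za-z_]\w*", rhs):   # another name
        return env[rhs]
    m = _COEF_NAME.fullmatch(rhs)            # coef*name or coef stuck to a name
    if m:
        return float(m.group(1)) * env[m.group(2)]
    left, sep, right = rhs.partition("*")    # name*name
    if sep:
        return env[left] * env[right]
    raise ValueError(f"unsupported constraint: {rhs!r}")


def resolve(constraints: Dict[str, str], value: int) -> Dict[str, int]:
    """Set the single symbol=-1 to `value`, then derive every other symbol."""
    searched = [n for n, rhs in constraints.items() if rhs == "-1"]
    if len(searched) != 1:
        raise ValueError("exactly one symbol=-1 is required")
    env: Dict[str, float] = {searched[0]: float(value)}
    pending = {n: rhs for n, rhs in constraints.items() if n not in env}
    while pending:
        solved = {}
        for name, rhs in pending.items():
            try:
                solved[name] = eval_rhs(rhs, env)
            except KeyError:
                pass  # depends on a symbol not resolved yet; retry on the next pass
        if not solved:
            raise ValueError(f"cannot resolve: {pending}")
        env.update(solved)
        pending = {n: r for n, r in pending.items() if n not in solved}
    # Python's round() sends .5 ties to the even integer; the library's tie rule is undocumented.
    return {name: round(v) for name, v in env.items()}


def materialize(shapes: List[Shape], sizes: Dict[str, int]) -> List[Tuple[int, ...]]:
    """Replace every symbolic dimension with its resolved size."""
    return [tuple(int(d) if d.isdigit() else sizes[d] for d in shape) for shape in shapes]


shapes, constraints = parse_dsl("(23, b, t, 45),(b, t, 12), t=1.5b, b=-1")
sizes = resolve(constraints, 10)
print(sizes)                       # {'b': 10, 't': 15}
print(materialize(shapes, sizes))  # [(23, 10, 15, 45), (10, 15, 12)]
```

Resolving in repeated passes lets a constraint refer to a symbol that is itself defined by another constraint, regardless of the order in which they are written.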
Mode 3b: input_shapes as a dict (named arguments)
Same semantics as the string DSL, but each `forward` parameter is keyed by name. Values are shape strings, optionally followed by `, int` or `, float` to set the tensor dtype. Use the key `"#constraints"` for the constraint list, which must include exactly one `symbol=-1`.
```python
max_b = find_max_minibatch(
    model=model,
    input_shapes={
        "input_ids": "(b, t), int",
        "attention_mask": "(b, t), int",
        "input_ids_encoder": "(d, b, t), int",
        "attention_mask_encoder": "(d, b, t), int",
        "labels": "(b, t)",
        "#constraints": "t=2b, b=-1",
    },
)
```
Import `CONSTRAINTS_KEY` from `batch_finder` (its value is `"#constraints"`) if you prefer not to spell out the string. `FINDER_CONSTRAINTS_KEY` is an alias for the same value.
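Parsing one dict entry can be sketched as follows, assuming the value format described above (`"(dims...)"` optionally followed by `, int` or `, float`). `parse_named_spec` and `split_named_shapes` are hypothetical helpers, and `CONSTRAINTS_KEY` is redefined locally with the documented value so the snippet runs without batch_finder installed. The README does not say which dtype is used when the suffix is omitted (as for `labels` above); the sketch returns `None` in that case, leaving the choice to later type inference.

```python
import re
from typing import Dict, Optional, Tuple

# Local copy of the documented value; batch_finder exports it as CONSTRAINTS_KEY.
CONSTRAINTS_KEY = "#constraints"
_SPEC = re.compile(r"\s*\(([^)]*)\)\s*(?:,\s*(int|float)\s*)?")

NamedSpec = Tuple[Tuple[str, ...], Optional[str]]


def parse_named_spec(spec: str) -> NamedSpec:
    """Split '(b, t), int' into (('b', 't'), 'int'); dtype is None when omitted."""
    m = _SPEC.fullmatch(spec)
    if m is None:
        raise ValueError(f"bad shape spec: {spec!r}")
    return tuple(d.strip() for d in m.group(1).split(",")), m.group(2)


def split_named_shapes(input_shapes: Dict[str, str]) -> Tuple[Dict[str, NamedSpec], str]:
    """Separate the per-argument shapes from the '#constraints' entry."""
    if CONSTRAINTS_KEY not in input_shapes:
        raise ValueError(f"missing {CONSTRAINTS_KEY!r} entry")
    shapes = {name: parse_named_spec(spec)
              for name, spec in input_shapes.items() if name != CONSTRAINTS_KEY}
    return shapes, input_shapes[CONSTRAINTS_KEY]


shapes, constraints = split_named_shapes({
    "input_ids": "(b, t), int",
    "labels": "(b, t)",
    "#constraints": "t=2b, b=-1",
})
print(shapes)       # {'input_ids': (('b', 't'), 'int'), 'labels': (('b', 't'), None)}
print(constraints)  # t=2b, b=-1
```

The constraint string is returned untouched, so it can go through the same resolution logic as the string DSL.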
Custom search parameters
```python
max_val = find_max_minibatch(
    model=model,
    input_shapes=(-1, 128, 512),
    initial_value=8,
    n_attempts=30,
    factor_down=3.0,  # divide by 3 on failure
    factor_up=2.0,    # multiply by 2 on success
)
```
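The README states the step rule (multiply by `factor_up` on success, divide by `factor_down` on failure, stop when value 1 fails or `n_attempts` is used up) but not how the search settles on an exact value such as the 1919 in the example output further down. This sketch therefore assumes that once both a passing and a failing value are known, any step that would leave that bracket is replaced by its midpoint. `search_max` and the `fits` callback are hypothetical; `fits` stands in for "runs at this size without OOM", and the `delay` pause is omitted.

```python
from typing import Callable, Optional


def search_max(
    fits: Callable[[int], bool],
    initial_value: int = 512,
    n_attempts: int = 50,
    factor_down: float = 2.0,
    factor_up: float = 2.0,
) -> Optional[int]:
    """Largest value for which fits(value) is True, or None if even 1 fails."""
    max_ok: Optional[int] = None    # largest value that passed so far
    min_fail: Optional[int] = None  # smallest value that failed so far
    value = initial_value
    for _ in range(n_attempts):
        if fits(value):
            max_ok = value
            proposal = int(value * factor_up)             # documented: grow on success
        else:
            if value == 1:
                return None                               # documented: nothing smaller to try
            min_fail = value
            proposal = max(1, int(value / factor_down))   # documented: shrink on failure
        if min_fail is not None:
            low = max_ok if max_ok is not None else 0
            if min_fail - low <= 1:
                break                                     # bracket closed: max_ok is final
            if not low < proposal < min_fail:
                proposal = (low + min_fail) // 2          # assumption: bisect inside the bracket
        value = proposal
    return max_ok


# Stand-in for "runs without OOM": pretend anything up to 1919 fits.
print(search_max(lambda v: v <= 1919))  # 1919
print(search_max(lambda v: v <= 100, initial_value=8, n_attempts=30,
                 factor_down=3.0, factor_up=2.0))  # 100
```

A real trial would also build the inputs, run the forward pass (plus backward unless `inference_only=True`), catch the out-of-memory error, and free memory before the next attempt; none of that is modelled here.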
📖 API Reference
`find_max_minibatch(model, ...)`
Find the maximum value for the modifiable axis without running out of memory.
Parameters:

| Parameter | Type | Default | Description |
|---|---|---|---|
| `model` | `torch.nn.Module` | – | PyTorch or HuggingFace model |
| `input_shapes` | `str \| dict \| Tuple[int, ...] \| List[int]` | `None` | String: multi-tensor DSL (Mode 3). Dict: named shapes + `"#constraints"` (Mode 3b). Tuple/list: shape of the first `forward` argument; must include `-1` (Mode 1). Mutually exclusive with `axis_to_maximize` |
| `axis_to_maximize` | `str` | `None` | Axis name when `input_shapes` is omitted, e.g. `"batch_size"` |
| `fixed_axis` | `Dict[str, int]` | `{}` | Fixed values, e.g. `{"seq_len": 128}` |
| `device` | `torch.device` | auto | Device to run on |
| `delay` | `float` | `3.0` | Seconds between attempts |
| `initial_value` | `int` | `512` | First value to try |
| `n_attempts` | `int` | `50` | Maximum number of attempts |
| `inference_only` | `bool` | `False` | If `True`, forward only without gradients. If `False`, full forward + backward |
| `factor_down` | `float` | `2.0` | On failure: next = value / factor_down |
| `factor_up` | `float` | `2.0` | On success: next = value * factor_up |
Returns: `Tuple[int, ...]` when `input_shapes` is a tuple/list, `int` when using a DSL string or `axis_to_maximize`, or `None` if no value succeeded.
Modes:
- Provide `input_shapes` as a tuple/list: uses the first input parameter of `forward` with the given shape; `-1` marks the variable axis.
- Provide `input_shapes` as a string: one shape tuple per `forward` tensor; symbolic names, one `symbol=-1`, and optional constraints.
- Provide `input_shapes` as a dict: same as the string DSL, with parameter names as keys and `"#constraints"` for the constraints.
- Provide `axis_to_maximize` + `fixed_axis`: builds inputs from the detected parameters and naming conventions.
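The four modes are mutually exclusive ways of calling the same function. A minimal sketch of how a call could be routed, using only the rules stated in this README (string, dict, tuple/list with a required `-1`, or `axis_to_maximize`); `pick_mode` and its return labels are hypothetical and not part of batch_finder.

```python
from typing import Any, Optional


def pick_mode(input_shapes: Any = None, axis_to_maximize: Optional[str] = None) -> str:
    """Return a label for the documented mode a call corresponds to."""
    if input_shapes is not None and axis_to_maximize is not None:
        raise ValueError("input_shapes and axis_to_maximize are mutually exclusive")
    if isinstance(input_shapes, str):
        return "dsl"            # Mode 3: result is an int
    if isinstance(input_shapes, dict):
        return "named"          # Mode 3b: named shapes + "#constraints"
    if isinstance(input_shapes, (tuple, list)):
        if -1 not in input_shapes:
            raise ValueError("tuple/list input_shapes must contain -1")
        return "single-tensor"  # Mode 1: result is a tuple
    if axis_to_maximize is not None:
        return "named-axis"     # Mode 2: result is an int
    raise ValueError("pass input_shapes or axis_to_maximize")


print(pick_mode((-1, 64, 256)))                  # single-tensor
print(pick_mode(axis_to_maximize="batch_size"))  # named-axis
```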
Example output (`axis_to_maximize` + `fixed_axis`):
```text
--- Detected inputs (type, estimated shape) ---
input_ids: integer, (32, 64)
attention_mask: integer, (32, 64)
---
batch_size fixed={'seq_len': 32}: 100%|████████████████████| 22/50 [01:26<00:00, 3.9s/it, gpus=1, i=22/50, max_ok=1919, min_fail=1920, status=✅, value=1919]
✅ Max value that passed: 1919
```
🔧 How It Works
- Input detection – Uses `inspect.signature` on `model.forward` to find input names.
- Type inference – Integer for `*ids`, `*mask`, and `labels`; float for everything else.
- Shape estimation – From the model (`Linear.in_features`, config, etc.) and parameter-name conventions.
- Search – On success, try `value * factor_up`; on failure, try `value / factor_down`. Stops when value 1 fails or `n_attempts` is reached.
- Loss – Uses `output.loss` if present, otherwise the sum of all output tensors.
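The first two steps can be approximated with the standard library alone. `detect_inputs` is a hypothetical helper, not batch_finder's code: it reads parameter names via `inspect.signature` and applies the `*ids` / `*mask` / `labels` rule with shell-style patterns. Skipping `self`, `*args`, and `**kwargs` is an assumption of this sketch, and shape estimation is not attempted.

```python
import inspect
from fnmatch import fnmatchcase
from typing import Callable, Dict

INT_PATTERNS = ("*ids", "*mask", "labels")  # naming rule from the list above


def detect_inputs(forward: Callable) -> Dict[str, str]:
    """Map forward() parameter names to 'int' or 'float' by naming convention."""
    kinds: Dict[str, str] = {}
    for name, param in inspect.signature(forward).parameters.items():
        # Assumption: self, *args and **kwargs are not treated as tensor inputs.
        if name == "self" or param.kind in (inspect.Parameter.VAR_POSITIONAL,
                                            inspect.Parameter.VAR_KEYWORD):
            continue
        is_int = any(fnmatchcase(name, pattern) for pattern in INT_PATTERNS)
        kinds[name] = "int" if is_int else "float"
    return kinds


class ToyModel:
    # Stand-in for an nn.Module; only the forward() signature matters here.
    def forward(self, input_ids, attention_mask=None, pixel_values=None, labels=None, **kwargs):
        return None


print(detect_inputs(ToyModel().forward))
# {'input_ids': 'int', 'attention_mask': 'int', 'pixel_values': 'float', 'labels': 'int'}
```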
⚠️ Important Notes
- Memory: Use a conservative `initial_value` on GPUs with limited memory.
- Time: Use `inference_only=True` for faster runs.
- Training: Use `inference_only=False` to stress-test with the backward pass.
- Value 1: If the run fails at value 1, the function returns `None` (there is no smaller value to try).
🤝 Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
📝 License
This project is licensed under the MIT License - see the LICENSE file for details.
Made with ❤️ for the PyTorch community
File details
Details for the file batch_finder-0.3.0.tar.gz.
File metadata
- Download URL: batch_finder-0.3.0.tar.gz
- Upload date:
- Size: 19.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `db372fe5e6f058a163d1f1244376f7a6dcf99c6fadc589e0b8938e35aada0aeb` |
| MD5 | `566bfa3a59d661745f9c845fbb408a45` |
| BLAKE2b-256 | `2bf69db35dc6f94807117c4c77abb45ac469039e8913914236138f3f6f3b20bf` |
File details
Details for the file batch_finder-0.3.0-py3-none-any.whl.
File metadata
- Download URL: batch_finder-0.3.0-py3-none-any.whl
- Upload date:
- Size: 16.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `a4f7569f9603678d4f29919374236667fde29cf52194bdd2eb6cf81a0893002c` |
| MD5 | `09eeb8b0e7b3268703158b209986ce65` |
| BLAKE2b-256 | `9ba7b4880cee1e9e4f2862fb7ed5199697e888371f291e62186d4f8ea2d15d91` |