batch-susceptibility
Model-free optimal batch size finder using susceptibility analysis.
Stop guessing your batch size. This tool uses a physics-inspired method to find the critical batch size where your training dynamics undergo a phase transition — the point where batch statistics change from noisy (too small) to information-losing (too large).
from batch_susceptibility.pytorch import find_optimal_batch_size
result = find_optimal_batch_size(model, dataset, loss_fn)
print(f"Use batch size: {result.K_c}") # e.g., 128
How It Works
For a sequence of per-step losses L_1, L_2, ..., L_N:
- Re-batch at multiple scales K: group the losses into batches of size K and compute batch means
- Measure V(K) = Var(batch means) — how variable are the batch statistics?
- Differentiate: chi(K) = |d log V / d log K| — the susceptibility
- Find the peak: K_c = argmax(chi) — the critical batch size
For i.i.d. data, V(K) ~ 1/K (law of large numbers). Deviations from this scaling reveal temporal structure in your training dynamics, and the peak of chi(K) marks the characteristic correlation scale.
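The whole pipeline fits in a few lines of NumPy. A minimal sketch of the scan described above (illustrative only; this is not the library's internal implementation, and the function name susceptibility_scan is ours):

import numpy as np

def susceptibility_scan(losses, K_values):
    # Variance of batch means at each scale K
    losses = np.asarray(losses, dtype=float)
    V = []
    for K in K_values:
        n = len(losses) // K                      # number of complete batches
        means = losses[: n * K].reshape(n, K).mean(axis=1)
        V.append(means.var())
    logK, logV = np.log(K_values), np.log(V)
    chi = np.abs(np.gradient(logV, logK))         # |d log V / d log K|
    K_c = K_values[int(np.argmax(chi))]           # critical batch size
    kappa = chi.max() / np.median(chi)            # peak sharpness
    return K_c, kappa, chi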
Key metrics:
- K_c: Optimal batch size (the scale where information density peaks)
- kappa: Peak sharpness (max(chi) / median(chi)). Higher = more confident; > 3 is significant.
- alpha: Scaling exponent. -1 = i.i.d., > -1 = correlated (drift), < -1 = anti-correlated.
Based on: M.C. Wurm, "Batch-Size Susceptibility across Five Computational Domains" (2024/2025).
Installation
# Core (numpy + scipy only)
pip install batch-susceptibility
# With PyTorch integration
pip install batch-susceptibility[pytorch]
# With TensorFlow/Keras
pip install batch-susceptibility[tensorflow]
# With PyTorch Lightning
pip install batch-susceptibility[lightning]
# With plotting
pip install batch-susceptibility[plot]
# Everything
pip install batch-susceptibility[all]
Quick Start
PyTorch: Find Optimal Batch Size
import torch.nn as nn
from torchvision import datasets, transforms, models
from batch_susceptibility.pytorch import find_optimal_batch_size
dataset = datasets.CIFAR10(root="./data", train=True, download=True,
                           transform=transforms.ToTensor())
model = models.resnet18(num_classes=10)

result = find_optimal_batch_size(
    model, dataset, nn.CrossEntropyLoss(),
    batch_sizes=[16, 32, 64, 128, 256, 512, 1024],
    steps_per_size=100,
)
print(f"Optimal batch size: {result.K_c}")
PyTorch: Monitor During Training
from batch_susceptibility.pytorch import SusceptibilityCallback
monitor = SusceptibilityCallback(check_every=500)
for epoch in range(10):
    for batch in dataloader:
        loss = train_step(batch)   # your own forward/backward/update step
        monitor.on_step(loss)
        # K_c is logged every 500 steps
final = monitor.result()
print(f"K_c = {final.K_c}, regime = {final.regime}")
TensorFlow / Keras
from batch_susceptibility.tensorflow import find_optimal_batch_size
result = find_optimal_batch_size(model, x_train, y_train)
print(f"Optimal batch size: {result.K_c}")
Or as a Keras callback:
from batch_susceptibility.tensorflow import SusceptibilityCallback
cb = SusceptibilityCallback(check_every=500)
model.fit(x_train, y_train, epochs=10, callbacks=[cb])
print(cb.result().K_c)
PyTorch Lightning
import pytorch_lightning as pl
from batch_susceptibility.lightning import SusceptibilityCallback
cb = SusceptibilityCallback(check_every=500)
trainer = pl.Trainer(callbacks=[cb])
trainer.fit(model, train_dataloader)
print(cb.result().K_c)
HuggingFace Transformers
from transformers import Trainer
from batch_susceptibility.examples.huggingface_transformers import SusceptibilityTrainerCallback
cb = SusceptibilityTrainerCallback(check_every=200)
trainer = Trainer(model=model, args=args, callbacks=[cb], ...)
trainer.train()
print(cb.result().K_c)
Generic Data (any framework)
from batch_susceptibility import BatchSusceptibility
# From per-step losses (any source)
bs = BatchSusceptibility()
bs.feed(losses)
result = bs.find_critical()
print(result.summary())
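To see the detector respond to temporal structure, you can feed it synthetic correlated data. A sketch using only the calls above (the AR(1) generator is ours; with rho = 0.95 the analysis should report alpha above -1 and a correlated regime):

import numpy as np
from batch_susceptibility import BatchSusceptibility

# AR(1) series: consecutive samples are positively correlated, so the
# scan should find a peak near the correlation scale.
rng = np.random.default_rng(0)
rho, n = 0.95, 50_000
losses = np.empty(n)
losses[0] = rng.standard_normal()
for t in range(1, n):
    losses[t] = rho * losses[t - 1] + rng.standard_normal()

bs = BatchSusceptibility()
bs.feed(losses)
result = bs.find_critical()
print(result.K_c, result.alpha, result.regime)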
Command Line
# Analyze a CSV of per-step losses
batch-susceptibility training_log.csv --column loss
# Analyze a batch-size sweep
batch-susceptibility sweep.csv --mode sweep --batch-col batch_size --metric-col loss
# JSON output
batch-susceptibility losses.csv -c loss --json
# With plot
batch-susceptibility losses.csv -c loss --plot result.png
# From stdin
cat losses.txt | batch-susceptibility -
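For reference, the CSV commands expect one row per optimization step, with a header naming the column passed via --column. A minimal input for the first command (the step column is an assumption for illustration; only the named column is read):

step,loss
0,2.3026
1,2.2980
2,2.3051
3,2.2913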
Interpreting Results
Critical batch size: K_c = 128
Sharpness: kappa = 5.42
Scaling exponent: alpha = -0.87 +/- 0.02
i.i.d. test: p = 0.0001
Regime: weakly-correlated
Significant: True
| Metric | Meaning |
|---|---|
| K_c | Optimal batch size. Use this for training. |
| kappa > 3 | Confident detection. There IS a characteristic scale. |
| kappa ~ 1 | No characteristic scale found. Data is scale-free. |
| alpha ~ -1 | i.i.d. — no temporal correlations (normal). |
| alpha > -1 | Positive correlations or drift (loss landscape structure). |
| alpha < -1 | Anti-correlations (rare; indicates over-mixing). |
Regime interpretation:
- iid: Your data has no temporal structure. Any batch size works equally well.
- correlated: Training has drift/momentum. K_c marks where correlations decay.
- weakly-correlated: Mild structure. K_c is a soft optimum.
- anti-correlated: Unusual. Check for data preprocessing artifacts.
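In code, one way to act on a result (a hypothetical consumer; the field names come from the table below, and the fallback of 256 is an arbitrary choice, not a library default):

if result.is_significant:            # kappa > 3 and p_iid < 0.05
    batch_size = int(result.K_c)     # a characteristic scale exists; use it
else:
    batch_size = 256                 # scale-free dynamics; any size works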
Visualization
from batch_susceptibility.plot import plot_susceptibility, plot_comparison
# Single result
fig = plot_susceptibility(result, title="My Model")
fig.savefig("susceptibility.png")
# Compare multiple experiments
fig = plot_comparison({
    "Adam lr=1e-3": result_adam,
    "SGD lr=0.1": result_sgd,
})
fig.savefig("comparison.png")
API Reference
BatchSusceptibility
Main interface class.
bs = BatchSusceptibility(K_min=2, K_max=None, n_points=60, min_batches=10)
bs.feed(values) # Feed 1D sequence of metric values
result = bs.find_critical() # Compute susceptibility, find K_c
SusceptibilityResult
Dataclass with all results.
| Attribute | Type | Description |
|---|---|---|
| K_c | float | Critical batch size |
| kappa | float | Peak sharpness |
| alpha | float | Scaling exponent |
| alpha_se | float | Standard error of alpha |
| p_iid | float | p-value for i.i.d. null |
| K_values | ndarray | Batch sizes tested |
| V_values | ndarray | Variance at each K |
| chi_values | ndarray | Susceptibility at each K |
| is_significant | bool | kappa > 3 and p < 0.05 |
| regime | str | iid / correlated / weakly-correlated / anti-correlated |
Framework integrations
| Module | Class/Function | Description |
|---|---|---|
| batch_susceptibility.pytorch | find_optimal_batch_size() | One-call PyTorch finder |
| batch_susceptibility.pytorch | BatchSizeFinder | Configurable PyTorch finder |
| batch_susceptibility.pytorch | SusceptibilityCallback | Training monitor |
| batch_susceptibility.tensorflow | find_optimal_batch_size() | One-call Keras finder |
| batch_susceptibility.tensorflow | SusceptibilityCallback | Keras callback |
| batch_susceptibility.lightning | SusceptibilityCallback | Lightning callback |
| batch_susceptibility.lightning | BatchSizeFinder | Lightning finder |
| batch_susceptibility.plot | plot_susceptibility() | Single result plot |
| batch_susceptibility.plot | plot_comparison() | Multi-result comparison |
| batch_susceptibility.cli | main() | CLI entry point |
Examples
See the examples/ directory:
- pytorch_cifar10.py — CIFAR-10 with ResNet-18
- pytorch_monitor.py — Online monitoring during training
- tensorflow_mnist.py — MNIST with Keras
- huggingface_transformers.py — HuggingFace Trainer integration
- generic_csv.py — Any data source / CSV
Theory
The method is based on the observation that for a stationary time series, the variance of batch means V(K) scales approximately as a power law, V(K) ~ K^alpha, where:
- alpha = -1 for independent, identically distributed (i.i.d.) samples
- alpha > -1 when consecutive samples are positively correlated
- alpha < -1 when consecutive samples are anti-correlated
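A sketch of one way to estimate alpha, via a least-squares fit on log-log axes (the library may use a different estimator):

import numpy as np

# The slope of log V(K) versus log K is the scaling exponent alpha.
alpha, intercept = np.polyfit(np.log(K_values), np.log(V_values), 1)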
The susceptibility chi(K) = |d log V / d log K| measures how sensitive the variance scaling is to the batch size. A peak in chi(K) indicates a critical scale — the batch size at which the dominant correlation structure is resolved.
For ML training: This critical scale corresponds to the batch size where:
- Smaller batches: gradient estimates are dominated by noise correlations
- Larger batches: diminishing returns — averaging out signal, not just noise
- At K_c: optimal information density per gradient step
This connects to the critical batch size concept from McCandlish et al. (2018) but requires no gradient computation — only loss values.
License
Dual-licensed under:
- AGPL-3.0 for open-source / non-commercial use (license_AGPL.txt)
- Commercial license for proprietary integration (license_COMMERCIAL.txt)
See LICENSE.txt for full details.
For commercial licensing, contact: nfo@forgottenforge.xyz
Related
- sigmacore — The general sigma_c framework for critical scale detection across domains