batch-susceptibility

Model-free optimal batch size finder using susceptibility analysis.

Stop guessing your batch size. This tool uses a physics-inspired method to find the critical batch size where your training dynamics undergo a phase transition — the point where batch statistics change from noisy (too small) to information-losing (too large).

from batch_susceptibility.pytorch import find_optimal_batch_size

result = find_optimal_batch_size(model, dataset, loss_fn)
print(f"Use batch size: {result.K_c}")  # e.g., 128

How It Works

For a sequence of per-step losses L_1, L_2, ..., L_N:

  1. Re-batch at multiple scales K: group into batches of size K and compute batch means
  2. Measure V(K) = Var(batch means) — how variable are the batch statistics?
  3. Differentiate: chi(K) = |d log V / d log K| — the susceptibility
  4. Find the peak: K_c = argmax(chi) — the critical batch size

For i.i.d. data, V(K) ~ 1/K (law of large numbers). Deviations from this scaling reveal temporal structure in your training dynamics, and the peak of chi(K) marks the characteristic correlation scale.
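
The whole scan fits in a few lines of NumPy. The sketch below is an illustration of the method, not the library's internals; my_losses stands in for any 1D array of per-step losses:

import numpy as np

def susceptibility_scan(losses, K_values):
    """Compute chi(K) and K_c from a 1D array of per-step losses."""
    losses = np.asarray(losses, dtype=float)
    V = []
    for K in K_values:
        n = (len(losses) // K) * K               # trim the ragged tail
        batch_means = losses[:n].reshape(-1, K).mean(axis=1)
        V.append(batch_means.var())              # V(K) = Var(batch means)
    logK, logV = np.log(K_values), np.log(V)
    chi = np.abs(np.gradient(logV, logK))        # chi(K) = |d log V / d log K|
    return chi, int(K_values[np.argmax(chi)])    # K_c = argmax chi(K)

K_values = np.unique(np.geomspace(2, 512, 40).astype(int))
chi, K_c = susceptibility_scan(my_losses, K_values)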

Key metrics:

  • K_c: Optimal batch size (the scale where information density peaks)
  • kappa: Peak sharpness (max(chi)/median(chi)). Higher = more confident. >3 is significant.
  • alpha: Scaling exponent. -1 = i.i.d., > -1 = correlated (drift), < -1 = anti-correlated

Based on: M.C. Wurm, "Batch-Size Susceptibility across Five Computational Domains" (2024/2025).

Installation

# Core (numpy + scipy only)
pip install batch-susceptibility

# With PyTorch integration
pip install batch-susceptibility[pytorch]

# With TensorFlow/Keras
pip install batch-susceptibility[tensorflow]

# With PyTorch Lightning
pip install batch-susceptibility[lightning]

# With plotting
pip install batch-susceptibility[plot]

# Everything
pip install batch-susceptibility[all]

Quick Start

PyTorch: Find Optimal Batch Size

import torch.nn as nn
from torchvision import datasets, transforms, models
from batch_susceptibility.pytorch import find_optimal_batch_size

dataset = datasets.CIFAR10(root="./data", train=True, download=True,
                           transform=transforms.ToTensor())
model = models.resnet18(num_classes=10)

result = find_optimal_batch_size(
    model, dataset, nn.CrossEntropyLoss(),
    batch_sizes=[16, 32, 64, 128, 256, 512, 1024],
    steps_per_size=100,
)
print(f"Optimal batch size: {result.K_c}")

PyTorch: Monitor During Training

from batch_susceptibility.pytorch import SusceptibilityCallback

monitor = SusceptibilityCallback(check_every=500)

for epoch in range(10):
    for batch in dataloader:
        loss = train_step(batch)
        monitor.on_step(loss)

# K_c is logged every 500 steps
final = monitor.result()
print(f"K_c = {final.K_c}, regime = {final.regime}")

TensorFlow / Keras

from batch_susceptibility.tensorflow import find_optimal_batch_size

result = find_optimal_batch_size(model, x_train, y_train)
print(f"Optimal batch size: {result.K_c}")

Or as a Keras callback:

from batch_susceptibility.tensorflow import SusceptibilityCallback

cb = SusceptibilityCallback(check_every=500)
model.fit(x_train, y_train, epochs=10, callbacks=[cb])
print(cb.result().K_c)

PyTorch Lightning

import pytorch_lightning as pl
from batch_susceptibility.lightning import SusceptibilityCallback

cb = SusceptibilityCallback(check_every=500)
trainer = pl.Trainer(callbacks=[cb])
trainer.fit(model, train_dataloader)
print(cb.result().K_c)

HuggingFace Transformers

from transformers import Trainer
from batch_susceptibility.examples.huggingface_transformers import SusceptibilityTrainerCallback

cb = SusceptibilityTrainerCallback(check_every=200)
trainer = Trainer(model=model, args=args, callbacks=[cb], ...)
trainer.train()
print(cb.result().K_c)

Generic Data (any framework)

from batch_susceptibility import BatchSusceptibility

# From per-step losses (any source)
bs = BatchSusceptibility()
bs.feed(losses)
result = bs.find_critical()
print(result.summary())
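
To sanity-check the generic interface end to end, feed a synthetic series with a known correlation scale; an AR(1) process should produce a non-flat V(K) and a detectable peak. A sketch, using only the documented feed()/find_critical() API (the coefficient 0.9 and the seed are arbitrary):

import numpy as np
from batch_susceptibility import BatchSusceptibility

rng = np.random.default_rng(0)
x = np.zeros(50_000)
for t in range(1, len(x)):            # AR(1): consecutive values correlated
    x[t] = 0.9 * x[t - 1] + rng.normal()

bs = BatchSusceptibility()
bs.feed(x)
result = bs.find_critical()
print(result.K_c, result.alpha, result.regime)   # expect alpha > -1 here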

Command Line

# Analyze a CSV of per-step losses
batch-susceptibility training_log.csv --column loss

# Analyze a batch-size sweep
batch-susceptibility sweep.csv --mode sweep --batch-col batch_size --metric-col loss

# JSON output
batch-susceptibility losses.csv -c loss --json

# With plot
batch-susceptibility losses.csv -c loss --plot result.png

# From stdin
cat losses.txt | batch-susceptibility -

Interpreting Results

Critical batch size:  K_c = 128
Sharpness:            kappa = 5.42
Scaling exponent:     alpha = -0.87 +/- 0.02
i.i.d. test:          p = 0.0001
Regime:               weakly-correlated
Significant:          True

Metric       Meaning
K_c          Optimal batch size. Use this for training.
kappa > 3    Confident detection: there IS a characteristic scale.
kappa ~ 1    No characteristic scale found; data is scale-free.
alpha ~ -1   i.i.d.: no temporal correlations (normal).
alpha > -1   Positive correlations or drift (loss-landscape structure).
alpha < -1   Anti-correlations (rare; indicates over-mixing).

Regime interpretation:

  • iid: Your data has no temporal structure. Any batch size works equally well.
  • correlated: Training has drift/momentum. K_c marks where correlations decay.
  • weakly-correlated: Mild structure. K_c is a soft optimum.
  • anti-correlated: Unusual. Check for data preprocessing artifacts.
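
These fields can also drive an automated choice. A minimal sketch, assuming a SusceptibilityResult named result from any of the entry points above (the 256 fallback is an arbitrary placeholder):

if result.is_significant and result.regime != "iid":
    batch_size = int(round(result.K_c))   # characteristic scale found: adopt it
else:
    batch_size = 256                      # scale-free: choose by hardware limits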

Visualization

from batch_susceptibility.plot import plot_susceptibility, plot_comparison

# Single result
fig = plot_susceptibility(result, title="My Model")
fig.savefig("susceptibility.png")

# Compare multiple experiments
fig = plot_comparison({
    "Adam lr=1e-3": result_adam,
    "SGD lr=0.1": result_sgd,
})
fig.savefig("comparison.png")

API Reference

BatchSusceptibility

Main interface class.

bs = BatchSusceptibility(K_min=2, K_max=None, n_points=60, min_batches=10)
bs.feed(values)             # Feed 1D sequence of metric values
result = bs.find_critical() # Compute susceptibility, find K_c
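
The constructor defaults are shown above but not described on this page; the readings in the comments below are assumptions, not documented behavior:

bs = BatchSusceptibility(
    K_min=2,         # assumed: smallest batch size on the scan grid
    K_max=None,      # assumed: None infers the largest K from the data length
    n_points=60,     # assumed: number of K values sampled (log-spaced)
    min_batches=10,  # assumed: minimum batch count required at any K
)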

SusceptibilityResult

Dataclass with all results.

Attribute       Type     Description
K_c             float    Critical batch size
kappa           float    Peak sharpness
alpha           float    Scaling exponent
alpha_se        float    Standard error of alpha
p_iid           float    p-value for the i.i.d. null
K_values        ndarray  Batch sizes tested
V_values        ndarray  Variance at each K
chi_values      ndarray  Susceptibility at each K
is_significant  bool     kappa > 3 and p < 0.05
regime          str      iid / correlated / weakly-correlated / anti-correlated

Framework integrations

Module                           Class/Function             Description
batch_susceptibility.pytorch     find_optimal_batch_size()  One-call PyTorch finder
batch_susceptibility.pytorch     BatchSizeFinder            Configurable PyTorch finder
batch_susceptibility.pytorch     SusceptibilityCallback     Training monitor
batch_susceptibility.tensorflow  find_optimal_batch_size()  One-call Keras finder
batch_susceptibility.tensorflow  SusceptibilityCallback     Keras callback
batch_susceptibility.lightning   SusceptibilityCallback     Lightning callback
batch_susceptibility.lightning   BatchSizeFinder            Lightning finder
batch_susceptibility.plot        plot_susceptibility()      Single-result plot
batch_susceptibility.plot        plot_comparison()          Multi-result comparison
batch_susceptibility.cli         main()                     CLI entry point

Examples

See the examples/ directory in the source repository.

Theory

The method is based on the observation that for a stationary time series, the variance of batch means V(K) scales approximately as a power law V(K) ~ K^alpha over a range of K, where:

  • alpha = -1 for independent, identically distributed (i.i.d.) samples
  • alpha > -1 when consecutive samples are positively correlated
  • alpha < -1 when consecutive samples are anti-correlated
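
A quick numeric check of the i.i.d. baseline (plain NumPy, independent of this library): for independent samples with variance sigma^2, the batch mean of K values has variance sigma^2 / K exactly, so K * V(K) should be flat:

import numpy as np

x = np.random.default_rng(1).normal(size=1_000_000)   # i.i.d., sigma^2 = 1
for K in (4, 16, 64, 256):
    n = (len(x) // K) * K                             # trim to a multiple of K
    V = x[:n].reshape(-1, K).mean(axis=1).var()
    print(K, K * V)   # ~1.0 at every K: alpha = -1, chi flat, no peak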

The susceptibility chi(K) = |d log V / d log K| measures how sensitive the variance scaling is to the batch size. A peak in chi(K) indicates a critical scale — the batch size at which the dominant correlation structure is resolved.

For ML training: This critical scale corresponds to the batch size where:

  • Smaller batches: gradient estimates are dominated by noise correlations
  • Larger batches: diminishing returns — averaging out signal, not just noise
  • At K_c: optimal information density per gradient step

This connects to the critical batch size concept from McCandlish et al. (2018) but requires no gradient computation — only loss values.

License

Dual-licensed; see LICENSE for the full terms.

For commercial licensing, contact: nfo@forgottenforge.xyz

Related

  • sigmacore — The general sigma_c framework for critical scale detection across domains

Download files

Source Distribution

batch_susceptibility-1.0.1.tar.gz (26.4 kB)

Built Distribution

batch_susceptibility-1.0.1-py3-none-any.whl (24.0 kB)

File details

Details for the file batch_susceptibility-1.0.1.tar.gz.

File metadata

  • Download URL: batch_susceptibility-1.0.1.tar.gz
  • Upload date:
  • Size: 26.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.2

File hashes

Hashes for batch_susceptibility-1.0.1.tar.gz
Algorithm    Hash digest
SHA256       a63dcf957e5b17d26eae2642bb85455e651f7b2b5abaddf0e8c289374c893b4b
MD5          4a7200d9bdf7f20e9293cab4be953ba9
BLAKE2b-256  fa155cc34108c99de14006bb49e8c8b429514caf77ad722280fe73624595941c

File details

Details for the file batch_susceptibility-1.0.1-py3-none-any.whl.

File hashes

Hashes for batch_susceptibility-1.0.1-py3-none-any.whl

Algorithm    Hash digest
SHA256       2c0bc2769bd0b6076d237cd9aeb944416d90f624e276ab13d9b4b3cda6f27f7e
MD5          bc13466497e7e7d34c9180822a938934
BLAKE2b-256  1e522470020ff088fae1d8b775a330c26df1c391a7c0ad75f47fa26aca828980
