
Training diagnostics for PyTorch: Gradient Noise Scale (free during gradient accumulation), Plasticity Probe, and probabilistic early stopping


traintools

Training diagnostics for PyTorch. Three tools, two lines of integration.

pip install "traintools[full]"

What it does

Tool                          Question it answers
GNS (Gradient Noise Scale)    Is my batch size wasting compute?
PlasticityProbe               Is my network losing the ability to learn?
TrainGuard                    Should I stop training yet?

Quick start

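from traintools.callbacks.pytorch import TraintoolsTracker

# model, loss_fn, optimizer and dataloader are your existing PyTorch objects
tracker = TraintoolsTracker(model, loss_fn)

for step, (x, y) in enumerate(dataloader):
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # val_loss: your most recent validation loss, refreshed by your own eval loop
    decision = tracker.step(step=step, inputs=x, targets=y, val_loss=val_loss)
    if decision and decision.should_stop:  # guard: decision can be None
        break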

HuggingFace Trainer:

from transformers import Trainer
from traintools.callbacks.huggingface import TraintoolsCallback

trainer = Trainer(model=model, ..., callbacks=[TraintoolsCallback()])

Gradient Noise Scale (GNS)

GNS is the ratio of the total per-example gradient variance, tr(Σ) (the trace of the per-example gradient covariance), to the squared norm of the true gradient G:

GNS = tr(Σ) / ||G||^2

It is an estimate of the critical batch size B*: the batch size beyond which larger batches give diminishing returns.
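For reference, here is the relation behind the critical-batch interpretation, restated from McCandlish et al. (2018). It holds under their local quadratic approximation of the loss, with Hessian H, true gradient G, per-example gradient covariance Σ, and a learning rate tuned for each batch size B:

```latex
\Delta L_{\mathrm{opt}}(B) = \frac{\Delta L_{\max}}{1 + B_{\mathrm{noise}} / B},
\qquad
B_{\mathrm{noise}} = \frac{\operatorname{tr}(H \Sigma)}{G^{\top} H G},
\qquad
B_{\mathrm{simple}} = \frac{\operatorname{tr}(\Sigma)}{\lVert G \rVert^{2}}
```

At B = B_noise, one step achieves half of the best possible per-step improvement ΔL_max. Well below B_noise, steps are noise-dominated. Well above it, extra examples buy little. The GNS formula above is B_simple, which equals B_noise when H is a multiple of the identity. That is why it serves as the practical estimate.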

With B as your current batch size:
  • GNS > B → under-batched: gradients are too noisy, so larger batches help
  • GNS < B → over-batched: the batch is larger than needed, so shrink it and save compute
  • GNS ≈ B → roughly optimal
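The three regimes reduce to comparing GNS with the current batch size. The following is a minimal sketch of that comparison. `batch_regime` is a hypothetical name, not traintools' API. The document does not say how close "≈" has to be, so the ±2x tolerance band for "optimal" is an assumption.

```python
def batch_regime(gns, batch_size, tolerance=2.0):
    """Label the batch-size regime by comparing a GNS estimate with the batch size.

    tolerance is a hypothetical band: a GNS within a factor of `tolerance` of the
    batch size counts as roughly optimal.
    """
    ratio = gns / batch_size
    if ratio > tolerance:
        return "under-batched", ratio
    if ratio < 1.0 / tolerance:
        return "over-batched", ratio
    return "optimal", ratio


regime, ratio = batch_regime(5010.7, 64)
print(f"{regime}: critical batch is ~{ratio:.0f}x the current batch")
# under-batched: critical batch is ~78x the current batch
```

With GNS ≈ 5011 and a batch of 64, the ratio works out to 5010.7 / 64 ≈ 78.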

traintools uses the unbiased estimators from McCandlish et al. 2018 (Bessel-corrected variance, bias-corrected signal) and tracks GNS as the ratio of two separate exponential moving averages — the stable estimator the paper recommends. (Naive single-shot estimates are biased low by ~2x and far too noisy to act on.)

[step 500] GNS=5010.7 (EMA)  critical_batch=5011  current=64  regime=under-batched
  > Batch size 64 is ~78x below the critical batch (~5011). Larger batches would give cleaner gradients per step.
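The estimator described above combines a Bessel-corrected variance, a bias-corrected signal, and a ratio of two separate EMAs. Below is a minimal plain-Python sketch of it, assuming each optimizer step yields k ≥ 2 micro-batch gradients of b examples each, flattened to lists of floats. `gns_estimates` and `EmaGNS` are hypothetical names, not traintools' API, and the EMA decay `beta=0.99` is an arbitrary choice. The sketch relies on two identities:

- the spread of micro-batch gradients around their mean estimates tr(Σ)/b;
- E|mean|² = |G|² + tr(Σ)/(k·b).

```python
import math
import random


def sq_norm(v):
    """Squared Euclidean norm of a vector given as a list of floats."""
    return sum(x * x for x in v)


def gns_estimates(micro_grads, b):
    """Per-step unbiased estimates of tr(Sigma) and |G|^2.

    micro_grads: k >= 2 gradient vectors, each the mean gradient over b examples.
    """
    k = len(micro_grads)
    if k < 2:
        raise ValueError("need at least two micro-batch gradients")
    dim = len(micro_grads[0])
    mean = [sum(g[i] for g in micro_grads) / k for i in range(dim)]
    # Bessel-corrected (k - 1) spread of micro-batch gradients estimates tr(Sigma) / b.
    spread = sum(sq_norm([g[i] - mean[i] for i in range(dim)]) for g in micro_grads)
    trace_sigma = b * spread / (k - 1)
    # E|mean|^2 = |G|^2 + tr(Sigma) / (k * b), so subtract that noise term.
    g_sq = sq_norm(mean) - trace_sigma / (k * b)
    return trace_sigma, g_sq


class EmaGNS:
    """Tracks GNS as the ratio of two separate EMAs (numerator and denominator)."""

    def __init__(self, beta=0.99):
        self.beta = beta
        self.ema_trace = 0.0
        self.ema_g_sq = 0.0

    def update(self, trace_sigma, g_sq):
        self.ema_trace = self.beta * self.ema_trace + (1 - self.beta) * trace_sigma
        self.ema_g_sq = self.beta * self.ema_g_sq + (1 - self.beta) * g_sq
        # Zero-init bias factors (1 - beta**t) are identical for both EMAs and cancel here.
        if self.ema_g_sq <= 0:
            return float("inf")
        return self.ema_trace / self.ema_g_sq


# Synthetic check: |G|^2 = 1 and per-example noise N(0, sigma^2) in each of 10 coordinates,
# so the true GNS is tr(Sigma) / |G|^2 = 10 * sigma^2 = 40.
random.seed(0)
true_g = [1.0] + [0.0] * 9
sigma, b, k = 2.0, 8, 4
tracker = EmaGNS(beta=0.99)
for _ in range(3000):
    # A b-example micro-batch gradient is G plus noise with std sigma / sqrt(b) per coordinate.
    micro = [[gi + random.gauss(0.0, sigma / math.sqrt(b)) for gi in true_g] for _ in range(k)]
    gns = tracker.update(*gns_estimates(micro, b))
print(f"EMA GNS estimate: {gns:.1f} (true value 40)")
```

A single step's `g_sq` can even come out negative, because the per-step estimates are unbiased but very noisy. Averaging numerator and denominator separately before dividing avoids acting on those single-step values.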

Free GNS during gradient accumulation

If you already use gradient accumulation, GNS costs zero extra forward/backward passes: the per-micro-batch gradients you compute anyway are exactly the samples GNS needs. Estimators that have to generate those samples themselves, for example with a second backward pass at a different batch size, pay for the extra passes.

from traintools import GradientAccumulationGNS

# B_micro: examples per micro-batch; accum_steps: micro-batches per optimizer step
gns = GradientAccumulationGNS(model, micro_batch_size=B_micro)

for step in range(num_steps):
    # micro_batches: the accum_steps (inputs, targets) pairs for this optimizer step
    for xm, ym in micro_batches:
        (loss_fn(model(xm), ym) / accum_steps).backward()
        gns.record_microbatch()          # after each micro-batch backward
    optimizer.step()
    result = gns.compute(step=step)      # GNSResult, no extra passes
    optimizer.zero_grad()
    gns.reset_accumulation()
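Per-micro-batch gradients can be recovered from the running `.grad` sum by differencing consecutive snapshots. The following is one possible approach, not necessarily how `record_microbatch()` works. It is a plain-Python sketch in which lists stand in for the flattened `.grad` tensors, and `MicroBatchRecorder` is a hypothetical name. In PyTorch the snapshot costs one extra gradient-sized buffer. A production version could also accumulate the squared norms of the differences instead of storing every micro-batch gradient.

```python
class MicroBatchRecorder:
    """Hypothetical helper: recovers per-micro-batch gradients from a running .grad sum."""

    def __init__(self, accum_steps):
        self.accum_steps = accum_steps
        self.prev = None          # snapshot of the accumulated gradient after the last backward
        self.micro_grads = []

    def record(self, accumulated_grad):
        # accumulated_grad: flattened .grad after this micro-batch's backward, i.e. the
        # running sum of gradients of (loss / accum_steps) over the micro-batches so far.
        if self.prev is None:
            delta = list(accumulated_grad)
        else:
            delta = [a - p for a, p in zip(accumulated_grad, self.prev)]
        # Undo the 1 / accum_steps loss scaling to get this micro-batch's mean gradient.
        self.micro_grads.append([d * self.accum_steps for d in delta])
        self.prev = list(accumulated_grad)

    def reset(self):
        # Call after optimizer.zero_grad(), mirroring reset_accumulation() above.
        self.prev = None
        self.micro_grads = []


# Simulate four micro-batches whose mean gradients are known.
true_micro = [[4.0, 8.0], [0.0, -4.0], [8.0, 4.0], [4.0, 0.0]]
rec = MicroBatchRecorder(accum_steps=4)
grad = [0.0, 0.0]                      # stands in for the accumulated .grad
for g in true_micro:
    grad = [acc + gi / 4 for acc, gi in zip(grad, g)]   # what backward() adds
    rec.record(grad)
print(rec.micro_grads)   # recovers true_micro
print(grad)              # [4.0, 2.0]: the usual accumulated gradient, the mean of the micro-batch gradients
```

The recovered gradients are the k samples that a Bessel-corrected variance estimate needs. The accumulated gradient used by `optimizer.step()` is left untouched.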

Reference: McCandlish et al. 2018, An Empirical Model of Large-Batch Training.

PlasticityProbe

Networks lose plasticity over long training runs or repeated fine-tuning. The failure shows up in the representations, so PlasticityProbe measures the activations directly (not weight matrices), matching the operational definitions in the loss-of-plasticity literature:

  • Dormant unit fraction — units whose activation is ~0 for every input, attributed to the activation module that produced them
  • Feature effective rank — effective rank of the activation covariance, normalised to [0,1]; low rank = representational collapse
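Both metrics can be computed from a matrix of activations with shape (inputs, units). The sketch below uses numpy. Several parts are assumptions, since the document does not specify how PlasticityProbe computes or normalises these quantities:

- the dormancy threshold `eps`;
- the entropy-based definition of effective rank;
- the linear mapping of [1, n_units] onto [0, 1].

The function names are hypothetical.

```python
import numpy as np


def dormant_fraction(acts, eps=1e-6):
    """Fraction of units whose |activation| <= eps for every input.

    acts: array of shape (n_inputs, n_units) collected from one activation module.
    """
    dead = np.all(np.abs(acts) <= eps, axis=0)
    return float(dead.mean())


def normalized_effective_rank(acts):
    """Entropy-based effective rank of the activation covariance, rescaled to [0, 1]."""
    n_units = acts.shape[1]
    if n_units < 2:
        raise ValueError("need at least two units")
    eig = np.clip(np.linalg.eigvalsh(np.cov(acts, rowvar=False)), 0.0, None)
    total = eig.sum()
    if total == 0.0:
        return 0.0                                   # every unit constant: fully collapsed
    p = eig[eig > 0] / total
    erank = float(np.exp(-np.sum(p * np.log(p))))   # lies in [1, n_units]
    # Assumed normalisation: map [1, n_units] linearly onto [0, 1].
    return (erank - 1.0) / (n_units - 1.0)


rng = np.random.default_rng(0)
healthy = rng.standard_normal((512, 16))                                  # diverse features
collapsed = rng.standard_normal((512, 1)) @ rng.standard_normal((1, 16))  # rank-1 features
partly_dead = healthy.copy()
partly_dead[:, :4] = 0.0                                                  # 4 of 16 units silent

print(normalized_effective_rank(healthy), normalized_effective_rank(collapsed))
print(dormant_fraction(partly_dead))  # 0.25
```

The healthy features score close to 1 and the rank-1 features close to 0. This is the "low rank = representational collapse" signal described above.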

Combined into a Plasticity Score ∈ [0, 1] (1 = fully plastic, 0 = dead).

[step 200] Plasticity Score: 0.706
  All layers healthy.

References: Dohare et al. 2024, Loss of plasticity in deep continual learning (Nature 632:768); Lyle et al. 2023, Understanding Plasticity in Neural Networks.

TrainGuard

Fits a power-law or exponential curve to your validation loss history, bootstraps uncertainty over the fit, and predicts whether continuing training is worth it:

[step 400] STOP
  current loss: 0.6536
  predicted final: 0.6119
  expected improvement: 0.0417 (90% CI: [0.0012, 0.0821])
  estimated plateau at step: 3200
  reason: No improvement in 300 steps (best=0.6350 at step 93).
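The fit-then-bootstrap idea can be sketched in numpy as follows. This is a minimal sketch, not TrainGuard's code. It assumes the power-law form L(t) = c + a·t^(−α), a grid search over α, and a residual bootstrap. The exponential alternative, TrainGuard's stopping threshold, and its patience rule are not shown and may differ. The function names are hypothetical.

```python
import numpy as np


def fit_power_law(t, y, alphas=np.linspace(0.05, 2.0, 40)):
    """Least-squares fit of y ~ c + a * t**(-alpha).

    For fixed alpha the model is linear in (c, a), so alpha is grid-searched and
    (c, a) are solved with lstsq.
    """
    best = None
    for alpha in alphas:
        X = np.column_stack([np.ones_like(t), t ** (-alpha)])
        coef, *_ = np.linalg.lstsq(X, y, rcond=None)
        sse = float(np.sum((X @ coef - y) ** 2))
        if best is None or sse < best[0]:
            best = (sse, alpha, coef)
    _, alpha, (c, a) = best
    return lambda s: c + a * np.asarray(s, dtype=float) ** (-alpha)


def expected_improvement(t, y, t_final, n_boot=200, seed=0):
    """Residual bootstrap: mean and 90% interval of fitted L(t_now) - L(t_final)."""
    rng = np.random.default_rng(seed)
    fitted = fit_power_law(t, y)(t)
    resid = y - fitted
    gains = []
    for _ in range(n_boot):
        # Refit on the fitted curve plus resampled residuals, then extrapolate.
        model = fit_power_law(t, fitted + rng.choice(resid, size=resid.size, replace=True))
        gains.append(float(model(t[-1]) - model(t_final)))
    lo, hi = np.percentile(gains, [5, 95])
    return float(np.mean(gains)), (float(lo), float(hi))


# Synthetic validation curve: L(t) = 0.6 + 0.5 * t**-0.5 plus small noise, observed up to step 400.
rng = np.random.default_rng(1)
t = np.arange(1, 401, dtype=float)
y = 0.6 + 0.5 * t ** -0.5 + rng.normal(0.0, 0.005, t.size)
gain, (lo, hi) = expected_improvement(t, y, t_final=3200)
print(f"expected improvement: {gain:.4f} (90% CI: [{lo:.4f}, {hi:.4f}])")
```

Solving (c, a) in closed form for each α keeps the fit cheap enough to repeat hundreds of times for the bootstrap.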

Installation

# Core (PyTorch only)
pip install traintools

# With curve fitting + plotting
pip install "traintools[full]"

# With HuggingFace Trainer integration
pip install "traintools[hf]"

Requirements

  • Python >= 3.9
  • PyTorch >= 2.0
  • scipy, numpy, matplotlib (optional, for [full])
  • transformers >= 4.30 (optional, for [hf])

License

MIT


Download files


Source distribution: traintools-0.2.0.tar.gz (29.9 kB)

Built distribution: traintools-0.2.0-py3-none-any.whl (24.1 kB, Python 3)

