Training diagnostics for PyTorch: Gradient Noise Scale (free during gradient accumulation), Plasticity Probe, and probabilistic early stopping
traintools
Training diagnostics for PyTorch. Three tools, two lines of integration.
```bash
pip install traintools[full]
```
What it does
| Tool | Question it answers |
|---|---|
| GNS — Gradient Noise Scale | Is my batch size wasting compute? |
| PlasticityProbe | Is my network losing the ability to learn? |
| TrainGuard | Should I stop training yet? |
Quick start
```python
from traintools.callbacks.pytorch import TraintoolsTracker

# model, loss_fn, optimizer and dataloader are your own objects.
tracker = TraintoolsTracker(model, loss_fn)

for step, (x, y) in enumerate(dataloader):
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # val_loss: your most recent validation loss, from your own eval loop.
    decision = tracker.step(step=step, inputs=x, targets=y, val_loss=val_loss)
    if decision and decision.should_stop:
        break
```
HuggingFace Trainer:
```python
from transformers import Trainer
from traintools.callbacks.huggingface import TraintoolsCallback

trainer = Trainer(model=model, ..., callbacks=[TraintoolsCallback()])
```
Gradient Noise Scale (GNS)
GNS is the ratio of gradient noise (the trace of the per-example gradient covariance Σ) to gradient signal (the squared norm of the true gradient G):
GNS = tr(Σ) / ||G||^2
It estimates the critical batch size B*, the point beyond which larger batches give diminishing returns.
With B the current batch size:
- GNS > B → under-batched: gradients are too noisy, and larger batches help.
- GNS < B → over-batched: the batch is larger than needed; shrink it and save compute.
- GNS ≈ B → optimal.
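The regime rule above can be sketched as a comparison between the GNS estimate and the current batch size. The `classify_regime` helper and its `tolerance` band for "≈" are hypothetical; traintools' actual thresholds may differ.

```python
def classify_regime(gns: float, batch_size: int, tolerance: float = 2.0) -> str:
    """Label the batch-size regime by comparing GNS with the current batch size.

    tolerance is a hypothetical factor: GNS within [B / tolerance, B * tolerance]
    counts as "optimal".
    """
    if gns <= 0 or batch_size <= 0:
        raise ValueError("gns and batch_size must be positive")
    ratio = gns / batch_size
    if ratio > tolerance:
        return "under-batched"  # noise dominates: larger batches help
    if ratio < 1.0 / tolerance:
        return "over-batched"   # batch bigger than needed: shrink it
    return "optimal"


# Numbers from the example log below: 5010.7 / 64 is about 78x.
print(classify_regime(5010.7, 64), round(5010.7 / 64))
```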
traintools uses the unbiased estimators from McCandlish et al. 2018
(Bessel-corrected variance, bias-corrected signal) and tracks GNS as the ratio
of two separate exponential moving averages — the stable estimator the paper
recommends. (Naive single-shot estimates are biased low by ~2x and far too
noisy to act on.)
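A minimal sketch of the idea, assuming you have squared gradient norms measured at two batch sizes. It follows the two-batch-size form in Appendix A of McCandlish et al. 2018, built on E‖G_B‖² = ‖G‖² + tr(Σ)/B, and then tracks GNS as the ratio of two separate EMAs. `unbiased_gns_terms`, `EmaGNS` and `beta=0.95` are hypothetical names and defaults, not the traintools API, and the library's Bessel-corrected formulation may arrange the arithmetic differently.

```python
from typing import Optional, Tuple


def unbiased_gns_terms(g_small_sq: float, g_big_sq: float,
                       b_small: int, b_big: int) -> Tuple[float, float]:
    """Unbiased estimates of |G|^2 and tr(Sigma) from two batch sizes.

    g_small_sq, g_big_sq: squared L2 norms of gradients averaged over batches of
    size b_small and b_big (b_small < b_big).
    """
    g_sq = (b_big * g_big_sq - b_small * g_small_sq) / (b_big - b_small)
    tr_sigma = (g_small_sq - g_big_sq) / (1.0 / b_small - 1.0 / b_big)
    return g_sq, tr_sigma


class EmaGNS:
    """Track GNS as EMA(tr(Sigma)) / EMA(|G|^2), not as an EMA of ratios."""

    def __init__(self, beta: float = 0.95):
        self.beta = beta
        self.ema_g_sq: Optional[float] = None
        self.ema_tr_sigma: Optional[float] = None

    def update(self, g_small_sq: float, g_big_sq: float,
               b_small: int, b_big: int) -> Optional[float]:
        g_sq, tr_sigma = unbiased_gns_terms(g_small_sq, g_big_sq, b_small, b_big)
        if self.ema_g_sq is None:
            # Seed with the first estimate instead of 0 to avoid start-up bias.
            self.ema_g_sq, self.ema_tr_sigma = g_sq, tr_sigma
        else:
            b = self.beta
            self.ema_g_sq = b * self.ema_g_sq + (1 - b) * g_sq
            self.ema_tr_sigma = b * self.ema_tr_sigma + (1 - b) * tr_sigma
        # The |G|^2 estimate can be <= 0 early on; there is no usable ratio then.
        if self.ema_g_sq <= 0:
            return None
        return self.ema_tr_sigma / self.ema_g_sq
```

Averaging numerator and denominator separately keeps a single noisy, near-zero ‖G‖² estimate from blowing up the tracked value.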
```text
[step 500] GNS=5010.7 (EMA) critical_batch=5011 current=64 regime=under-batched
> Batch size 64 is ~78x below the critical batch (~5011). Larger batches would give cleaner gradients per step.
```
Free GNS during gradient accumulation
If you already use gradient accumulation, GNS costs zero extra forward/backward passes: the per-micro-batch gradients you compute anyway are exactly the samples GNS needs. Estimators that draw separate small batches, by contrast, pay for extra passes.
```python
from traintools import GradientAccumulationGNS

# B_micro: examples per micro-batch; accum_steps: micro-batches per optimizer step.
gns = GradientAccumulationGNS(model, micro_batch_size=B_micro)

for step in range(num_steps):
    for xm, ym in micro_batches:  # the micro-batches for this optimizer step
        (loss_fn(model(xm), ym) / accum_steps).backward()
        gns.record_microbatch()  # after each micro-batch backward

    optimizer.step()
    result = gns.compute(step=step)  # GNSResult, no extra passes
    optimizer.zero_grad()
    gns.reset_accumulation()
```
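Why accumulation makes GNS free, as a sketch: the k micro-batch gradients of one optimizer step are k independent small-batch samples (B_small = micro-batch size), and their mean is the big-batch gradient (B_big = k × micro-batch size), so the two-batch-size estimator needs nothing extra. `gns_from_microbatches` is a hypothetical helper on plain lists. How each micro-batch gradient is captured in PyTorch (for example, k times the change in `.grad` across one backward when the loss is divided by `accum_steps`) is an assumption, not a description of `record_microbatch`.

```python
from typing import List, Optional, Tuple


def gns_from_microbatches(micro_grads: List[List[float]],
                          micro_batch_size: int) -> Tuple[float, float, Optional[float]]:
    """Estimate (tr(Sigma), |G|^2, GNS) from the k micro-batch gradients of one step.

    Each micro_grads[i] is a flattened gradient averaged over micro_batch_size
    examples.
    """
    k = len(micro_grads)
    if k < 2:
        raise ValueError("need at least two micro-batches per step")
    dim = len(micro_grads[0])
    b_small, b_big = micro_batch_size, k * micro_batch_size

    # Mean squared norm over micro-batches estimates E|G_{B_small}|^2.
    g_small_sq = sum(sum(v * v for v in g) for g in micro_grads) / k
    # The accumulated (averaged) gradient is the big-batch gradient.
    mean_grad = [sum(g[j] for g in micro_grads) / k for j in range(dim)]
    g_big_sq = sum(v * v for v in mean_grad)

    # Same unbiased two-batch-size estimators as McCandlish et al., Appendix A.
    g_sq = (b_big * g_big_sq - b_small * g_small_sq) / (b_big - b_small)
    tr_sigma = (g_small_sq - g_big_sq) / (1.0 / b_small - 1.0 / b_big)
    gns = tr_sigma / g_sq if g_sq > 0 else None
    return tr_sigma, g_sq, gns


print(gns_from_microbatches([[3.0, 0.0], [1.0, 0.0]], micro_batch_size=4))
```

A single step gives a very noisy estimate; in practice the numerator and denominator would be smoothed over many steps before taking the ratio.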
Reference: McCandlish et al. 2018, An Empirical Model of Large-Batch Training.
PlasticityProbe
Networks lose plasticity over long training runs or repeated fine-tuning. The failure shows up in the representations, so PlasticityProbe measures the activations directly (not weight matrices), matching the operational definitions in the loss-of-plasticity literature:
- Dormant unit fraction — units whose activation is ~0 for every input, attributed to the activation module that produced them
- Feature effective rank — effective rank of the activation covariance, normalised to [0,1]; low rank = representational collapse
Combined into a Plasticity Score ∈ [0, 1] (1 = fully plastic, 0 = dead).
```text
[step 200] Plasticity Score: 0.706
All layers healthy.
```
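The two activation-based metrics can be sketched with NumPy, assuming `acts` holds one layer's activations as an (inputs × units) array. The dormancy threshold `eps`, the entropy-based effective rank (exp of the entropy of the normalised covariance eigenvalues), and normalising by the unit count are assumptions; how traintools weights the two into the Plasticity Score is not shown.

```python
import numpy as np


def dormant_fraction(acts: np.ndarray, eps: float = 1e-6) -> float:
    """Fraction of units whose |activation| <= eps for every input in the batch."""
    return float(np.mean(np.all(np.abs(acts) <= eps, axis=0)))


def normalised_effective_rank(acts: np.ndarray, rel_tol: float = 1e-10) -> float:
    """Effective rank of the activation covariance divided by the number of units.

    erank = exp(H(p)) with p the eigenvalues normalised to sum to 1, so the
    result lies in [0, 1]; low values mean representational collapse.
    """
    cov = np.atleast_2d(np.cov(acts, rowvar=False))  # (units, units)
    eig = np.clip(np.linalg.eigvalsh(cov), 0.0, None)  # drop round-off negatives
    if eig.max() <= 0:
        return 0.0  # every unit is constant: fully collapsed
    eig = eig[eig > rel_tol * eig.max()]
    p = eig / eig.sum()
    erank = np.exp(-np.sum(p * np.log(p)))
    return float(erank / acts.shape[1])


acts = np.array([[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]])
print(dormant_fraction(acts), normalised_effective_rank(acts))
```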
References: Dohare et al. 2024 (Nature 632:768), Loss of plasticity in deep continual learning; Lyle et al. 2023, Understanding Plasticity in Neural Networks.
TrainGuard
Fits a power-law or exponential curve to your validation loss history, bootstraps uncertainty over the fit, and predicts whether continuing training is worth it:
```text
[step 400] STOP
current loss: 0.6536
predicted final: 0.6119
expected improvement: 0.0417 (90% CI: [0.0012, 0.0821])
estimated plateau at step: 3200
reason: No improvement in 300 steps (best=0.6350 at step 93).
```
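A minimal sketch of the power-law branch, assuming a fit of the form L(t) = a·t^(−b) + c and a residual bootstrap for the interval. `fit_power_law`, `expected_improvement`, `horizon` and the grid search over c are illustrative choices rather than TrainGuard's implementation, which also considers an exponential curve and a patience rule (the `reason` line above).

```python
import numpy as np


def fit_power_law(steps, losses, n_grid=100):
    """Fit L(t) = a * t**(-b) + c, grid-searching the asymptote c.

    For a fixed c, log(L - c) = log(a) - b * log(t) is linear, so a and b come
    from least squares; the c with the lowest squared error in loss space wins.
    Assumes steps >= 1 and strictly positive losses.
    """
    t = np.asarray(steps, dtype=float)
    y = np.asarray(losses, dtype=float)
    if y.min() <= 0:
        raise ValueError("losses must be positive")
    best = None
    # Every candidate c stays below min(y) so that log(y - c) is defined.
    for c in np.linspace(0.0, y.min(), n_grid, endpoint=False):
        slope, intercept = np.polyfit(np.log(t), np.log(y - c), 1)
        a, b = np.exp(intercept), -slope
        sse = float(np.sum((a * t ** (-b) + c - y) ** 2))
        if best is None or sse < best[0]:
            best = (sse, a, b, c)
    return best[1], best[2], best[3]


def expected_improvement(steps, losses, horizon, n_boot=100, seed=0):
    """Return (median, 5th pct, 95th pct) of current loss minus the predicted
    loss at `horizon`, i.e. a point estimate and a 90% bootstrap interval."""
    rng = np.random.default_rng(seed)
    t = np.asarray(steps, dtype=float)
    y = np.asarray(losses, dtype=float)
    a, b, c = fit_power_law(t, y)
    fitted = a * t ** (-b) + c
    resid = y - fitted
    gains = []
    for _ in range(n_boot):
        # Resample residuals onto the fitted curve, refit, and extrapolate.
        y_star = fitted + rng.choice(resid, size=resid.size, replace=True)
        a_s, b_s, c_s = fit_power_law(t, y_star)
        gains.append(y[-1] - (a_s * horizon ** (-b_s) + c_s))
    lo, med, hi = np.percentile(gains, [5, 50, 95])
    return float(med), float(lo), float(hi)


steps = np.arange(1, 101, dtype=float)
losses = 2.0 * steps ** -0.5 + 0.5  # synthetic, noise-free curve
print(expected_improvement(steps, losses, horizon=1000, n_boot=30))
```

A stop rule could then compare the lower end of the interval with the smallest improvement worth the remaining compute.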
Installation
```bash
# Core (PyTorch only)
pip install traintools

# With curve fitting + plotting
pip install traintools[full]

# With HuggingFace Trainer integration
pip install traintools[hf]
```
Requirements
- Python >= 3.9
- PyTorch >= 2.0
- scipy, numpy, matplotlib (optional, for `[full]`)
- transformers >= 4.30 (optional, for `[hf]`)
License
MIT
Download files
File details
Details for the file traintools-0.2.0.tar.gz.
File metadata
- Download URL: traintools-0.2.0.tar.gz
- Upload date:
- Size: 29.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 841b9c32edd3f347dbb9bc4f761139c133d09531d3c5222dedaeea44e7e0715c |
| MD5 | cfbefa28eae257582e9562dfe9a5652d |
| BLAKE2b-256 | d3c19521d916e2328bdf94721318d649dda1eb30a8c5ad1f2def5131a0c37c87 |
File details
Details for the file traintools-0.2.0-py3-none-any.whl.
File metadata
- Download URL: traintools-0.2.0-py3-none-any.whl
- Upload date:
- Size: 24.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b47079d77a0026510310532cebd778f6e08428b9425246be98da2394b2c1f18c |
| MD5 | b7eba64aced84b255797ad49830d76cd |
| BLAKE2b-256 | d291aaf362b262c48953a4ec94d4d24167b9dbba17bcec861ac5e2943cbda92e |