Training diagnostics for PyTorch: Gradient Noise Scale, Plasticity Probe, and probabilistic early stopping
traintools
Training diagnostics for PyTorch. Three tools, two lines of integration.
pip install traintools[full]
What it does
| Tool | Question it answers |
|---|---|
| GNS — Gradient Noise Scale | Is my batch size wasting compute? |
| PlasticityProbe | Is my network losing the ability to learn? |
| TrainGuard | Should I stop training yet? |
Quick start
from traintools.callbacks.pytorch import TraintoolsTracker

tracker = TraintoolsTracker(model, loss_fn)

for step, (x, y) in enumerate(dataloader):
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # val_loss is your most recent validation loss, computed elsewhere in your loop
    decision = tracker.step(step=step, inputs=x, targets=y, val_loss=val_loss)
    if decision and decision.should_stop:
        break
HuggingFace Trainer:
from transformers import Trainer
from traintools.callbacks.huggingface import TraintoolsCallback
trainer = Trainer(model=model, ..., callbacks=[TraintoolsCallback()])
Gradient Noise Scale (GNS)
GNS measures the ratio of gradient noise to gradient signal across your batch:
GNS = B * Var[g] / ||E[g]||^2
When GNS >> B, your batch is too small: gradient quality would still improve with more samples per step.
When GNS << B, your batch is too large: the extra samples barely reduce gradient noise, so you're wasting compute.
When GNS ≈ B, you're at the efficient frontier.
Output:
[step 100] GNS=36.1 critical_batch=36 current=64 regime=optimal
> Batch size 64 is near the critical batch size (36). No change needed.
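To make the formula concrete, here is a minimal sketch of one way to estimate the noise scale from per-example gradients. It is not traintools' implementation: the function name `gradient_noise_scale` and the list-of-vectors input are assumptions made for illustration, and the bias correction on the squared gradient norm is one common choice rather than something the package is known to use. In PyTorch, per-example gradients could plausibly be obtained with `torch.func`, but the sketch stays in plain Python so it runs anywhere.

```python
def gradient_noise_scale(per_example_grads):
    """Estimate B * Var[g] / ||E[g]||^2 from per-example gradient vectors.

    per_example_grads: list of B flattened gradients (lists of floats),
    one per sample in the batch. Hypothetical helper, not a traintools API.
    """
    b = len(per_example_grads)
    if b < 2:
        raise ValueError("need at least two per-example gradients")
    dim = len(per_example_grads[0])

    # Mean gradient over the batch: the estimate of E[g].
    mean = [sum(g[k] for g in per_example_grads) / b for k in range(dim)]

    # Unbiased total variance of a single-example gradient (trace of the
    # covariance). This equals B * Var[g] when g is the batch-mean gradient.
    trace_sigma = sum(
        sum((g[k] - mean[k]) ** 2 for k in range(dim))
        for g in per_example_grads
    ) / (b - 1)

    # ||mean||^2 overestimates ||E[g]||^2 by trace_sigma / B, so subtract it.
    signal = sum(m * m for m in mean) - trace_sigma / b
    if signal <= 0:
        # Noise dominates completely; the estimate is unbounded.
        return float("inf")
    return trace_sigma / signal
```

The returned value can then be compared against the current batch size in the way the three regimes above describe. A single batch gives a noisy estimate, so in practice one would likely average the numerator and denominator over many steps before taking the ratio.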
Reference: McCandlish et al. 2018, An Empirical Model of Large-Batch Training.
PlasticityProbe
Networks lose plasticity over long training runs or repeated fine-tuning. PlasticityProbe tracks three signals per layer:
- Dead neuron fraction — neurons that never activate on any batch sample
- Effective rank — how collapsed the weight matrix spectrum is
- Gradient/weight ratio — whether a layer is still receiving meaningful updates
These signals are combined into a single Plasticity Score ∈ [0, 1] (1 = fully plastic, 0 = dead).
[step 200] Plasticity Score: 0.968
All layers healthy.
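The three per-layer signals can be sketched in a few lines. This is an illustration under stated assumptions, not the package's code: the helper names are hypothetical, "dead" is taken to mean a non-positive activation on every sample (a ReLU-style definition), and effective rank uses the entropy of the normalized singular values, which is one common definition and may differ from what PlasticityProbe computes. How the three numbers are merged into the Plasticity Score is not specified here, so the sketch stops short of that.

```python
import math


def dead_neuron_fraction(activations):
    """Fraction of neurons whose activation is <= 0 on every sample.

    activations: list of rows, one per sample, each a list of per-neuron
    post-activation values. Assumes a ReLU-like notion of "dead".
    """
    n_neurons = len(activations[0])
    dead = sum(
        1 for j in range(n_neurons)
        if all(row[j] <= 0.0 for row in activations)
    )
    return dead / n_neurons


def effective_rank(singular_values):
    """exp(entropy) of the normalized singular value distribution.

    Takes singular values as input; in PyTorch they could be computed
    with torch.linalg.svdvals on the layer's weight matrix.
    """
    total = sum(singular_values)
    probs = [s / total for s in singular_values if s > 0.0]
    entropy = -sum(p * math.log(p) for p in probs)
    return math.exp(entropy)


def grad_weight_ratio(grads, weights):
    """L2 norm of the flattened gradient over L2 norm of the weights."""
    g_norm = math.sqrt(sum(g * g for g in grads))
    w_norm = math.sqrt(sum(w * w for w in weights))
    return g_norm / w_norm
```

An effective rank close to the matrix's full rank suggests a spread-out spectrum, while a value near 1 suggests the spectrum has collapsed onto a single direction.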
Reference: Lyle et al. 2023, Understanding Plasticity in Neural Networks.
TrainGuard
Fits a power-law or exponential curve to your validation loss history, bootstraps uncertainty over the fit, and predicts whether continuing training is worth it:
[step 400] STOP
current loss: 0.6536
predicted final: 0.6119
expected improvement: 0.0417 (90% CI: [0.0012, 0.0821])
estimated plateau at step: 3200
reason: No improvement in 300 steps (best=0.6350 at step 93).
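The fit-then-bootstrap idea can be sketched as follows. This is a deliberately simplified stand-in, not TrainGuard's algorithm: it fits a power law without a loss floor, `loss = a * step^(-b)`, by least squares in log space, and resamples the (step, loss) pairs to get an interval on the predicted improvement. The function names, the `horizon` parameter, and the 90% percentile interval are all assumptions for illustration. A fit with an asymptote term would need a nonlinear solver such as `scipy.optimize.curve_fit`.

```python
import math
import random


def fit_power_law(steps, losses):
    """Fit log(loss) = log(a) - b * log(step); return (a, b).

    Requires steps >= 1 and losses > 0 so the logarithms are defined.
    """
    xs = [math.log(s) for s in steps]
    ys = [math.log(v) for v in losses]
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mx) ** 2 for x in xs)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return math.exp(my - slope * mx), -slope


def predicted_improvement(steps, losses, horizon, n_boot=200, seed=0):
    """Point estimate and 90% bootstrap interval for the loss drop
    between the latest observation and the hypothetical `horizon` step."""
    rng = random.Random(seed)
    a, b = fit_power_law(steps, losses)
    point = losses[-1] - a * horizon ** (-b)

    pairs = list(zip(steps, losses))
    samples = []
    for _ in range(n_boot):
        resample = [rng.choice(pairs) for _ in pairs]
        if len({s for s, _ in resample}) < 2:
            continue  # a single distinct step cannot determine a slope
        ra, rb = fit_power_law(*zip(*resample))
        samples.append(losses[-1] - ra * horizon ** (-rb))

    samples.sort()
    lo = samples[int(0.05 * (len(samples) - 1))]
    hi = samples[int(0.95 * (len(samples) - 1))]
    return point, (lo, hi)
```

A stopping rule could then compare the lower end of the interval against a minimum worthwhile improvement. The output above also reports a patience-style reason, which suggests the real tool combines more than one criterion.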
Installation
# Core (PyTorch only)
pip install traintools
# With curve fitting + plotting
pip install traintools[full]
# With HuggingFace Trainer integration
pip install traintools[hf]
Requirements
- Python >= 3.9
- PyTorch >= 2.0
- scipy, numpy, matplotlib (optional, for [full])
- transformers >= 4.30 (optional, for [hf])
License
MIT