Training diagnostics for PyTorch: Gradient Noise Scale, Plasticity Probe, and probabilistic early stopping

traintools

Training diagnostics for PyTorch. Three tools, two lines of integration.

pip install "traintools[full]"

What it does

Tool                          Question it answers
GNS (Gradient Noise Scale)    Is my batch size wasting compute?
PlasticityProbe               Is my network losing the ability to learn?
TrainGuard                    Should I stop training yet?

Quick start

from traintools.callbacks.pytorch import TraintoolsTracker

# model, loss_fn, optimizer, and dataloader are your existing training objects.
tracker = TraintoolsTracker(model, loss_fn)

for step, (x, y) in enumerate(dataloader):
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # val_loss is the most recent validation loss from your own evaluation code.
    decision = tracker.step(step=step, inputs=x, targets=y, val_loss=val_loss)
    if decision and decision.should_stop:
        break

HuggingFace Trainer:

from transformers import Trainer
from traintools.callbacks.huggingface import TraintoolsCallback

trainer = Trainer(model=model, ..., callbacks=[TraintoolsCallback()])

Gradient Noise Scale (GNS)

GNS measures the ratio of gradient noise to gradient signal across your batch:

GNS = B * Var[g] / ||E[g]||^2

When GNS >> B your batch is too small — gradient quality improves with more data.
When GNS << B your batch is too large — you're wasting compute.
When GNS ≈ B you're at the efficient frontier.

Output:

[step 100] GNS=36.1  critical_batch=36  current=64  regime=optimal
  > Batch size 64 is near the critical batch size (36). No change needed.

Reference: McCandlish et al. 2018, An Empirical Model of Large-Batch Training.
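
To make the formula above concrete, here is a minimal sketch of one way to estimate the noise scale from per-example gradients. It is not traintools' implementation. It uses the unbiased two-batch-size estimators described in McCandlish et al., with batch sizes 1 and B, and assumes g in the formula is the mini-batch gradient. The function name `estimate_gns` and the plain-list gradient format are assumptions made for illustration.

```python
def estimate_gns(per_example_grads):
    """Estimate the gradient noise scale tr(Sigma) / |G|^2.

    per_example_grads: list of B flattened gradient vectors (lists of floats),
    one per training example. Hypothetical input format for this sketch.
    """
    B = len(per_example_grads)
    if B < 2:
        raise ValueError("need at least two per-example gradients")
    dim = len(per_example_grads[0])

    # Squared norm of the batch-mean gradient (batch size B).
    mean_grad = [sum(g[i] for g in per_example_grads) / B for i in range(dim)]
    big_sq = sum(m * m for m in mean_grad)

    # Average squared norm of the individual gradients (batch size 1).
    small_sq = sum(sum(x * x for x in g) for g in per_example_grads) / B

    # Unbiased estimates of the true squared gradient norm and of the
    # per-example gradient variance (trace of the covariance).
    grad_sq = (B * big_sq - small_sq) / (B - 1)
    noise = (small_sq - big_sq) / (1.0 - 1.0 / B)

    # grad_sq can be zero or negative on very noisy batches; a real tool
    # would average both estimates over many steps before dividing.
    return noise / grad_sq


# Two one-dimensional "gradients": mean 2, sample variance 2.
gns = estimate_gns([[3.0, 0.0], [1.0, 0.0]])
print(round(gns, 4))
```

The returned value is what you would compare against the current batch size B when applying the three rules above.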

PlasticityProbe

Networks lose plasticity over long training runs or repeated fine-tuning. PlasticityProbe tracks three signals per layer:

  • Dead neuron fraction — neurons that never activate on any batch sample
  • Effective rank — how collapsed the weight matrix spectrum is
  • Gradient/weight ratio — whether a layer is still receiving meaningful updates

Combined into a Plasticity Score ∈ [0, 1] (1 = fully plastic, 0 = dead).

[step 200] Plasticity Score: 0.968
  All layers healthy.

Reference: Lyle et al. 2023, Understanding Plasticity in Neural Networks.
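
As a rough illustration of the three per-layer signals listed above, the sketch below computes each one in plain Python. The function names, the input formats, and the entropy-based definition of effective rank are assumptions; traintools may define these differently, and the way it combines them into the Plasticity Score is not shown here.

```python
import math


def dead_neuron_fraction(activations):
    """Fraction of neurons with no positive activation on any sample.

    activations: list of rows, one per sample, holding one layer's
    post-ReLU outputs (hypothetical input format).
    """
    n_neurons = len(activations[0])
    dead = sum(
        1 for j in range(n_neurons)
        if all(row[j] <= 0.0 for row in activations)
    )
    return dead / n_neurons


def effective_rank(singular_values):
    """Entropy-based effective rank: exp of the Shannon entropy of the
    normalized singular values. Ranges from 1 (collapsed) to len(values)."""
    total = sum(singular_values)
    probs = [s / total for s in singular_values if s > 0]
    entropy = -sum(p * math.log(p) for p in probs)
    return math.exp(entropy)


def grad_weight_ratio(grads, weights, eps=1e-12):
    """L2 norm of the gradient divided by the L2 norm of the weights."""
    g_norm = math.sqrt(sum(x * x for x in grads))
    w_norm = math.sqrt(sum(x * x for x in weights))
    return g_norm / (w_norm + eps)


# Second neuron fires, first never does: half the layer is dead.
print(dead_neuron_fraction([[0.0, 1.0], [0.0, 0.5]]))
# A flat spectrum keeps full rank; a single dominant value collapses it.
print(effective_rank([1.0, 1.0]), effective_rank([1.0, 0.0]))
print(grad_weight_ratio([3.0, 4.0], [0.0, 10.0]))
```

In a PyTorch setting the singular values would typically come from `torch.linalg.svdvals` applied to a layer's weight matrix, and the activations from a forward hook.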

TrainGuard

Fits a power-law or exponential curve to your validation loss history, bootstraps uncertainty over the fit, and predicts whether continuing training is worth it:

[step 400] STOP
  current loss: 0.6536
  predicted final: 0.6119
  expected improvement: 0.0417 (90% CI: [0.0012, 0.0821])
  estimated plateau at step: 3200
  reason: No improvement in 300 steps (best=0.6350 at step 93).
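
The idea of fitting a curve and bootstrapping the fit can be sketched as follows. This is a simplified stand-in and not TrainGuard's code: it fits a pure power law `loss = a * step**(-b)` by least squares in log-log space (no asymptote term, no exponential alternative), then resamples the history with replacement to get an interval on the expected improvement at a chosen horizon. All function names and parameters are hypothetical.

```python
import math
import random


def fit_power_law(steps, losses):
    """Least-squares fit of log(loss) = log(a) - b * log(step)."""
    xs = [math.log(s) for s in steps]
    ys = [math.log(v) for v in losses]
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mx) ** 2 for x in xs)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return math.exp(my - slope * mx), -slope  # (a, b)


def predict(a, b, step):
    return a * step ** (-b)


def expected_improvement(steps, losses, horizon, n_boot=200, level=0.90, seed=0):
    """Point estimate and bootstrap interval for current loss minus the
    loss predicted at `horizon`."""
    rng = random.Random(seed)
    current = losses[-1]
    pairs = list(zip(steps, losses))
    gains = []
    for _ in range(n_boot):
        sample = [rng.choice(pairs) for _ in pairs]
        if len({s for s, _ in sample}) < 2:
            continue  # a line cannot be fit through a single distinct step
        a, b = fit_power_law(*zip(*sample))
        gains.append(current - predict(a, b, horizon))
    gains.sort()
    lo = gains[int((1 - level) / 2 * (len(gains) - 1))]
    hi = gains[int((1 + level) / 2 * (len(gains) - 1))]
    a, b = fit_power_law(steps, losses)
    return current - predict(a, b, horizon), (lo, hi)


steps = [1, 2, 4, 8, 16]
losses = [2.0 * s ** -0.5 for s in steps]  # synthetic, noise-free history
gain, (lo, hi) = expected_improvement(steps, losses, horizon=64)
print(round(gain, 4), round(lo, 4), round(hi, 4))
```

A stopping rule could then compare the lower end of the interval against a minimum improvement worth paying for; the threshold is a policy choice that this sketch leaves out.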

Installation

# Core (PyTorch only)
pip install traintools

# With curve fitting + plotting
pip install "traintools[full]"

# With HuggingFace Trainer integration
pip install "traintools[hf]"

Requirements

  • Python >= 3.9
  • PyTorch >= 2.0
  • scipy, numpy, matplotlib (optional, for [full])
  • transformers >= 4.30 (optional, for [hf])

License

MIT
