Lightweight PyTorch tensor diagnostics hooks for training loops
NN diagnostics
A tool for dumping diagnostics information from training runs and checkpoints.
Install
pip install nndiagnostics
Quick Start
Dump diagnostics information
- Integrate diagnostics in your training loop
from diagnostics import maybe_attach_diagnostics

# model, train_loader, optimizer, train_step, and args come from your own training script
diag = maybe_attach_diagnostics(model)
for step, batch in enumerate(train_loader):
    loss = train_step(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Write the collected diagnostics after a few steps, then stop training
    if diag and diag.should_stop(step, stop_after_steps=5):
        diag.print(f"{args.exp_dir}/diagnostics-step-{step}.txt")
        break
- Dump diagnostics information by setting the environment variable DUMP_DIAGNOSTICS:
DUMP_DIAGNOSTICS=1 python train.py
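The `if diag and ...` guard in the loop above suggests that `maybe_attach_diagnostics` returns a falsy value when DUMP_DIAGNOSTICS is unset, so the loop runs unchanged by default. The sketch below illustrates that opt-in pattern in general terms. `maybe_attach` and `DummyDiagnostics` are hypothetical names, and the rules for which values enable the feature and when `should_stop` fires are assumptions, not the package's actual code.

```python
import os


class DummyDiagnostics:
    """Hypothetical stand-in for a diagnostics object (not the package's class)."""

    def should_stop(self, step, stop_after_steps=5):
        # Assumption: stop once `stop_after_steps` steps have completed.
        return step + 1 >= stop_after_steps


def maybe_attach(env=None, var="DUMP_DIAGNOSTICS"):
    """Return a diagnostics object only when the environment variable is set."""
    env = os.environ if env is None else env
    # Assumption: any non-empty value other than "0" enables diagnostics.
    if env.get(var, "0") in ("", "0"):
        return None
    return DummyDiagnostics()


diag = maybe_attach()
for step in range(100):
    # ... a training step would go here ...
    if diag and diag.should_stop(step, stop_after_steps=5):
        break
```

With this pattern the same training script serves both normal runs and short diagnostic runs, and the only switch is the environment variable.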
Add Inf/NaN Check Hooks
Register forward and backward hooks on every module and parameter to detect non-finite values (NaN/Inf) during training. When a non-finite value is detected, a warning is logged with the module/parameter name.
- Register inf check hooks to your model
from diagnostics import maybe_register_inf_check_hooks

# Register the hooks once, before the training loop starts
maybe_register_inf_check_hooks(model)
for step, batch in enumerate(train_loader):
    loss = train_step(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
- Enable by setting the environment variable INF_CHECK:
INF_CHECK=1 python train.py
If any module output, gradient, or parameter gradient becomes non-finite, you will see warnings like:
WARNING: The sum of encoder.layers.3.output is not finite
WARNING: The sum of encoder.layers.3.grad[0] is not finite
WARNING: The sum of encoder.layers.3.weight.param_grad is not finite
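The wording of these warnings suggests that each hook tests the sum of a tensor rather than every element. The plain-Python sketch below shows why that is enough to catch NaN and Inf: either one propagates into the sum. It uses lists of floats in place of tensors, and `sum_is_finite` and `check` are hypothetical helpers, not part of the package.

```python
import math


def sum_is_finite(values):
    """Return True when the sum of `values` is neither NaN nor +/-Inf."""
    return math.isfinite(sum(values))


def check(name, values, log=print):
    # Mirrors the warning format shown above; `check` is a hypothetical helper.
    if not sum_is_finite(values):
        log(f"WARNING: The sum of {name} is not finite")
        return False
    return True


# A single NaN poisons the sum, so a warning is logged.
check("encoder.layers.3.output", [0.5, float("nan"), 1.0])
# +Inf and -Inf do not cancel: their sum is NaN, so this is logged too.
check("encoder.layers.3.output", [float("inf"), float("-inf")])
# All values finite: nothing is logged.
check("encoder.layers.2.output", [0.5, -0.25, 1.0])
```

One caveat of a sum-based test is that very large finite values can overflow to Inf when added, so a warning marks a tensor worth inspecting rather than proving that one specific element is bad.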
CLI Tools
The package installs a diagnostics command with several subcommands for
post-processing diagnostics output. All tools support:
- Reading from a file or stdin (via pipe)
- Writing to a file with -o/--output-file (default: stdout)
diagnostics show_infinite
Detect transitions to non-finite values (NaN/Inf). Prints lines containing "finite" that follow lines NOT containing "finite" -- highlighting where non-finite values first appear.
diagnostics show_infinite diagnostics.txt
cat diagnostics.txt | diagnostics show_infinite
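The transition filter described above can be sketched in a few lines of Python. This is a minimal illustration of the rule as stated, not the tool's source. `first_nonfinite_lines` is a hypothetical name, the sample lines are made up, and treating the start of the input as a line without "finite" is an assumption.

```python
def first_nonfinite_lines(lines, needle="finite"):
    """Yield lines containing `needle` whose previous line did not contain it."""
    prev_matched = False  # assumption: the start of input counts as "no match"
    for line in lines:
        matched = needle in line
        if matched and not prev_matched:
            yield line
        prev_matched = matched


# Made-up sample input: two separate runs of warnings.
sample = [
    "step 1 ok",
    "WARNING: The sum of encoder.layers.3.output is not finite",
    "WARNING: The sum of encoder.layers.3.grad[0] is not finite",
    "step 2 ok",
    "WARNING: The sum of encoder.layers.3.weight.param_grad is not finite",
]
for line in first_nonfinite_lines(sample):
    print(line)
```

Only the first warning of each consecutive run is printed, which is usually the place where the problem originated.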
diagnostics show_rms
Extract the RMS of each module's output from diagnostics text.
diagnostics show_rms diagnostics.txt | sort -gr -k2 | head
diagnostics show_eigs
Analyze eigenvalue statistics of module outputs. Computes ratios that indicate how concentrated the variance is across eigen-directions (next-largest-ratio, top_ratio, 2norm/1norm, etc.).
diagnostics show_eigs diagnostics.txt
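To give a feel for what such ratios measure, the sketch below computes two concentration measures from a list of eigenvalues. The definitions are assumptions chosen for illustration (largest eigenvalue over the total, and 2-norm over 1-norm), and the tool's exact formulas may differ. `concentration` and its key names are hypothetical.

```python
import math


def concentration(eigs):
    """Summarize how concentrated a set of eigenvalues is (illustrative only)."""
    mags = sorted((abs(e) for e in eigs), reverse=True)
    one_norm = sum(mags)
    if one_norm == 0:
        raise ValueError("all eigenvalues are zero")
    two_norm = math.sqrt(sum(m * m for m in mags))
    return {
        # Share of the total carried by the largest eigenvalue.
        "top_share": mags[0] / one_norm,
        # Equals 1/sqrt(n) for a flat spectrum and 1.0 for a single direction.
        "l2_over_l1": two_norm / one_norm,
    }


flat = concentration([1.0, 1.0, 1.0, 1.0])    # variance spread evenly
peaked = concentration([4.0, 0.0, 0.0, 0.0])  # variance in one direction
```

Both measures grow toward 1.0 as the variance collapses into fewer directions, which is the kind of signal these statistics are meant to surface.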
diagnostics param_importance
Compute a normalized importance score for each parameter. Importance is defined
as value_mean * grad_mean * num_params, aggregated by module name prefixes
and suffixes, and normalized so all scores sum to 1.0.
Two-file mode compares the importance outputs of two files and prints the ratio.
# Single file: analyze importance
diagnostics param_importance diagnostics.txt | sort -gr -k2 | head
# Two files: compare importance
diagnostics param_importance diag_epoch5.txt diag_epoch10.txt
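The importance formula can be checked by hand with a small sketch. It assumes that value_mean and grad_mean are mean absolute values and it skips the aggregation by name prefixes and suffixes. `normalized_importance` is a hypothetical helper and the input numbers are made up.

```python
def normalized_importance(stats):
    """stats maps parameter name -> (value_mean, grad_mean, num_params)."""
    # Raw score per parameter: value_mean * grad_mean * num_params.
    raw = {name: v * g * n for name, (v, g, n) in stats.items()}
    total = sum(raw.values())
    # Normalize so that all scores sum to 1.0.
    return {name: score / total for name, score in raw.items()}


# Made-up statistics for two parameters.
stats = {
    "encoder.weight": (0.5, 0.02, 1000),  # raw score is about 10
    "decoder.weight": (0.25, 0.01, 400),  # raw score is about 1
}
scores = normalized_importance(stats)
for name, score in sorted(scores.items(), key=lambda kv: -kv[1]):
    print(f"{name} {score:.4f}")
```

Because the scores are normalized, they are comparable across runs of the same model even when the overall gradient scale changes.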
diagnostics param_magnitude
Extract the mean absolute value of each parameter from diagnostics text.
Two-file mode compares the magnitude outputs of two files and prints the ratio.
# Single file: extract magnitudes
diagnostics param_magnitude diagnostics.txt
# Two files: compare magnitudes
diagnostics param_magnitude diag_epoch5.txt diag_epoch10.txt
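The two-file comparison can be pictured as a per-name ratio of two tables. The sketch below assumes that each output line holds a name followed by a numeric value in the second column and that the ratio is second file over first. Both are assumptions, and `parse_two_columns` and `ratios` are hypothetical helpers.

```python
def parse_two_columns(text):
    """Parse 'name value' lines into a dict, skipping lines that do not fit."""
    table = {}
    for line in text.splitlines():
        parts = line.split()
        if len(parts) < 2:
            continue
        try:
            table[parts[0]] = float(parts[1])
        except ValueError:
            continue  # second column is not numeric
    return table


def ratios(first, second):
    """Return second/first for every name present in both tables."""
    return {
        name: second[name] / first[name]
        for name in first
        if name in second and first[name] != 0
    }


# Made-up magnitudes from two hypothetical epochs.
epoch5 = parse_two_columns("encoder.weight 0.5\ndecoder.weight 2.0\n")
epoch10 = parse_two_columns("encoder.weight 1.0\ndecoder.weight 1.0\n")
change = ratios(epoch5, epoch10)
```

A ratio well above or below 1.0 marks parameters whose magnitude changed substantially between the two dumps.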
diagnostics compare_epochs
Compare model parameters between two PyTorch checkpoints. For each float32 parameter, computes the RMS norm and normalized relative difference.
With --summarize, additionally aggregates the diffs by module name prefixes
and suffixes.
# Compare two checkpoints
diagnostics compare_epochs exp/epoch-5.pt exp/epoch-10.pt
# With hierarchical summary
diagnostics compare_epochs exp/epoch-5.pt exp/epoch-10.pt --summarize
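As a rough guide to the two quantities, the sketch below computes an RMS norm and a relative difference on plain Python lists standing in for flattened parameters. The normalization (RMS of the difference divided by the RMS of the first checkpoint's values) is an assumption, since the exact definition is not stated here. `rms` and `relative_diff` are hypothetical helpers.

```python
import math


def rms(values):
    """Root mean square of a non-empty sequence of numbers."""
    return math.sqrt(sum(v * v for v in values) / len(values))


def relative_diff(old, new):
    """RMS of (new - old) divided by the RMS of old (assumed normalization)."""
    diff_rms = rms([n - o for o, n in zip(old, new)])
    base = rms(old)
    if base == 0:
        # Degenerate case: an all-zero parameter has no scale to normalize by.
        return 0.0 if diff_rms == 0 else float("inf")
    return diff_rms / base


# Made-up parameter values: every entry grows by 10%.
old = [1.0, 1.0, 1.0, 1.0]
new = [1.1, 1.1, 1.1, 1.1]
change = relative_diff(old, new)  # close to 0.1
```

Normalizing by the parameter's own RMS makes the difference comparable across layers with very different scales.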
Download files
File details
Details for the file nndiagnostics-0.1.8.tar.gz.
File metadata
- Download URL: nndiagnostics-0.1.8.tar.gz
- Upload date:
- Size: 19.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.10.20
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 9068ea7d1abf572d2d7ca61ffa3f729eadcb10ab414bfa537bf06bbee35a30a3 |
| MD5 | 822a68720a04fc9f4c04b64cb6a37398 |
| BLAKE2b-256 | 42b68938cea009dd2ea5756adcf4a3482a36390cb226b9033938ac248d30ec27 |
File details
Details for the file nndiagnostics-0.1.8-py3-none-any.whl.
File metadata
- Download URL: nndiagnostics-0.1.8-py3-none-any.whl
- Upload date:
- Size: 25.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.10.20
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 50478f8ecbc3e2e9faf07d55d7a3dae64e338c73f705f897a3749691f57e73de |
| MD5 | 6e8702c3ce36a8be32461c5a82f7877c |
| BLAKE2b-256 | cfdc4fa0083b35db5966116a6bcdc7e41cc41ca534d55d76f72977c54b738f30 |