
Lightweight PyTorch tensor diagnostics hooks for training loops

Project description

NN diagnostics

A tool for dumping diagnostics information about a PyTorch model, either during training or from saved checkpoints.

Install

pip install nndiagnostics

Quick Start

Dump diagnostics information

  1. Integrate diagnostics in your training loop
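from diagnostics import maybe_attach_diagnostics

# model, optimizer, train_loader, train_step and args come from your own training script.
# diag is presumably None unless DUMP_DIAGNOSTICS is set, hence the `if diag` guard below.
diag = maybe_attach_diagnostics(model)

for step, batch in enumerate(train_loader):
    loss = train_step(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Once enough steps have run, write the collected diagnostics to a text file and stop.
    if diag and diag.should_stop(step, stop_after_steps=5):
        diag.print(f"{args.exp_dir}/diagnostics-step-{step}.txt")
        break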
  2. Dump diagnostics information by setting the environment variable DUMP_DIAGNOSTICS:
DUMP_DIAGNOSTICS=1 python train.py
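The `maybe_` prefix and the `if diag and ...` guard in the loop suggest that `maybe_attach_diagnostics` returns a falsy value unless DUMP_DIAGNOSTICS is set. The following is a minimal stdlib sketch of that environment-variable gating pattern, not the package's implementation. `env_flag`, `Diagnostics` and `maybe_attach` are hypothetical names. The sketch also assumes that an empty value or "0" counts as "off", which may differ from the real library.

```python
import os
from typing import Optional


def env_flag(name: str) -> bool:
    """Hypothetical helper: an unset, empty, or "0" variable counts as disabled."""
    return os.environ.get(name, "0") not in ("", "0")


class Diagnostics:
    """Hypothetical stand-in for the object returned when diagnostics are on."""

    def __init__(self, model: object) -> None:
        self.model = model


def maybe_attach(model: object) -> Optional[Diagnostics]:
    # With the flag off, return None so `if diag and ...` short-circuits
    # and the training loop runs without any diagnostics overhead.
    if not env_flag("DUMP_DIAGNOSTICS"):
        return None
    return Diagnostics(model)
```

With this design, the same training script serves both normal runs and diagnostic runs. Only the environment changes, for example `DUMP_DIAGNOSTICS=1 python train.py`.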

Add Inf/NaN Check Hooks

Register forward and backward hooks on every module and parameter to detect non-finite values (NaN/Inf) during training. When a non-finite value is detected, a warning is logged with the module/parameter name.

  1. Register inf-check hooks on your model
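from diagnostics import maybe_register_inf_check_hooks

# The hooks are only active when the INF_CHECK environment variable is set (step 2).
maybe_register_inf_check_hooks(model)

# The loop itself needs no changes: warnings are logged during the forward
# and backward passes whenever a non-finite value appears.
for step, batch in enumerate(train_loader):
    loss = train_step(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()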
  2. Enable the hooks by setting the environment variable INF_CHECK:
INF_CHECK=1 python train.py

If any module output, gradient, or parameter gradient becomes non-finite, you will see warnings like:

WARNING: The sum of encoder.layers.3.output is not finite
WARNING: The sum of encoder.layers.3.grad[0] is not finite
WARNING: The sum of encoder.layers.3.weight.param_grad is not finite
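The warnings refer to "the sum" of each tensor. Checking a single reduction is a cheap proxy for checking every element. Any NaN element makes the sum NaN. Any Inf element makes the sum Inf, or NaN if +Inf and -Inf meet. A finite sum therefore rules out non-finite elements. The converse does not hold: summing large finite values can overflow, so a non-finite sum is a warning sign rather than proof.

The stdlib sketch below illustrates this, with plain Python floats standing in for tensor elements. `sum_is_finite` and `warning_for` are hypothetical names, not the package's hook code.

```python
import math
from typing import Iterable, Optional


def sum_is_finite(values: Iterable[float]) -> bool:
    """Return True if the plain float sum of `values` is finite.

    NaN propagates through addition, and Inf stays Inf unless it meets
    -Inf (which gives NaN), so one reduction catches every NaN/Inf element.
    Overflow of large finite values can also produce a non-finite sum.
    """
    return math.isfinite(sum(values))


def warning_for(name: str, values: Iterable[float]) -> Optional[str]:
    # Same message shape as the warnings shown above; None means "looks fine".
    if sum_is_finite(values):
        return None
    return f"WARNING: The sum of {name} is not finite"
```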

CLI Tools

The package installs a diagnostics command with several subcommands for post-processing diagnostics output. All tools support:

  • Reading from a file or stdin (via pipe)
  • Writing to a file with -o/--output-file (default: stdout)
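A subcommand CLI with these shared options can be sketched with the standard library's `argparse`. The sketch below is hypothetical and only mirrors the behaviour described above: an optional positional input that falls back to stdin, and `-o/--output-file` that falls back to stdout. The package's actual argument handling may differ, and only a subset of subcommand names is registered here for illustration.

```python
import argparse
import sys
from typing import TextIO, Tuple

# Subset of subcommand names, for illustration only.
SUBCOMMANDS = ("show_infinite", "show_rms", "show_eigs")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(prog="diagnostics")
    sub = parser.add_subparsers(dest="command", required=True)
    for name in SUBCOMMANDS:
        p = sub.add_parser(name)
        p.add_argument("input", nargs="?", help="diagnostics text file; stdin if omitted")
        p.add_argument("-o", "--output-file", help="output path; stdout if omitted")
    return parser


def open_streams(args: argparse.Namespace) -> Tuple[TextIO, TextIO]:
    """Resolve the input and output streams shared by every subcommand."""
    inp = open(args.input) if args.input else sys.stdin
    out = open(args.output_file, "w") if args.output_file else sys.stdout
    return inp, out
```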

diagnostics show_infinite

Detect transitions to non-finite values (NaN/Inf). The command prints each line containing "finite" that follows a line not containing "finite", which highlights where non-finite values first appear.

diagnostics show_infinite diagnostics.txt
cat diagnostics.txt | diagnostics show_infinite
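The filtering rule is simple enough to sketch as a Python generator. This sketch illustrates the stated rule and is not the package's source; `first_transitions` is a hypothetical name. It assumes the first line behaves as if it followed a line without "finite", and it uses a plain substring match, so a phrase such as "is not finite" also matches.

```python
from typing import Iterable, Iterator


def first_transitions(lines: Iterable[str]) -> Iterator[str]:
    """Yield each line containing "finite" whose predecessor does not.

    Assumption: the first line is treated as following a line without
    "finite", so it is yielded if it contains the word.
    """
    prev_has_finite = False
    for line in lines:
        has_finite = "finite" in line
        if has_finite and not prev_has_finite:
            yield line
        prev_has_finite = has_finite
```

Because it accepts any iterable of lines, the generator works the same on an open file and on `sys.stdin`, matching the file-or-pipe usage shown above.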

diagnostics show_rms

Extract the RMS of each module's output from diagnostics text.

diagnostics show_rms diagnostics.txt | sort -gr -k2 | head
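The RMS of a tensor x with n elements is RMS(x) = sqrt((1/n) Σ x_i²). The pipeline above then sorts the output numerically by its second column, largest first, and keeps the top lines.

The stdlib sketch below covers both steps. It assumes each output line has the form `<module-name> <rms>`. `rms` and `top_rows` are hypothetical helpers, not part of the package.

```python
import math
from typing import Iterable, List, Sequence


def rms(values: Sequence[float]) -> float:
    """Root mean square: sqrt(mean(x_i ** 2))."""
    if not values:
        raise ValueError("RMS of an empty sequence is undefined")
    return math.sqrt(sum(v * v for v in values) / len(values))


def top_rows(lines: Iterable[str], n: int = 10) -> List[str]:
    """Rough equivalent of `sort -gr -k2 | head -n N`.

    Sorts by the second whitespace-separated field as a float, largest first.
    """
    rows = []
    for line in lines:
        fields = line.split()
        if len(fields) < 2:
            continue
        try:
            rows.append((float(fields[1]), line.rstrip("\n")))
        except ValueError:
            continue  # skip headers or malformed lines
    rows.sort(key=lambda r: r[0], reverse=True)
    return [text for _, text in rows[:n]]
```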

diagnostics show_eigs

Analyze eigenvalue statistics of module outputs. Computes ratios that indicate how concentrated the variance is across eigen-directions (next-largest-ratio, top_ratio, 2norm/1norm, etc.).

diagnostics show_eigs diagnostics.txt
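This section names the ratios but does not define them. The sketch below uses plausible definitions for non-negative eigenvalues λ_1 ≥ λ_2 ≥ … of an output covariance. These definitions are assumptions for illustration and are not necessarily the formulas `show_eigs` uses:

- top_ratio = λ_1 / Σλ
- next_largest_ratio = λ_1 / λ_2
- 2norm/1norm = sqrt(Σλ²) / Σλ

With n equal eigenvalues, 2norm/1norm equals 1/sqrt(n). When all variance lies in one direction, it equals 1.

```python
import math
from typing import Dict, Sequence


def eig_concentration(eigs: Sequence[float]) -> Dict[str, float]:
    """Assumed concentration ratios for non-negative eigenvalues."""
    lam = sorted(eigs, reverse=True)
    if len(lam) < 2 or lam[-1] < 0:
        raise ValueError("need at least two non-negative eigenvalues")
    total = sum(lam)
    if total == 0:
        raise ValueError("all eigenvalues are zero")
    return {
        # Share of total variance in the single largest direction.
        "top_ratio": lam[0] / total,
        # How much the largest eigenvalue dominates the next one.
        "next_largest_ratio": math.inf if lam[1] == 0 else lam[0] / lam[1],
        # 1/sqrt(n) for a flat spectrum, 1.0 for a single direction.
        "2norm/1norm": math.sqrt(sum(x * x for x in lam)) / total,
    }
```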

diagnostics param_importance

Compute a normalized importance score for each parameter. Importance is defined as value_mean * grad_mean * num_params, aggregated by module name prefixes and suffixes, and normalized so all scores sum to 1.0.

Two-file mode compares the importance outputs of two files and prints the ratio.

# Single file: analyze importance
diagnostics param_importance diagnostics.txt | sort -gr -k2 | head

# Two files: compare importance
diagnostics param_importance diag_epoch5.txt diag_epoch10.txt
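The scoring rule above is value_mean * grad_mean * num_params per parameter, divided by the total so that all scores sum to 1.0. The sketch below implements that rule on plain Python data. It uses a hypothetical input format mapping each parameter name to `(value_mean, grad_mean, num_params)`, and it assumes the means are non-negative (for example, mean absolute values), so the normalized scores behave like shares.

The prefix and suffix grouping is also an assumption. In this sketch, "enc.l0.weight" contributes to the groups "enc", "enc.l0" and "*.weight". The real tool's grouping may differ.

```python
from collections import defaultdict
from typing import Dict, Tuple

# name -> (value_mean, grad_mean, num_params); hypothetical input format.
Stats = Dict[str, Tuple[float, float, int]]


def importance(stats: Stats) -> Dict[str, float]:
    """Score = value_mean * grad_mean * num_params, normalized to sum to 1."""
    raw = {name: v * g * n for name, (v, g, n) in stats.items()}
    total = sum(raw.values())
    if total == 0:
        raise ValueError("all raw importance scores are zero")
    return {name: s / total for name, s in raw.items()}


def aggregate(scores: Dict[str, float]) -> Dict[str, float]:
    """Sum scores over dotted-name prefixes and over the final name component."""
    groups: Dict[str, float] = defaultdict(float)
    for name, s in scores.items():
        parts = name.split(".")
        for i in range(1, len(parts)):
            groups[".".join(parts[:i])] += s  # e.g. "enc", then "enc.l0"
        groups["*." + parts[-1]] += s  # e.g. "*.weight"
    return dict(groups)
```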

diagnostics param_magnitude

Extract the mean absolute value of each parameter from diagnostics text.

Two-file mode compares the magnitude outputs of two files and prints the ratio.

# Single file: extract magnitudes
diagnostics param_magnitude diagnostics.txt

# Two files: compare magnitudes
diagnostics param_magnitude diag_epoch5.txt diag_epoch10.txt
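The mean absolute value of a parameter with n elements is (1/n) Σ |x_i|. It never exceeds the RMS and is less sensitive to a few large entries.

The stdlib sketch below computes it and shows one way a two-file comparison could work. The direction of the ratio (second file divided by first) is an assumption, since this section does not state it. `mean_abs` and `magnitude_ratios` are hypothetical names.

```python
from typing import Dict, Sequence


def mean_abs(values: Sequence[float]) -> float:
    """Mean absolute value: (1/n) * sum(|x_i|)."""
    if not values:
        raise ValueError("empty parameter")
    return sum(abs(v) for v in values) / len(values)


def magnitude_ratios(before: Dict[str, float], after: Dict[str, float]) -> Dict[str, float]:
    """Assumed two-file mode: after / before for parameter names present in both.

    Names whose earlier magnitude is zero are skipped to avoid division by zero.
    """
    return {
        name: after[name] / before[name]
        for name in sorted(before.keys() & after.keys())
        if before[name] != 0
    }
```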

diagnostics compare_epochs

Compare model parameters between two PyTorch checkpoints. For each float32 parameter, it computes the RMS norm and the normalized relative difference.

With --summarize, additionally aggregates the diffs by module name prefixes and suffixes.

# Compare two checkpoints
diagnostics compare_epochs exp/epoch-5.pt exp/epoch-10.pt

# With hierarchical summary
diagnostics compare_epochs exp/epoch-5.pt exp/epoch-10.pt --summarize
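The per-parameter comparison can be sketched on plain lists, standing in for tensors that would normally come from each checkpoint's state dict. This section does not define "normalized relative difference". The sketch assumes it is RMS(new - old) / RMS(old), which is only one plausible choice. `rms` and `compare_params` are hypothetical names.

```python
import math
from typing import Dict, Sequence, Tuple


def rms(values: Sequence[float]) -> float:
    """Root mean square of a non-empty sequence."""
    return math.sqrt(sum(v * v for v in values) / len(values))


def compare_params(
    old: Dict[str, Sequence[float]], new: Dict[str, Sequence[float]]
) -> Dict[str, Tuple[float, float]]:
    """For each shared parameter, return (rms(new), rms(new - old) / rms(old)).

    The relative-difference formula is an assumption for illustration.
    """
    out: Dict[str, Tuple[float, float]] = {}
    for name in sorted(old.keys() & new.keys()):
        a, b = old[name], new[name]
        if not a or len(a) != len(b):
            continue  # skip empty parameters and shape mismatches
        base = rms(a)
        diff = rms([y - x for x, y in zip(a, b)])
        out[name] = (rms(b), diff / base if base else math.inf)
    return out
```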



Download files

Download the file for your platform.

Source Distribution

nndiagnostics-0.1.8.tar.gz (19.9 kB)

Uploaded Source

Built Distribution


nndiagnostics-0.1.8-py3-none-any.whl (25.9 kB)

Uploaded Python 3

File details

Details for the file nndiagnostics-0.1.8.tar.gz.

File metadata

  • Download URL: nndiagnostics-0.1.8.tar.gz
  • Upload date:
  • Size: 19.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.10.20

File hashes

Hashes for nndiagnostics-0.1.8.tar.gz
Algorithm Hash digest
SHA256 9068ea7d1abf572d2d7ca61ffa3f729eadcb10ab414bfa537bf06bbee35a30a3
MD5 822a68720a04fc9f4c04b64cb6a37398
BLAKE2b-256 42b68938cea009dd2ea5756adcf4a3482a36390cb226b9033938ac248d30ec27


File details

Details for the file nndiagnostics-0.1.8-py3-none-any.whl.

File metadata

  • Download URL: nndiagnostics-0.1.8-py3-none-any.whl
  • Upload date:
  • Size: 25.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.10.20

File hashes

Hashes for nndiagnostics-0.1.8-py3-none-any.whl
Algorithm Hash digest
SHA256 50478f8ecbc3e2e9faf07d55d7a3dae64e338c73f705f897a3749691f57e73de
MD5 6e8702c3ce36a8be32461c5a82f7877c
BLAKE2b-256 cfdc4fa0083b35db5966116a6bcdc7e41cc41ca534d55d76f72977c54b738f30

