Lightweight PyTorch tensor diagnostics hooks for training loops

NN diagnostics

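A lightweight tool for dumping diagnostic statistics from a PyTorch model during training, with command-line tools for analyzing the dumps and comparing checkpoints.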

Install

pip install nndiagnostics

Quick Start

Dump diagnostics information

  1. Integrate diagnostics in your training loop
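from diagnostics import maybe_attach_diagnostics

# model, optimizer, train_loader, train_step and args come from your own training script.
# Diagnostics are attached only when DUMP_DIAGNOSTICS is set; otherwise diag is falsy.
diag = maybe_attach_diagnostics(model)

for step, batch in enumerate(train_loader):
    loss = train_step(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # After a few steps, write the collected statistics to a file and stop.
    if diag and diag.should_stop(step, stop_after_steps=5):
        diag.print(f"{args.exp_dir}/diagnostics-step-{step}.txt")
        break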
  2. Dump diagnostics information by setting the environment variable DUMP_DIAGNOSTICS:
DUMP_DIAGNOSTICS=1 python train.py

Add Inf/NaN Check Hooks

Register forward and backward hooks on every module and parameter to detect non-finite values (NaN/Inf) during training. When a non-finite value is detected, a warning is logged with the module/parameter name.

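  1. Register inf check hooks on your model:
from diagnostics import maybe_register_inf_check_hooks

# The hooks are only registered when INF_CHECK is set (see step 2).
maybe_register_inf_check_hooks(model)

# The training loop itself needs no changes; the hooks log warnings as it runs.
for step, batch in enumerate(train_loader):
    loss = train_step(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
  2. Enable the hooks by setting the environment variable INF_CHECK: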
INF_CHECK=1 python train.py

If any module output, gradient, or parameter gradient becomes non-finite, you will see warnings like:

WARNING: The sum of encoder.layers.3.output is not finite
WARNING: The sum of encoder.layers.3.grad[0] is not finite
WARNING: The sum of encoder.layers.3.weight.param_grad is not finite
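
The mechanism behind such warnings can be illustrated with standard PyTorch hooks. The following is a minimal sketch of that general technique, not the package's implementation: `register_nonfinite_hooks` is a hypothetical name, and the warning text only imitates the format shown above.

```python
import logging
from typing import List

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


def _check(name: str, t: torch.Tensor) -> None:
    # One reduction per tensor: any NaN or Inf element makes the sum non-finite
    # (an overflowing sum of huge finite values would also be reported).
    if not torch.isfinite(t.detach().float().sum()):
        logger.warning("The sum of %s is not finite", name)


def register_nonfinite_hooks(model: nn.Module) -> List:
    """Hypothetical helper (not part of nndiagnostics): attach NaN/Inf checks."""
    handles = []
    for name, module in model.named_modules():
        label = name or "model"  # the root module has an empty name

        def fwd_hook(mod, inputs, output, label=label):
            if isinstance(output, torch.Tensor):
                _check(f"{label}.output", output)

        def bwd_hook(mod, grad_input, grad_output, label=label):
            # Gradients flowing into the module from its output side.
            for i, g in enumerate(grad_output):
                if g is not None:
                    _check(f"{label}.grad[{i}]", g)

        handles.append(module.register_forward_hook(fwd_hook))
        handles.append(module.register_full_backward_hook(bwd_hook))

    for name, param in model.named_parameters():
        # Tensor hooks see each incoming gradient before it is accumulated.
        def param_hook(grad, name=name):
            _check(f"{name}.param_grad", grad)

        handles.append(param.register_hook(param_hook))
    return handles
```

The returned handles can be passed to `handle.remove()` to detach every check once debugging is done.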

CLI Tools

The package installs a diagnostics command with several subcommands for post-processing diagnostics output. All tools support:

  • Reading from a file or stdin (via pipe)
  • Writing to a file with -o/--output-file (default: stdout)
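
The shared input/output convention can be sketched with argparse. This is a hypothetical stand-alone script, not the package's entry point: only the `-o/--output-file` option and the file-or-stdin behavior come from the description above, and `process` is a placeholder for a subcommand's own filtering logic.

```python
import argparse
import sys
from typing import Iterable, Iterator, List, Optional


def process(lines: Iterable[str]) -> Iterator[str]:
    # Placeholder: a real subcommand would filter or parse lines here.
    for line in lines:
        yield line


def main(argv: Optional[List[str]] = None) -> None:
    parser = argparse.ArgumentParser(description="Post-process diagnostics text (sketch).")
    parser.add_argument("input_file", nargs="?", default=None, help="defaults to stdin")
    parser.add_argument("-o", "--output-file", default=None, help="defaults to stdout")
    args = parser.parse_args(argv)

    src = open(args.input_file) if args.input_file else sys.stdin
    dst = open(args.output_file, "w") if args.output_file else sys.stdout
    try:
        for line in process(src):
            dst.write(line)
    finally:
        # Close only the files opened here, never the standard streams.
        if src is not sys.stdin:
            src.close()
        if dst is not sys.stdout:
            dst.close()
```

With this layout, `script.py diagnostics.txt -o out.txt` and `cat diagnostics.txt | script.py > out.txt` produce the same result (`script.py` being whatever file holds the sketch).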

diagnostics show_infinite

Detect transitions to non-finite values (NaN/Inf). Prints each line containing "finite" whose preceding line does not contain "finite", which highlights where non-finite values first appear.

diagnostics show_infinite diagnostics.txt
cat diagnostics.txt | diagnostics show_infinite
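
The rule above amounts to a filter with one line of lookbehind. A minimal Python sketch, assuming plain substring matching and treating the first line as having no "finite" predecessor (the real tool may differ in such details):

```python
from typing import Iterable, Iterator


def show_infinite(lines: Iterable[str]) -> Iterator[str]:
    """Yield each line containing "finite" whose previous line does not."""
    prev_has_finite = False  # assumption: nothing precedes the first line
    for line in lines:
        has_finite = "finite" in line
        if has_finite and not prev_has_finite:
            yield line  # start of a run of "finite" lines
        prev_has_finite = has_finite
```

Only the first line of each consecutive run is reported, so a burst of warnings from one failing layer collapses to a single entry.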

diagnostics show_rms

Extract the RMS of each module's output from diagnostics text.

diagnostics show_rms diagnostics.txt | sort -gr -k2 | head
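
The `sort -gr -k2 | head` pipeline ranks the output by its second whitespace-separated column, numerically and in descending order, and keeps the first 10 lines. An equivalent Python sketch, assuming each output line has the form `name value` (the exact output format is an assumption here):

```python
from typing import Iterable, List, Tuple


def top_by_second_column(lines: Iterable[str], n: int = 10) -> List[Tuple[str, float]]:
    rows = []
    for line in lines:
        fields = line.split()
        if len(fields) < 2:
            continue  # skip blank or malformed lines
        try:
            value = float(fields[1])
        except ValueError:
            continue  # second column is not numeric
        rows.append((fields[0], value))
    # Descending numeric sort, like `sort -gr -k2`; slicing plays the role of `head`.
    rows.sort(key=lambda r: r[1], reverse=True)
    return rows[:n]
```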

diagnostics show_eigs

Analyze eigenvalue statistics of module outputs. Computes ratios that indicate how concentrated the variance is across eigen-directions (next-largest-ratio, top_ratio, 2norm/1norm, etc.).

diagnostics show_eigs diagnostics.txt
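
To build intuition for these ratios, here is a sketch computing two concentration measures from a list of eigenvalues. The definitions are assumptions for illustration (top_ratio as the largest eigenvalue over the sum, 2norm/1norm as the L2 norm of the eigenvalues over their L1 norm) and may not match the tool's exact formulas.

```python
import math
from typing import Dict, Sequence


def concentration_ratios(eigs: Sequence[float]) -> Dict[str, float]:
    # Covariance eigenvalues are non-negative; abs() guards against round-off.
    lam = sorted((abs(x) for x in eigs), reverse=True)
    l1 = sum(lam)
    l2 = math.sqrt(sum(x * x for x in lam))
    return {
        # Fraction of the total variance in the single largest direction.
        "top_ratio": lam[0] / l1,
        # Ranges from 1/sqrt(n) (all eigenvalues equal) up to 1 (one dominates).
        "2norm/1norm": l2 / l1,
    }
```

For four equal eigenvalues the sketch gives 0.25 and 0.5; when one eigenvalue holds all the variance, both measures reach 1.0.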

diagnostics param_importance

Compute a normalized importance score for each parameter. Importance is defined as value_mean * grad_mean * num_params, aggregated by module name prefixes and suffixes, and normalized so all scores sum to 1.0.

Two-file mode compares the importance outputs of two files and prints the ratio.

# Single file: analyze importance
diagnostics param_importance diagnostics.txt | sort -gr -k2 | head

# Two files: compare importance
diagnostics param_importance diag_epoch5.txt diag_epoch10.txt
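
The importance formula can be sketched in a few lines. The input dictionary and its keys are hypothetical, aggregating over every dotted prefix is an assumed reading of "module name prefixes", and suffix aggregation is omitted.

```python
from collections import defaultdict
from typing import Dict


def param_importance(stats: Dict[str, Dict[str, float]]) -> Dict[str, float]:
    """stats maps a parameter name to hypothetical keys value_mean, grad_mean, num_params."""
    raw = {
        name: s["value_mean"] * s["grad_mean"] * s["num_params"]
        for name, s in stats.items()
    }
    total = sum(raw.values())
    # Normalize so the per-parameter scores sum to 1.0.
    return {name: score / total for name, score in raw.items()}


def aggregate_by_prefix(scores: Dict[str, float]) -> Dict[str, float]:
    # "encoder.layers.3.weight" contributes to "encoder", "encoder.layers", "encoder.layers.3".
    agg: Dict[str, float] = defaultdict(float)
    for name, score in scores.items():
        parts = name.split(".")
        for i in range(1, len(parts)):
            agg[".".join(parts[:i])] += score
    return dict(agg)
```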

diagnostics param_magnitude

Extract the mean absolute value of each parameter from diagnostics text.

Two-file mode compares the magnitude outputs of two files and prints the ratio.

# Single file: extract magnitudes
diagnostics param_magnitude diagnostics.txt

# Two files: compare magnitudes
diagnostics param_magnitude diag_epoch5.txt diag_epoch10.txt
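
Both the magnitude statistic and the two-file comparison are simple to express directly. In this sketch the parameter values are plain Python lists, and the ratio is taken as second file over first; the direction of the ratio and the handling of names present in only one file are assumptions.

```python
from typing import Dict, Sequence


def mean_abs(values: Sequence[float]) -> float:
    # Mean absolute value of one parameter's entries.
    return sum(abs(v) for v in values) / len(values)


def magnitude_ratios(first: Dict[str, float], second: Dict[str, float]) -> Dict[str, float]:
    # Compare only parameters present in both files; skip zero denominators.
    return {
        name: second[name] / first[name]
        for name in first.keys() & second.keys()
        if first[name] != 0
    }
```

A ratio well above 1.0 flags a parameter whose magnitude grew between the two dumps.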

diagnostics compare_epochs

Compare model parameters between two PyTorch checkpoints. For each float32 parameter, computes the RMS norm and normalized relative difference.

With --summarize, additionally aggregates the diffs by module name prefixes and suffixes.

# Compare two checkpoints
diagnostics compare_epochs exp/epoch-5.pt exp/epoch-10.pt

# With hierarchical summary
diagnostics compare_epochs exp/epoch-5.pt exp/epoch-10.pt --summarize
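
A rough equivalent in plain PyTorch is sketched below. It assumes each checkpoint has already been loaded into a flat state dict (for example with `torch.load(path, map_location="cpu")`), and it defines the normalized relative difference as RMS(new - old) / RMS(old); the tool's exact normalization may differ. `compare_state_dicts` is a hypothetical name.

```python
from typing import Dict, Tuple

import torch


def rms(t: torch.Tensor) -> float:
    return t.pow(2).mean().sqrt().item()


def compare_state_dicts(
    old: Dict[str, torch.Tensor], new: Dict[str, torch.Tensor]
) -> Dict[str, Tuple[float, float, float]]:
    """Return {name: (rms_old, rms_new, relative_diff)} for shared float32 tensors."""
    out = {}
    for name, a in old.items():
        b = new.get(name)
        if b is None or a.dtype != torch.float32 or b.shape != a.shape:
            continue  # only compare float32 tensors present in both with equal shapes
        old_rms, diff_rms = rms(a), rms(b - a)
        if old_rms > 0:
            rel = diff_rms / old_rms
        else:
            rel = 0.0 if diff_rms == 0 else float("inf")  # all-zero old parameter
        out[name] = (old_rms, rms(b), rel)
    return out
```

If a checkpoint nests its weights under a key rather than storing the state dict directly, pass that inner dictionary instead.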

Download files

Source Distribution

nndiagnostics-0.1.9.tar.gz (20.0 kB)

Uploaded Source

Built Distribution

nndiagnostics-0.1.9-py3-none-any.whl (27.5 kB)

Uploaded Python 3

File details

Details for the file nndiagnostics-0.1.9.tar.gz.

File metadata

  • Download URL: nndiagnostics-0.1.9.tar.gz
  • Upload date:
  • Size: 20.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.10.20

File hashes

Hashes for nndiagnostics-0.1.9.tar.gz
Algorithm Hash digest
SHA256 7f6d038a1a9d7c892ba1d6913073d758d3fbb34ed34445a08258592fcb44305c
MD5 a002769df987c24c87e971109f993a13
BLAKE2b-256 7422831c50ea7aae45a8e3ef35a77feb6e1327954477166abf8ebe81e7d2b9aa

File details

Details for the file nndiagnostics-0.1.9-py3-none-any.whl.

File metadata

  • Download URL: nndiagnostics-0.1.9-py3-none-any.whl
  • Upload date:
  • Size: 27.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.10.20

File hashes

Hashes for nndiagnostics-0.1.9-py3-none-any.whl
Algorithm Hash digest
SHA256 b9042808e77e012b73192215c2906a9e2d5fbb5f8b573610bfcc1b239370cdab
MD5 607d56ba7c89e691769e999f26573a3b
BLAKE2b-256 b3e8d1ba739c5cbfcf7f2b89aeee2c5e2af0b4da4bcf82955ec0c90dd12f61d8
