
PyTorch model health diagnostics — gradient checks, dead neuron detection, training verification. Built from an SRE perspective.

Project description

torchdiag


PyTorch model health diagnostics — built from an SRE perspective.

Stop guessing why your model isn't learning. torchdiag gives you five diagnostic commands that answer the questions that matter: Are gradients flowing? Are neurons alive? Did the optimizer actually update weights?

Installation

pip install torchdiag

Quick Start

import torch
import torch.nn as nn
import torchdiag

model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)

# 1. Model overview
torchdiag.summary(model)

# 2. Check for dead neurons
x = torch.randn(100, 784)
torchdiag.check_dead_neurons(model, x)

# 3. Verify a full training step works
torchdiag.verify_step(
    model,
    torch.optim.Adam(model.parameters()),
    nn.CrossEntropyLoss(),
    torch.randn(32, 784),
    torch.randint(0, 10, (32,)),
)

# 4. Check gradient health (after backward)
model.zero_grad()  # clear any gradients left over from earlier steps
x = torch.randn(32, 784)
loss = nn.CrossEntropyLoss()(model(x), torch.randint(0, 10, (32,)))
loss.backward()
torchdiag.check_gradients(model)

# 5. Memory usage
torchdiag.memory_report()

What Each Command Does

torchdiag.summary(model)

Prints parameter count per layer, total/trainable/frozen breakdown, memory footprint, device placement, and dtype distribution. Flags issues like all-frozen parameters or split-device models.
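The bookkeeping behind a summary like this can be approximated in plain PyTorch. The sketch below is not torchdiag's implementation; `param_summary` is a hypothetical helper that tallies per-layer counts, trainable vs. frozen parameters, bytes used, devices, and dtypes, and raises the two warnings described above.

```python
from collections import Counter

import torch.nn as nn


def param_summary(model: nn.Module) -> dict:
    """Hypothetical helper: tally parameters roughly the way summary() is described."""
    per_layer = {}
    for name, module in model.named_modules():
        # recurse=False so each parameter is counted once, under the module that owns it
        n = sum(p.numel() for p in module.parameters(recurse=False))
        if n:
            per_layer[name] = n

    total = trainable = num_bytes = 0
    devices = set()
    dtypes = Counter()
    for p in model.parameters():
        total += p.numel()
        if p.requires_grad:
            trainable += p.numel()
        num_bytes += p.numel() * p.element_size()
        devices.add(str(p.device))
        dtypes[str(p.dtype)] += p.numel()

    warnings = []
    if total and not trainable:
        warnings.append("all parameters are frozen")
    if len(devices) > 1:
        warnings.append(f"parameters split across devices: {sorted(devices)}")

    return {
        "per_layer": per_layer,
        "total": total,
        "trainable": trainable,
        "frozen": total - trainable,
        "bytes": num_bytes,
        "devices": sorted(devices),
        "dtypes": dict(dtypes),
        "warnings": warnings,
    }


model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
model[0].weight.requires_grad_(False)  # freeze one tensor so the frozen count is non-zero
print(param_summary(model))
```

Counting with `parameters(recurse=False)` per module keeps nested containers from double-counting their children's weights.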

torchdiag.check_gradients(model)

Call after loss.backward(). Reports gradient mean, max, and min per layer. Flags vanishing gradients (max < 1e-7), exploding gradients (max > 100), and disconnected parameters (None gradients).
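A minimal sketch of this kind of check, written directly against `named_parameters()` and `.grad`. It is not torchdiag's code: `gradient_report` is hypothetical, it assumes the thresholds apply to gradient magnitudes (absolute values), and it adds a non-finite check that the description above does not mention.

```python
import torch
import torch.nn as nn


def gradient_report(model: nn.Module, vanish: float = 1e-7, explode: float = 100.0) -> dict:
    """Hypothetical helper: per-parameter gradient stats with the thresholds described above."""
    report = {}
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue  # frozen parameters are expected to have no gradient
        if p.grad is None:
            report[name] = {"status": "disconnected"}  # never reached by backward()
            continue
        g = p.grad.detach().abs()  # assumption: stats are computed on magnitudes
        max_abs = g.max().item()
        if not torch.isfinite(g).all():
            status = "non-finite"
        elif max_abs < vanish:
            status = "vanishing"
        elif max_abs > explode:
            status = "exploding"
        else:
            status = "ok"
        report[name] = {
            "mean": g.mean().item(),
            "max": max_abs,
            "min": g.min().item(),
            "status": status,
        }
    return report


torch.manual_seed(0)
model = nn.ModuleDict({"used": nn.Linear(4, 2), "unused": nn.Linear(4, 2)})
loss = model["used"](torch.randn(8, 4)).pow(2).mean()
loss.backward()  # "unused" never takes part in the forward pass, so its .grad stays None
for name, stats in gradient_report(model).items():
    print(name, stats["status"])
```

The unused branch is the typical "disconnected" case: no error is raised, and its gradients simply stay `None`.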

torchdiag.check_dead_neurons(model, sample_input)

Runs a forward pass and checks activation layers for neurons that output zero for every sample. Reports dead neuron count and percentage per layer. Flags critical layers (>50% dead) and warnings (>20% dead).
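One way to implement a dead-neuron check is with forward hooks on activation modules. The sketch below assumes only `nn.ReLU` layers are inspected and, for outputs with more than two dimensions, treats every non-batch position as a separate neuron. `dead_neuron_report` is a hypothetical name, not part of torchdiag.

```python
import torch
import torch.nn as nn


def dead_neuron_report(model, sample_input, activation_types=(nn.ReLU,)):
    """Hypothetical helper: count units that output zero for every sample."""
    stats, handles = {}, []

    def make_hook(name):
        def hook(module, inputs, output):
            flat = output.detach().reshape(output.shape[0], -1)  # (batch, units)
            dead = (flat == 0).all(dim=0)  # zero across the whole batch
            stats[name] = {
                "dead": int(dead.sum()),
                "total": dead.numel(),
                "pct": 100.0 * dead.float().mean().item(),
            }
        return hook

    for name, module in model.named_modules():
        if isinstance(module, activation_types):
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(sample_input)
    finally:
        for h in handles:
            h.remove()  # always detach hooks, even if the forward pass fails

    for s in stats.values():
        s["level"] = "critical" if s["pct"] > 50 else "warning" if s["pct"] > 20 else "ok"
    return stats


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 10), nn.ReLU())
with torch.no_grad():
    model[0].weight[:5] = 0.0
    model[0].bias[:5] = -1.0  # units 0-4 always produce ReLU(-1) = 0
print(dead_neuron_report(model, torch.randn(100, 4)))
```

A unit counts as dead only relative to the sample batch, so a larger and more representative `sample_input` gives a more trustworthy result.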

torchdiag.verify_step(model, optimizer, loss_fn, x, y)

Runs one complete training step (forward → loss → backward → step) and verifies each stage works: output shape is correct, loss is finite, gradients are computed, and parameters actually change.
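The stages of such a step check can be sketched in plain PyTorch. This is an illustration, not torchdiag's implementation. `verify_training_step` is hypothetical, and it reads "output shape is correct" as "one output row per input sample". That reading is an assumption.

```python
import torch
import torch.nn as nn


def verify_training_step(model, optimizer, loss_fn, x, y) -> dict:
    """Hypothetical helper: run one step and record which stage, if any, misbehaved."""
    trainable = {n: p for n, p in model.named_parameters() if p.requires_grad}
    before = {n: p.detach().clone() for n, p in trainable.items()}

    model.train()
    optimizer.zero_grad()  # in PyTorch 2.x this sets grads to None, so stale grads can't pass
    out = model(x)
    loss = loss_fn(out, y)
    loss.backward()
    optimizer.step()

    return {
        # Assumption: "correct shape" means one output row per input sample.
        "output_batch_matches": out.shape[0] == x.shape[0],
        "loss_finite": bool(torch.isfinite(loss).item()),
        "grads_computed": all(p.grad is not None for p in trainable.values()),
        # A tensor counts as updated if any element differs from its pre-step copy.
        "params_changed": all(not torch.equal(before[n], p.detach()) for n, p in trainable.items()),
    }


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(20, 16), nn.ReLU(), nn.Linear(16, 3))
checks = verify_training_step(
    model,
    torch.optim.Adam(model.parameters()),
    nn.CrossEntropyLoss(),
    torch.randn(8, 20),
    torch.randint(0, 3, (8,)),
)
print(checks)
```

Snapshotting parameters before `optimizer.step()` catches a common silent failure: an optimizer built from a different model's parameters, or a learning rate of zero, leaves the weights untouched without raising anything.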

torchdiag.memory_report()

Reports CPU peak RSS, GPU memory (allocated, cached, peak, total) per device, and MPS memory on Apple Silicon. Flags when GPU memory usage exceeds 90% of a device's total.
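A rough equivalent can be assembled from the standard library and PyTorch's memory counters. This sketch is not torchdiag's code. `memory_snapshot` is hypothetical, and it assumes that "cached" corresponds to PyTorch's `memory_reserved()` and that the 90% flag compares reserved memory to device total. The `resource` module is Unix-only, and the MPS call is guarded because it only exists in newer PyTorch releases.

```python
import resource  # Unix-only; not available on Windows
import sys

import torch


def memory_snapshot(threshold: float = 0.90) -> dict:
    """Hypothetical helper: CPU peak RSS plus per-device GPU/MPS memory counters."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in kilobytes on Linux but in bytes on macOS.
    report = {"cpu_peak_rss_bytes": peak if sys.platform == "darwin" else peak * 1024, "gpus": []}

    if torch.cuda.is_available():
        for d in range(torch.cuda.device_count()):
            total = torch.cuda.get_device_properties(d).total_memory
            reserved = torch.cuda.memory_reserved(d)  # "cached" by the caching allocator
            report["gpus"].append({
                "device": d,
                "allocated": torch.cuda.memory_allocated(d),
                "reserved": reserved,
                "peak": torch.cuda.max_memory_allocated(d),
                "total": total,
                "high_usage": reserved / total > threshold,
            })

    if torch.backends.mps.is_available() and hasattr(torch, "mps") \
            and hasattr(torch.mps, "current_allocated_memory"):
        report["mps_allocated"] = torch.mps.current_allocated_memory()
    return report


print(memory_snapshot())
```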

Why This Exists

Most PyTorch debugging happens by staring at loss curves. That's like monitoring a distributed system by watching a single dashboard number.

torchdiag brings SRE observability practices to model training:

  • Measure, don't guess — print the actual gradient values, don't assume they're fine
  • Check preconditions — verify the training step works before running 100 epochs
  • Detect silent failures — dead neurons and None gradients don't raise errors

Requirements

  • Python 3.8+
  • PyTorch 2.0+

License

MIT

Author

Aditya Mehra — Staff Engineer, IEEE Senior Member, PyTorch ecosystem contributor.



Download files

Download the file for your platform.

Source Distribution

torchdiag-0.1.0.tar.gz (8.0 kB)

Uploaded Source

Built Distribution


torchdiag-0.1.0-py3-none-any.whl (8.4 kB)

Uploaded Python 3

File details

Details for the file torchdiag-0.1.0.tar.gz.

File metadata

  • Download URL: torchdiag-0.1.0.tar.gz
  • Upload date:
  • Size: 8.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.0

File hashes

Hashes for torchdiag-0.1.0.tar.gz
Algorithm Hash digest
SHA256 50a2fad840ddcc4dc38b0b64b80cfedd4f05970cc931ad18f8be989db5a502b3
MD5 d8353fc4cdade43de4bb2258f4462d6e
BLAKE2b-256 9d9f980225499c0ecc8b2cc77f07add8b34599dbe686b9998a3eb3c91d11c9d3


File details

Details for the file torchdiag-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: torchdiag-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 8.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.0

File hashes

Hashes for torchdiag-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 72351be775210e3506ec5e40287b3493260e7b2f17422af7447283f2ace7e734
MD5 8748890aa110fdadf05a01dd5dc9b87c
BLAKE2b-256 cb73e8bc66edd4896dc1649e288bd953ad0f895015bf8b86c67cbda3f85fec99

