PyTorch model health diagnostics — gradient checks, dead neuron detection, training verification. Built from an SRE perspective.
torchdiag
Stop guessing why your model isn't learning. torchdiag gives you five diagnostic commands that answer the questions that matter: Are gradients flowing? Are neurons alive? Did the optimizer actually update weights?
Installation
pip install torchdiag
Quick Start
import torch
import torch.nn as nn
import torchdiag
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)
# 1. Model overview
torchdiag.summary(model)
# 2. Check for dead neurons
x = torch.randn(100, 784)
torchdiag.check_dead_neurons(model, x)
# 3. Verify a full training step works
torchdiag.verify_step(
    model,
    torch.optim.Adam(model.parameters()),
    nn.CrossEntropyLoss(),
    torch.randn(32, 784),
    torch.randint(0, 10, (32,)),
)
# 4. Check gradient health (after backward)
model.zero_grad()  # clear gradients left over from any earlier backward pass
x = torch.randn(32, 784)
loss = nn.CrossEntropyLoss()(model(x), torch.randint(0, 10, (32,)))
loss.backward()
torchdiag.check_gradients(model)
# 5. Memory usage
torchdiag.memory_report()
What Each Command Does
torchdiag.summary(model)
Prints parameter count per layer, total/trainable/frozen breakdown, memory footprint, device placement, and dtype distribution. Flags issues like all-frozen parameters or split-device models.
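To make the breakdown concrete, here is a minimal sketch of how such numbers can be gathered with plain PyTorch. It is not torchdiag's implementation; `parameter_breakdown` and its dictionary keys are hypothetical names chosen for illustration.

```python
from collections import Counter

import torch.nn as nn


def parameter_breakdown(model: nn.Module) -> dict:
    """Hypothetical helper: count parameters and group them by device and dtype."""
    total = trainable = n_bytes = 0
    devices, dtypes = Counter(), Counter()
    for p in model.parameters():
        n = p.numel()
        total += n
        if p.requires_grad:
            trainable += n
        n_bytes += n * p.element_size()
        devices[str(p.device)] += n
        dtypes[str(p.dtype)] += n
    return {
        "total": total,
        "trainable": trainable,
        "frozen": total - trainable,
        "megabytes": n_bytes / 1e6,
        "devices": dict(devices),  # more than one key means a split-device model
        "dtypes": dict(dtypes),
    }


model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)
model[0].weight.requires_grad_(False)  # freeze one tensor to exercise the frozen count
stats = parameter_breakdown(model)
print(stats)
```

A model whose `trainable` count is zero, or whose `devices` dictionary has more than one key, matches the two issues the summary is described as flagging.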
torchdiag.check_gradients(model)
Call after loss.backward(). Reports gradient mean, max, and min per layer. Flags vanishing gradients (max < 1e-7), exploding gradients (max > 100), and disconnected parameters (None gradients).
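The thresholds above can be applied by hand with a few lines of PyTorch. The sketch below is an illustration only and makes two assumptions: that "max" means the largest absolute gradient value, and that frozen parameters are skipped. `gradient_flags` is a hypothetical name, and the non-finite check is an addition not listed above.

```python
import math

import torch
import torch.nn as nn

VANISHING = 1e-7   # threshold quoted in the description
EXPLODING = 100.0  # threshold quoted in the description


def gradient_flags(model: nn.Module) -> dict:
    """Hypothetical helper: label each trainable parameter's gradient."""
    flags = {}
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue  # frozen on purpose, so a missing gradient is expected
        if p.grad is None:
            flags[name] = "disconnected"
            continue
        g_max = p.grad.abs().max().item()  # assumption: max of absolute values
        if not math.isfinite(g_max):
            flags[name] = "non-finite"  # extra check, not part of the list above
        elif g_max < VANISHING:
            flags[name] = "vanishing"
        elif g_max > EXPLODING:
            flags[name] = "exploding"
        else:
            flags[name] = "ok"
    return flags


class TwoHeads(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 2)
        self.unused = nn.Linear(4, 2)  # never called in forward

    def forward(self, x):
        return self.used(x)


torch.manual_seed(0)
net = TwoHeads()
net(torch.randn(8, 4)).sum().backward()
flags = gradient_flags(net)
print(flags)
```

The `unused` layer never takes part in the forward pass, so its gradients stay `None` and no error is raised. This is the silent failure the "disconnected" flag is meant to surface.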
torchdiag.check_dead_neurons(model, sample_input)
Runs a forward pass and checks activation layers for neurons that output zero for every sample. Reports dead neuron count and percentage per layer. Flags critical layers (>50% dead) and warnings (>20% dead).
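One common way to run this kind of check is with forward hooks. The sketch below assumes only `nn.ReLU` layers are inspected and that a neuron counts as dead when its output is exactly zero for every sample in the batch; `dead_relu_fraction` is a hypothetical name, not part of torchdiag.

```python
import torch
import torch.nn as nn


def dead_relu_fraction(model: nn.Module, sample_input: torch.Tensor) -> dict:
    """Hypothetical helper: fraction of ReLU units that are zero on every sample."""
    results, handles = {}, []

    def make_hook(name):
        def hook(module, inputs, output):
            # Flatten to (batch, units); a unit is dead if it is zero on all samples.
            flat = output.detach().flatten(start_dim=1)
            dead = (flat == 0).all(dim=0)
            results[name] = dead.float().mean().item()
        return hook

    for name, module in model.named_modules():
        if isinstance(module, nn.ReLU):
            handles.append(module.register_forward_hook(make_hook(name)))

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            model(sample_input)
    finally:
        for handle in handles:
            handle.remove()  # always detach hooks, even if the forward pass fails
        model.train(was_training)
    return results


layer = nn.Linear(4, 6)
with torch.no_grad():
    layer.bias.fill_(-1e6)  # force every pre-activation to be negative
net = nn.Sequential(layer, nn.ReLU())
report = dead_relu_fraction(net, torch.randn(16, 4))
print(report)
```

The result depends on the sample batch: a neuron that is silent on 16 random inputs may still fire on real data, so a representative batch gives a more trustworthy percentage.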
torchdiag.verify_step(model, optimizer, loss_fn, x, y)
Runs one complete training step (forward → loss → backward → step) and verifies each stage works: output shape is correct, loss is finite, gradients are computed, and parameters actually change.
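The stages listed above can be sketched as follows. This is an assumption-laden illustration rather than torchdiag's code: `step_changes_parameters` and its result keys are hypothetical, and the shape check is reduced to comparing batch dimensions.

```python
import torch
import torch.nn as nn


def step_changes_parameters(model, optimizer, loss_fn, x, y) -> dict:
    """Hypothetical helper: run one training step and record what happened."""
    params = [p for p in model.parameters() if p.requires_grad]
    before = [p.detach().clone() for p in params]  # snapshot for comparison

    optimizer.zero_grad()
    output = model(x)
    loss = loss_fn(output, y)
    loss.backward()
    grads_present = all(p.grad is not None for p in params)
    optimizer.step()

    changed = any(
        not torch.equal(old, p.detach()) for old, p in zip(before, params)
    )
    return {
        "batch_dim_ok": output.shape[0] == x.shape[0],
        "loss_finite": bool(torch.isfinite(loss).all()),
        "grads_present": grads_present,
        "params_changed": changed,
    }


torch.manual_seed(0)
x, y = torch.randn(32, 20), torch.randint(0, 3, (32,))
net = nn.Linear(20, 3)
loss_fn = nn.CrossEntropyLoss()

healthy = step_changes_parameters(
    net, torch.optim.SGD(net.parameters(), lr=0.1), loss_fn, x, y
)
# A zero learning rate raises no error, but the weights never move.
stalled = step_changes_parameters(
    net, torch.optim.SGD(net.parameters(), lr=0.0), loss_fn, x, y
)
print(healthy, stalled)
```

Note that a check like this really updates the model's weights, so it is best run on a throwaway copy or before training starts.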
torchdiag.memory_report()
Reports CPU peak RSS, GPU memory (allocated, cached, peak, total) per device, and MPS memory on Apple Silicon. Flags when GPU memory utilization exceeds 90%.
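For reference, the same figures can be read from the standard library and PyTorch directly. The sketch below is not torchdiag's implementation: `memory_snapshot` is a hypothetical name, and the 90% flag is assumed to compare allocated memory against total device memory.

```python
import sys

import torch


def memory_snapshot() -> dict:
    """Hypothetical helper: collect CPU, CUDA, and MPS memory figures in MiB."""
    report = {}
    try:
        import resource  # Unix only; not available on Windows

        peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        # ru_maxrss is reported in kilobytes on Linux and in bytes on macOS.
        divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
        report["cpu_peak_rss_mb"] = peak / divisor
    except ImportError:
        report["cpu_peak_rss_mb"] = None

    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            total = torch.cuda.get_device_properties(i).total_memory
            allocated = torch.cuda.memory_allocated(i)
            report[f"cuda:{i}"] = {
                "allocated_mb": allocated / 2**20,
                "reserved_mb": torch.cuda.memory_reserved(i) / 2**20,
                "peak_mb": torch.cuda.max_memory_allocated(i) / 2**20,
                "total_mb": total / 2**20,
                # Assumption: utilization means allocated / total.
                "over_90_percent": allocated / total > 0.9,
            }

    if torch.backends.mps.is_available():
        report["mps_allocated_mb"] = torch.mps.current_allocated_memory() / 2**20
    return report


snapshot = memory_snapshot()
print(snapshot)
```

On a CPU-only machine the snapshot contains just the peak RSS entry; the CUDA and MPS sections appear only when those backends are available.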
Why This Exists
Most PyTorch debugging happens by staring at loss curves. That's like monitoring a distributed system by watching a single dashboard number.
torchdiag brings SRE observability practices to model training:
- Measure, don't guess — print the actual gradient values, don't assume they're fine
- Check preconditions — verify the training step works before running 100 epochs
- Detect silent failures — dead neurons and None gradients don't raise errors
Requirements
- Python 3.8+
- PyTorch 2.0+
License
MIT
Author
Aditya Mehra — Staff Engineer, IEEE Senior Member, PyTorch ecosystem contributor.
File details
Details for the file torchdiag-0.1.0.tar.gz.
File metadata
- Download URL: torchdiag-0.1.0.tar.gz
- Upload date:
- Size: 8.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 50a2fad840ddcc4dc38b0b64b80cfedd4f05970cc931ad18f8be989db5a502b3 |
| MD5 | d8353fc4cdade43de4bb2258f4462d6e |
| BLAKE2b-256 | 9d9f980225499c0ecc8b2cc77f07add8b34599dbe686b9998a3eb3c91d11c9d3 |
File details
Details for the file torchdiag-0.1.0-py3-none-any.whl.
File metadata
- Download URL: torchdiag-0.1.0-py3-none-any.whl
- Upload date:
- Size: 8.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 72351be775210e3506ec5e40287b3493260e7b2f17422af7447283f2ace7e734 |
| MD5 | 8748890aa110fdadf05a01dd5dc9b87c |
| BLAKE2b-256 | cb73e8bc66edd4896dc1649e288bd953ad0f895015bf8b86c67cbda3f85fec99 |