torch-surgeon
Real-time gradient pathology detection for PyTorch — in 2 lines.
A loss curve is a lagging indicator. By the time it shows a problem, vanishing or exploding gradients have been compounding for hundreds of steps. torch-surgeon attaches diagnostic hooks to your model and surfaces per-layer pathologies in real time, before they turn into an unrecoverable run.
Install
```
pip install torch-surgeon
```
Usage
```python
from torch_surgeon import Surgeon

surgeon = Surgeon(model, rules="default")
surgeon.attach()

# ... your existing training loop, unchanged ...
for epoch in range(epochs):
    loss = criterion(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

report = surgeon.report()   # per-layer stats dict
surgeon.detach()            # clean removal of all hooks
```
What it detects
| Pathology | Detection method |
|---|---|
| Vanishing gradients | Per-layer norm ratio drops below threshold vs EMA baseline |
| Exploding gradients | Per-layer norm ratio exceeds threshold vs EMA baseline |
| Stagnant layers | Norm near-zero for N consecutive steps — layer stopped learning |
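All three checks compare each layer's current gradient norm to a per-layer EMA baseline rather than to absolute values. A minimal sketch of what such a check could look like, using a hypothetical `LayerState` / `classify` pair that is not part of the library's API:

```python
from dataclasses import dataclass

@dataclass
class LayerState:
    ema_norm: float = 0.0      # running EMA of the layer's gradient norm
    stagnant_count: int = 0    # consecutive steps with a near-zero norm

def classify(state: LayerState, grad_norm: float,
             vanishing=0.01, exploding=100.0, stagnant_steps=50,
             ema_beta=0.99, eps=1e-12) -> str | None:
    """Flag a pathology by comparing the current norm against the EMA baseline."""
    # Count consecutive near-zero steps for the stagnation check.
    state.stagnant_count = state.stagnant_count + 1 if grad_norm < eps else 0

    # Ratio of the current norm to the layer's own historical baseline.
    ratio = grad_norm / (state.ema_norm + eps) if state.ema_norm > 0 else 1.0

    # Fold the new observation into the EMA baseline.
    state.ema_norm = ema_beta * state.ema_norm + (1 - ema_beta) * grad_norm

    if state.stagnant_count >= stagnant_steps:
        return "stagnant"
    if ratio < vanishing:
        return "vanishing"
    if ratio > exploding:
        return "exploding"
    return None
```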
Custom rules
```python
surgeon = Surgeon(model, rules={
    "vanishing_threshold": 0.01,   # default
    "exploding_threshold": 100.0,  # default
    "stagnant_steps": 50,          # default
    "log_every": 10,               # print summary every N steps
    "plot": True,                  # live matplotlib plot
    "verbose": True,
})
```
How it works
torch-surgeon uses PyTorch's register_full_backward_hook API to intercept gradients
at every leaf layer during the backward pass. Statistics (mean, std, norm) are computed
inside the hook and the raw gradient tensor is discarded immediately — keeping overhead
under 1% on typical training loops.
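For illustration, a stripped-down version of this hook mechanism might look as follows. The `attach_grad_hooks` helper and the stats dict layout are hypothetical; only `register_full_backward_hook` and the mean/std/norm statistics come from the description above.

```python
import torch.nn as nn

def attach_grad_hooks(model: nn.Module):
    """Register a full backward hook on every leaf module and collect gradient stats."""
    stats, handles = {}, []

    def make_hook(name):
        def hook(module, grad_input, grad_output):
            g = grad_output[0]
            if g is None:
                return
            # Compute scalars inside the hook; the gradient tensor itself is not retained.
            stats[name] = {
                "mean": g.mean().item(),
                "std": g.std().item(),
                "norm": g.norm().item(),
            }
        return hook

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            handles.append(module.register_full_backward_hook(make_hook(name)))
    return stats, handles
```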
Pathology detection uses an exponential moving average (EMA) baseline per layer rather than fixed thresholds — so it generalises across architectures without manual tuning.
Performance
Training overhead stays below 1% on standard loops, validated with 100-step timing benchmarks on Linear/ReLU networks. Statistics are computed inside the hook and no gradient tensors are stored between steps.
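To sanity-check the overhead on your own setup, a rough timing harness along these lines could be used. The network size, batch shape, and step count below are arbitrary choices for illustration, not the project's benchmark script.

```python
import time
import torch
import torch.nn as nn
from torch_surgeon import Surgeon

def time_training(model, attach_surgeon: bool, steps: int = 100) -> float:
    """Time a short training loop with or without the diagnostic hooks attached."""
    x = torch.randn(256, 64)
    y = torch.randn(256, 1)
    opt = torch.optim.SGD(model.parameters(), lr=1e-3)
    surgeon = Surgeon(model, rules="default") if attach_surgeon else None
    if surgeon is not None:
        surgeon.attach()
    start = time.perf_counter()
    for _ in range(steps):
        loss = nn.functional.mse_loss(model(x), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
    elapsed = time.perf_counter() - start
    if surgeon is not None:
        surgeon.detach()
    return elapsed

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 1))
baseline = time_training(model, attach_surgeon=False)
hooked = time_training(model, attach_surgeon=True)
print(f"overhead: {(hooked / baseline - 1) * 100:.2f}%")
```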
License
MIT
Download files
Source Distribution: torch_surgeon-0.1.0.tar.gz
Built Distribution: torch_surgeon-0.1.0-py3-none-any.whl
File details
Details for the file torch_surgeon-0.1.0.tar.gz.
File metadata
- Download URL: torch_surgeon-0.1.0.tar.gz
- Upload date:
- Size: 10.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a0b7a342e4e6a5f7186ca04f445475135f71f3953bfec2c51a105e0086588d61 |
| MD5 | 24c102cd0543a7f03521053101d6ad6b |
| BLAKE2b-256 | 0e8f0c28ab0b77fea35d8714ffb0f801015354a67e7cd4096b8bd3c68c5eb6ff |
File details
Details for the file torch_surgeon-0.1.0-py3-none-any.whl.
File metadata
- Download URL: torch_surgeon-0.1.0-py3-none-any.whl
- Upload date:
- Size: 8.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 2ad8d5f76d75daae7be4169fc759bea339551bcc689ba613618ed194aa1b429d |
| MD5 | 6c00abf56c3938753d03b44cc2fb81a7 |
| BLAKE2b-256 | 04c25d25bbcf73b98fbd32fa0c38c83c902543d50a471927821704421606f8f1 |