A framework-agnostic library for monitoring deep learning training runs and diagnosing issues
training-doctor
Stop wasting GPU hours on broken training runs.
training-doctor monitors your deep learning training and tells you when something's wrong, why it's happening, and how to fix it.
WARNING Learning rate likely too high
Training is showing signs of instability typically caused by an excessive learning rate.
Evidence:
- loss: NaN values detected at step 2847
- loss: Loss spiked 5 times (>3.0 std)
- grad_norm: Gradient explosion detected (value=1.2e+6)
Suggested actions:
- Reduce learning rate from 1.00e-02 to ~3.00e-03
- Try gradient clipping (e.g., max_grad_norm=1.0)
- Consider using learning rate warmup
Confidence: 0.85
Why This Exists
Training large models is expensive and opaque:
- Loss plateaus and you don't know if it's normal
- Runs silently waste days of compute
- "Fixes" are guesswork based on vibes
- Existing tools show metrics but don't explain anything
training-doctor is the judgment layer. It watches your metrics and tells you what's actually wrong.
Installation
pip install training-doctor
With framework integrations:
pip install training-doctor[transformers] # HuggingFace Trainer
pip install training-doctor[lightning] # PyTorch Lightning
pip install training-doctor[wandb] # Load W&B logs
pip install training-doctor[all] # Everything
Quick Start
Option 1: Drop into Any Training Loop
from training_doctor import Doctor
doctor = Doctor()
for step, batch in enumerate(dataloader):
    # Your normal training code
    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss = model(batch)
    loss.backward()
    optimizer.step()
    # Add this line - that's it
    doctor.log(
        step=step,
        loss=loss.item(),
        lr=optimizer.param_groups[0]['lr'],
    )
Diagnoses print automatically when problems are detected.
Option 2: HuggingFace Transformers
from transformers import Trainer
from training_doctor.integrations import HuggingFaceCallback
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    callbacks=[HuggingFaceCallback()],  # Add this
)
trainer.train()
Option 3: PyTorch Lightning
import pytorch_lightning as pl
from training_doctor.integrations import LightningCallback
trainer = pl.Trainer(
    max_epochs=10,
    callbacks=[LightningCallback()],  # Add this
)
trainer.fit(model, datamodule)
Option 4: Analyze Existing Logs
from training_doctor import Doctor
doctor = Doctor()
# Load from CSV
doctor.load("training_logs.csv")
# Or from W&B
doctor.load("wandb/run-20240115_120000/")
# Or from TensorBoard
doctor.load("lightning_logs/version_0/")
# Run analysis
for diagnosis in doctor.analyze():
    print(diagnosis)
What It Detects
Learning Rate Too High
Signals: Loss spikes, NaN values, gradient explosion, slow divergence
CRITICAL Learning rate likely too high
Evidence:
- loss: NaN values detected at step 1250
- grad_norm: Exploding gradients (>10.0): 8 occurrences
Suggested actions:
- Reduce learning rate from 3.00e-03 to ~9.00e-04
- Enable gradient clipping (e.g., max_grad_norm=1.0)
- Consider using learning rate warmup
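The spike evidence in the sample output (`>3.0 std`) can be pictured as a rolling z-score test. The following is a minimal sketch under that assumption, not training-doctor's actual implementation; `count_loss_spikes`, `first_nan_step`, and the 50-step window are hypothetical names and values chosen for illustration.

```python
import math
import statistics


def count_loss_spikes(losses, window=50, threshold_std=3.0):
    """Count steps whose loss sits more than threshold_std standard
    deviations above the mean of the preceding `window` finite values."""
    spikes = 0
    for i in range(window, len(losses)):
        value = losses[i]
        if not math.isfinite(value):
            continue  # NaN/Inf are reported separately, not as spikes
        history = [x for x in losses[i - window:i] if math.isfinite(x)]
        if len(history) < 2:
            continue
        mean = statistics.fmean(history)
        std = statistics.stdev(history)
        if std > 0 and (value - mean) / std > threshold_std:
            spikes += 1
    return spikes


def first_nan_step(losses):
    """Return the index of the first NaN loss, or None if there is none."""
    for step, value in enumerate(losses):
        if math.isnan(value):
            return step
    return None
```

A single spike inflates the trailing standard deviation, so the steps right after it are less likely to be flagged again; that is a property of this sketch, not a documented behavior of the library.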
Learning Rate Too Low
Signals: Flat loss, minimal progress, tiny gradients
WARNING Learning rate may be too low
Evidence:
- loss: Loss barely changed over 500 steps (0.3%)
- loss: Loss has extremely low variance (essentially flat)
Suggested actions:
- Increase learning rate from 1.00e-06 to ~3.00e-06
- Consider using a learning rate finder
- Verify data is being loaded correctly
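The "barely changed over 500 steps" evidence can be read as a relative-change test over a trailing window. This is a hedged sketch of that idea only; `relative_change_pct`, `is_flat`, and the 1% cutoff are assumptions for illustration, not the library's thresholds.

```python
def relative_change_pct(losses, window=500):
    """Percent change between the first and last loss in the trailing window."""
    recent = losses[-window:]
    start, end = recent[0], recent[-1]
    if start == 0:
        return 0.0  # avoid dividing by zero; treat as no measurable change
    return (end - start) / abs(start) * 100.0


def is_flat(losses, window=500, max_change_pct=1.0):
    """True when a full window of data moved by less than max_change_pct."""
    if len(losses) < window:
        return False  # not enough data to judge
    return abs(relative_change_pct(losses, window)) < max_change_pct
```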
Training Plateau
Signals: Loss stagnation + high gradient variance (often means batch size too small)
WARNING Training plateau detected
Evidence:
- loss: Loss plateaued for 100 steps (change: -0.8%)
- loss: Loss still at high level (not converged)
- grad_norm: High gradient variance during plateau (CV: 0.54)
Suggested actions:
- Increase batch size or gradient accumulation steps
- Target effective batch size of 500k-1M tokens for LLMs
- Reduce learning rate (current phase may need lower LR)
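The `CV` in the evidence above is a coefficient of variation: standard deviation divided by mean. A small sketch of how the three plateau signals could be combined follows; the function names and every threshold (`max_change_pct`, `min_cv`, `converged_below`) are hypothetical, and the real detector may weigh them differently.

```python
import statistics


def coefficient_of_variation(values):
    """Population standard deviation divided by the absolute mean."""
    mean = statistics.fmean(values)
    if mean == 0:
        return float("inf")
    return statistics.pstdev(values) / abs(mean)


def looks_like_plateau(losses, grad_norms,
                       max_change_pct=1.0, min_cv=0.5, converged_below=0.1):
    """Flat loss AND noisy gradients AND loss still above a 'converged' level."""
    if losses[0] == 0:
        return False
    change_pct = (losses[-1] - losses[0]) / abs(losses[0]) * 100.0
    flat = abs(change_pct) < max_change_pct
    noisy = coefficient_of_variation(grad_norms) > min_cv
    unconverged = losses[-1] > converged_below
    return flat and noisy and unconverged
```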
Overfitting
Signals: Train loss decreasing while eval loss increases
WARNING Overfitting detected
Evidence:
- train_loss: Train loss decreasing (-15.2%)
- eval_loss: Eval loss increasing (+8.3%)
- train_eval_gap: Eval loss 23% higher than train loss
Suggested actions:
- Add or increase dropout
- Use weight decay (L2 regularization)
- Implement early stopping based on eval loss
- Add more training data or data augmentation
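The three evidence lines above (train trend, eval trend, train/eval gap) are simple percentages. The sketch below shows one way to compute them from two loss series; `overfitting_signals` and its return keys are made up for this example and are not part of training-doctor.

```python
def pct_change(series):
    """Percent change from the first to the last value of a series."""
    first, last = series[0], series[-1]
    if first == 0:
        return 0.0
    return (last - first) / abs(first) * 100.0


def overfitting_signals(train_losses, eval_losses):
    """Summarize train/eval trends and whether the gap between them grew."""
    train_trend = pct_change(train_losses)
    eval_trend = pct_change(eval_losses)
    gap_start = eval_losses[0] - train_losses[0]
    gap_end = eval_losses[-1] - train_losses[-1]
    gap_pct = gap_end / abs(train_losses[-1]) * 100.0
    return {
        "train_trend_pct": train_trend,
        "eval_trend_pct": eval_trend,
        "gap_pct": gap_pct,
        # All three must agree before calling it overfitting.
        "overfitting": train_trend < 0 and eval_trend > 0 and gap_end > gap_start,
    }
```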
Gradient Instability
Signals: NaN/Inf gradients, explosions, vanishing, high variance
CRITICAL Gradient instability detected
Evidence:
- grad_norm: NaN/Inf gradients detected (3 occurrences)
- grad_norm: Exploding gradients (>10.0): 12 occurrences
Suggested actions:
- Enable gradient clipping (e.g., max_grad_norm=1.0)
- Reduce learning rate significantly
- Check for division by zero in loss computation
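The counts in the evidence above can be reproduced with a single pass over logged gradient norms. This is an illustrative sketch: the `10.0` cutoff is taken from the sample output, while `gradient_report` and the `vanish_below` value are assumptions.

```python
import math


def gradient_report(grad_norms, explode_above=10.0, vanish_below=1e-7):
    """Count non-finite, exploding, and vanishing gradient norms."""
    report = {"non_finite": 0, "exploding": 0, "vanishing": 0}
    for norm in grad_norms:
        if not math.isfinite(norm):
            report["non_finite"] += 1  # NaN or +/-Inf
        elif norm > explode_above:
            report["exploding"] += 1
        elif norm < vanish_below:
            report["vanishing"] += 1
    return report
```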
What Metrics to Log
| Metric | How to Log | What It Enables |
|---|---|---|
| Loss | loss=loss.item() | All detectors (required) |
| Learning Rate | lr=optimizer.param_groups[0]['lr'] | Better LR diagnostics |
| Gradient Norm | grad_norm=clip_grad_norm_(...) | Instability detection |
| Eval Loss | eval_loss=eval_loss.item() | Overfitting detection |
Minimal (just loss):
doctor.log(step=step, loss=loss.item())
Recommended (full diagnostics):
doctor.log(
    step=step,
    loss=loss.item(),
    lr=scheduler.get_last_lr()[0],
    grad_norm=torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0),
)
With evaluation:
doctor.log(
    step=step,
    train_loss=train_loss.item(),
    eval_loss=eval_loss.item(),  # Log during eval steps
)
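For readers unsure what `grad_norm` measures: `torch.nn.utils.clip_grad_norm_` returns the norm of all parameter gradients viewed as one vector (the 2-norm by default). The dependency-free sketch below computes the same quantity, with plain lists of floats standing in for gradient tensors; `global_grad_norm` is a hypothetical helper name.

```python
import math


def global_grad_norm(per_parameter_grads):
    """L2 norm of all gradients treated as one long vector.

    per_parameter_grads: iterable of flat lists, one list per parameter.
    """
    total_sq = 0.0
    for grad in per_parameter_grads:
        total_sq += sum(g * g for g in grad)
    return math.sqrt(total_sq)
```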
API Reference
Doctor
The main class that orchestrates everything.
from training_doctor import Doctor
# Create with default detectors
doctor = Doctor()
# Or customize
from training_doctor.detectors import LearningRateTooHighDetector
doctor = Doctor(detectors=[LearningRateTooHighDetector(loss_spike_threshold=2.5)])
Methods:
| Method | Description |
|---|---|
| doctor.log(step, **metrics) | Log metrics and check for issues |
| doctor.check() | Run detection without logging new data |
| doctor.load(path) | Load metrics from file (CSV, W&B, TensorBoard) |
| doctor.analyze() | Run full analysis on loaded data |
| doctor.diagnoses | List of all diagnoses found |
| doctor.summary() | Get summary statistics |
| doctor.clear() | Reset all data |
Diagnosis
What you get when an issue is detected.
diagnosis.problem # "Learning rate likely too high"
diagnosis.explanation # Why this is happening
diagnosis.evidence # List of Evidence objects
diagnosis.suggestions # List of action items
diagnosis.confidence # 0.0 to 1.0
diagnosis.severity # Severity.INFO, WARNING, or CRITICAL
diagnosis.step # When it was detected
Controlling Output
# Disable auto-printing
doctor.set_auto_print(False)
# Manually handle diagnoses
diagnoses = doctor.log(step=step, loss=loss)
for d in diagnoses:
    if d.confidence > 0.5:
        print(d)
# Or use the reporter directly
from training_doctor.reporters import ConsoleReporter
reporter = ConsoleReporter(use_color=True)
reporter.report(diagnosis)
How It Works
No Guessing
Every diagnosis requires multiple corroborating signals:
| Detector | Requires |
|---|---|
| LR Too High | NaN OR multiple spikes OR divergence + gradient issues |
| Plateau | Flat loss + high variance + still far from converged |
| Overfitting | Train↓ + Eval↑ + Growing gap |
Single noisy data points don't trigger false alarms.
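One way to picture "multiple corroborating signals" is a gate that only fires when enough independent checks agree. The sketch below is an assumption about the general idea, not the library's rule engine; `corroborated`, the signal names, and the default `minimum=2` are hypothetical.

```python
def corroborated(signals, minimum=2):
    """Return (fired, names) where fired is True only if at least
    `minimum` of the named boolean signals are present."""
    present = [name for name, fired in signals.items() if fired]
    return len(present) >= minimum, present
```

For example, a lone loss spike would not pass the gate, while a spike together with an exploding gradient norm would:

```python
fired, names = corroborated({"loss_spike": True, "grad_explosion": True, "nan": False})
```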
Confidence Scores
Each diagnosis has a confidence score (0.0-1.0) based on how much evidence supports it:
- < 0.3: Weak signal, might be noise
- 0.3-0.6: Moderate signal, worth investigating
- > 0.6: Strong signal, likely a real problem
Cooldowns
Detectors have cooldown periods to avoid spamming the same diagnosis:
# Default: won't repeat same diagnosis for 200 steps
LearningRateTooHighDetector(cooldown_steps=200)
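A cooldown of this kind can be modeled as a small gate that remembers the last step at which a diagnosis was emitted. This is a sketch of the concept only; the `Cooldown` class and its `allow` method are hypothetical and do not mirror training-doctor's internals.

```python
class Cooldown:
    """Suppress repeat reports until cooldown_steps have passed."""

    def __init__(self, cooldown_steps=200):
        self.cooldown_steps = cooldown_steps
        self._last_fired = None  # step of the most recent allowed report

    def allow(self, step):
        """Return True and record the step if a report may be emitted now."""
        if (self._last_fired is not None
                and step - self._last_fired < self.cooldown_steps):
            return False
        self._last_fired = step
        return True
```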
CSV Format
If logging to CSV for offline analysis:
step,loss,lr,grad_norm,eval_loss
0,4.521,0.0001,1.23,
100,4.102,0.0001,1.15,
200,3.854,0.0001,0.98,4.012
300,3.612,0.0001,1.02,
- step column is required (or iteration, global_step)
- Missing values are fine (eval_loss only logged sometimes)
- Column names are normalized automatically (learning_rate → lr)
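The rules above can be illustrated with the standard `csv` module. The sketch below is not the loader behind `doctor.load`; `read_metrics` and the `ALIASES` table are assumptions that cover only the aliases named in this section.

```python
import csv
import io

# Column aliases mentioned in this README; the real loader may accept more.
ALIASES = {"learning_rate": "lr", "iteration": "step", "global_step": "step"}


def read_metrics(text):
    """Parse CSV text into a list of {metric: value} dicts, one per row.

    Empty cells are skipped, so sparse columns such as eval_loss are fine.
    """
    rows = []
    for raw in csv.DictReader(io.StringIO(text)):
        row = {}
        for key, value in raw.items():
            if key is None or value is None or str(value).strip() == "":
                continue  # extra, missing, or blank cell
            name = key.strip().lower()
            name = ALIASES.get(name, name)
            row[name] = int(float(value)) if name == "step" else float(value)
        if "step" not in row:
            raise ValueError("each row needs a step (or iteration/global_step) value")
        rows.append(row)
    return rows
```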
Extending
Custom Detectors
from training_doctor import Doctor
from training_doctor.detectors import BaseDetector
from training_doctor.diagnosis import Diagnosis, Evidence, Severity

some_threshold = 1.0  # placeholder; pick a value that suits your metric

class MyCustomDetector(BaseDetector):
    name = "my_detector"
    description = "Detects my custom issue"
    min_data_points = 50
    cooldown_steps = 100

    @property
    def required_metrics(self):
        return ["loss", "my_metric"]

    def detect(self, store):
        # Your detection logic
        my_metric_stats = store.compute_window_stats("my_metric", 100)
        if my_metric_stats.mean > some_threshold:
            return Diagnosis(
                problem="My custom issue detected",
                explanation="...",
                evidence=[Evidence(metric="my_metric", observation="...")],
                suggestions=["..."],
                confidence=0.7,
                severity=Severity.WARNING,
            )
        return None

# Use it
doctor = Doctor()
doctor.add_detector(MyCustomDetector())
Examples
Full Training Script
import torch
from torch.utils.data import DataLoader
from training_doctor import Doctor
model = MyModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10000)
dataloader = DataLoader(dataset, batch_size=32)
doctor = Doctor()
for step, batch in enumerate(dataloader):
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()
    # Get gradient norm before clipping
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()
    # Log to training-doctor
    doctor.log(
        step=step,
        loss=loss.item(),
        lr=scheduler.get_last_lr()[0],
        grad_norm=grad_norm.item(),
    )
    # Regular logging
    if step % 100 == 0:
        print(f"Step {step}: loss={loss.item():.4f}")
# Final summary
print(f"\nTraining complete. Found {len(doctor.diagnoses)} issues.")
Analyzing a Failed Run
from training_doctor import Doctor
# Load the logs from your failed run
doctor = Doctor()
doctor.load("failed_run_logs.csv")
# Get all diagnoses
diagnoses = doctor.analyze()
# Print sorted by confidence
for d in sorted(diagnoses, key=lambda x: x.confidence, reverse=True):
    print(f"\n[{d.confidence:.0%}] {d.problem}")
    print(f"  Step: {d.step}")
    print("  Suggestions:")
    for s in d.suggestions[:2]:
        print(f"    - {s}")
License
MIT
Contributing
Issues and PRs welcome at https://github.com/SalahAlHaismawi/training-doctor