A framework-agnostic library for monitoring deep learning training runs and diagnosing issues

training-doctor

Stop wasting GPU hours on broken training runs.

training-doctor monitors your deep learning training and tells you when something's wrong, why it's happening, and how to fix it.

WARNING Learning rate likely too high
Training is showing signs of instability typically caused by an excessive learning rate.

Evidence:
  - loss: NaN values detected at step 2847
  - loss: Loss spiked 5 times (>3.0 std)
  - grad_norm: Gradient explosion detected (value=1.2e+6)

Suggested actions:
  - Reduce learning rate from 1.00e-02 to ~3.00e-03
  - Try gradient clipping (e.g., max_grad_norm=1.0)
  - Consider using learning rate warmup

Confidence: 0.85

Why This Exists

Training large models is expensive and opaque:

  • Loss plateaus and you don't know if it's normal
  • Runs silently waste days of compute
  • "Fixes" are guesswork based on vibes
  • Existing tools show metrics but don't explain anything

training-doctor is the judgment layer. It watches your metrics and tells you what's actually wrong.


Installation

pip install training-doctor

With framework integrations (the quotes keep shells such as zsh from treating the brackets as a glob pattern):

pip install "training-doctor[transformers]"  # HuggingFace Trainer
pip install "training-doctor[lightning]"     # PyTorch Lightning
pip install "training-doctor[wandb]"         # Load W&B logs
pip install "training-doctor[all]"           # Everything

Quick Start

Option 1: Drop into Any Training Loop

from training_doctor import Doctor

doctor = Doctor()

for step, batch in enumerate(dataloader):
    # Your normal training code
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()
    optimizer.step()

    # Add this line - that's it
    doctor.log(
        step=step,
        loss=loss.item(),
        lr=optimizer.param_groups[0]['lr'],
    )

Diagnoses print automatically when problems are detected.

Option 2: HuggingFace Transformers

from transformers import Trainer
from training_doctor.integrations import HuggingFaceCallback

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    callbacks=[HuggingFaceCallback()],  # Add this
)
trainer.train()

Option 3: PyTorch Lightning

import pytorch_lightning as pl
from training_doctor.integrations import LightningCallback

trainer = pl.Trainer(
    max_epochs=10,
    callbacks=[LightningCallback()],  # Add this
)
trainer.fit(model, datamodule)

Option 4: Analyze Existing Logs

from training_doctor import Doctor

doctor = Doctor()

# Load from CSV
doctor.load("training_logs.csv")

# Or from W&B
doctor.load("wandb/run-20240115_120000/")

# Or from TensorBoard
doctor.load("lightning_logs/version_0/")

# Run analysis
for diagnosis in doctor.analyze():
    print(diagnosis)

What It Detects

Learning Rate Too High

Signals: Loss spikes, NaN values, gradient explosion, slow divergence

CRITICAL Learning rate likely too high

Evidence:
  - loss: NaN values detected at step 1250
  - grad_norm: Exploding gradients (>10.0): 8 occurrences

Suggested actions:
  - Reduce learning rate from 3.00e-03 to ~9.00e-04
  - Enable gradient clipping (e.g., max_grad_norm=1.0)
  - Consider using learning rate warmup
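The "NaN values detected" and "Loss spiked N times (>3.0 std)" evidence lines suggest checks along the following lines. This is an illustrative, standard-library-only sketch, not training-doctor's actual implementation; the function names and the trailing window of 50 points are assumptions.

```python
import math
import statistics


def first_nan_step(steps, losses):
    """Return the first step whose loss is NaN or infinite, or None."""
    for step, loss in zip(steps, losses):
        if not math.isfinite(loss):
            return step
    return None


def count_loss_spikes(losses, window=50, num_std=3.0):
    """Count losses more than `num_std` standard deviations above the mean
    of the preceding `window` finite values (window size is an assumption)."""
    spikes = 0
    for i in range(window, len(losses)):
        history = [x for x in losses[i - window:i] if math.isfinite(x)]
        if len(history) < 2 or not math.isfinite(losses[i]):
            continue
        mean = statistics.fmean(history)
        std = statistics.stdev(history)
        if std > 0 and losses[i] > mean + num_std * std:
            spikes += 1
    return spikes
```

Comparing each point only against earlier values keeps the check usable online, one logged step at a time.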

Learning Rate Too Low

Signals: Flat loss, minimal progress, tiny gradients

WARNING Learning rate may be too low

Evidence:
  - loss: Loss barely changed over 500 steps (0.3%)
  - loss: Loss has extremely low variance (essentially flat)

Suggested actions:
  - Increase learning rate from 1.00e-06 to ~3.00e-06
  - Consider using a learning rate finder
  - Verify data is being loaded correctly
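In the sample outputs above, the suggested learning rate is roughly 0.3× the current value when it is too high (1.00e-02 → ~3.00e-03, 3.00e-03 → ~9.00e-04) and roughly 3× when it is too low (1.00e-06 → ~3.00e-06). The helper below reproduces that arithmetic; the 0.3 and 3.0 factors are inferred from these examples, not taken from the library's source.

```python
def suggest_lr(current_lr: float, too_high: bool) -> float:
    """Scale the learning rate down (x0.3) or up (x3.0).

    Both factors are assumptions inferred from the sample diagnoses.
    """
    factor = 0.3 if too_high else 3.0
    return current_lr * factor


def format_suggestion(current_lr: float, too_high: bool) -> str:
    """Render the suggestion in the same style as the sample output."""
    verb = "Reduce" if too_high else "Increase"
    new_lr = suggest_lr(current_lr, too_high)
    return f"{verb} learning rate from {current_lr:.2e} to ~{new_lr:.2e}"
```

For example, format_suggestion(1e-2, True) returns "Reduce learning rate from 1.00e-02 to ~3.00e-03".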

Training Plateau

Signals: Loss stagnation + high gradient variance (often means batch size too small)

WARNING Training plateau detected

Evidence:
  - loss: Loss plateaued for 100 steps (change: -0.8%)
  - loss: Loss still at high level (not converged)
  - grad_norm: High gradient variance during plateau (CV: 0.54)

Suggested actions:
  - Increase batch size or gradient accumulation steps
  - Target effective batch size of 500k-1M tokens for LLMs
  - Reduce learning rate (current phase may need lower LR)

Overfitting

Signals: Train loss decreasing while eval loss increases

WARNING Overfitting detected

Evidence:
  - train_loss: Train loss decreasing (-15.2%)
  - eval_loss: Eval loss increasing (+8.3%)
  - train_eval_gap: Eval loss 23% higher than train loss

Suggested actions:
  - Add or increase dropout
  - Use weight decay (L2 regularization)
  - Implement early stopping based on eval loss
  - Add more training data or data augmentation
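The overfitting evidence boils down to percentage changes in train and eval loss plus the relative gap between them. The sketch below computes those from train and eval losses recorded at the same eval steps, comparing the first and last points; the function names, the endpoint comparison, and the "growing gap" rule are assumptions, and training-doctor may smooth or window these series differently.

```python
def relative_gap(train_loss, eval_loss):
    """How much higher eval loss is than train loss, as a fraction."""
    return (eval_loss - train_loss) / train_loss


def overfitting_signals(train_losses, eval_losses):
    """Return (train_change, eval_change, start_gap, end_gap) as fractions.

    Assumes both lists are aligned: index i is the same eval step.
    """
    train_change = (train_losses[-1] - train_losses[0]) / train_losses[0]
    eval_change = (eval_losses[-1] - eval_losses[0]) / eval_losses[0]
    start_gap = relative_gap(train_losses[0], eval_losses[0])
    end_gap = relative_gap(train_losses[-1], eval_losses[-1])
    return train_change, eval_change, start_gap, end_gap


def is_overfitting(train_losses, eval_losses):
    """Train falling, eval rising, and the train/eval gap widening."""
    train_change, eval_change, start_gap, end_gap = overfitting_signals(
        train_losses, eval_losses
    )
    return train_change < 0 and eval_change > 0 and end_gap > start_gap
```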

Gradient Instability

Signals: NaN/Inf gradients, explosions, vanishing, high variance

CRITICAL Gradient instability detected

Evidence:
  - grad_norm: NaN/Inf gradients detected (3 occurrences)
  - grad_norm: Exploding gradients (>10.0): 12 occurrences

Suggested actions:
  - Enable gradient clipping (e.g., max_grad_norm=1.0)
  - Reduce learning rate significantly
  - Check for division by zero in loss computation
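The gradient evidence counts non-finite norms and norms above a fixed threshold (10.0 in the sample output). A minimal counter, assuming gradient norms have already been converted to Python floats (for example with `.item()` on the value returned by `torch.nn.utils.clip_grad_norm_`); the vanishing threshold of 1e-7 is an assumption, since the README names the signal but not a cutoff.

```python
import math


def gradient_issue_counts(grad_norms, explode_threshold=10.0, vanish_threshold=1e-7):
    """Return (non_finite, exploding, vanishing) counts for a list of floats."""
    non_finite = exploding = vanishing = 0
    for g in grad_norms:
        if not math.isfinite(g):
            non_finite += 1      # NaN or Inf
        elif g > explode_threshold:
            exploding += 1       # e.g. "Exploding gradients (>10.0)"
        elif g < vanish_threshold:
            vanishing += 1       # assumed cutoff for vanishing gradients
    return non_finite, exploding, vanishing
```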

What Metrics to Log

Metric          How to Log                                What It Enables
Loss            loss=loss.item()                          All detectors (required)
Learning Rate   lr=optimizer.param_groups[0]['lr']        Better LR diagnostics
Gradient Norm   grad_norm=clip_grad_norm_(...).item()     Instability detection
Eval Loss       eval_loss=eval_loss.item()                Overfitting detection

Minimal (just loss):

doctor.log(step=step, loss=loss.item())

Recommended (full diagnostics):

doctor.log(
    step=step,
    loss=loss.item(),
    lr=scheduler.get_last_lr()[0],
    grad_norm=torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item(),
)

With evaluation:

doctor.log(
    step=step,
    train_loss=train_loss.item(),
    eval_loss=eval_loss.item(),  # Log during eval steps
)

API Reference

Doctor

The main class that orchestrates everything.

from training_doctor import Doctor

# Create with default detectors
doctor = Doctor()

# Or customize
from training_doctor.detectors import LearningRateTooHighDetector
doctor = Doctor(detectors=[LearningRateTooHighDetector(loss_spike_threshold=2.5)])

Methods:

Method                        Description
doctor.log(step, **metrics)   Log metrics and check for issues
doctor.check()                Run detection without logging new data
doctor.load(path)             Load metrics from a CSV file, W&B run directory, or TensorBoard log directory
doctor.analyze()              Run full analysis on loaded data
doctor.diagnoses              List of all diagnoses found so far
doctor.summary()              Get summary statistics
doctor.clear()                Reset all data

Diagnosis

What you get when an issue is detected.

diagnosis.problem      # "Learning rate likely too high"
diagnosis.explanation  # Why this is happening
diagnosis.evidence     # List of Evidence objects
diagnosis.suggestions  # List of action items
diagnosis.confidence   # 0.0 to 1.0
diagnosis.severity     # Severity.INFO, WARNING, or CRITICAL
diagnosis.step         # When it was detected

Controlling Output

# Disable auto-printing
doctor.set_auto_print(False)

# Manually handle diagnoses
diagnoses = doctor.log(step=step, loss=loss.item())
for d in diagnoses:
    if d.confidence > 0.5:
        print(d)

# Or use the reporter directly
from training_doctor.reporters import ConsoleReporter
reporter = ConsoleReporter(use_color=True)
for d in doctor.diagnoses:
    reporter.report(d)

How It Works

No Guessing

Each detector looks for a specific combination of signals rather than reacting to a single reading:

Detector      Requires
LR Too High   NaN, OR multiple loss spikes, OR divergence + gradient issues
Plateau       Flat loss + high gradient variance + still far from converged
Overfitting   Train loss ↓ + eval loss ↑ + growing gap

Single noisy data points don't trigger false alarms.
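Read literally, the "LR Too High" row combines its signals as the boolean expression below. This is a sketch of the rule as the table states it, with a hypothetical `LossSignals` container; the spike count of 3 is an assumption, since the table only says "multiple".

```python
from dataclasses import dataclass


@dataclass
class LossSignals:
    """Hypothetical summary of signals already extracted from the metrics."""
    has_nan: bool
    spike_count: int
    diverging: bool
    gradient_issues: bool


def lr_too_high(signals: LossSignals, min_spikes: int = 3) -> bool:
    """NaN OR multiple spikes OR (divergence AND gradient issues)."""
    return (
        signals.has_nan
        or signals.spike_count >= min_spikes
        or (signals.diverging and signals.gradient_issues)
    )
```

Note that a single spike, or divergence without any gradient trouble, is not enough on its own under this reading.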

Confidence Scores

Each diagnosis has a confidence score (0.0-1.0) based on how much evidence supports it:

  • < 0.3 - Weak signal, might be noise
  • 0.3-0.6 - Moderate signal, worth investigating
  • > 0.6 - Strong signal, likely a real problem
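To group or filter diagnoses by these bands in your own code, a small helper over `diagnosis.confidence` is enough. Treating the boundary values 0.3 and 0.6 themselves as "moderate" is our reading of the ranges above.

```python
def confidence_band(confidence: float) -> str:
    """Map a 0.0-1.0 confidence score to the bands described above."""
    if not 0.0 <= confidence <= 1.0:
        raise ValueError(f"confidence must be in [0, 1], got {confidence}")
    if confidence < 0.3:
        return "weak"      # might be noise
    if confidence <= 0.6:
        return "moderate"  # worth investigating
    return "strong"        # likely a real problem
```

For example, keeping only strong findings could be written as [d for d in doctor.diagnoses if confidence_band(d.confidence) == "strong"].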

Cooldowns

Detectors have cooldown periods to avoid spamming the same diagnosis:

# Default: won't repeat same diagnosis for 200 steps
LearningRateTooHighDetector(cooldown_steps=200)
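Conceptually, a cooldown remembers the step at which each diagnosis last fired and suppresses repeats until `cooldown_steps` have passed. The class below is a hypothetical stand-alone illustration of that idea, not training-doctor's internal implementation.

```python
from typing import Dict


class Cooldown:
    """Suppress repeats of the same diagnosis within `cooldown_steps`."""

    def __init__(self, cooldown_steps: int = 200):
        self.cooldown_steps = cooldown_steps
        self._last_fired: Dict[str, int] = {}

    def should_report(self, key: str, step: int) -> bool:
        last = self._last_fired.get(key)
        if last is not None and step - last < self.cooldown_steps:
            return False  # still cooling down: stay quiet
        self._last_fired[key] = step
        return True
```

Keying by diagnosis name means one detector's cooldown does not silence a different problem.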

CSV Format

If logging to CSV for offline analysis:

step,loss,lr,grad_norm,eval_loss
0,4.521,0.0001,1.23,
100,4.102,0.0001,1.15,
200,3.854,0.0001,0.98,4.012
300,3.612,0.0001,1.02,

  • A step column is required (iteration or global_step also work)
  • Missing values are fine (e.g., eval_loss is only logged on some steps)
  • Column names are normalized automatically (learning_rate → lr)
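To sanity-check a CSV before handing it to doctor.load, the sketch below applies the same rules with Python's standard `csv` module: it renames `iteration`/`global_step` to `step` and `learning_rate` to `lr`, and turns empty cells into `None`. The alias table is an assumption covering only the names listed above; the library's own normalization may recognize more.

```python
import csv
import io

# Assumed alias table: only the names mentioned in this README
ALIASES = {"iteration": "step", "global_step": "step", "learning_rate": "lr"}


def read_metrics_csv(text):
    """Parse CSV text into a list of {metric: float or None} dicts."""
    rows = []
    for raw in csv.DictReader(io.StringIO(text)):
        row = {}
        for name, value in raw.items():
            if name is None:  # extra unnamed cells on a ragged row
                continue
            key = name.strip().lower()
            key = ALIASES.get(key, key)
            value = (value or "").strip()
            # Empty (or missing) cells become None instead of raising
            row[key] = float(value) if value else None
        if row.get("step") is None:
            raise ValueError("each row needs a step, iteration, or global_step value")
        row["step"] = int(row["step"])
        rows.append(row)
    return rows
```

Usage: read_metrics_csv(Path("training_logs.csv").read_text()), with Path from pathlib.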

Extending

Custom Detectors

from training_doctor import Doctor
from training_doctor.detectors import BaseDetector
from training_doctor.diagnosis import Diagnosis, Evidence, Severity

MY_METRIC_THRESHOLD = 10.0  # Example value: choose a threshold that fits your metric

class MyCustomDetector(BaseDetector):
    name = "my_detector"
    description = "Detects my custom issue"
    min_data_points = 50
    cooldown_steps = 100

    @property
    def required_metrics(self):
        return ["loss", "my_metric"]

    def detect(self, store):
        # Your detection logic
        my_metric_stats = store.compute_window_stats("my_metric", 100)

        if my_metric_stats.mean > MY_METRIC_THRESHOLD:
            return Diagnosis(
                problem="My custom issue detected",
                explanation="...",
                evidence=[Evidence(metric="my_metric", observation="...")],
                suggestions=["..."],
                confidence=0.7,
                severity=Severity.WARNING,
            )
        return None

# Use it
doctor = Doctor()
doctor.add_detector(MyCustomDetector())

Examples

Full Training Script

import torch
from torch.utils.data import DataLoader
from training_doctor import Doctor

model = MyModel()  # Placeholder: your nn.Module, assumed to return the loss
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10000)
dataloader = DataLoader(dataset, batch_size=32)  # Placeholder: your Dataset

doctor = Doctor()

for step, batch in enumerate(dataloader):
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()

    # Clip gradients; clip_grad_norm_ returns the total norm measured before clipping
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optimizer.step()
    scheduler.step()

    # Log to training-doctor
    doctor.log(
        step=step,
        loss=loss.item(),
        lr=scheduler.get_last_lr()[0],
        grad_norm=grad_norm.item(),
    )

    # Regular logging
    if step % 100 == 0:
        print(f"Step {step}: loss={loss.item():.4f}")

# Final summary
print(f"\nTraining complete. Found {len(doctor.diagnoses)} issues.")

Analyzing a Failed Run

from training_doctor import Doctor

# Load the logs from your failed run
doctor = Doctor()
doctor.load("failed_run_logs.csv")

# Get all diagnoses
diagnoses = doctor.analyze()

# Print sorted by confidence
for d in sorted(diagnoses, key=lambda x: x.confidence, reverse=True):
    print(f"\n[{d.confidence:.0%}] {d.problem}")
    print(f"  Step: {d.step}")
    print("  Suggestions:")
    for s in d.suggestions[:2]:
        print(f"    - {s}")

License

MIT


Contributing

Issues and PRs welcome at https://github.com/SalahAlHaismawi/training-doctor

Download files

Source Distribution

training_doctor-0.1.0.tar.gz (49.1 kB)

Uploaded Source

Built Distribution

training_doctor-0.1.0-py3-none-any.whl (37.2 kB)

Uploaded Python 3

File details

Details for the file training_doctor-0.1.0.tar.gz.

File metadata

  • Download URL: training_doctor-0.1.0.tar.gz
  • Size: 49.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.3

File hashes

Hashes for training_doctor-0.1.0.tar.gz
Algorithm     Hash digest
SHA256        bfe1f797829f6c33fc89a28d8a52b7ee866c32bf33119a5634cbcb4cbe2db3b7
MD5           2d82fc95428166589302d3bf1dfb436a
BLAKE2b-256   7f0facb82f592d8e09593d56e03fa7e22f797ffc6796da42440be7185c3227bb

File details

Details for the file training_doctor-0.1.0-py3-none-any.whl.

File hashes

Hashes for training_doctor-0.1.0-py3-none-any.whl
Algorithm     Hash digest
SHA256        e61a20577d88a2435fca23121d68ce63dbc73d0a2d8dd9346dbed3b8c6e944d0
MD5           f4f38dadf882716dc1306fd72367e6fe
BLAKE2b-256   7ca2aa4c4f7244bc9b97986c46ca551d77f8337bc6c55b643d2099361a486113
