
Thermoclaw

Thermodynamic diagnostics for neural network training — see what your optimizer is actually doing.

Thermoclaw wraps any PyTorch optimizer and exposes real-time thermodynamic diagnostics: entropy production, internal vs. external entropy flow, weight-decay collapse detection, equilibrium detection, and actionable per-layer recommendations.


It found the problem. Acting on it worked.

We trained a GPT-2 small (124M params) from scratch on WikiText-103 with SGD + momentum and weight_decay=5.0. At step 19, CollapseDetector flagged HIGH-confidence weight decay collapse:

[HIGH] Weight decay is eroding 2 embedding layers: param norms dropped
36% during training. Reduce weight_decay.

We branched at that point — one arm continued with wd=5.0, one arm reduced to wd=0.01.

Run | Final PPL (600 steps post-branch)
Unmodified (wd=5.0 throughout) | 50,257 (vocab size: model completely dead)
Thermoclaw intervention (wd→0.01) | 1,377
Improvement | 36× lower PPL

Replicated identically across 3 seeds (42, 137, 2024). The unmodified arm locks at PPL = vocab size, outputting a uniform distribution over all tokens (zero information); the intervention arm learns.

No hyperparameter search. No manual inspection. One warning, one change.

That's what Thermoclaw does. The warning is causal — not correlational.


Install

pip install thermoclaw                    # core (torch + numpy only)
pip install thermoclaw[viz]               # + matplotlib dashboards
pip install thermoclaw[hf]                # + HuggingFace Trainer callback
pip install thermoclaw[all]               # everything

Quick Start

Observe any optimizer

import torch

from thermoclaw import Observer, diagnose

model = YourModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
observer = Observer(model, optimizer)

for batch in loader:
    loss = criterion(model(batch))
    loss.backward()
    optimizer.step()
    observer.step(loss=loss.item())        # ← record after optimizer.step(), before zero_grad()
    optimizer.zero_grad()

report = diagnose(observer)
print(report)
observer.plot_dashboard(save_path='dashboard.png')

d_iS / d_eS entropy split (the killer feature)

from thermoclaw import Observer, EntropySplit, diagnose

observer = Observer(model, optimizer)
splitter = EntropySplit(model, optimizer, observer)

for batch in loader:
    loss = criterion(model(batch))
    loss.backward()
    optimizer.step()
    observer.step(loss=loss.item())
    splitter.step()                        # ← decomposes entropy
    optimizer.zero_grad()

report = diagnose(observer, splitter)
print(report)
# Output:
#   [HIGH] Weight decay is the dominant entropy source for 12 attention
#   layers (mean R_ie=4.2). Consider reducing weight_decay for attention
#   layers by 4-8×, or excluding them.

splitter.plot_entropy_split(save_path='entropy_split.png')

CollapseDetector — catch weight-decay collapse in real time

Weight decay can silently erode parameter capacity faster than gradient signal can rebuild it. The observable signature is ‖g‖/‖θ‖ (grad/param ratio) dropping over time. CollapseDetector tracks this directly — no calibration period required, no incommensurable quantities.
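
For intuition, the raw ratio is easy to compute by hand. A minimal sketch (illustrative only, not Thermoclaw's implementation):

def grad_param_ratios(model):
    # Per-parameter ‖g‖ / ‖θ‖; a persistent downward trend is the collapse signature.
    ratios = {}
    for name, p in model.named_parameters():
        if p.grad is None:
            continue
        ratios[name] = (p.grad.norm() / (p.norm() + 1e-12)).item()
    return ratios

CollapseDetector automates this bookkeeping and turns a persistent downward trend into graded recommendations: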

from thermoclaw import CollapseDetector

detector = CollapseDetector(model, optimizer)

for batch in loader:
    loss = criterion(model(batch))
    loss.backward()
    optimizer.step()
    detector.step()                        # ← call before zero_grad
    optimizer.zero_grad()

    # Mid-training intervention trigger
    if detector.is_collapsing:
        for pg in optimizer.param_groups:
            pg['weight_decay'] *= 0.1

# Post-training report
recs = detector.get_recommendations()
# → ["[HIGH] Weight decay collapse in 5 mlp layers: grad/param ratio
#     dropped 4.2× from early to late training. Reduce weight_decay."]

is_collapsing returns True as soon as any HIGH or MEDIUM collapse signal fires — use it for automated mid-training interventions. Validated with a controlled 3-seed experiment (GPT-2 small, SGD + momentum, WikiText-103): fires at step 19 with HIGH confidence, PPL gap +48,880 vs unmodified arm (which locks at vocab-size PPL — a complete dead-end).

ThermoScheduler (drop-in replacement)

One line to replace your cosine scheduler. Keeps your optimizer, just makes the schedule thermodynamically aware:

import torch

from thermoclaw import ThermoScheduler

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = ThermoScheduler(optimizer, total_steps=10000)

for step, batch in enumerate(loader):
    loss = criterion(model(batch))
    loss.backward()
    optimizer.step()
    scheduler.step()          # ← replaces cosine_scheduler.step()
    optimizer.zero_grad()

HuggingFace integration (one line)

from transformers import Trainer

from thermoclaw.integrations.huggingface import ThermoclawCallback

trainer = Trainer(
    model=model,
    args=training_args,
    callbacks=[ThermoclawCallback()],      # ← that's it
)
trainer.train()
# Dashboard PNG, CSV, and diagnostic report saved to output_dir automatically

What Thermoclaw measures

Quantity | Symbol | What it means
Entropy production | σ_l = η‖g‖² | How much "thermodynamic work" each layer is doing
Entropy ratio | r_l = σ/σ* | 1.0 = equilibrium; <0.85 = under-trained; >1.15 = over-trained
External entropy | d_eS | Entropy that reduces loss (productive learning)
Internal entropy | d_iS | Entropy from weight decay, momentum, noise (overhead)
Internal/external ratio | R_ie = d_iS/d_eS | The diagnostic ratio; >2 = warning, >5 = critical
Grad/param ratio | ‖g‖/‖θ‖ | CollapseDetector signal; a dropping trend means weight decay is eroding capacity
Dispersion | D = Var(r_l) | Inter-layer training uniformity
Gradient alignment | ρ = cos(g_t, g_{t−1}) | Step coherence; negative = oscillation
Parameter distance | E = ‖θ − θ₀‖² | How far weights have moved from init
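
To make the first row concrete, here is a rough per-layer computation of σ_l for a single step (an illustrative sketch; Thermoclaw's own accounting may differ in detail):

def entropy_production(model, lr):
    # sigma_l = eta * ||g_l||^2, using the gradients left by the last backward pass
    sigma = {}
    for name, p in model.named_parameters():
        if p.grad is None:
            continue
        sigma[name] = lr * p.grad.pow(2).sum().item()
    return sigma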

A note on ρ and momentum (non-obvious)

High momentum (β₁ → 1) increases ρ rather than decreasing it: momentum smooths consecutive gradient vectors, making them more correlated. This means ρ is not a reliable signal for over-damped momentum. If β₁ is too high, use the equilibrium fraction instead: over-damped runs show low eq_fraction despite high ρ, because damped updates cannot reach homeostatic σ* levels. Validated at Pythia-410M scale: β₁=0.999 produced ρ=0.63 vs a baseline ρ=0.39, with eq_fraction=0.15 vs 0.59.
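
A minimal sketch of how ρ can be tracked step to step (illustrative only; not Thermoclaw's internals):

import torch

class GradAlignment:
    # Tracks rho = cos(g_t, g_{t-1}) across optimization steps.
    def __init__(self):
        self.prev = None

    def step(self, model):
        # Call after loss.backward(); returns None on the first step.
        g = torch.cat([p.grad.flatten() for p in model.parameters() if p.grad is not None])
        rho = None
        if self.prev is not None:
            rho = torch.nn.functional.cosine_similarity(g, self.prev, dim=0).item()
        self.prev = g.detach().clone()
        return rho

Values near 1 indicate heavily smoothed (high-momentum) gradients; negative values indicate oscillation, which is why eq_fraction is the better signal for over-damping.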

The d_iS / d_eS decomposition

Standard training observes total loss and calls it a day. But total entropy production σ conflates two fundamentally different thermodynamic flows:

  • d_eS (external) — gradient-driven parameter updates that reduce loss. This is productive work.
  • d_iS (internal) — entropy from weight decay, momentum friction, stochastic noise. This is heat.

When d_iS >> d_eS, the optimizer is spending most of its entropy budget on overhead. Thermoclaw decomposes d_iS further into:

  • d_iS_wd — weight decay contribution
  • d_iS_momentum — momentum friction
  • d_iS_noise — stochastic gradient noise

This tells you exactly which layers, at which step, are wasting compute — and why.
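
As a crude illustration of the idea (the formulas below are an assumption for this sketch, treating each force's contribution as η times its squared norm; they are not Thermoclaw's exact decomposition, and momentum and noise terms are omitted for brevity):

import torch

def entropy_sources_sketch(model, lr, weight_decay):
    # Gradient-driven (external-ish) vs weight-decay (internal) contributions for one step.
    d_eS = 0.0
    d_iS_wd = 0.0
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is None:
                continue
            d_eS += lr * p.grad.pow(2).sum().item()
            d_iS_wd += lr * (weight_decay * p).pow(2).sum().item()
    return d_eS, d_iS_wd

In practice, EntropySplit performs the full per-layer, per-source decomposition for you.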

Per-layer param groups

For full per-layer resolution, use make_param_groups:

import torch

from thermoclaw import make_param_groups

groups = make_param_groups(model, lr=3e-4, weight_decay=0.01)
optimizer = torch.optim.AdamW(groups)
observer = Observer(model, optimizer)

Confidence scoring

Recommendations are conservative. Thermoclaw only flags issues where the physics signal is unambiguous:

  • [HIGH] — Single dominant source (>60% of d_iS), R_ie > 5, consistent across regions. Safe to act on.
  • [MEDIUM] — Clear signal but moderate R_ie (2-5). Worth investigating.
  • [LOW] — Signal present but multiple sources contribute. Informational only.
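
A rough sketch of that thresholding (a hypothetical helper, not Thermoclaw's internals; the consistency-across-regions check is omitted):

def confidence_label(r_ie, dominant_share):
    # r_ie: d_iS / d_eS ratio; dominant_share: fraction of d_iS from the largest source
    if r_ie > 5 and dominant_share > 0.60:
        return "HIGH"
    if r_ie >= 2:
        return "MEDIUM"
    return "LOW"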

Wrong recommendations that sound authoritative destroy trust faster than no recommendations at all. Thermoclaw would rather under-claim than over-claim.

Validated

Three-tier validation on H100 80GB (Pythia-410M, WikiText-103, bfloat16):

Tier | Test | Result
T1: Analytical | σ, ρ, d_iS_wd, d_eS + d_iS = σ, E, D (8 ground-truth checks) | 8/8 PASS
T2A: High LR | lr=3e-2 → R_ie=1.6×10²⁰, eq=0.014, flagged HIGH | PASS
T2B: High WD | wd=5.0 → unhealthy, flagged MEDIUM | PASS
T2C: Overdamped | β₁=0.999 → ρ=0.63 (vs baseline 0.39), eq=0.15, flagged HIGH | PASS
T2D: Baseline | lr=3e-4, wd=0.01, β₁=0.9 → no collapse or WD pathology recs | PASS
T3: Intervention | CollapseDetector fires at step 19 with HIGH confidence (SGD, wd=5.0); PPL gap +48,880 vs the dead arm (3/3 seeds) | PASS

Origin

Thermoclaw's thermodynamic framework comes from the EPTO (Entropy-Production Targeted Optimisation) research project. The key insight: neural network training is a non-equilibrium thermodynamic process, and the quantities that matter (entropy production, entropy ratios, equilibrium fraction) can be measured for any optimizer, not just EPTO.

License

Apache 2.0

