Thermoclaw
Thermodynamic diagnostics for neural network training — see what your optimizer is actually doing.
Thermoclaw wraps any PyTorch optimizer and exposes real-time thermodynamic quantities: entropy production, internal vs external entropy flow, weight-decay collapse detection, equilibrium detection, and actionable per-layer diagnostics.
It found the problem. Acting on it worked.
We ran Thermoclaw on a Pythia-410M fine-tuning run with weight_decay=0.3 on WikiText-103.
At step 39, CollapseDetector flagged weight decay collapse in the attention layers:
[MEDIUM] Weight decay collapse in 5 attention layers: grad/param ratio
dropped 3.1x from early to late training. Reduce weight_decay.
We reduced weight_decay from 0.3 to 0.01 at that point and kept training.
| Run | Final PPL (steps 280-300) |
|---|---|
| Unmodified (wd=0.3 throughout) | 256.8 |
| Thermoclaw intervention at step 39 | 222.9 |
| Improvement | 33.8 PPL lower |
No hyperparameter search. No manual inspection. One warning, one change.
That's what Thermoclaw does. And the warning is causal rather than merely correlational: the two arms differ only in whether we acted on it.
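The change itself is one loop over the optimizer's param groups; a minimal sketch of the step-39 intervention in plain PyTorch:

# Standard PyTorch: lower weight decay mid-run (what the step-39 intervention did)
for pg in optimizer.param_groups:
    pg['weight_decay'] = 0.01  # was 0.3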
Install
pip install thermoclaw # core (torch + numpy only)
pip install thermoclaw[viz] # + matplotlib dashboards
pip install thermoclaw[hf] # + HuggingFace Trainer callback
pip install thermoclaw[all] # everything
Quick Start
Observe any optimizer
import torch
from thermoclaw import Observer, diagnose

model = YourModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
observer = Observer(model, optimizer)

for batch in loader:
    loss = criterion(model(batch))
    loss.backward()
    optimizer.step()
    observer.step(loss=loss.item())  # record thermodynamic state for this step
    optimizer.zero_grad()

report = diagnose(observer)
print(report)
observer.plot_dashboard(save_path='dashboard.png')
d_iS / d_eS entropy split (the killer feature)
from thermoclaw import Observer, EntropySplit, diagnose

observer = Observer(model, optimizer)
splitter = EntropySplit(model, optimizer, observer)

for batch in loader:
    loss = criterion(model(batch))
    loss.backward()
    optimizer.step()
    observer.step(loss=loss.item())
    splitter.step()  # ← decomposes entropy
    optimizer.zero_grad()

report = diagnose(observer, splitter)
print(report)
# Output:
# [HIGH] Weight decay is the dominant entropy source for 12 attention
# layers (mean R_ie=4.2). Consider reducing weight_decay for attention
# layers by 4-8×, or excluding them.

splitter.plot_entropy_split(save_path='entropy_split.png')
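Acting on a recommendation like "excluding them" is again a param-group change. A sketch in plain PyTorch, where the 'attn' substring match is a placeholder for your model's actual attention-layer names:

import torch

decay, no_decay = [], []
for name, p in model.named_parameters():
    # 'attn' is a hypothetical name pattern; match your model's attention layers
    (no_decay if 'attn' in name else decay).append(p)
optimizer = torch.optim.AdamW(
    [{'params': decay, 'weight_decay': 0.01},
     {'params': no_decay, 'weight_decay': 0.0}],  # attention excluded from decay
    lr=3e-4,
)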
CollapseDetector — catch weight-decay collapse in real time
Weight decay can silently erode parameter capacity faster than gradient signal can rebuild it. The observable signature is ‖g‖/‖θ‖ (grad/param ratio) dropping over time. CollapseDetector tracks this directly — no calibration period required, no incommensurable quantities.
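For intuition, the raw signal is cheap to compute yourself; a minimal sketch in plain PyTorch (not Thermoclaw's internals):

import torch

def grad_param_ratios(model):
    # Per-parameter ||g|| / ||theta||, the raw signal CollapseDetector tracks
    return {name: (p.grad.norm() / (p.norm() + 1e-12)).item()
            for name, p in model.named_parameters() if p.grad is not None}

CollapseDetector tracks the trend of this ratio across training and turns it into graded warnings: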
from thermoclaw import CollapseDetector

detector = CollapseDetector(model, optimizer)

for batch in loader:
    loss = criterion(model(batch))
    loss.backward()
    optimizer.step()
    detector.step()  # ← call before zero_grad

    # Mid-training intervention trigger
    if detector.is_collapsing:
        for pg in optimizer.param_groups:
            pg['weight_decay'] *= 0.1

    optimizer.zero_grad()

# Post-training report
recs = detector.get_recommendations()
# → ["[HIGH] Weight decay collapse in 5 mlp layers: grad/param ratio
#     dropped 4.2× from early to late training. Reduce weight_decay."]
is_collapsing returns True as soon as any HIGH or MEDIUM collapse signal fires, so you can use it for automated mid-training interventions. Validated at Pythia-410M scale: it fired at step 39 with wd=0.3, and acting on it produced the +33.8 PPL gap over the unmodified arm.
ThermoScheduler (drop-in replacement)
One line to replace your cosine scheduler. Keeps your optimizer, just makes the schedule thermodynamically aware:
from thermoclaw import ThermoScheduler

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = ThermoScheduler(optimizer, total_steps=10000)

for step, batch in enumerate(loader):
    loss = criterion(model(batch))
    loss.backward()
    optimizer.step()
    scheduler.step()  # ← replaces cosine_scheduler.step()
    optimizer.zero_grad()
HuggingFace integration (one line)
from transformers import Trainer
from thermoclaw.integrations.huggingface import ThermoclawCallback

trainer = Trainer(
    model=model,
    args=training_args,
    callbacks=[ThermoclawCallback()],  # ← that's it
)
trainer.train()
# Dashboard PNG, CSV, and diagnostic report saved to output_dir automatically
What Thermoclaw measures
| Quantity | Symbol | What it means |
|---|---|---|
| Entropy production | σ_l = η‖g‖² | How much "thermodynamic work" each layer is doing |
| Entropy ratio | r_l = σ_l/σ_l* | 1.0 = equilibrium. <0.85 = under-trained. >1.15 = over-trained |
| External entropy | d_eS | Entropy that reduces loss (productive learning) |
| Internal entropy | d_iS | Entropy from weight decay, momentum, noise (overhead) |
| Internal/external ratio | R_ie = d_iS/d_eS | The diagnostic ratio. >2 = warning, >5 = critical |
| Grad/param ratio | ‖g‖/‖θ‖ | CollapseDetector signal. Dropping trend = weight decay eroding capacity |
| Dispersion | D = Var(r_l) | Inter-layer training uniformity |
| Gradient alignment | ρ = cos(g_t, g_{t-1}) | Step coherence. Negative = oscillation |
| Parameter distance | E = ‖θ−θ₀‖² | How far weights have moved from init |
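Most of these follow directly from the table's definitions. A sketch computing three of them (σ_l, ρ, E) from the formulas above, assuming the table's definitions rather than mirroring Thermoclaw's internals:

import torch
import torch.nn.functional as F

def sigma_per_layer(model, lr):
    # Entropy production sigma_l = eta * ||g_l||^2 (table definition)
    return {n: lr * p.grad.norm().item() ** 2
            for n, p in model.named_parameters() if p.grad is not None}

def grad_alignment(g_prev, g_curr):
    # Alignment rho = cos(g_t, g_{t-1}) on flattened gradient vectors
    return F.cosine_similarity(g_prev.flatten(), g_curr.flatten(), dim=0).item()

def param_distance(model, theta0):
    # Distance E = ||theta - theta_0||^2; theta0 maps names to initial copies
    return sum(((p.detach() - theta0[n]) ** 2).sum().item()
               for n, p in model.named_parameters())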
A note on ρ and momentum (non-obvious)
High momentum (β₁ → 1) increases ρ, not decreases it — momentum smooths consecutive gradient vectors, making them more correlated. This means ρ is not a reliable signal for over-damped momentum. If β₁ is too high, use equilibrium fraction instead: over-damped runs show low eq_fraction despite high ρ, because damped updates can't reach homeostatic σ* levels. Validated at Pythia-410M scale: β₁=0.999 produced ρ=0.63 vs baseline ρ=0.39, with eq_fraction=0.15 vs 0.59.
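The effect is easy to reproduce without a model. A toy demo on synthetic gradients, showing that an EMA with β₁ near 1 makes consecutive vectors far more correlated than the raw gradients it smooths:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
d, beta1 = 1000, 0.999
signal = torch.randn(d)                # fixed 'true' gradient direction
m = torch.zeros(d)                     # momentum buffer
g_prev = m_prev = None
raw, smoothed = [], []
for _ in range(200):
    g = signal + 2.0 * torch.randn(d)  # noisy stochastic gradient
    m = beta1 * m + (1 - beta1) * g    # EMA, as in Adam's first moment
    if g_prev is not None:
        raw.append(F.cosine_similarity(g, g_prev, dim=0).item())
        smoothed.append(F.cosine_similarity(m, m_prev, dim=0).item())
    g_prev, m_prev = g, m
print(f"mean rho, raw gradients:   {sum(raw) / len(raw):.2f}")
print(f"mean rho, momentum buffer: {sum(smoothed) / len(smoothed):.2f}")

The raw-gradient ρ sits near the signal-to-noise floor while the smoothed ρ approaches 1, which is why a high ρ alone cannot distinguish coherent progress from an over-damped buffer.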
The d_iS / d_eS decomposition
Standard training observes total loss and calls it a day. But total entropy production σ conflates two fundamentally different thermodynamic flows:
- d_eS (external) — gradient-driven parameter updates that reduce loss. This is productive work.
- d_iS (internal) — entropy from weight decay, momentum friction, stochastic noise. This is heat.
When d_iS >> d_eS, the optimizer is spending most of its entropy budget on overhead. Thermoclaw decomposes d_iS further into:
- d_iS_wd — weight decay contribution
- d_iS_momentum — momentum friction
- d_iS_noise — stochastic gradient noise
This tells you exactly which layers, at which step, are wasting compute — and why.
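The exact split is computed inside the library. For intuition, a naive back-of-envelope version that treats the loss gradient as the external force and the weight-decay force λθ as an internal one (these formulas are this sketch's assumption, not Thermoclaw's actual decomposition):

import torch

def naive_entropy_split(model, lr, weight_decay):
    # ASSUMPTION: d_eS ~ eta*||g||^2 (loss gradient) and
    # d_iS_wd ~ eta*||lambda*theta||^2 (weight-decay force).
    # Illustrative only, not Thermoclaw's internals.
    split = {}
    for name, p in model.named_parameters():
        if p.grad is None:
            continue
        d_eS = lr * p.grad.norm().item() ** 2
        d_iS_wd = lr * (weight_decay * p.detach()).norm().item() ** 2
        split[name] = {'d_eS': d_eS, 'd_iS_wd': d_iS_wd,
                       'R_ie': d_iS_wd / (d_eS + 1e-12)}
    return split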
Per-layer param groups
For full per-layer resolution, use make_param_groups:
from thermoclaw import make_param_groups
groups = make_param_groups(model, lr=3e-4, weight_decay=0.01)
optimizer = torch.optim.AdamW(groups)
observer = Observer(model, optimizer)
Confidence scoring
Recommendations are conservative. Thermoclaw only flags issues where the physics signal is unambiguous:
- [HIGH] — Single dominant source (>60% of d_iS), R_ie > 5, consistent across regions. Safe to act on.
- [MEDIUM] — Clear signal but moderate R_ie (2-5). Worth investigating.
- [LOW] — Signal present but multiple sources contribute. Informational only.
Wrong recommendations that sound authoritative destroy trust faster than no recommendations at all. Thermoclaw would rather under-claim than over-claim.
Validated
Three-tier validation on H100 80GB (Pythia-410M, WikiText-103, bfloat16):
| Tier | Test | Result |
|---|---|---|
| T1: Analytical | σ, ρ, d_iS_wd, d_eS+d_iS=σ, E, D — 8 ground-truth checks | 8/8 PASS |
| T2A: High LR | lr=3e-2 → R_ie=1.6×10²⁰, eq=0.014, flagged HIGH | PASS |
| T2B: High WD | wd=5.0 → unhealthy, flagged MEDIUM | PASS |
| T2C: Overdamped | β₁=0.999 → ρ=0.63 (vs baseline 0.39), eq=0.15, flagged HIGH | PASS |
| T2D: Baseline | lr=3e-4/wd=0.01/β₁=0.9 → no collapse or WD pathology recs | PASS |
| T3: Intervention | CollapseDetector fires step 39 (wd=0.3), PPL gap +33.8 vs unmodified | PASS |
Origin
Thermoclaw's thermodynamic framework comes from the EPTO (Entropy-Production Targeted Optimisation) research project. The key insight: neural network training is a non-equilibrium thermodynamic process, and the quantities that matter (entropy production, entropy ratios, equilibrium fraction) can be measured for any optimizer, not just EPTO.
License
Apache 2.0