
Thermoclaw

Thermodynamic diagnostics for neural network training — see what your optimizer is actually doing.

Thermoclaw wraps any PyTorch optimizer and exposes real-time thermodynamic quantities: entropy production, internal vs external entropy flow, weight-decay collapse detection, equilibrium detection, and actionable per-layer diagnostics.


It found the problem. Acting on it worked.

We ran Thermoclaw on a Pythia-410M fine-tuning run with weight_decay=0.3 on WikiText-103. At step 39, CollapseDetector flagged weight decay collapse in the attention layers:

[MEDIUM] Weight decay collapse in 5 attention layers: grad/param ratio
dropped 3.1x from early to late training. Reduce weight_decay.

We reduced weight_decay from 0.3 to 0.01 at that point and kept training.

| Run | Final PPL (steps 280-300) |
|---|---|
| Unmodified (wd=0.3 throughout) | 256.8 |
| Thermoclaw intervention at step 39 | 222.9 |
| Improvement | +33.8 PPL |

No hyperparameter search. No manual inspection. One warning, one change.

That's what Thermoclaw does. The warning is causal — not correlational.


Install

pip install thermoclaw                    # core (torch + numpy only)
pip install thermoclaw[viz]               # + matplotlib dashboards
pip install thermoclaw[hf]                # + HuggingFace Trainer callback
pip install thermoclaw[all]               # everything

Quick Start

Observe any optimizer

import torch

from thermoclaw import Observer, diagnose

model = YourModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
observer = Observer(model, optimizer)

for batch in loader:
    loss = criterion(model(batch))
    loss.backward()
    optimizer.step()
    observer.step(loss=loss.item())
    optimizer.zero_grad()

report = diagnose(observer)
print(report)
observer.plot_dashboard(save_path='dashboard.png')

d_iS / d_eS entropy split (the killer feature)

from thermoclaw import Observer, EntropySplit, diagnose

observer = Observer(model, optimizer)
splitter = EntropySplit(model, optimizer, observer)

for batch in loader:
    loss = criterion(model(batch))
    loss.backward()
    optimizer.step()
    observer.step(loss=loss.item())
    splitter.step()                        # ← decomposes entropy
    optimizer.zero_grad()

report = diagnose(observer, splitter)
print(report)
# Output:
#   [HIGH] Weight decay is the dominant entropy source for 12 attention
#   layers (mean R_ie=4.2). Consider reducing weight_decay for attention
#   layers by 4-8×, or excluding them.

splitter.plot_entropy_split(save_path='entropy_split.png')

CollapseDetector — catch weight-decay collapse in real time

Weight decay can silently erode parameter capacity faster than gradient signal can rebuild it. The observable signature is ‖g‖/‖θ‖ (grad/param ratio) dropping over time. CollapseDetector tracks this directly — no calibration period required, no incommensurable quantities.
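The signal itself is cheap to compute. A standalone sketch of the grad/param ratio and its drop over training, with toy numbers standing in for a real layer (this is an illustration, not Thermoclaw's implementation):

```python
import math

def l2(v):
    """Euclidean norm of a flat list of parameter or gradient values."""
    return math.sqrt(sum(x * x for x in v))

# Toy layer: parameters and gradients sampled at an early and a late step.
theta_early, grad_early = [0.5, -0.3, 0.8], [0.05, -0.02, 0.04]
theta_late,  grad_late  = [0.4, -0.2, 0.6], [0.01, -0.004, 0.008]

ratio_early = l2(grad_early) / l2(theta_early)   # ‖g‖/‖θ‖ early in training
ratio_late  = l2(grad_late)  / l2(theta_late)    # ‖g‖/‖θ‖ late in training

# A multiplicative drop in this ratio is the collapse signature.
drop = ratio_early / ratio_late
print(f"early={ratio_early:.4f} late={ratio_late:.4f} drop={drop:.1f}x")
```

Here the ratio drops by roughly 3.8×, comparable to the 3.1× drop flagged in the Pythia-410M run above.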

from thermoclaw import CollapseDetector

detector = CollapseDetector(model, optimizer)

for batch in loader:
    loss = criterion(model(batch))
    loss.backward()
    optimizer.step()
    detector.step()                        # ← call before zero_grad
    optimizer.zero_grad()

    # Mid-training intervention trigger
    if detector.is_collapsing:
        for pg in optimizer.param_groups:
            pg['weight_decay'] *= 0.1

# Post-training report
recs = detector.get_recommendations()
# → ["[HIGH] Weight decay collapse in 5 mlp layers: grad/param ratio
#     dropped 4.2× from early to late training. Reduce weight_decay."]

is_collapsing returns True as soon as any HIGH or MEDIUM collapse signal fires — use it for automated mid-training interventions. Validated at Pythia-410M scale: it fired correctly at step 39 with wd=0.3, yielding a +33.8 PPL gap vs. the unmodified arm.

ThermoScheduler (drop-in replacement)

One line to replace your cosine scheduler. Keeps your optimizer, just makes the schedule thermodynamically aware:

import torch

from thermoclaw import ThermoScheduler

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = ThermoScheduler(optimizer, total_steps=10000)

for step, batch in enumerate(loader):
    loss = criterion(model(batch))
    loss.backward()
    optimizer.step()
    scheduler.step()          # ← replaces cosine_scheduler.step()
    optimizer.zero_grad()

HuggingFace integration (one line)

from thermoclaw.integrations.huggingface import ThermoclawCallback

trainer = Trainer(
    model=model,
    args=training_args,
    callbacks=[ThermoclawCallback()],      # ← that's it
)
trainer.train()
# Dashboard PNG, CSV, and diagnostic report saved to output_dir automatically

What Thermoclaw measures

| Quantity | Symbol | What it means |
|---|---|---|
| Entropy production | σ_l = η‖g‖² | How much "thermodynamic work" each layer is doing |
| Entropy ratio | r_l = σ/σ* | 1.0 = equilibrium; <0.85 = under-trained; >1.15 = over-trained |
| External entropy | d_eS | Entropy that reduces loss (productive learning) |
| Internal entropy | d_iS | Entropy from weight decay, momentum, noise (overhead) |
| Internal/external ratio | R_ie = d_iS/d_eS | The diagnostic ratio; >2 = warning, >5 = critical |
| Grad/param ratio | ‖g‖/‖θ‖ | CollapseDetector signal; a dropping trend = weight decay eroding capacity |
| Dispersion | D = Var(r_l) | Inter-layer training uniformity |
| Gradient alignment | ρ = cos(g_t, g_{t-1}) | Step coherence; negative = oscillation |
| Parameter distance | E = ‖θ−θ₀‖² | How far weights have moved from init |
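To make two of these concrete, here is a plain-Python sketch (toy vectors, independent of Thermoclaw) of gradient alignment ρ and parameter distance E:

```python
import math

def cos_sim(a, b):
    """Cosine similarity between two flat gradient vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)

g_prev = [0.2, -0.1, 0.3]
g_curr = [0.18, -0.12, 0.28]   # coherent step: nearly the same direction
g_osc  = [-0.2, 0.1, -0.3]     # oscillation: exactly sign-flipped

rho_coherent = cos_sim(g_prev, g_curr)   # close to +1
rho_osc = cos_sim(g_prev, g_osc)         # exactly -1

theta0 = [0.5, -0.3, 0.8]        # weights at init
theta  = [0.45, -0.35, 0.9]      # weights now
E = sum((a - b) ** 2 for a, b in zip(theta, theta0))   # ‖θ−θ₀‖²

print(rho_coherent, rho_osc, E)
```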

A note on ρ and momentum (non-obvious)

High momentum (β₁ → 1) increases ρ, not decreases it — momentum smooths consecutive gradient vectors, making them more correlated. This means ρ is not a reliable signal for over-damped momentum. If β₁ is too high, use equilibrium fraction instead: over-damped runs show low eq_fraction despite high ρ, because damped updates can't reach homeostatic σ* levels. Validated at Pythia-410M scale: β₁=0.999 produced ρ=0.63 vs baseline ρ=0.39, with eq_fraction=0.15 vs 0.59.
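The smoothing effect is easy to reproduce outside any training loop. A small self-contained simulation (plain Python, toy Gaussian gradients rather than real training, β₁=0.999 as in the T2C run) shows that EMA-smoothed gradient buffers are far more correlated step-to-step than the raw gradients:

```python
import math
import random

def cos_sim(a, b):
    """Cosine similarity between two flat vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)

random.seed(0)
dim, steps = 64, 200
beta = 0.999                       # heavy momentum, as in T2C

m = [0.0] * dim                    # momentum buffer
raw_rhos, smooth_rhos = [], []
prev_g, prev_m = None, None
for _ in range(steps):
    # Noisy gradient with a small consistent drift component.
    g = [random.gauss(0.1, 1.0) for _ in range(dim)]
    m = [beta * mi + (1 - beta) * gi for mi, gi in zip(m, g)]
    if prev_g is not None:
        raw_rhos.append(cos_sim(prev_g, g))       # ρ of raw gradients
        smooth_rhos.append(cos_sim(prev_m, m))    # ρ of momentum buffers
    prev_g, prev_m = g, m

mean = lambda xs: sum(xs) / len(xs)
print(f"raw ρ ≈ {mean(raw_rhos):.2f}, momentum-smoothed ρ ≈ {mean(smooth_rhos):.2f}")
```

The raw gradients are nearly uncorrelated while the smoothed vectors are almost perfectly aligned, which is why a high ρ alone cannot distinguish healthy coherence from over-damped momentum.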

The d_iS / d_eS decomposition

Standard training observes total loss and calls it a day. But total entropy production σ conflates two fundamentally different thermodynamic flows:

  • d_eS (external) — gradient-driven parameter updates that reduce loss. This is productive work.
  • d_iS (internal) — entropy from weight decay, momentum friction, stochastic noise. This is heat.

When d_iS >> d_eS, the optimizer is spending most of its entropy budget on overhead. Thermoclaw decomposes d_iS further into:

  • d_iS_wd — weight decay contribution
  • d_iS_momentum — momentum friction
  • d_iS_noise — stochastic gradient noise

This tells you exactly which layers, at which step, are wasting compute — and why.
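The exact decomposition Thermoclaw computes is not spelled out here, but the idea can be sketched for plain SGD with decoupled weight decay, under the assumption that each flow's per-step entropy scales as η‖·‖², matching the σ_l formula above (toy numbers, hypothetical split):

```python
eta, lam = 3e-4, 0.3          # learning rate and weight decay
theta = [0.5, -0.3, 0.8]      # layer parameters
grad = [0.05, -0.02, 0.04]    # loss gradient for the same layer

sq = lambda v: sum(x * x for x in v)   # squared L2 norm

# Gradient-driven flow (the d_eS-like, productive part) vs.
# weight-decay flow (the d_iS_wd-like, overhead part).
sigma_grad = eta * sq(grad)
sigma_wd   = eta * sq([lam * t for t in theta])

R = sigma_wd / sigma_grad     # analogue of R_ie restricted to the wd source
print(f"sigma_grad={sigma_grad:.2e} sigma_wd={sigma_wd:.2e} ratio={R:.1f}")
```

With these numbers the weight-decay flow dwarfs the gradient flow, the regime where a "dominant entropy source" warning would be justified.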

Per-layer param groups

For full per-layer resolution, use make_param_groups:

from thermoclaw import make_param_groups

groups = make_param_groups(model, lr=3e-4, weight_decay=0.01)
optimizer = torch.optim.AdamW(groups)
observer = Observer(model, optimizer)
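make_param_groups' internals are not documented here, but one plausible grouping strategy, sketched below with hypothetical parameter names in place of model.named_parameters(), is to key each parameter on its owning layer prefix and emit one optimizer param group per layer:

```python
from collections import defaultdict

# Hypothetical parameter names, standing in for model.named_parameters().
names = [
    "embed.weight",
    "h.0.attn.weight", "h.0.attn.bias",
    "h.0.mlp.weight",  "h.0.mlp.bias",
    "h.1.attn.weight", "h.1.mlp.weight",
]

def layer_key(name):
    """Group parameters by owning layer: 'h.<idx>' for blocks, else top module."""
    parts = name.split(".")
    return ".".join(parts[:2]) if parts[0] == "h" else parts[0]

groups = defaultdict(list)
for n in names:
    groups[layer_key(n)].append(n)

# One param group per layer, each with its own lr/weight_decay knobs,
# in the format torch.optim optimizers accept.
param_groups = [
    {"name": k, "params": v, "lr": 3e-4, "weight_decay": 0.01}
    for k, v in groups.items()
]
print([g["name"] for g in param_groups])
```

Per-layer groups are what let the diagnostics resolve (and let you tune) weight decay layer by layer instead of globally.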

Confidence scoring

Recommendations are conservative. Thermoclaw only flags issues where the physics signal is unambiguous:

  • [HIGH] — Single dominant source (>60% of d_iS), R_ie > 5, consistent across regions. Safe to act on.
  • [MEDIUM] — Clear signal but moderate R_ie (2-5). Worth investigating.
  • [LOW] — Signal present but multiple sources contribute. Informational only.

Wrong recommendations that sound authoritative destroy trust faster than no recommendations at all. Thermoclaw would rather under-claim than over-claim.
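The tiering above can be summarized as a tiny decision rule. This is a sketch of the listed thresholds only, omitting the consistency-across-regions and multi-source checks:

```python
def severity(dominant_share, r_ie):
    """Map entropy-split statistics to a recommendation tier.

    dominant_share: fraction of d_iS from the largest single source.
    r_ie: internal/external entropy ratio for the affected layers.
    """
    if dominant_share > 0.60 and r_ie > 5:
        return "HIGH"      # single dominant source, unambiguous signal
    if r_ie > 2:
        return "MEDIUM"    # clear signal, moderate ratio
    return "LOW"           # informational only

print(severity(0.8, 6.1), severity(0.5, 3.0), severity(0.4, 1.2))
```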

Validated

Three-tier validation on H100 80GB (Pythia-410M, WikiText-103, bfloat16):

| Tier | Test | Result |
|---|---|---|
| T1: Analytical | σ, ρ, d_iS_wd, d_eS+d_iS=σ, E, D — 8 ground-truth checks | 8/8 PASS |
| T2A: High LR | lr=3e-2 → R_ie=1.6×10²⁰, eq=0.014, flagged HIGH | PASS |
| T2B: High WD | wd=5.0 → unhealthy, flagged MEDIUM | PASS |
| T2C: Overdamped | β₁=0.999 → ρ=0.63 (vs baseline 0.39), eq=0.15, flagged HIGH | PASS |
| T2D: Baseline | lr=3e-4 / wd=0.01 / β₁=0.9 → no collapse or WD pathology recs | PASS |
| T3: Intervention | CollapseDetector fires at step 39 (wd=0.3), PPL gap +33.8 vs unmodified | PASS |

Origin

Thermoclaw's thermodynamic framework comes from the EPTO (Entropy-Production Targeted Optimisation) research project. The key insight: neural network training is a non-equilibrium thermodynamic process, and the quantities that matter (entropy production, entropy ratios, equilibrium fraction) can be measured for any optimizer, not just EPTO.

License

Apache 2.0
