
Thermoclaw

Thermodynamic diagnostics for neural network training — see what your optimizer is actually doing.

Thermoclaw wraps any PyTorch optimizer and exposes real-time thermodynamic quantities: entropy production, internal vs external entropy flow, weight-decay collapse detection, equilibrium detection, and actionable per-layer diagnostics.


It found the problem. Acting on it worked.

We ran Thermoclaw on a Pythia-410M fine-tuning run with weight_decay=0.3 on WikiText-103. At step 39, CollapseDetector flagged weight decay collapse in the attention layers:

[MEDIUM] Weight decay collapse in 5 attention layers: grad/param ratio
dropped 3.1x from early to late training. Reduce weight_decay.

We reduced weight_decay from 0.3 to 0.01 at that point and kept training.

Run                                   Final PPL (steps 280-300)
Unmodified (wd=0.3 throughout)        256.8
Thermoclaw intervention at step 39    222.9
Improvement                           +33.8 PPL

No hyperparameter search. No manual inspection. One warning, one change.

That's what Thermoclaw does. The warning is causal — not correlational.


Install

pip install thermoclaw                    # core (torch + numpy only)
pip install thermoclaw[viz]               # + matplotlib dashboards
pip install thermoclaw[hf]                # + HuggingFace Trainer callback
pip install thermoclaw[all]               # everything

Quick Start

Observe any optimizer

import torch

from thermoclaw import Observer, diagnose

model = YourModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
observer = Observer(model, optimizer)

for batch in loader:
    loss = criterion(model(batch))
    loss.backward()
    optimizer.step()
    observer.step(loss=loss.item())
    optimizer.zero_grad()

report = diagnose(observer)
print(report)
observer.plot_dashboard(save_path='dashboard.png')

d_iS / d_eS entropy split (the killer feature)

from thermoclaw import Observer, EntropySplit, diagnose

observer = Observer(model, optimizer)
splitter = EntropySplit(model, optimizer, observer)

for batch in loader:
    loss = criterion(model(batch))
    loss.backward()
    optimizer.step()
    observer.step(loss=loss.item())
    splitter.step()                        # ← decomposes entropy
    optimizer.zero_grad()

report = diagnose(observer, splitter)
print(report)
# Output:
#   [HIGH] Weight decay is the dominant entropy source for 12 attention
#   layers (mean R_ie=4.2). Consider reducing weight_decay for attention
#   layers by 4-8×, or excluding them.

splitter.plot_entropy_split(save_path='entropy_split.png')

CollapseDetector — catch weight-decay collapse in real time

Weight decay can silently erode parameter capacity faster than gradient signal can rebuild it. The observable signature is ‖g‖/‖θ‖ (grad/param ratio) dropping over time. CollapseDetector tracks this directly — no calibration period required, no incommensurable quantities.

from thermoclaw import CollapseDetector

detector = CollapseDetector(model, optimizer)

for batch in loader:
    loss = criterion(model(batch))
    loss.backward()
    optimizer.step()
    detector.step()                        # ← call before zero_grad
    optimizer.zero_grad()

    # Mid-training intervention trigger
    if detector.is_collapsing:
        for pg in optimizer.param_groups:
            pg['weight_decay'] *= 0.1

# Post-training report
recs = detector.get_recommendations()
# → ["[HIGH] Weight decay collapse in 5 mlp layers: grad/param ratio
#     dropped 4.2× from early to late training. Reduce weight_decay."]

is_collapsing returns True as soon as any HIGH or MEDIUM collapse signal fires; use it for automated mid-training interventions. Validated at Pythia-410M scale: it fired at step 39 with wd=0.3, and intervening produced the +33.8 PPL gap vs the unmodified arm.
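The underlying signal is cheap to sanity-check by hand. A minimal sketch of the early-vs-late ‖g‖/‖θ‖ comparison, with illustrative numbers and a toy two-checkpoint window; this is not the library's implementation:

```python
import math

def grad_param_ratio(grads, params):
    """Compute ||g|| / ||theta|| for one layer (flattened toy vectors)."""
    gn = math.sqrt(sum(g * g for g in grads))
    pn = math.sqrt(sum(p * p for p in params))
    return gn / pn

ratio = grad_param_ratio([0.3, -0.4], [3.0, 4.0])   # 0.5 / 5.0 = 0.1

# Illustrative per-checkpoint history of the ratio under heavy weight decay:
history = [0.12, 0.10, 0.07, 0.05, 0.03, 0.02]
early = sum(history[:2]) / 2     # mean over the first checkpoints
late = sum(history[-2:]) / 2     # mean over the last checkpoints
drop = early / late              # a drop of ~3x or more is the collapse signature
print(f"ratio dropped {drop:.1f}x")
```

The averaging window and the ~3x alert threshold here are assumptions for illustration; CollapseDetector's actual regions and thresholds may differ.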

ThermoScheduler (drop-in replacement)

One line to replace your cosine scheduler. Keeps your optimizer, just makes the schedule thermodynamically aware:

from thermoclaw import ThermoScheduler

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = ThermoScheduler(optimizer, total_steps=10000)

for step, batch in enumerate(loader):
    loss = criterion(model(batch))
    loss.backward()
    optimizer.step()
    scheduler.step()          # ← replaces cosine_scheduler.step()
    optimizer.zero_grad()
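ThermoScheduler's internal policy isn't spelled out here, so as a hedged illustration only: a toy feedback rule that nudges the learning rate whenever a measured entropy ratio r leaves a healthy band around 1.0 (assumed 0.85-1.15, matching the thresholds quoted elsewhere in this page). The function name and multipliers are hypothetical.

```python
# Hypothetical feedback schedule (NOT ThermoScheduler's actual algorithm):
# shrink lr when the entropy ratio says "over-trained" (r > 1.15),
# grow it when "under-trained" (r < 0.85), hold near equilibrium.
def thermo_lr(lr, r, up=1.05, down=0.95, lo=0.85, hi=1.15):
    if r > hi:
        return lr * down
    if r < lo:
        return lr * up
    return lr

lr = 3e-4
for r in [1.3, 1.2, 1.0, 0.8, 0.8]:   # fake per-step entropy ratios
    lr = thermo_lr(lr, r)
print(lr)
```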

HuggingFace integration (one line)

from transformers import Trainer

from thermoclaw.integrations.huggingface import ThermoclawCallback

trainer = Trainer(
    model=model,
    args=training_args,
    callbacks=[ThermoclawCallback()],      # ← that's it
)
trainer.train()
# Dashboard PNG, CSV, and diagnostic report saved to output_dir automatically

What Thermoclaw measures

  • Entropy production, σ_l = η‖g‖²: how much "thermodynamic work" each layer is doing
  • Entropy ratio, r_l = σ/σ*: 1.0 = equilibrium; <0.85 = under-trained; >1.15 = over-trained
  • External entropy, d_eS: entropy that reduces loss (productive learning)
  • Internal entropy, d_iS: entropy from weight decay, momentum, noise (overhead)
  • Diagnostic ratio, R_ie = d_iS/d_eS: >2 = warning; >5 = critical
  • Grad/param ratio, ‖g‖/‖θ‖: the CollapseDetector signal; a dropping trend means weight decay is eroding capacity
  • Dispersion, D = Var(r_l): inter-layer training uniformity
  • Gradient alignment, ρ = cos(g_t, g_{t-1}): step coherence; negative = oscillation
  • Parameter distance, E = ‖θ−θ₀‖²: how far weights have moved from init
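For concreteness, the per-layer quantities above are cheap to compute. A stdlib-only sketch with toy vectors; Thermoclaw's actual estimators may normalise or average differently:

```python
import math

eta = 3e-4                       # learning rate, the η in σ_l = η‖g‖²

def norm_sq(v):
    return sum(x * x for x in v)

def cosine(a, b):
    return sum(x * y for x, y in zip(a, b)) / math.sqrt(norm_sq(a) * norm_sq(b))

theta0 = [0.5, -0.2, 0.1]        # toy layer weights at init
theta  = [0.8, -0.5, 0.4]        # the same weights now
g_prev = [0.20, 0.10, -0.10]     # gradient at step t-1
g_curr = [0.18, 0.12, -0.09]     # gradient at step t

sigma = eta * norm_sq(g_curr)                          # entropy production σ_l
rho = cosine(g_curr, g_prev)                           # gradient alignment ρ
E = norm_sq([a - b for a, b in zip(theta, theta0)])    # distance from init E

print(sigma, round(rho, 3), E)
```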

A note on ρ and momentum (non-obvious)

High momentum (β₁ → 1) increases ρ, not decreases it — momentum smooths consecutive gradient vectors, making them more correlated. This means ρ is not a reliable signal for over-damped momentum. If β₁ is too high, use equilibrium fraction instead: over-damped runs show low eq_fraction despite high ρ, because damped updates can't reach homeostatic σ* levels. Validated at Pythia-410M scale: β₁=0.999 produced ρ=0.63 vs baseline ρ=0.39, with eq_fraction=0.15 vs 0.59.
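The momentum effect is easy to reproduce: with i.i.d. random "gradients", EMA momentum makes consecutive step directions far more correlated than the raw gradients are, even though no real learning signal exists. A self-contained, seeded demonstration:

```python
import math
import random

random.seed(0)
dim, steps, beta1 = 256, 200, 0.99

def cos(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / math.sqrt(sum(x * x for x in a) * sum(x * x for x in b))

m = [0.0] * dim
g_prev = m_prev = None
raw_cos, mom_cos = [], []
for _ in range(steps):
    g = [random.gauss(0.0, 1.0) for _ in range(dim)]              # i.i.d. "gradients"
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]   # EMA momentum
    if g_prev is not None:
        raw_cos.append(cos(g, g_prev))    # raw gradients: nearly orthogonal
        mom_cos.append(cos(m, m_prev))    # momentum vectors: highly aligned
    g_prev, m_prev = g, m

raw_mean = sum(raw_cos) / len(raw_cos)
mom_mean = sum(mom_cos) / len(mom_cos)
print(round(raw_mean, 3), round(mom_mean, 3))
```

With β₁=0.99 the momentum-vector cosine sits near 1 while the raw-gradient cosine hovers near 0, which is exactly why a high ρ cannot distinguish "coherent learning" from "over-smoothed momentum".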

The d_iS / d_eS decomposition

Standard training observes total loss and calls it a day. But total entropy production σ conflates two fundamentally different thermodynamic flows:

  • d_eS (external) — gradient-driven parameter updates that reduce loss. This is productive work.
  • d_iS (internal) — entropy from weight decay, momentum friction, stochastic noise. This is heat.

When d_iS >> d_eS, the optimizer is spending most of its entropy budget on overhead. Thermoclaw decomposes d_iS further into:

  • d_iS_wd — weight decay contribution
  • d_iS_momentum — momentum friction
  • d_iS_noise — stochastic gradient noise

This tells you exactly which layers, at which step, are wasting compute — and why.
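As a sketch of the split, consider plain SGD with decoupled weight decay, Δθ = −η(g + λθ), and treat η times each term's squared norm as its entropy contribution. That attribution rule is an illustrative assumption, not necessarily Thermoclaw's exact estimator:

```python
# Illustrative split for SGD with decoupled weight decay:
#   external entropy ~ η‖g‖², internal (weight-decay) entropy ~ η‖λθ‖².
eta, lam = 3e-4, 0.3              # lr and weight_decay (wd deliberately large)
g = [0.02, -0.01, 0.015]          # one layer's gradient
theta = [0.9, -1.1, 0.7]          # that layer's weights

d_eS = eta * sum(x * x for x in g)                  # productive, loss-reducing
d_iS_wd = eta * sum((lam * x) ** 2 for x in theta)  # weight-decay "heat"
R_ie = d_iS_wd / d_eS
print(round(R_ie, 1))             # far above the >5 critical threshold
```

Small gradients against large decayed weights push R_ie into the hundreds here, which is the "entropy budget spent on overhead" situation the decomposition is meant to expose.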

Per-layer param groups

For full per-layer resolution, use make_param_groups:

from thermoclaw import make_param_groups

groups = make_param_groups(model, lr=3e-4, weight_decay=0.01)
optimizer = torch.optim.AdamW(groups)
observer = Observer(model, optimizer)
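The per-layer grouping idea can be sketched without torch. Here hypothetical parameter names stand in for model.named_parameters(), and the real make_param_groups may group and label differently:

```python
from collections import defaultdict

# Hypothetical parameter names, as a torch model might report them
named = ["attn.0.weight", "attn.0.bias", "mlp.0.weight", "mlp.0.bias"]

groups = defaultdict(list)
for name in named:
    layer = name.rsplit(".", 1)[0]   # group by owning module, e.g. "attn.0"
    groups[layer].append(name)

# One optimizer param group per layer, so diagnostics (and fixes like
# per-layer weight_decay) can be applied at layer resolution.
param_groups = [
    {"params": ps, "lr": 3e-4, "weight_decay": 0.01, "name": layer}
    for layer, ps in groups.items()
]
print([g["name"] for g in param_groups])
```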

Confidence scoring

Recommendations are conservative. Thermoclaw only flags issues where the physics signal is unambiguous:

  • [HIGH] — Single dominant source (>60% of d_iS), R_ie > 5, consistent across regions. Safe to act on.
  • [MEDIUM] — Clear signal but moderate R_ie (2-5). Worth investigating.
  • [LOW] — Signal present but multiple sources contribute. Informational only.

Wrong recommendations that sound authoritative destroy trust faster than no recommendations at all. Thermoclaw would rather under-claim than over-claim.
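One plausible reading of those tiers as code. The thresholds for the ambiguous cases are hypothetical, and the real scorer also checks consistency across training regions:

```python
def severity(dominant_fraction, r_ie):
    """Map a layer's entropy diagnostics to a confidence tier (sketch).

    dominant_fraction: the largest single source's share of d_iS.
    r_ie: the d_iS/d_eS ratio.
    """
    if dominant_fraction > 0.6 and r_ie > 5:
        return "HIGH"      # single dominant source, unambiguous ratio
    if r_ie > 2:
        return "MEDIUM"    # clear signal, moderate ratio
    if r_ie > 1:
        return "LOW"       # internal exceeds external, but weakly
    return None            # nothing worth flagging

print(severity(0.7, 6.0), severity(0.7, 3.0), severity(0.4, 1.5))
```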

Validated

Three-tier validation on H100 80GB (Pythia-410M, WikiText-103, bfloat16):

  • T1 (Analytical): σ, ρ, d_iS_wd, d_eS+d_iS=σ, E, D; 8 ground-truth checks: 8/8 PASS
  • T2A (High LR): lr=3e-2 → R_ie=1.6×10²⁰, eq=0.014, flagged HIGH: PASS
  • T2B (High WD): wd=5.0 → unhealthy, flagged MEDIUM: PASS
  • T2C (Overdamped): β₁=0.999 → ρ=0.63 (vs baseline 0.39), eq=0.15, flagged HIGH: PASS
  • T2D (Baseline): lr=3e-4/wd=0.01/β₁=0.9 → no collapse or WD pathology recs: PASS
  • T3 (Intervention): CollapseDetector fires at step 39 (wd=0.3), PPL gap +33.8 vs unmodified: PASS

Origin

Thermoclaw's thermodynamic framework comes from the EPTO (Entropy-Production Targeted Optimisation) research project. The key insight: neural network training is a non-equilibrium thermodynamic process, and the quantities that matter (entropy production, entropy ratios, equilibrium fraction) can be measured for any optimizer, not just EPTO.

License

Apache 2.0
