Diagnostics for PyTorch model training - monitor activations, parameters, and gradients
Monitor the training of your PyTorch modules
A simple research code base to monitor the training of small-to-medium neural networks. Log arbitrary metrics of activations, gradients, and parameters to W&B with a few lines of code!
We also implement the refined coordinate check (RCC) from the NeurIPS 2025 paper "On the Surprising Effectiveness of Large Learning Rates under Standard Width Scaling" (Haas et al., 2025).
⚡For a complete working example, see how the monitor can be integrated into nanoGPT.⚡
Installation
```bash
pip install torch-module-monitor
```
Features
1. Monitor arbitrary metrics of activations, gradients, and parameters with a few lines of code
- Add new metrics for activations, gradients, and parameters with a single line of code.
- Regex-based filtering to determine what should be logged
- Monitor the internals of the attention operation (query/key/value tensor metrics, attention entropy)
- Aggregation of activation metrics across micro-batches
2. Perform the Refined Coordinate Check (RCC) from https://arxiv.org/abs/2505.22491
- We provide an implementation of the refined coordinate check.
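With gradient accumulation, one optimizer step is split into several micro-batches, so an activation metric is computed once per micro-batch and then has to be combined. The plain-Python sketch below, independent of the library, shows why a sample-weighted average of per-micro-batch means recovers the full-batch mean while an unweighted average does not. The helper `aggregate_means` is hypothetical, and the library's actual aggregation rule may differ. Note also that metrics such as norms do not combine by simple averaging at all.

```python
from statistics import fmean

def aggregate_means(per_batch_means, batch_sizes):
    """Combine per-micro-batch means into the mean over all samples (hypothetical helper)."""
    total = sum(batch_sizes)
    return sum(m * n for m, n in zip(per_batch_means, batch_sizes)) / total

# Two micro-batches of unequal size (made-up values standing in for activations)
micro_batches = [[1.0, 2.0, 3.0], [10.0]]
means = [fmean(b) for b in micro_batches]  # per-micro-batch means: [2.0, 10.0]
sizes = [len(b) for b in micro_batches]    # samples per micro-batch: [3, 1]

# Reference value: the mean over all samples of the full batch
full_batch_mean = fmean([x for b in micro_batches for x in b])

print(aggregate_means(means, sizes))  # 4.0, matches full_batch_mean
print(fmean(means))                   # 6.0, unweighted averaging is biased here
```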
Basic Monitoring
```python
import wandb
from torch_module_monitor import ModuleMonitor

# Assumes model, dataloader, criterion, and optimizer are already defined
# and that wandb.init(...) has been called.

# Initialize the monitor (record metrics every 10 steps) and add metrics
monitor = ModuleMonitor(monitor_step_fn=lambda step: step % 10 == 0)
monitor.set_module(model)
monitor.add_activation_metric("mean", lambda x: x.mean())
monitor.add_parameter_metric("norm", lambda x: x.norm())
monitor.add_gradient_metric("norm", lambda x: x.norm())

# Training loop
for step, (inputs, targets) in enumerate(dataloader):
    monitor.begin_step(step)
    outputs = model(inputs)  # Activations captured via hooks
    loss = criterion(outputs, targets)
    loss.backward()
    monitor.monitor_parameters()
    monitor.monitor_gradients()  # gradients exist after backward() and before zero_grad()
    optimizer.step()
    optimizer.zero_grad()
    monitor.end_step()

    # Log metrics on monitored steps
    if monitor.is_step_monitored(step):
        wandb.log(monitor.get_step_metrics())
```
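To make the control flow above concrete, here is a stripped-down, library-independent sketch of the same pattern: metrics are registered as name/function pairs, a step function decides whether a step is recorded, and the collected values come back as a flat dictionary. The class `ToyMonitor`, its methods, and the `"{kind}/{tensor_name}/{metric}"` key format are hypothetical illustrations, not the internals or key names of `torch_module_monitor`; plain Python lists stand in for tensors.

```python
from statistics import fmean

class ToyMonitor:
    """Hypothetical minimal monitor illustrating step gating and metric registration."""

    def __init__(self, monitor_step_fn):
        self.monitor_step_fn = monitor_step_fn
        self.metrics = {"activation": {}, "parameter": {}}
        self.step = None
        self.values = {}

    def add_metric(self, kind, name, fn):
        self.metrics[kind][name] = fn

    def is_step_monitored(self, step):
        return self.monitor_step_fn(step)

    def begin_step(self, step):
        self.step = step
        self.values = {}

    def record(self, kind, tensor_name, tensor):
        # Skip all metric computation on steps that are not monitored
        if not self.is_step_monitored(self.step):
            return
        for metric_name, fn in self.metrics[kind].items():
            self.values[f"{kind}/{tensor_name}/{metric_name}"] = fn(tensor)

    def get_step_metrics(self):
        return dict(self.values)


monitor = ToyMonitor(monitor_step_fn=lambda step: step % 10 == 0)
monitor.add_metric("activation", "mean", fmean)
monitor.add_metric("parameter", "max_abs", lambda xs: max(abs(x) for x in xs))

logged = {}
for step in range(25):
    monitor.begin_step(step)
    monitor.record("activation", "layer1", [0.5, 1.5, float(step)])
    monitor.record("parameter", "layer1.weight", [-2.0, 1.0])
    if monitor.is_step_monitored(step):
        logged[step] = monitor.get_step_metrics()

print(sorted(logged))  # [0, 10, 20]
print(logged[10])      # {'activation/layer1/mean': 4.0, 'parameter/layer1.weight/max_abs': 2.0}
```

Gating inside `record` means unmonitored steps cost almost nothing, which is the reason to pass a step function rather than computing every metric on every step.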
Complete Examples
See examples/ for complete examples:
- metrics.ipynb - Basic metric monitoring
- reference-model.ipynb - Reference module comparison
- refined-coordinate-check.ipynb - Refined coordinate check
⚡We also show how to integrate the monitor into nanoGPT.⚡
Integration with Weights & Biases
We name the different metrics such that they are nicely visualized in Weights & Biases.
Log the collected metrics in a single line of code:
```python
wandb.log(monitor.get_step_metrics(), step=current_step)
```
Examples: TODO Provide Links
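By default, the W&B workspace groups metrics whose keys share the prefix before the first `/` into the same panel section. The helper below mimics that grouping on a flat metrics dictionary, so you can preview how a `get_step_metrics()` result would be organized. The helper name `group_by_section` and the example keys are made up for illustration; the keys only imitate the per-head names listed for attention monitoring and are not guaranteed to match the library's exact output.

```python
from collections import defaultdict

def group_by_section(metrics):
    """Group flat metric keys by the prefix before the first '/' (illustrative helper)."""
    sections = defaultdict(dict)
    for key, value in metrics.items():
        section, sep, rest = key.partition("/")
        if sep:
            sections[section][rest] = value
        else:
            # Keys without a '/' fall into a catch-all section
            sections["(ungrouped)"][key] = value
    return dict(sections)

# Made-up keys in the style of the attention metric names
step_metrics = {
    "activation/blocks.0.attn.head_0.query": 0.12,
    "activation/blocks.0.attn.head_1.query": 0.09,
    "attention_entropy/blocks.0.attn.head_0": 2.3,
    "loss": 3.1,
}
for section, entries in group_by_section(step_metrics).items():
    print(section, sorted(entries))
```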
Patterns
Regex-Based Module Filtering
You can use a regex to specify that a metric should only be computed for specific tensors.
```python
# Monitor only modules whose name contains "mlp"
monitor.add_activation_metric(
    "abs_mean", lambda x: x.abs().mean(), metric_regex=r".*mlp.*"
)
```
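The regex is matched against module names such as those produced by `model.named_modules()`. The snippet below uses made-up nanoGPT-style names to show which ones the pattern `.*mlp.*` selects, and a second pattern that would select attention layers instead. Whether the library applies the pattern with `re.match`, `re.search`, or `re.fullmatch` is an assumption here; with a leading and trailing `.*`, all three give the same result.

```python
import re

# Hypothetical module names in the style of nn.Module.named_modules() output
module_names = [
    "transformer.h.0.attn",
    "transformer.h.0.attn.c_attn",
    "transformer.h.0.mlp",
    "transformer.h.0.mlp.c_fc",
    "transformer.h.1.mlp.c_proj",
    "lm_head",
]

# Names containing "mlp" anywhere
mlp_pattern = re.compile(r".*mlp.*")
selected = [name for name in module_names if mlp_pattern.fullmatch(name)]
print(selected)

# Attention modules and their children: "...attn" optionally followed by ".something"
attn_only = [name for name in module_names if re.fullmatch(r".*\.attn(\..*)?", name)]
print(attn_only)
```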
Reference Module Comparison
In infinite-width theory, we often want to measure how far the activations and parameters have moved from their values at initialization. We implement this via an arbitrary reference model to which our model can be compared.
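```python
import copy

# For example, a copy of the model taken before training starts
reference_model = copy.deepcopy(model)
monitor.set_reference_module(reference_model)

# Track drift from initialization
monitor.add_parameter_difference_metric(
    "l2_distance", lambda p, p_ref: (p - p_ref).norm()
)
```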
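The difference metric receives each parameter together with its counterpart in the reference module. As a plain-Python illustration (lists instead of tensors, made-up values), this computes the absolute L2 distance from initialization and, as a variant that is easier to compare across layers of different sizes, the distance relative to the initial norm. The relative variant and all helper names are our own illustration, not metrics the library is documented to provide.

```python
import math

def l2_norm(xs):
    return math.sqrt(sum(x * x for x in xs))

def l2_distance(p, p_ref):
    """Plain-list analogue of (p - p_ref).norm()."""
    return l2_norm([a - b for a, b in zip(p, p_ref)])

def relative_distance(p, p_ref):
    """Distance from the reference, scaled by the reference norm (assumes it is nonzero)."""
    return l2_distance(p, p_ref) / l2_norm(p_ref)

# Made-up parameter values: reference at initialization vs. after some updates
reference = {"layer.weight": [3.0, 4.0], "layer.bias": [1.0, 0.0]}
current = {"layer.weight": [6.0, 8.0], "layer.bias": [1.0, 0.0]}

for name in current:
    print(name, l2_distance(current[name], reference[name]),
          relative_distance(current[name], reference[name]))
```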
Complex Modules
By default, we monitor the activations of modules that return a single tensor. To monitor statistics of complex modules, these modules can implement MonitorMixin. We use this approach to monitor the internals of the attention operation.
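```python
import torch.nn as nn
import torch.nn.functional as F
from torch_module_monitor import MonitorMixin, monitor_scaled_dot_product_attention

class MultiHeadAttention(nn.Module, MonitorMixin):
    # __init__, compute_qkv, and output_projection are omitted here;
    # output_projection is assumed to merge the heads before the output layer.
    def forward(self, x):
        q, k, v = self.compute_qkv(x)  # each of shape (batch, heads, seq_len, head_dim)
        attn_output = F.scaled_dot_product_attention(q, k, v)
        if self.is_monitoring:
            # Log per-head query/key/value and attention statistics
            monitor_scaled_dot_product_attention(
                self.get_module_monitor(), module=self,
                query=q, key=k, value=v, activation=attn_output
            )
        return self.output_projection(attn_output)
```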
This logs per-head metrics: activation/{module}.head_{i}.query, attention_entropy/{module}.head_{i}, etc.
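Attention entropy is commonly defined as the Shannon entropy of each query's attention distribution (the softmax over keys), averaged over queries. We assume that definition in the plain-Python sketch below; the library's exact reduction (for example over batch elements, or how a causal mask is handled) may differ. A low value means a head attends sharply to a few keys, and the maximum, the log of the number of keys, is reached by uniform attention. The sketch uses the same default 1/sqrt(head_dim) scaling as `F.scaled_dot_product_attention`; the helper names and toy values are made up.

```python
import math

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_entropy(queries, keys):
    """Mean Shannon entropy (in nats) of one head's attention rows, averaged over queries."""
    d = len(keys[0])
    entropies = []
    for q in queries:
        # Scaled dot-product scores with the default 1/sqrt(d) scaling
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        probs = softmax(scores)
        entropies.append(-sum(p * math.log(p) for p in probs if p > 0))
    return sum(entropies) / len(entropies)

# Made-up toy head with 4 keys of dimension 2
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
uniform = attention_entropy([[0.0, 0.0]], keys)  # zero query: all scores equal
peaked = attention_entropy([[20.0, 0.0]], keys)  # large query: attends almost only to key 0
print(round(uniform, 4), round(math.log(len(keys)), 4))  # 1.3863 1.3863
print(peaked < 0.01)  # True
```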
Custom metrics in any module:
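```python
import torch.nn as nn
from torch_module_monitor import MonitorMixin

class CustomLayer(nn.Module, MonitorMixin):
    # __init__ (which defines self.transform) is omitted here
    def forward(self, x):
        output = self.transform(x)
        if self.is_monitoring:
            # Log the L2 norm of the output along the last dimension
            self.get_module_monitor().log_tensor("custom_stat", output.norm(dim=-1))
        return output
```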
Multi-GPU Support
In principle, the monitor can support multi-GPU training, though we do not provide direct support for any parallelization strategy. With FSDP, for example, every GPU could have its own monitor. However, we do not currently implement the synchronization of activation metrics across GPUs. The refined coordinate check was only tested for single-GPU training.
Citation
If you use this code, please cite:
```bibtex
@inproceedings{haas2025splargelr,
  title={On the Surprising Effectiveness of Large Learning Rates under Standard Width Scaling},
  author={Haas, Moritz and Bordt, Sebastian and von Luxburg, Ulrike and Vankadara, Leena Chennuru},
  booktitle={Advances in Neural Information Processing Systems 38},
  year={2025}
}
```
Contributing
We provide this code as-is. We may accept pull requests that fix bugs or add new features.