TorchHook: A PyTorch hook management library
TorchHook is a library for managing PyTorch model hooks. It provides a convenient interface for capturing intermediate feature maps and debugging models.
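For readers new to PyTorch hooks, the sketch below shows the plain-PyTorch mechanism a hook manager like this typically builds on. `nn.Module.register_forward_hook` attaches a callback that receives a layer's output on every forward pass, and calling `remove()` on the returned handle detaches it. `MiniHookManager` is hypothetical: it only mirrors the method names used in the usage example and is not TorchHook's actual implementation. The `ClassName_Index` auto-naming is an assumption inferred from the example output, and the sketch assumes a PyTorch version that provides `nn.Module.get_submodule` (1.9 or later).

```python
import torch
import torch.nn as nn


class MiniHookManager:
    """Hypothetical, simplified hook manager (not TorchHook's implementation)."""

    def __init__(self, model: nn.Module):
        self.model = model
        self.handles = {}        # hook name -> RemovableHandle
        self.features = {}       # hook name -> list of captured outputs
        self._class_counts = {}  # class name -> next auto-name index

    def register_forward_hook(self, name=None, layer=None, layer_name=None):
        if layer_name is not None:
            # Resolve a submodule by its attribute path, e.g. "conv1".
            layer = self.model.get_submodule(layer_name)
            name = name or layer_name
        if name is None:
            # Assumed naming scheme: ClassName_Index, e.g. "ReLU_0".
            cls = type(layer).__name__
            idx = self._class_counts.get(cls, 0)
            self._class_counts[cls] = idx + 1
            name = f"{cls}_{idx}"

        def hook(module, inputs, output):
            # Detach so stored outputs do not keep the autograd graph alive.
            self.features.setdefault(name, []).append(output.detach())

        self.handles[name] = layer.register_forward_hook(hook)
        return name

    def get_features(self, name):
        return self.features.get(name, [])

    def clear_hooks(self):
        # Removing a handle stops its callback from firing.
        for handle in self.handles.values():
            handle.remove()
        self.handles.clear()

    def clear_features(self):
        self.features.clear()


# Usage on a tiny two-layer model.
model = nn.Sequential()
model.add_module("conv1", nn.Conv2d(3, 16, 3))
model.add_module("relu", nn.ReLU())

mgr = MiniHookManager(model)
mgr.register_forward_hook(layer_name="conv1")
relu_name = mgr.register_forward_hook(layer=model.relu)

for _ in range(3):
    model(torch.randn(2, 3, 32, 32))

print(relu_name, len(mgr.get_features("conv1")), mgr.get_features("conv1")[0].shape)
```

Because each forward pass appends another tensor per hook, memory grows with the number of runs; removing hooks and clearing stored features once they are no longer needed keeps that bounded.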
Installation
pip install torchhook
Usage Example
import torch
import torch.nn as nn
from torchhook import HookManager
# Define a simple model
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)  # 3x32x32 input -> 16x30x30 output
        self.relu = nn.ReLU()
        self.fc = nn.Linear(16 * 30 * 30, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 16 * 30 * 30)
        x = self.fc(x)
        return x
# Initialize model and HookManager
model = MyModel()
hook_manager = HookManager(model)
# Register hooks using layer_name (recommended for simplicity)
hook_manager.register_forward_hook(layer_name="conv1")
# Register hooks using the layer object (automatically named ClassName_Index, e.g. ReLU_0)
hook_manager.register_forward_hook(layer=model.relu)
# Register hooks with a custom name (useful for distinguishing hooks when debugging)
hook_manager.register_forward_hook('CustomName', layer=model.fc)
# Run the model
for _ in range(5):
    # Generate random input data
    input_tensor = torch.randn(2, 3, 32, 32)
    output = model(input_tensor)
# Print HookManager information
print(hook_manager)
print("Current keys:", hook_manager.get_keys()) # Get all registered hook names
# Get intermediate results (feature maps)
print("\nconv1:", hook_manager.get_features('conv1')[0].shape) # Feature map of conv1
print(" fc:", hook_manager.get_features('CustomName')[0].shape) # Feature map of fc
# Get all feature maps
all_features = hook_manager.get_all()
# Concatenate each layer's captured feature maps into one tensor (may run out of memory if many large feature maps were captured)
concatenated_features = {key: torch.cat(features, dim=0) for key, features in all_features.items()}
# Compute mean and standard deviation
stats = {key: (torch.mean(value), torch.std(value)) for key, value in concatenated_features.items()}
# Print results
print("\nMean and Std of features:")
for key, (mean, std) in stats.items():
    print(f"Layer: {key}, Mean: {mean.item():.4f}, Std: {std.item():.4f}")
# Clear hooks and features
hook_manager.clear_hooks()
hook_manager.clear_features()
Example Output:
Model: MyModel
Layer Name          Feature Count       Feature Shape
--------------------------------------------------------------------------------
conv1               5                   (2, 16, 30, 30)
ReLU_0              5                   (2, 16, 30, 30)
CustomName          5                   (2, 10)
--------------------------------------------------------------------------------
Current keys: ['conv1', 'ReLU_0', 'CustomName']

conv1: torch.Size([2, 16, 30, 30])
 fc: torch.Size([2, 10])

Mean and Std of features:
Layer: conv1, Mean: -0.0060, Std: 0.5978
Layer: ReLU_0, Mean: 0.2344, Std: 0.3463
Layer: CustomName, Mean: 0.0245, Std: 0.2332
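The feature shapes above, and the hard-coded `16 * 30 * 30` in `MyModel`, follow from the `nn.Conv2d` output-size formula given in the PyTorch documentation: H_out = floor((H_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1). With a 32x32 input, kernel size 3, and the defaults stride 1, padding 0, dilation 1, this gives 30x30. The small helper below is a sketch for checking this arithmetic when changing the input size or kernel; `conv2d_out_size` is a hypothetical name, not part of TorchHook or PyTorch.

```python
def conv2d_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Spatial output size of nn.Conv2d along one dimension (integer inputs)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


h = conv2d_out_size(32, kernel_size=3)
w = conv2d_out_size(32, kernel_size=3)
fc_in = 16 * h * w  # out_channels * H_out * W_out
print(h, w, fc_in)  # 30 30 14400
```

For example, a 64x64 input would produce a 62x62 feature map, so `fc` would need 16 * 62 * 62 = 61504 input features.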
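The usage example computes these statistics by concatenating every captured feature map per layer with `torch.cat`, which holds all of them in one extra tensor. A minimal alternative sketch, assuming `get_all()` returns a dict mapping each hook name to a list of tensors (as the example's `torch.cat(features, dim=0)` call implies), merges per-tensor statistics with Chan et al.'s parallel variance update so only one feature map is processed at a time. `streaming_mean_std` is a hypothetical helper, not part of TorchHook. Like `torch.std` with its default settings, it returns the sample (Bessel-corrected) standard deviation.

```python
import math

import torch


def streaming_mean_std(features):
    """Hypothetical helper: mean and sample std over a list of tensors
    without concatenating them (Chan et al. pairwise merge)."""
    n, mean, m2 = 0, 0.0, 0.0
    for t in features:
        # Work in float64 to limit rounding error; one tensor at a time.
        flat = t.detach().to(torch.float64).flatten()
        n_b = flat.numel()
        if n_b == 0:
            continue
        mean_b = flat.mean().item()
        m2_b = ((flat - mean_b) ** 2).sum().item()
        # Merge the running (n, mean, m2) with this tensor's statistics.
        total = n + n_b
        delta = mean_b - mean
        mean += delta * n_b / total
        m2 += m2_b + delta * delta * n * n_b / total
        n = total
    if n < 2:
        raise ValueError("need at least two elements for a sample std")
    return mean, math.sqrt(m2 / (n - 1))


# Compare against the concatenate-then-reduce approach on random data.
features = [torch.randn(2, 16, 30, 30) for _ in range(5)]
mean, std = streaming_mean_std(features)
ref = torch.cat(features, dim=0)
print(f"streaming: {mean:.6f} {std:.6f}")
print(f"torch.cat: {ref.mean().item():.6f} {ref.std().item():.6f}")
```

With the example's hook manager, `stats = {key: streaming_mean_std(feats) for key, feats in hook_manager.get_all().items()}` would replace the concatenation and statistics steps. Since the helper returns Python floats, the final print loop would format `mean` and `std` directly instead of calling `.item()`.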
License
MIT License
Download files
Source Distribution
torchhook-0.1.0.tar.gz (6.1 kB)
Built Distribution
torchhook-0.1.0-py3-none-any.whl (6.2 kB)
File details
Details for the file torchhook-0.1.0.tar.gz.
File metadata
- Download URL: torchhook-0.1.0.tar.gz
- Upload date:
- Size: 6.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.17
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3a58b2e4d823f8cf3ddf3d160e24ae2fe1c6a49d824e2f370de760a81a81e572 |
| MD5 | 9e6e295469ba63c97b6c6738eab7d089 |
| BLAKE2b-256 | 81907c5c3fb1809e04dae63890c1ebdaf104ebeec64249de100f9f336f190be6 |
File details
Details for the file torchhook-0.1.0-py3-none-any.whl.
File metadata
- Download URL: torchhook-0.1.0-py3-none-any.whl
- Upload date:
- Size: 6.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.17
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e897f1ee895abfea8722264ba61e6fc2128cfc3594cb924236ce8aa59591c953 |
| MD5 | 1f860a9182cc3da567939a7261520b3b |
| BLAKE2b-256 | 467164ef4f14a7f629bd9495e28dacf4a1758d41b2280e1b84b5858df5955ddb |