TorchHook: A PyTorch hook management library
TorchHook
TorchHook is a library for managing PyTorch model hooks, providing convenient interfaces to capture feature maps and debug models.
Features
- Easy Hook Management: Simplify the process of registering and managing hooks in PyTorch models.
- Feature Map Extraction: Capture intermediate feature maps for analysis and debugging.
- Customizable: Support for custom hook names and flexible usage.
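For context, the features above build on PyTorch's own `register_forward_hook` mechanism. The following is a minimal sketch of capturing a feature map with plain PyTorch only; it does not show TorchHook's internals, and the `features` dictionary and `make_hook` helper are hypothetical names chosen for illustration.

```python
import torch
import torch.nn as nn

# Captured outputs, keyed by a name we choose (illustrative storage only)
features = {}

def make_hook(name):
    # PyTorch calls a forward hook with (module, inputs, output)
    def hook(module, inputs, output):
        # detach() so the stored tensor does not keep the autograd graph alive
        features.setdefault(name, []).append(output.detach())
    return hook

conv = nn.Conv2d(3, 16, 3)
handle = conv.register_forward_hook(make_hook("conv"))

with torch.no_grad():
    conv(torch.randn(2, 3, 32, 32))

print(features["conv"][0].shape)  # torch.Size([2, 16, 30, 30])

# Removing the handle detaches the hook, so later calls are not recorded
handle.remove()
with torch.no_grad():
    conv(torch.randn(2, 3, 32, 32))

print(len(features["conv"]))  # 1
```

Keeping track of one handle and one storage list per layer is the bookkeeping that a hook manager is meant to take over.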
Installation
pip install torchhook
Usage Example
import torch
import torch.nn as nn
from torchhook import HookManager
# Define a simple model
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(16 * 30 * 30, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
# Initialize model and HookManager
model = MyModel()
hook_manager = HookManager(model)
# Register hooks using layer_name (recommended for simplicity)
hook_manager.register_forward_hook(layer_name="conv1")
# Register hooks using layer object (automatically named ClassName_Index, e.g. ReLU_0)
hook_manager.register_forward_hook(layer=model.relu)
# Register hooks with a custom name (useful for distinguishing hooks when debugging)
hook_manager.register_forward_hook('CustomName', layer=model.fc)
# Run the model
for _ in range(5):
    # Generate random input data
    input_tensor = torch.randn(2, 3, 32, 32)
    output = model(input_tensor)
# Print HookManager information
print(hook_manager)
print("Current keys:", hook_manager.get_keys()) # Get all registered hook names
# Get intermediate results (feature maps)
print("\nconv1:", hook_manager.get_features('conv1')[0].shape) # Feature map of conv1
print(" fc:", hook_manager.get_features('CustomName')[0].shape) # Feature map of fc
# Get all feature maps
all_features = hook_manager.get_all()
# Concatenate feature maps for each layer (may run out of memory if the data is too large)
concatenated_features = {key: torch.cat(features, dim=0) for key, features in all_features.items()}
# Compute mean and standard deviation
stats = {key: (torch.mean(value), torch.std(value)) for key, value in concatenated_features.items()}
# Print results
print("\nMean and Std of features:")
for key, (mean, std) in stats.items():
    print(f"Layer: {key}, Mean: {mean.item():.4f}, Std: {std.item():.4f}")
# Clear hooks and features
hook_manager.clear_hooks()
hook_manager.clear_features()
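The concatenation step in the example above builds a second copy of every captured tensor. As a hedged alternative, the sketch below computes the same mean and standard deviation one batch at a time by merging per-batch statistics. `running_mean_std` is a hypothetical helper, not part of TorchHook; it assumes it is given a list of tensors, such as the per-layer lists the example indexes with `[0]`.

```python
import torch

def running_mean_std(batches):
    """Mean and sample std over all elements of an iterable of tensors."""
    count, mean, m2 = 0, 0.0, 0.0
    for batch in batches:
        values = batch.detach().double().flatten()
        n_b = values.numel()
        mean_b = values.mean().item()
        m2_b = ((values - mean_b) ** 2).sum().item()
        # Merge the batch statistics into the running totals
        delta = mean_b - mean
        total = count + n_b
        mean += delta * n_b / total
        m2 += m2_b + delta * delta * count * n_b / total
        count = total
    # Divide by (count - 1) to match the default behaviour of torch.std
    return mean, (m2 / (count - 1)) ** 0.5

# Stand-in data with the same shape as the conv1 feature maps
batches = [torch.randn(2, 16, 30, 30) for _ in range(5)]
mean, std = running_mean_std(batches)

# Reference values from the concatenation approach
reference = torch.cat(batches, dim=0).double()
print(abs(mean - reference.mean().item()) < 1e-6)
print(abs(std - reference.std().item()) < 1e-6)
```

The sketch only changes how the statistics are accumulated; hook registration and feature retrieval stay as in the example.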
Example Output:
Model: MyModel | Total Parameters: 144.46 K
Layer Name               Feature Count   Feature Shape
--------------------------------------------------------------------------------
conv1                    5               (2, 16, 30, 30)
ReLU_0                   5               (2, 16, 30, 30)
CustomName               5               (2, 10)
--------------------------------------------------------------------------------
Current keys: ['conv1', 'ReLU_0', 'CustomName']
conv1: torch.Size([2, 16, 30, 30])
fc: torch.Size([2, 10])
Mean and Std of features:
Layer: conv1, Mean: -0.0460, Std: 0.5873
Layer: ReLU_0, Mean: 0.2116, Std: 0.3276
Layer: CustomName, Mean: -0.0596, Std: 0.2248
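The shapes and the parameter count in the output above can be checked by hand. The sketch below redoes that arithmetic in plain Python using the standard convolution output-size formula; `conv2d_out` is a hypothetical helper, and the assumption that the reported total is the raw count divided by 1000 is inferred from the numbers, not from TorchHook's documentation.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    # Standard output-size formula for a convolution without dilation
    return (size + 2 * padding - kernel) // stride + 1

side = conv2d_out(32, 3)              # 32x32 input, 3x3 kernel, no padding
conv_params = 3 * 16 * 3 * 3 + 16     # conv1 weights + biases
fc_in = 16 * side * side              # flattened size fed into fc
fc_params = fc_in * 10 + 10           # fc weights + biases
total = conv_params + fc_params

print(side, fc_in, total)             # 30 14400 144458
print(f"{total / 1000:.2f} K")        # 144.46 K
```

The result matches the `(2, 16, 30, 30)` feature shape, the `16 * 30 * 30` input size of `fc`, and the reported `144.46 K` parameters.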
License
MIT License