TorchHook: A PyTorch hook management library
TorchHook
TorchHook is a library for managing PyTorch model hooks, with convenient interfaces for capturing intermediate feature maps and debugging models.
Features
- Easy Hook Management: Simplify the process of registering and managing hooks in PyTorch models.
- Feature Map Extraction: Capture intermediate feature maps for analysis and debugging.
- Customizable: Register hooks by layer name or by layer object, optionally under a custom hook name.
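To make these features concrete, here is a minimal sketch of how a hook manager of this kind can be built on PyTorch's own `nn.Module.register_forward_hook`. It is not TorchHook's implementation: the class `MiniHookManager` and its method names are hypothetical, chosen to mirror the usage example below, and the `ClassName_Index` auto-naming is an assumption based on the `ReLU_0` key in the example output.

```python
import torch
import torch.nn as nn
from collections import defaultdict


class MiniHookManager:
    """Hypothetical, simplified hook manager (not TorchHook's actual code)."""

    def __init__(self, model: nn.Module):
        self.model = model
        self.features = defaultdict(list)       # hook name -> list of captured outputs
        self.handles = {}                       # hook name -> RemovableHandle
        self._class_counts = defaultdict(int)   # per-class counter for auto-naming

    def register_forward_hook(self, name=None, layer_name=None, layer=None):
        if layer_name is not None:
            # Look the layer up by its dotted name, e.g. "conv1" or "0"
            layer = dict(self.model.named_modules())[layer_name]
            name = name or layer_name
        elif layer is None:
            raise ValueError("pass either layer_name or layer")
        if name is None:
            # Assumed auto-naming scheme: ClassName_Index, e.g. ReLU_0
            cls = type(layer).__name__
            name = f"{cls}_{self._class_counts[cls]}"
            self._class_counts[cls] += 1

        def hook(module, inputs, output):
            # Detach and move to CPU so stored features do not keep the autograd graph alive
            self.features[name].append(output.detach().cpu())

        self.handles[name] = layer.register_forward_hook(hook)
        return name

    def get_features(self, name):
        return self.features[name]

    def get_keys(self):
        return list(self.handles)

    def clear_hooks(self):
        for handle in self.handles.values():
            handle.remove()
        self.handles.clear()

    def clear_features(self):
        self.features.clear()


# Usage on a small model
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU())
mgr = MiniHookManager(model)
mgr.register_forward_hook(layer_name="0")    # named after the layer: "0"
mgr.register_forward_hook(layer=model[1])    # auto-named "ReLU_0"
with torch.no_grad():
    model(torch.randn(2, 3, 32, 32))
print(mgr.get_keys())                        # ['0', 'ReLU_0']
print(mgr.get_features("ReLU_0")[0].shape)   # torch.Size([2, 16, 30, 30])
mgr.clear_hooks()                            # hooks removed; captured features are kept
```

Detaching and moving outputs to CPU is a design choice in this sketch: stored tensors then neither pin the autograd graph nor occupy GPU memory, at the cost of a device copy per forward pass.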
Installation
pip install torchhook
Usage Example
import torch
import torch.nn as nn
from torchhook import HookManager
# Define a simple model
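class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # 3x32x32 input -> 16x30x30 (3x3 kernel, stride 1, no padding)
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.relu = nn.ReLU()
        # Flattened 16x30x30 feature map -> 10 outputs
        self.fc = nn.Linear(16 * 30 * 30, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = x.view(x.size(0), -1)  # flatten all but the batch dimension
        x = self.fc(x)
        return x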
# Initialize model and HookManager
model = MyModel()
hook_manager = HookManager(model)
# Register hooks using layer_name (recommended for simplicity)
hook_manager.register_forward_hook(layer_name="conv1")
# Register hooks using a layer object (automatically named ClassName_Index, e.g. ReLU_0)
hook_manager.register_forward_hook(layer=model.relu)
# Register hooks with a custom name (useful for distinguishing hooks when debugging)
hook_manager.register_forward_hook('CustomName', layer=model.fc)
# Run the model
for _ in range(5):
    # Generate random input data
    input_tensor = torch.randn(2, 3, 32, 32)
    output = model(input_tensor)
# Print HookManager information
print(hook_manager)
print("Current keys:", hook_manager.get_keys()) # Get all registered hook names
# Get intermediate results (feature maps)
print("\nconv1:", hook_manager.get_features('conv1')[0].shape) # Feature map of conv1
print(" fc:", hook_manager.get_features('CustomName')[0].shape) # Feature map of fc
# Get all feature maps
all_features = hook_manager.get_all()
# Concatenate feature maps for each layer (this allocates a second copy of every captured map and may run out of memory for large models or many batches)
concatenated_features = {key: torch.cat(features, dim=0) for key, features in all_features.items()}
# Compute mean and standard deviation
stats = {key: (torch.mean(value), torch.std(value)) for key, value in concatenated_features.items()}
# Print results
print("\nMean and Std of features:")
for key, (mean, std) in stats.items():
    print(f"Layer: {key}, Mean: {mean.item():.4f}, Std: {std.item():.4f}")
# Clear hooks and features
hook_manager.clear_hooks()
hook_manager.clear_features()
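The concatenation step above allocates a second copy of every captured feature map before computing statistics. If that copy is the problem, per-layer mean and standard deviation can instead be accumulated batch by batch. The sketch below uses the parallel (Chan et al.) form of Welford's update and is not part of TorchHook; `streaming_mean_std` is a hypothetical helper, and it assumes, as the example code implies, that `get_all()` returns a dict mapping each hook name to a list of tensors. The per-batch tensors themselves are still held by the hook manager; only the concatenated copy is avoided.

```python
import math
import torch


def streaming_mean_std(features):
    """Return (mean, std) over all elements of a list of tensors without torch.cat.

    Merges per-batch (count, mean, M2) with Chan et al.'s pairwise update.
    The std is the unbiased estimate (correction=1), matching torch.std's default.
    """
    n, mean, m2 = 0, 0.0, 0.0
    for t in features:
        x = t.detach().double()  # accumulate in float64 for stability
        n_b = x.numel()
        if n_b == 0:
            continue
        mean_b = x.mean().item()
        m2_b = ((x - mean_b) ** 2).sum().item()
        delta = mean_b - mean
        n_new = n + n_b
        mean += delta * n_b / n_new
        m2 += m2_b + delta * delta * n * n_b / n_new
        n = n_new
    if n < 2:
        raise ValueError("need at least two elements to estimate std")
    return mean, math.sqrt(m2 / (n - 1))


# Quick check against the concatenate-then-reduce approach
batches = [torch.randn(2, 16, 30, 30) for _ in range(5)]
mean, std = streaming_mean_std(batches)
ref = torch.cat(batches, dim=0).double()
print(mean - ref.mean().item(), std - ref.std().item())  # both differences should be ~0
```

With the README's `hook_manager`, the statistics could then be built as `{key: streaming_mean_std(feats) for key, feats in hook_manager.get_all().items()}`. The values are plain Python floats rather than tensors, so the `.item()` calls in the printing loop would be dropped.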
Example Output:
Model: MyModel | Total Parameters: 144.46 K
Layer Name          Feature Count    Feature Shape
--------------------------------------------------------------------------------
conv1               5                (2, 16, 30, 30)
ReLU_0              5                (2, 16, 30, 30)
CustomName          5                (2, 10)
--------------------------------------------------------------------------------
Current keys: ['conv1', 'ReLU_0', 'CustomName']

conv1: torch.Size([2, 16, 30, 30])
 fc: torch.Size([2, 10])

Mean and Std of features:
Layer: conv1, Mean: -0.0460, Std: 0.5873
Layer: ReLU_0, Mean: 0.2116, Std: 0.3276
Layer: CustomName, Mean: -0.0596, Std: 0.2248
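The `Total Parameters: 144.46 K` line can be checked by hand. The 3-to-16 channel `Conv2d` with a 3×3 kernel has 16·3·3·3 = 432 weights plus 16 biases, the `ReLU` has no parameters, and the `Linear` layer maps 16·30·30 = 14,400 inputs to 10 outputs. The plain-Python arithmetic below (helper names are hypothetical) reproduces the figure, which also indicates that "K" here means thousands (1,000) rather than 1,024.

```python
def conv2d_params(in_ch, out_ch, k, bias=True):
    # Weight tensor is out_ch x in_ch x k x k, plus one bias per output channel
    return out_ch * in_ch * k * k + (out_ch if bias else 0)


def linear_params(in_features, out_features, bias=True):
    # Weight matrix is out_features x in_features, plus one bias per output
    return out_features * in_features + (out_features if bias else 0)


# A 3x3 kernel with stride 1 and no padding shrinks 32x32 inputs to 30x30
conv_out = 32 - 3 + 1
conv = conv2d_params(3, 16, 3)                     # 432 + 16 = 448
fc = linear_params(16 * conv_out * conv_out, 10)   # 144000 + 10 = 144010
total = conv + fc                                  # ReLU adds no parameters
print(total, f"{total / 1000:.2f} K")              # 144458 144.46 K
```

With PyTorch installed, `sum(p.numel() for p in MyModel().parameters())` should give the same 144,458.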
License
MIT License
Download files
Source distribution: torchhook-0.1.6.tar.gz (8.9 kB)
Built distribution: torchhook-0.1.6-py3-none-any.whl (8.6 kB)
File details
Details for the file torchhook-0.1.6.tar.gz.
File metadata
- Download URL: torchhook-0.1.6.tar.gz
- Upload date:
- Size: 8.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.17
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 78c14486204b915e7730f2d7095a9b3a0012adc73b8a1dcaec800b53484cef2d |
| MD5 | 36083c978e24aaa1aa70520c4e96d7e5 |
| BLAKE2b-256 | 2151b229315a5537c81ddc0b86a3fe5172f57d9245625813829b871da1b98081 |
File details
Details for the file torchhook-0.1.6-py3-none-any.whl.
File metadata
- Download URL: torchhook-0.1.6-py3-none-any.whl
- Upload date:
- Size: 8.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.17
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 65830679993250583a18b799f1d9f21a78459ec8209d4b625e906f8d4fd455bf |
| MD5 | 12fd843fba5d4e545c1d88b6fd47df6f |
| BLAKE2b-256 | 9a7df1544abf71507c298c49b8b93d79531cc138bfd4a2dd06fe704ac9f81a69 |
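The digests listed for each file let you check that a download is intact. A minimal sketch using Python's standard `hashlib` module follows; `file_sha256` is a hypothetical helper name, and the commented usage line assumes the source distribution was saved as `torchhook-0.1.6.tar.gz` in the current directory.

```python
import hashlib


def file_sha256(path, chunk_size=1 << 20):
    """Compute the SHA256 hex digest of a file, reading it in 1 MiB chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


# SHA256 digest listed above for the source distribution
EXPECTED_SDIST_SHA256 = "78c14486204b915e7730f2d7095a9b3a0012adc73b8a1dcaec800b53484cef2d"

# Usage, assuming the sdist was downloaded into the current directory:
# assert file_sha256("torchhook-0.1.6.tar.gz") == EXPECTED_SDIST_SHA256
```

pip can enforce the same check itself: in hash-checking mode (`pip install --require-hashes -r requirements.txt`), each requirement line carries a `--hash=sha256:<digest>` option.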