
TorchHook: A PyTorch hook management library


TorchHook

TorchHook is a library for managing PyTorch model hooks, providing convenient interfaces to capture feature maps and debug models.

Features

  • Easy Hook Management: Simplify the process of registering and managing hooks in PyTorch models.
  • Feature Map Extraction: Capture intermediate feature maps for analysis and debugging.
  • Customizable: Register hooks by layer name, by layer object, or under a custom name.
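Feature-map capture in PyTorch is built on forward hooks. The sketch below uses only the standard `nn.Module.register_forward_hook` API, not TorchHook itself, to illustrate the mechanism that a manager like `HookManager` presumably wraps. The helper `attach_feature_hooks` is hypothetical and does not reflect TorchHook's actual implementation.

```python
from collections import defaultdict

import torch
import torch.nn as nn


def attach_feature_hooks(model: nn.Module, layer_names):
    """Hypothetical helper: register forward hooks that append each named
    layer's output to a list.

    Returns (features, handles): features maps layer name -> list of captured
    tensors; handles are the RemovableHandle objects used to detach the hooks.
    """
    features = defaultdict(list)
    modules = dict(model.named_modules())
    handles = []
    for name in layer_names:
        # Bind `name` as a default argument so each hook keeps its own layer name
        def hook(module, inputs, output, name=name):
            # detach() drops the autograd graph so stored tensors don't keep it alive
            features[name].append(output.detach())

        handles.append(modules[name].register_forward_hook(hook))
    return features, handles


model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU())
features, handles = attach_feature_hooks(model, ["0", "1"])
for _ in range(3):
    model(torch.randn(2, 3, 32, 32))
for h in handles:
    h.remove()  # stop capturing; later forward passes add nothing
print({k: (len(v), tuple(v[0].shape)) for k, v in features.items()})
# {'0': (3, (2, 16, 30, 30)), '1': (3, (2, 16, 30, 30))}
```

Keeping the hook handles is what makes cleanup possible: without calling `remove()`, the hooks keep firing and the stored lists keep growing on every forward pass.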

Installation

pip install torchhook

Usage Example

import torch
import torch.nn as nn
from torchhook import HookManager

# Define a simple model
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(16 * 30 * 30, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Initialize model and HookManager
model = MyModel()
hook_manager = HookManager(model)

# Register hooks using layer_name (recommended for simplicity)
hook_manager.register_forward_hook(layer_name="conv1")

# Register hooks using layer object (automatically named as ClassName_Index, e.g. ReLU_0)
hook_manager.register_forward_hook(layer=model.relu)

# Register hooks with a custom name (useful for distinguishing hooks when debugging)
hook_manager.register_forward_hook('CustomName', layer=model.fc)

# Run five forward passes; each pass stores one feature map per hooked layer
for _ in range(5):
    # Generate random input data
    input_tensor = torch.randn(2, 3, 32, 32)
    output = model(input_tensor)

# Print HookManager information
print(hook_manager)
print("Current keys:", hook_manager.get_keys())  # Get all registered hook names

# Get intermediate results (feature maps)
print("\nconv1:", hook_manager.get_features('conv1')[0].shape)  # First stored feature map of conv1
print("   fc:", hook_manager.get_features('CustomName')[0].shape)  # First stored feature map of fc (hooked as 'CustomName')

# Get all feature maps
all_features = hook_manager.get_all()

# Concatenate the stored feature maps for each layer (this can exhaust memory when many large maps are stored)
concatenated_features = {key: torch.cat(features, dim=0) for key, features in all_features.items()}

# Compute mean and standard deviation
stats = {key: (torch.mean(value), torch.std(value)) for key, value in concatenated_features.items()}

# Print results
print("\nMean and Std of features:")
for key, (mean, std) in stats.items():
    print(f"Layer: {key}, Mean: {mean.item():.4f}, Std: {std.item():.4f}")

# Clear hooks and features
hook_manager.clear_hooks()
hook_manager.clear_features()
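The comment in the example warns that concatenating every stored feature map can exhaust memory. One alternative, sketched here under the assumption that each layer's features are a list of tensors (as the `torch.cat` call implies), is to merge per-tensor statistics with Chan et al.'s parallel variance update, so only one tensor (plus its float64 copy) is processed at a time. `streaming_mean_std` is a hypothetical helper, not part of TorchHook.

```python
import torch


def streaming_mean_std(tensors):
    """Hypothetical helper: mean and unbiased std over all elements of a
    sequence of tensors, merged one tensor at a time (Chan et al. parallel
    variance update) so the tensors are never concatenated.

    Needs at least two elements in total; otherwise the std is nan or inf.
    """
    count = 0
    mean = torch.tensor(0.0, dtype=torch.float64)
    m2 = torch.tensor(0.0, dtype=torch.float64)  # sum of squared deviations from the mean
    for t in tensors:
        x = t.detach().to(torch.float64).flatten()
        n_b = x.numel()
        if n_b == 0:
            continue
        mean_b = x.mean()
        m2_b = ((x - mean_b) ** 2).sum()
        delta = mean_b - mean
        total = count + n_b
        # Merge the running (count, mean, m2) with this tensor's statistics
        mean = mean + delta * n_b / total
        m2 = m2 + m2_b + delta ** 2 * count * n_b / total
        count = total
    std = torch.sqrt(m2 / (count - 1))  # correction=1, matching torch.std's default
    return mean, std


chunks = [torch.randn(2, 16, 30, 30) for _ in range(5)]
mean, std = streaming_mean_std(chunks)
print(f"Mean: {mean.item():.4f}, Std: {std.item():.4f}")
```

With the objects from the usage example, this would read `stats = {key: streaming_mean_std(feats) for key, feats in hook_manager.get_all().items()}`, and the results should agree with the `torch.cat` version up to floating-point rounding.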

Example Output:

Model: MyModel | Total Parameters: 144.46 K
Layer Name                    Feature Count       Feature Shape                 
--------------------------------------------------------------------------------
conv1                         5                   (2, 16, 30, 30)               
ReLU_0                        5                   (2, 16, 30, 30)               
CustomName                    5                   (2, 10)                       
--------------------------------------------------------------------------------
Current keys: ['conv1', 'ReLU_0', 'CustomName']

conv1: torch.Size([2, 16, 30, 30])
   fc: torch.Size([2, 10])

Mean and Std of features:
Layer: conv1, Mean: -0.0460, Std: 0.5873
Layer: ReLU_0, Mean: 0.2116, Std: 0.3276
Layer: CustomName, Mean: -0.0596, Std: 0.2248
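The header line reports 144.46 K parameters for `MyModel`. Assuming this figure is the plain element count over `model.parameters()` (an assumption; the README does not say how it is computed), it can be reproduced with standard PyTorch. `count_parameters` is a hypothetical helper, and `nn.Sequential` is used only to hold the same layers as `MyModel`.

```python
import torch.nn as nn


def count_parameters(model: nn.Module) -> int:
    # Total number of elements across all parameter tensors (weights and biases)
    return sum(p.numel() for p in model.parameters())


# Same parameterized layers as MyModel in the usage example
conv1 = nn.Conv2d(3, 16, 3)       # 16*3*3*3 weights + 16 biases = 448
fc = nn.Linear(16 * 30 * 30, 10)  # 14400*10 weights + 10 biases = 144,010
model = nn.Sequential(conv1, nn.ReLU(), nn.Flatten(), fc)
total = count_parameters(model)
print(total)                      # 144458
print(f"{total / 1e3:.2f} K")     # 144.46 K
```

Almost all of the parameters sit in the fully connected layer, because it maps every element of the flattened 16 × 30 × 30 conv1 output to each of the 10 outputs.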

License

MIT License


Download files


Source Distribution

torchhook-0.2.0.tar.gz (9.9 kB)

Uploaded Source

Built Distribution


torchhook-0.2.0-py3-none-any.whl (9.4 kB)

Uploaded Python 3

File details

Details for the file torchhook-0.2.0.tar.gz.

File metadata

  • Download URL: torchhook-0.2.0.tar.gz
  • Upload date:
  • Size: 9.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.17

File hashes

Hashes for torchhook-0.2.0.tar.gz
Algorithm    Hash digest
SHA256       e8d2e16fa44281b3f4cfd14d999088eb2d2db65283ffb48ac8985504418f998c
MD5          47b10bc3f9b89aa3de87c8214da70b53
BLAKE2b-256  d2345db2fe1c5c78556b4025cfa5180c5a27d9c4031bd75ffecf2a6cd435af36


File details

Details for the file torchhook-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: torchhook-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 9.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.17

File hashes

Hashes for torchhook-0.2.0-py3-none-any.whl
Algorithm    Hash digest
SHA256       deaeb5d50737e910edd176427e263b75e1282382040358af5b72ae1fb5642cc7
MD5          ca0d706da0caa6dda33d7666e4b95eea
BLAKE2b-256  cffe8cdef0b2b609a0ceaedbc49ad7a802b4c19f04fa5dd3946e391c0752421b

