TorchHook: A PyTorch hook management library
TorchHook
English Blog | Chinese Blog | Chinese Documentation
TorchHook is a lightweight, easy-to-use Python library that simplifies extracting intermediate features from PyTorch models. It provides a clean API for managing the PyTorch hooks that capture layer outputs, so you don't have to write the usual hook boilerplate yourself.
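For comparison, here is a minimal sketch of the bookkeeping that plain PyTorch requires without TorchHook. It relies only on `nn.Module.register_forward_hook`, which calls `hook(module, inputs, output)` after the module's forward pass and returns a handle whose `remove()` method unregisters the hook. The small `nn.Sequential` model, the `features` dict and the `make_hook` helper are illustrative assumptions, not part of TorchHook.

```python
import torch
import torch.nn as nn

# Tiny stand-in model (illustrative); any nn.Module is handled the same way.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
model.eval()

features = {}  # layer name -> most recent output
handles = []   # kept so every hook can be unregistered later


def make_hook(name):
    # Forward hooks are called as hook(module, inputs, output) after forward().
    def hook(module, inputs, output):
        features[name] = output.detach()
    return hook


# Look up layers by their dotted names and attach one hook to each.
modules = dict(model.named_modules())
for name in ["0", "1", "4"]:
    handles.append(modules[name].register_forward_hook(make_hook(name)))

with torch.no_grad():
    out = model(torch.randn(2, 3, 16, 16))

# Hooks stay registered until removed, so clean up explicitly.
for handle in handles:
    handle.remove()

print({name: tuple(t.shape) for name, t in features.items()})
```

Every extra layer adds another entry to `handles`, and forgetting the final `remove()` loop leaves the hooks running on later forward passes; that repetitive setup and teardown is what a hook manager wraps.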
Key Features
- Easy Hook Registration: Quickly register hooks for desired model layers by name or object.
- Flexible Feature Extraction: Retrieve captured features for a single layer by name, or for all registered layers at once.
- Highly Customizable: Define custom hook logic or output transformations.
- Resource Management: Automatic cleanup of registered hooks.
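To make the "custom hook logic" and "automatic cleanup" ideas concrete, the sketch below shows one way to get both in plain PyTorch: a context manager that registers forward hooks, applies an optional transform to each output, and removes every hook on exit, even if the forward pass raises. `capture` is a hypothetical helper written for illustration; it is not TorchHook's API, which is built around `HookManager`.

```python
import contextlib

import torch
import torch.nn as nn


@contextlib.contextmanager
def capture(model, layer_names, transform=None):
    """Hypothetical helper (not TorchHook's API): record named layer outputs.

    `transform`, if given, is applied to each detached output before storing,
    e.g. to pool a feature map into one vector per sample.
    """
    store = {}
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            value = output.detach()
            store[name] = transform(value) if transform is not None else value
        return hook

    try:
        for name in layer_names:
            # get_submodule resolves dotted names such as "layer4.1.relu".
            layer = model.get_submodule(name)
            handles.append(layer.register_forward_hook(make_hook(name)))
        yield store
    finally:
        # Runs on normal exit and on exceptions, so no hook is left behind.
        for handle in handles:
            handle.remove()


model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.ReLU())
with capture(model, ["0"], transform=lambda t: t.mean(dim=(2, 3))) as feats:
    model(torch.randn(2, 3, 8, 8))

print(feats["0"].shape)  # torch.Size([2, 4])
```

Tying hook removal to a `with` block means cleanup cannot be skipped by an early return or an exception, while the captured features remain available after the block ends.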
Installation
pip install torchhook
Or install from source:
git clone https://github.com/zzaiyan/TorchHook.git
cd TorchHook
pip install .
Quick Start
import torch
import torchvision.models as models
from torchhook import HookManager
# 1. Load your model
model = models.resnet18()
model.eval()
# 2. Initialize HookManager
hook_manager = HookManager(model, max_size=1) # Keep only the latest feature per hook
# 3. Register layers
hook_manager.add(layer_name='conv1')
hook_manager.add(layer_name='layer4.1.relu')
hook_manager.add(layer_name='fully_connected', layer=model.fc) # Optional: pass layer object
# 4. Forward pass
dummy_input = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    output = model(dummy_input)
# 5. Get features
features_conv1 = hook_manager.get('conv1')
features_relu = hook_manager.get('layer4.1.relu')
all_features = hook_manager.get_all() # Get all features as a dict
print(f"Conv1 feature shape: {features_conv1[0].shape}")
print(f"Layer 4.1 ReLU feature shape: {features_relu[0].shape}")
# 6. Summary (Optional)
hook_manager.summary()
# 7. Remove the hooks when done (important: registered hooks keep firing on every later forward pass)
hook_manager.clear_hooks()
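The Quick Start describes `max_size=1` only as keeping the latest feature per hook, and indexes the result of `get()` with `[0]`, which suggests each layer maps to a short list of captured outputs. As a pure-Python illustration of that behaviour, not TorchHook's actual implementation, a bounded buffer can be built on `collections.deque(maxlen=...)`, which silently drops the oldest item once it is full. The `FeatureBuffer` class is hypothetical; its `get` and `get_all` methods only mirror the names used in the Quick Start.

```python
from collections import defaultdict, deque


class FeatureBuffer:
    """Hypothetical per-layer store keeping at most `max_size` outputs each."""

    def __init__(self, max_size=None):
        # deque(maxlen=None) is unbounded; otherwise the oldest item is dropped.
        self._data = defaultdict(lambda: deque(maxlen=max_size))

    def append(self, name, value):
        self._data[name].append(value)

    def get(self, name):
        # Plain list, so callers can index it, e.g. buf.get("conv1")[0].
        # dict.get avoids creating an empty entry for unknown names.
        return list(self._data.get(name, ()))

    def get_all(self):
        return {name: list(items) for name, items in self._data.items()}


buf = FeatureBuffer(max_size=1)
for step in range(3):
    buf.append("conv1", f"output of forward pass {step}")

print(buf.get("conv1"))  # ['output of forward pass 2']
```

A bounded buffer keeps memory flat when the model runs many forward passes while hooks are registered; an unbounded one (`max_size=None` here) grows with every pass.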
For advanced usage, such as custom hooks and output transformations, see the blog posts: English | Chinese
Download files
File details
Details for the file TorchHook-0.2.7.tar.gz.
File metadata
- Download URL: TorchHook-0.2.7.tar.gz
- Upload date:
- Size: 9.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 51d578f8704c95f234c89fe6d2eafc62673d52212c0c82eb14b512bcdf446e4a |
| MD5 | d183345a36c86111b8e0b1ff9d97f1be |
| BLAKE2b-256 | ccc2b7ccb3b95c8c65604ab8bae95b0221823aa005fffaa3fd9859be04b62e7b |
File details
Details for the file TorchHook-0.2.7-py3-none-any.whl.
File metadata
- Download URL: TorchHook-0.2.7-py3-none-any.whl
- Upload date:
- Size: 10.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 71c0e3ebf6e9ac377f1cacebeb1244f3071a6d85aa72dc06892af99f0631ab0d |
| MD5 | 2811f81b9a14e3b3578786d912f12be3 |
| BLAKE2b-256 | 75ba17277d7a3934fbcb97eac8870aa153ce9957086e0562c84e4c49d13511db |
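The listed digests can be used to check that a downloaded file arrived intact. Below is a minimal sketch using only Python's standard `hashlib`; the local file path is an assumption (wherever you saved the download). The "BLAKE2b-256" entry corresponds to BLAKE2b with a 32-byte digest, i.e. `hashlib.blake2b(digest_size=32)`.

```python
import hashlib


def file_digest(path, algorithm="sha256", chunk_size=1 << 20):
    """Return the hex digest of a file, reading it in chunks to bound memory."""
    if algorithm == "blake2b-256":
        # BLAKE2b with a 256-bit (32-byte) digest.
        h = hashlib.blake2b(digest_size=32)
    else:
        # Any name hashlib.new() accepts, e.g. "sha256" or "md5".
        h = hashlib.new(algorithm)
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


# SHA256 listed for TorchHook-0.2.7.tar.gz.
EXPECTED_SDIST_SHA256 = "51d578f8704c95f234c89fe6d2eafc62673d52212c0c82eb14b512bcdf446e4a"

# Usage, assuming the sdist was downloaded to the current directory:
# if file_digest("TorchHook-0.2.7.tar.gz") != EXPECTED_SDIST_SHA256:
#     raise SystemExit("hash mismatch: do not install this file")
```

On Python 3.11 and later, `hashlib.file_digest(f, "sha256")` performs the same chunked read for an open binary file.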