TorchHook: A PyTorch hook management library

TorchHook

License: MIT

English Blog | Chinese Blog | Chinese Documentation

TorchHook is a lightweight, easy-to-use Python library that simplifies extracting intermediate features from PyTorch models. It provides a clean API for managing the PyTorch hooks that capture layer outputs, so you don't have to write the hook boilerplate yourself.
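For comparison, here is a sketch of the plain-PyTorch boilerplate that a hook manager is meant to replace: registering forward hooks with `register_forward_hook`, storing outputs in a dict, and removing the handles afterwards. It uses only standard `torch` APIs; the small `nn.Sequential` model, the `captured` dict and the `make_hook` helper are illustrative assumptions, not part of TorchHook.

```python
import torch
import torch.nn as nn

# A tiny illustrative model (an assumption; any nn.Module works the same way).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
model.eval()

captured = {}  # layer name -> output tensor from the latest forward pass
handles = []   # hook handles, kept so the hooks can be removed later


def make_hook(name):
    # Forward hooks receive (module, inputs, output); returning None keeps the output unchanged.
    def hook(module, inputs, output):
        captured[name] = output.detach()
    return hook


for name, module in model.named_modules():
    if name in {"0", "1"}:  # the Conv2d and ReLU layers of this Sequential
        handles.append(module.register_forward_hook(make_hook(name)))

with torch.no_grad():
    model(torch.randn(1, 3, 32, 32))

print({name: tuple(t.shape) for name, t in captured.items()})

# Manual cleanup: without this, the hooks stay attached to the model.
for handle in handles:
    handle.remove()
```

Every piece here (closure per layer, shared storage, handle bookkeeping, removal) is something you would otherwise repeat in each project.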

Key Features

  • Easy Hook Registration: Quickly register hooks on the desired model layers, either by name or by passing the layer object.
  • Flexible Feature Extraction: Retrieve captured features for a single layer or for all registered layers at once.
  • Customizable: Define custom hook logic or output transformations.
  • Resource Management: Automatic cleanup of registered hooks.
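Layer names such as `'conv1'` or `'layer4.1.relu'` in the Quick Start appear to follow PyTorch's own dotted module naming, as listed by `named_modules()`. Assuming that convention, the sketch below shows how a name resolves to a module object with the standard `nn.Module.get_submodule` method (PyTorch 1.9+), which is why registering by name or by object can target the same layer. It illustrates the PyTorch side only, not TorchHook's internals; the nested model is hypothetical.

```python
import torch.nn as nn

# A small nested model (illustrative) so that dotted names like "block.1" exist.
model = nn.Sequential()
model.add_module("stem", nn.Conv2d(3, 8, kernel_size=3))
model.add_module("block", nn.Sequential(nn.Conv2d(8, 8, kernel_size=3), nn.ReLU()))
model.add_module("fc", nn.Linear(8, 10))

# Every submodule has a dotted name; "" is the root module itself, so skip it.
names = [name for name, _ in model.named_modules() if name]
print(names)  # ['stem', 'block', 'block.0', 'block.1', 'fc']

# Resolving a name returns the very same object you could pass directly.
by_name = model.get_submodule("block.1")
by_object = model.block[1]
print(by_name is by_object)  # True
```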

Installation

pip install torchhook

Or install from source:

git clone https://github.com/zzaiyan/TorchHook.git
cd TorchHook
pip install .
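After installing, you can confirm which version is present with the standard-library `importlib.metadata` module (Python 3.8+). This is a generic Python check rather than a TorchHook feature, and `installed_version` is a hypothetical helper name used only here.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if it is missing."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Prints a version string, or None if torchhook is not installed.
print(installed_version("torchhook"))
```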

Quick Start

import torch
import torchvision.models as models
from torchhook import HookManager

# 1. Load your model
model = models.resnet18()
model.eval()

# 2. Initialize HookManager
hook_manager = HookManager(model, max_size=1) # Keep only the latest feature per hook

# 3. Register layers
hook_manager.add(layer_name='conv1')
hook_manager.add(layer_name='layer4.1.relu')
hook_manager.add(layer_name='fully_connected', layer=model.fc) # Optional: pass layer object

# 4. Forward pass
dummy_input = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    output = model(dummy_input)

# 5. Get features
features_conv1 = hook_manager.get('conv1')
features_relu = hook_manager.get('layer4.1.relu')
all_features = hook_manager.get_all() # Get all features as a dict

print(f"Conv1 feature shape: {features_conv1[0].shape}")
print(f"Layer 4.1 ReLU feature shape: {features_relu[0].shape}")

# 6. Summary (Optional)
hook_manager.summary()

# 7. Clean up hooks (Important!)
hook_manager.clear_hooks()
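The `max_size=1` argument in step 2 is described as keeping only the latest feature per hook, and `get()` results are indexed with `[0]`, which suggests each hook stores a list of outputs. Assuming that behavior, a bounded per-layer buffer can be sketched in plain Python with `collections.deque(maxlen=...)`. The `FeatureBuffer` class is hypothetical and is not TorchHook's actual implementation; it only illustrates keep-the-latest-N semantics.

```python
from collections import deque


class FeatureBuffer:
    """Hypothetical per-layer store that keeps at most `max_size` recent outputs."""

    def __init__(self, max_size=None):
        # deque(maxlen=None) is unbounded; a positive maxlen drops the oldest item on overflow.
        self.max_size = max_size
        self._store = {}

    def append(self, layer_name, output):
        buf = self._store.setdefault(layer_name, deque(maxlen=self.max_size))
        buf.append(output)

    def get(self, layer_name):
        # Return a list, oldest first, so callers can index it like features[0].
        return list(self._store.get(layer_name, ()))

    def get_all(self):
        return {name: list(buf) for name, buf in self._store.items()}


buffer = FeatureBuffer(max_size=1)
for step in range(3):  # simulate three forward passes
    buffer.append("conv1", f"output of pass {step}")

print(buffer.get("conv1"))  # ['output of pass 2']
```

A bounded buffer matters in practice: without a limit, running many batches through a hooked model keeps every captured tensor alive in memory.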

For advanced usage such as custom hooks and output transformations, see the blog posts: English | Chinese
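TorchHook's own API for custom hooks and output transformations is documented in those posts and is not reproduced here. As a plain-PyTorch illustration of the idea, a forward hook can transform a layer's output before storing it, for example global-average-pooling a feature map of shape (N, C, H, W) down to (N, C). The model and the `pool_and_store` helper are assumptions for this sketch.

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
)
model.eval()

pooled = []  # transformed outputs, one entry per forward pass


def pool_and_store(module, inputs, output):
    # Transform (N, C, H, W) -> (N, C) by averaging over the spatial dimensions,
    # then store a detached copy. Returning None keeps the model's real output unchanged.
    pooled.append(output.mean(dim=(2, 3)).detach())


handle = model[1].register_forward_hook(pool_and_store)
with torch.no_grad():
    out = model(torch.randn(4, 3, 28, 28))
handle.remove()

print(out.shape, pooled[0].shape)  # torch.Size([4, 16, 28, 28]) torch.Size([4, 16])
```

Transforming inside the hook keeps only the reduced tensor, which is much smaller than the full feature map.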

Download files

Source Distribution

torchhook-0.2.5.tar.gz (10.5 kB)

Built Distribution

torchhook-0.2.5-py3-none-any.whl (10.1 kB)

File details

Details for the file torchhook-0.2.5.tar.gz.

File metadata

  • Download URL: torchhook-0.2.5.tar.gz
  • Upload date:
  • Size: 10.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.17

File hashes

Hashes for torchhook-0.2.5.tar.gz
  • SHA256: 079f74952ba0da81b94b04a779c749211587d65ba7aba8ff4c6364697e538b0b
  • MD5: 2796b0b5cd3725b925ffeb6bd981998c
  • BLAKE2b-256: fb08f8c96aae342751928c71326282f7a5f2162689abf41a943992876cf8a4e6

File details

Details for the file torchhook-0.2.5-py3-none-any.whl.

File metadata

  • Download URL: torchhook-0.2.5-py3-none-any.whl
  • Upload date:
  • Size: 10.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.17

File hashes

Hashes for torchhook-0.2.5-py3-none-any.whl
  • SHA256: f988e76e8c5ede3b911b9f26bcf40af7e00758d96b2eec48e7eedae331ba770f
  • MD5: a289014557dc53156074151e2f744763
  • BLAKE2b-256: f57e64676daaf3f8eb1a50e8f54d07193340adff1181ca85fb295558532a276a
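The published SHA256 digests can be used to check a downloaded file. Below is a minimal sketch using the standard-library `hashlib`, assuming the archive has already been saved locally; `sha256_of` and `verify` are hypothetical helper names. pip can also enforce hashes on its own through its hash-checking mode (`--require-hashes`).

```python
import hashlib
from pathlib import Path

# SHA256 digests published above for torchhook 0.2.5.
EXPECTED_SHA256 = {
    "torchhook-0.2.5.tar.gz": "079f74952ba0da81b94b04a779c749211587d65ba7aba8ff4c6364697e538b0b",
    "torchhook-0.2.5-py3-none-any.whl": "f988e76e8c5ede3b911b9f26bcf40af7e00758d96b2eec48e7eedae331ba770f",
}


def sha256_of(path, chunk_size=1 << 16):
    """Compute the SHA256 hex digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path):
    """Return True only if the file's SHA256 matches the published digest for its name."""
    path = Path(path)
    expected = EXPECTED_SHA256.get(path.name)
    return expected is not None and sha256_of(path) == expected
```

For example, `verify("torchhook-0.2.5-py3-none-any.whl")` run in the download directory returns True only for an unmodified copy of that wheel.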