torch-featurelayer
🧠 Simple utility functions and wrappers for hooking onto layers within PyTorch models for feature extraction.
[!TIP] This library is intended to be a simple, well-documented implementation for extracting the output(s) of a PyTorch model's intermediate layers. For a more sophisticated and complete implementation, consider using `torchvision.models.feature_extraction`, or see the official `torch.fx` documentation.
Usage
```python
import torch
from torchvision.models import vgg11
from torch_featurelayer import FeatureLayer

# Load a pretrained VGG-11 model
model = vgg11(weights='DEFAULT').eval()

# Hook onto layer `features.15` of the model
layer_path = 'features.15'
hooked_model = FeatureLayer(model, layer_path)

# Forward pass an input tensor through the model
x = torch.randn(1, 3, 224, 224)
feature_output, output = hooked_model(x)

# Print the output shapes
print(f'Feature layer output shape: {feature_output.shape}')  # [1, 512, 14, 14]
print(f'Model output shape: {output.shape}')  # [1, 1000]
```
Check the examples directory for more.
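The general technique behind this kind of wrapper is a PyTorch forward hook. The following is a minimal sketch of that technique, not the library's actual implementation. `SimpleFeatureLayer` is a hypothetical name, and the sketch assumes the layer is addressed by a dotted path that `nn.Module.get_submodule` can resolve.

```python
from __future__ import annotations

import torch
import torch.nn as nn


class SimpleFeatureLayer:
    """Hypothetical, simplified stand-in: captures one submodule's output via a forward hook."""

    def __init__(self, model: nn.Module, layer_path: str) -> None:
        self.model = model
        # Resolve a dotted path such as 'features.15' to the submodule it names
        self.layer = model.get_submodule(layer_path)

    def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor | None, torch.Tensor]:
        captured: dict[str, torch.Tensor] = {}

        def hook(module: nn.Module, inputs: tuple, output: torch.Tensor) -> None:
            # Store the layer's output as it is computed during the forward pass
            captured['feature'] = output

        handle = self.layer.register_forward_hook(hook)
        try:
            output = self.model(x)
        finally:
            # Always remove the hook so repeated calls do not stack hooks
            handle.remove()
        return captured.get('feature'), output


# Usage on a small toy model: capture the output of the ReLU at index '1'
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4)).eval()
hooked = SimpleFeatureLayer(model, '1')
feature, out = hooked(torch.randn(2, 8))
print(feature.shape, out.shape)  # torch.Size([2, 16]) torch.Size([2, 4])
```

Registering the hook only for the duration of each call keeps the original model unmodified between calls, at the cost of re-registering the hook on every forward pass.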
API
torch_featurelayer.FeatureLayer(model: torch.nn.Module, feature_layer_path: str)
torch_featurelayer.FeatureLayers(model: torch.nn.Module, feature_layer_paths: list[str])
torch_featurelayer.get_layer_candidates(module: torch.nn.Module, max_depth: int = 1)
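The multi-layer and layer-listing entries above can be illustrated with plain PyTorch. This is a hedged sketch of the same ideas, not the library's code: `list_layer_paths` and `capture_layers` are hypothetical helpers, and the depth rule used here (count the dots in a module's qualified name) is an assumption that may not match how `get_layer_candidates` interprets `max_depth`.

```python
from __future__ import annotations

import torch
import torch.nn as nn


def list_layer_paths(module: nn.Module, max_depth: int = 1) -> list[str]:
    """Hypothetical helper: dotted paths of submodules nested at most `max_depth` levels deep."""
    return [
        name
        for name, _ in module.named_modules()
        # named_modules() yields the root module first with name ''; skip it
        if name and name.count('.') < max_depth
    ]


def capture_layers(
    model: nn.Module, layer_paths: list[str], x: torch.Tensor
) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
    """Hypothetical helper: one forward pass, returning several layer outputs keyed by path."""
    features: dict[str, torch.Tensor] = {}
    handles = []
    for path in layer_paths:
        layer = model.get_submodule(path)

        # Bind `path` as a default argument so each hook records under its own key
        def hook(module, inputs, output, path=path):
            features[path] = output

        handles.append(layer.register_forward_hook(hook))
    try:
        output = model(x)
    finally:
        for handle in handles:
            handle.remove()
    return features, output


# Usage on a nested toy model
model = nn.Sequential(
    nn.Sequential(nn.Linear(8, 16), nn.ReLU()),
    nn.Linear(16, 4),
).eval()
print(list_layer_paths(model, max_depth=1))  # ['0', '1']
print(list_layer_paths(model, max_depth=2))  # ['0', '0.0', '0.1', '1']
features, out = capture_layers(model, ['0', '0.1'], torch.randn(3, 8))
print({k: tuple(v.shape) for k, v in features.items()}, tuple(out.shape))
```

Capturing several layers in a single forward pass avoids re-running the model once per layer, which matters for large models.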
License
Hashes for torch_featurelayer-0.1.1-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | 970bd6a58c4854636d42b0d769b25a795ff187d49efd22c63cdc01f38de9a144 |
| MD5 | ec3e29e70a04c4a2dbf151501192fe3a |
| BLAKE2b-256 | cdf590af9f8d9258aa03509c31ebbea50fbd962b9c13baadb4d91a3beddb9de3 |