Useful utility functions and wrappers for hooking onto layers within PyTorch models for feature extraction.
torch-featurelayer
🧠 Simple utility functions and wrappers for hooking onto layers within PyTorch models for feature extraction.
[!TIP] This library is intended to be a simplified and well-documented implementation for extracting a PyTorch model's intermediate layer output(s). For a more sophisticated and complete implementation, either consider using `torchvision.models.feature_extraction`, or check the official `torch.fx`.
Usage
import torch
from torchvision.models import vgg11
from torch_featurelayer import FeatureLayer
# Load a pretrained VGG-11 model
model = vgg11(weights='DEFAULT').eval()
# Hook onto layer `features.15` of the model
layer_path = 'features.15'
hooked_model = FeatureLayer(model, layer_path)
# Forward pass an input tensor through the model
x = torch.randn(1, 3, 224, 224)
feature_output, output = hooked_model(x)
# Print the output shape
print(f'Feature layer output shape: {feature_output.shape}') # [1, 512, 14, 14]
print(f'Model output shape: {output.shape}') # [1, 1000]
Check the examples directory for more.
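To make the idea of "hooking onto a layer" concrete, the following sketch shows the general forward-hook mechanism in plain PyTorch. It is not this library's implementation: `get_submodule_output` and `toy_model` are hypothetical names introduced only for illustration, and the sketch assumes a PyTorch version that provides `nn.Module.get_submodule`.

```python
import torch
from torch import nn


def get_submodule_output(model, layer_path, x):
    """Run model(x) and also return the output of the submodule at layer_path.

    Hypothetical helper for illustration; not part of torch_featurelayer.
    """
    captured = {}

    def hook(module, inputs, output):
        # Called by PyTorch right after the hooked submodule's forward pass.
        captured['feature'] = output

    handle = model.get_submodule(layer_path).register_forward_hook(hook)
    try:
        with torch.no_grad():
            output = model(x)
    finally:
        # Always detach the hook so the model is left unchanged.
        handle.remove()
    return captured['feature'], output


# A small stand-in model so the sketch runs without downloading weights.
toy_model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 4),
).eval()

# In an nn.Sequential, submodules are addressed by their index as a string.
feature, output = get_submodule_output(toy_model, '1', torch.randn(2, 3, 16, 16))
print(feature.shape, output.shape)
```

Removing the hook in a `finally` block is a design choice of this sketch: it keeps repeated calls from stacking hooks on the same layer.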
API
torch_featurelayer.FeatureLayer(model: torch.nn.Module, feature_layer_path: str)
torch_featurelayer.FeatureLayers(model: torch.nn.Module, feature_layer_paths: list[str])
torch_featurelayer.get_layer_candidates(module: torch.nn.Module, max_depth: int = 1)
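The signatures above do not show return values, so the sketch below does not call them. Instead it illustrates, in plain PyTorch, the two underlying ideas: listing dotted layer paths up to a nesting depth, and capturing several layers in one forward pass. `list_layer_paths`, `run_with_features`, and `ToyNet` are hypothetical names for this example, and their behavior is an assumption, not a description of the library's functions.

```python
import torch
from torch import nn


def list_layer_paths(module, max_depth=1):
    """Return dotted paths of submodules nested at most max_depth levels deep."""
    return [
        name
        for name, _ in module.named_modules()
        if name and name.count('.') < max_depth  # '' is the root module itself
    ]


def run_with_features(model, layer_paths, x):
    """Run model(x) once and collect the outputs of every layer in layer_paths."""
    captured = {}

    def make_hook(path):
        def hook(module, inputs, output):
            captured[path] = output
        return hook

    handles = [
        model.get_submodule(path).register_forward_hook(make_hook(path))
        for path in layer_paths
    ]
    try:
        with torch.no_grad():
            output = model(x)
    finally:
        for handle in handles:
            handle.remove()
    return captured, output


class ToyNet(nn.Module):
    """Small stand-in model with nested submodules."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(8, 4)

    def forward(self, x):
        x = self.pool(self.features(x))
        return self.classifier(torch.flatten(x, 1))


net = ToyNet().eval()
top_level = list_layer_paths(net, max_depth=1)
two_levels = list_layer_paths(net, max_depth=2)
features, output = run_with_features(net, ['features.0', 'pool'], torch.randn(2, 3, 16, 16))
print(top_level, two_levels, {k: v.shape for k, v in features.items()})
```

Listing paths first is a convenient way to find a valid string such as `features.15` before hooking onto it.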
License