torch-featurelayer
🧠 Simple utility functions and wrappers for hooking onto layers within PyTorch models for feature extraction.
> [!TIP]
> This library is intended to be a simple, well-documented implementation for extracting the output(s) of a PyTorch model's intermediate layer(s). For a more sophisticated and complete implementation, consider using `torchvision.models.feature_extraction` or the official `torch.fx` module instead.
Install
```shell
pip install torch-featurelayer
```
Usage
Imports:
```python
import torch
from torchvision.models import vgg11
from torch_featurelayer import FeatureLayer
```
Load a pretrained VGG-11 model:
```python
model = vgg11(weights='DEFAULT').eval()
```
Hook onto layer `features.15` of the model:
```python
layer_path = 'features.15'
hooked_model = FeatureLayer(model, layer_path)
```
Run a forward pass with an input tensor:
```python
x = torch.randn(1, 3, 224, 224)
feature_output, output = hooked_model(x)
```
`feature_output` is the output of layer `features.15`. Print the output shapes:
```python
print(f'Feature layer output shape: {feature_output.shape}')  # [1, 512, 14, 14]
print(f'Model output shape: {output.shape}')  # [1, 1000]
```
Check the examples directory for more.
API
torch_featurelayer.FeatureLayer

The `FeatureLayer` class wraps a model and provides a hook to access the output of a specific feature layer.

- `__init__(self, model: torch.nn.Module, feature_layer_path: str)`: Initializes the `FeatureLayer` instance.
  - `model`: The model containing the feature layer.
  - `feature_layer_path`: The dotted path to the feature layer in the model (e.g. `'features.15'`).
- `__call__(self, *args: Any, **kwargs: Any) -> tuple[torch.Tensor | None, torch.Tensor]`: Performs a forward pass through the model and captures the output of the hooked feature layer. Returns a tuple containing the feature layer output and the model output.
  - `*args`: Positional arguments for the model's forward pass.
  - `**kwargs`: Keyword arguments for the model's forward pass.
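To illustrate how this kind of wrapper can work, here is a minimal sketch built on PyTorch's `register_forward_hook`. The class name `SimpleFeatureLayer` and its `remove` method are hypothetical; this is not the library's actual implementation, only one plausible way to get the same `(feature_output, output)` behavior.

```python
from __future__ import annotations

from typing import Any

import torch
from torch import nn


class SimpleFeatureLayer:
    """Hypothetical sketch: capture one layer's output with a forward hook."""

    def __init__(self, model: nn.Module, feature_layer_path: str) -> None:
        self.model = model
        self.feature_output: torch.Tensor | None = None
        # Resolve the dotted path (e.g. 'features.15') to the actual submodule
        layer = model.get_submodule(feature_layer_path)
        # The hook runs after layer.forward and receives (module, inputs, output)
        self._handle = layer.register_forward_hook(self._hook)

    def _hook(self, module: nn.Module, inputs: tuple[Any, ...], output: torch.Tensor) -> None:
        # Returning None leaves the layer's output unchanged
        self.feature_output = output

    def __call__(self, *args: Any, **kwargs: Any) -> tuple[torch.Tensor | None, torch.Tensor]:
        self.feature_output = None  # clear any output left over from a previous call
        output = self.model(*args, **kwargs)
        return self.feature_output, output

    def remove(self) -> None:
        # Detach the hook so the model runs without it
        self._handle.remove()


model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
hooked = SimpleFeatureLayer(model, '1')  # '1' is the ReLU inside the Sequential
feat, out = hooked(torch.randn(3, 8))
print(feat.shape, out.shape)  # torch.Size([3, 4]) torch.Size([3, 2])

hooked.remove()
feat_after, _ = hooked(torch.randn(3, 8))
print(feat_after)  # None: the hook no longer fires
```

Resetting the stored output before each call means a layer that is skipped during the forward pass yields `None` rather than a stale tensor, which matches the `torch.Tensor | None` in the documented return type.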
torch_featurelayer.FeatureLayers

The `FeatureLayers` class wraps a model and provides hooks to access the outputs of multiple feature layers.

- `__init__(self, model: torch.nn.Module, feature_layer_paths: list[str])`: Initializes the `FeatureLayers` instance.
  - `model`: The model containing the feature layers.
  - `feature_layer_paths`: A list of dotted paths to the feature layers in the model.
- `__call__(self, *args: Any, **kwargs: Any) -> tuple[dict[str, torch.Tensor | None], torch.Tensor]`: Performs a forward pass through the model and captures the outputs of the hooked feature layers. Returns a tuple containing the feature layer outputs and the model output.
  - `*args`: Positional arguments for the model's forward pass.
  - `**kwargs`: Keyword arguments for the model's forward pass.
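Extending the single-hook idea to several layers mainly requires binding each hook to its own path. The sketch below does that with a closure and keys the result dict by layer path; the class name `SimpleFeatureLayers` is hypothetical, and keying by path is an assumption about the shape of the `dict[str, ...]` the library returns.

```python
from __future__ import annotations

from typing import Any, Callable

import torch
from torch import nn


class SimpleFeatureLayers:
    """Hypothetical sketch: capture several layers' outputs, keyed by layer path."""

    def __init__(self, model: nn.Module, feature_layer_paths: list[str]) -> None:
        self.model = model
        self.feature_outputs: dict[str, torch.Tensor | None] = {p: None for p in feature_layer_paths}
        self._handles = [
            model.get_submodule(path).register_forward_hook(self._make_hook(path))
            for path in feature_layer_paths
        ]

    def _make_hook(self, path: str) -> Callable[..., None]:
        # The closure binds each hook to its own path so outputs do not overwrite each other
        def hook(module: nn.Module, inputs: tuple[Any, ...], output: torch.Tensor) -> None:
            self.feature_outputs[path] = output

        return hook

    def __call__(self, *args: Any, **kwargs: Any) -> tuple[dict[str, torch.Tensor | None], torch.Tensor]:
        # Reset so layers not reached in this pass report None instead of stale tensors
        self.feature_outputs = {p: None for p in self.feature_outputs}
        output = self.model(*args, **kwargs)
        return dict(self.feature_outputs), output

    def remove(self) -> None:
        for handle in self._handles:
            handle.remove()


model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),  # (N, 3, 8, 8) -> (N, 8, 8, 8)
    nn.ReLU(),
    nn.MaxPool2d(2),                            # -> (N, 8, 4, 4)
    nn.Flatten(),                               # -> (N, 128)
    nn.Linear(8 * 4 * 4, 10),                   # -> (N, 10)
)
hooked = SimpleFeatureLayers(model, ['1', '2'])
feats, out = hooked(torch.randn(2, 3, 8, 8))
for path, feat in feats.items():
    print(path, tuple(feat.shape))  # 1 (2, 8, 8, 8), then 2 (2, 8, 4, 4)
print(tuple(out.shape))  # (2, 10)
```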
torch_featurelayer.get_layer_candidates(module: torch.nn.Module, max_depth: int = 1) -> Generator[str, None, None]

The `get_layer_candidates` function returns a generator of layer paths for a given model, up to a specified depth.

- `module`: The model to get layer paths from.
- `max_depth`: The maximum depth to traverse in the model's layers.

Returns a generator of layer paths.
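A depth-limited traversal like this can be written recursively over `named_children`. The sketch below is a hypothetical stand-in (`iter_layer_paths`), not the library's function; in particular, whether parent modules are yielded alongside their children, and how depth is counted, are assumptions here. Each yielded path can be passed to `model.get_submodule` or used as a `FeatureLayer` path.

```python
from __future__ import annotations

from collections.abc import Generator

from torch import nn


def iter_layer_paths(module: nn.Module, max_depth: int = 1, prefix: str = '') -> Generator[str, None, None]:
    """Hypothetical sketch: yield dotted submodule paths up to max_depth levels deep."""
    if max_depth < 1:
        return
    for name, child in module.named_children():
        path = f'{prefix}.{name}' if prefix else name
        yield path  # assumption: parents are yielded as well as their children
        # Recurse one level deeper with a reduced depth budget
        yield from iter_layer_paths(child, max_depth - 1, path)


model = nn.Sequential(nn.Sequential(nn.Linear(4, 4), nn.ReLU()), nn.Linear(4, 2))
print(list(iter_layer_paths(model, max_depth=1)))  # ['0', '1']
print(list(iter_layer_paths(model, max_depth=2)))  # ['0', '0.0', '0.1', '1']
```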
License
File details
Details for the file `torch_featurelayer-0.1.2.tar.gz`.
File metadata
- Download URL: torch_featurelayer-0.1.2.tar.gz
- Upload date:
- Size: 8.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.10
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 42c6cf350febc2c7b5b172d83d1d0624645893e01ab0ceee7a21c8469c0a9052 |
| MD5 | 89ce73fd7ad656cfc93328c4ced57135 |
| BLAKE2b-256 | 7c0f2e0b893125ef4196a59cf7fa7aa139265687ce09e06a3c5d1fe6a02347c7 |
File details
Details for the file `torch_featurelayer-0.1.2-py3-none-any.whl`.
File metadata
- Download URL: torch_featurelayer-0.1.2-py3-none-any.whl
- Upload date:
- Size: 7.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.10
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3638311e0b9ff972ecb184b0bcb39b82721972da93d922ea36ef65cfc43e8a2e |
| MD5 | 11109bc185a7b4dfdeafe464d7d2b015 |
| BLAKE2b-256 | c2b2118c64dbd60e13a87116bf68a7a6d4ff09ed51fd547fa64a51878c911a60 |