torch-featurelayer
🧠 Simple utility functions and wrappers for hooking onto layers within PyTorch models for feature extraction.
> [!TIP]
> This library is intended to be a simple, well-documented way to extract the output(s) of a PyTorch model's intermediate layer(s). For a more sophisticated and complete implementation, consider `torchvision.models.feature_extraction` or the official `torch.fx` toolkit.

Install
```sh
pip install torch-featurelayer
```
Usage
Imports:
```python
import torch
from torchvision.models import vgg11
from torch_featurelayer import FeatureLayer
```
Load a pretrained VGG-11 model:
```python
model = vgg11(weights='DEFAULT').eval()
```
Hook onto layer features.15 of the model:
```python
layer_path = 'features.15'
hooked_model = FeatureLayer(model, layer_path)
```
Forward pass an input tensor through the model:
```python
x = torch.randn(1, 3, 224, 224)
feature_output, output = hooked_model(x)
```
feature_output is the output of layer features.15. Print the output shape:
```python
print(f'Feature layer output shape: {feature_output.shape}')  # torch.Size([1, 512, 14, 14])
print(f'Model output shape: {output.shape}')  # torch.Size([1, 1000])
```
Check the examples directory for more.
API
torch_featurelayer.FeatureLayer
The FeatureLayer class wraps a model and provides a hook to access the output of a specific feature layer.
- `__init__(self, model: torch.nn.Module, feature_layer_path: str)`
  Initializes the `FeatureLayer` instance.
  - `model`: The model containing the feature layer.
  - `feature_layer_path`: The dotted path to the feature layer within the model (for example, `'features.15'`).
- `__call__(self, *args: Any, **kwargs: Any) -> tuple[torch.Tensor | None, torch.Tensor]`
  Performs a forward pass through the model and records the output of the hooked feature layer.
  - `*args`: Variable-length positional arguments passed to the model.
  - `**kwargs`: Arbitrary keyword arguments passed to the model.
  Returns a tuple containing the feature layer output and the model output.
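A wrapper like this can be approximated with a standard PyTorch forward hook. This is a minimal sketch under assumptions, not the library's actual implementation: `SimpleFeatureLayer` is a hypothetical name, the layer is looked up with `torch.nn.Module.get_submodule`, and the hook is removed after every call. The captured output stays `None` if the hooked layer never runs during the forward pass, which is one way a `torch.Tensor | None` return type could arise.

```python
from __future__ import annotations

from typing import Any

import torch
import torch.nn as nn


class SimpleFeatureLayer:
    """Hypothetical minimal analogue of FeatureLayer (not the library's code)."""

    def __init__(self, model: nn.Module, feature_layer_path: str) -> None:
        self.model = model
        self.feature_layer_path = feature_layer_path
        self.feature_output: torch.Tensor | None = None

    def _hook(self, module: nn.Module, inputs: tuple, output: torch.Tensor) -> None:
        # Store the layer's output each time the layer runs.
        self.feature_output = output

    def __call__(self, *args: Any, **kwargs: Any) -> tuple[torch.Tensor | None, torch.Tensor]:
        self.feature_output = None  # reset so a stale output never leaks between calls
        layer = self.model.get_submodule(self.feature_layer_path)
        handle = layer.register_forward_hook(self._hook)
        try:
            output = self.model(*args, **kwargs)
        finally:
            handle.remove()  # always detach the hook, even if the forward pass raises
        return self.feature_output, output


# Usage on a small toy model; Sequential children are addressed as '0', '1', ...
toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
hooked = SimpleFeatureLayer(toy, '1')
x = torch.randn(3, 4)
feat, out = hooked(x)
print(feat.shape, out.shape)  # torch.Size([3, 8]) torch.Size([3, 2])
```

Registering and removing the hook on each call leaves the wrapped model unmodified between calls; a single persistent hook would also work, but it stays attached to the model until explicitly removed.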
torch_featurelayer.FeatureLayers
The FeatureLayers class wraps a model and provides hooks to access the output of multiple feature layers.
- `__init__(self, model: torch.nn.Module, feature_layer_paths: list[str])`
  Initializes the `FeatureLayers` instance.
  - `model`: The model containing the feature layers.
  - `feature_layer_paths`: A list of dotted paths to the feature layers within the model.
- `__call__(self, *args: Any, **kwargs: Any) -> tuple[dict[str, torch.Tensor | None], torch.Tensor]`
  Performs a forward pass through the model and records the outputs of the hooked feature layers.
  - `*args`: Variable-length positional arguments passed to the model.
  - `**kwargs`: Arbitrary keyword arguments passed to the model.
  Returns a tuple containing the feature layer outputs (as a dictionary) and the model output.
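For several layers, one forward hook per path can collect outputs into a dictionary. The sketch below is an illustration under assumptions, not the library's source: `SimpleFeatureLayers` is a hypothetical name, and keying the dictionary by layer path is an assumption about how such a wrapper might be organized.

```python
from __future__ import annotations

from typing import Any

import torch
import torch.nn as nn


class SimpleFeatureLayers:
    """Hypothetical minimal analogue of FeatureLayers (not the library's code)."""

    def __init__(self, model: nn.Module, feature_layer_paths: list[str]) -> None:
        self.model = model
        self.feature_layer_paths = feature_layer_paths

    def __call__(self, *args: Any, **kwargs: Any) -> tuple[dict[str, torch.Tensor | None], torch.Tensor]:
        # Every path starts at None; a layer that never runs keeps that value.
        features: dict[str, torch.Tensor | None] = {p: None for p in self.feature_layer_paths}
        handles = []
        for path in self.feature_layer_paths:
            layer = self.model.get_submodule(path)

            # Bind `path` through a default argument so each hook writes to its own key.
            def hook(module: nn.Module, inputs: tuple, output: torch.Tensor, path: str = path) -> None:
                features[path] = output

            handles.append(layer.register_forward_hook(hook))
        try:
            output = self.model(*args, **kwargs)
        finally:
            for handle in handles:
                handle.remove()  # detach every hook, even if the forward pass raises
        return features, output


toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
hooked = SimpleFeatureLayers(toy, ['0', '1'])
features, out = hooked(torch.randn(3, 4))
print({k: tuple(v.shape) for k, v in features.items()}, tuple(out.shape))
# {'0': (3, 8), '1': (3, 8)} (3, 2)
```

One caveat for this sketch: it stores tensors exactly as the hooked layers return them, so an in-place operation later in the network (such as the `nn.ReLU(inplace=True)` layers in torchvision's VGG) can overwrite a stored output. Calling `output.clone()` inside the hook avoids this at the cost of extra memory.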
torch_featurelayer.get_layer_candidates(module: torch.nn.Module, max_depth: int = 1) -> Generator[str, None, None]
The `get_layer_candidates` function returns a generator of layer paths for a given model, up to a specified depth.
- `module`: The model to get layer paths from.
- `max_depth`: The maximum depth to traverse in the model's layers.

Returns a generator of layer paths.
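A layer-path generator of this kind can be sketched with `torch.nn.Module.named_modules`. The function below, `layer_candidates`, is hypothetical, and its depth convention is an assumption: depth 1 means the direct children of `module`, and each additional `.` in a path adds one level. The library's own definition of depth may differ.

```python
from __future__ import annotations

from collections import OrderedDict
from collections.abc import Generator

import torch.nn as nn


def layer_candidates(module: nn.Module, max_depth: int = 1) -> Generator[str, None, None]:
    """Hypothetical analogue of get_layer_candidates (not the library's code).

    Yields dotted sub-module paths whose nesting depth is at most `max_depth`,
    where depth 1 means direct children of `module`.
    """
    for name, _ in module.named_modules():
        if not name:
            continue  # skip the root module itself, whose name is ''
        if name.count('.') + 1 <= max_depth:
            yield name


model = nn.Sequential(OrderedDict(
    features=nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()),
    head=nn.Linear(8, 2),
))
print(list(layer_candidates(model, max_depth=1)))  # ['features', 'head']
print(list(layer_candidates(model, max_depth=2)))  # ['features', 'features.0', 'features.1', 'head']
```

This sketch walks the entire module tree and filters by depth, which is simple but also visits modules deeper than `max_depth`; a recursive walk over `named_children` could stop early instead.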
License
Download files
File details
Details for the file torch_featurelayer-0.1.2.tar.gz.
File metadata
- Download URL: torch_featurelayer-0.1.2.tar.gz
- Upload date:
- Size: 8.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.10
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 42c6cf350febc2c7b5b172d83d1d0624645893e01ab0ceee7a21c8469c0a9052 |
| MD5 | 89ce73fd7ad656cfc93328c4ced57135 |
| BLAKE2b-256 | 7c0f2e0b893125ef4196a59cf7fa7aa139265687ce09e06a3c5d1fe6a02347c7 |
File details
Details for the file torch_featurelayer-0.1.2-py3-none-any.whl.
File metadata
- Download URL: torch_featurelayer-0.1.2-py3-none-any.whl
- Upload date:
- Size: 7.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.10
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3638311e0b9ff972ecb184b0bcb39b82721972da93d922ea36ef65cfc43e8a2e |
| MD5 | 11109bc185a7b4dfdeafe464d7d2b015 |
| BLAKE2b-256 | c2b2118c64dbd60e13a87116bf68a7a6d4ff09ed51fd547fa64a51878c911a60 |