torch-featurelayer

🧠 Simple utility functions and wrappers for hooking onto layers within PyTorch models for feature extraction.

Tip: This library is intended as a simple, well-documented way to extract the output(s) of a PyTorch model's intermediate layer(s). For a more sophisticated and complete implementation, consider torchvision.models.feature_extraction or the official torch.fx.

Install

pip install torch-featurelayer

Usage

Imports:

import torch
from torchvision.models import vgg11
from torch_featurelayer import FeatureLayer

Load a pretrained VGG-11 model:

model = vgg11(weights='DEFAULT').eval()

Hook onto layer features.15 of the model:

layer_path = 'features.15'
hooked_model = FeatureLayer(model, layer_path)

Forward pass an input tensor through the model:

x = torch.randn(1, 3, 224, 224)
feature_output, output = hooked_model(x)

feature_output is the output of layer features.15, and output is the model's final output. Print both shapes:

print(f'Feature layer output shape: {feature_output.shape}')  # [1, 512, 14, 14]
print(f'Model output shape: {output.shape}')  # [1, 1000]
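The [1, 512, 14, 14] shape can be checked by hand. The sketch below is plain Python (no PyTorch required) and assumes torchvision's VGG-11 layout, where model.features is a flat nn.Sequential of 3x3 convolutions (padding 1), each followed by a ReLU, plus 2x2 max-pooling layers. feature_shapes is a hypothetical helper, not part of torch_featurelayer.

```python
# Assumed VGG-11 layout ("configuration A" in torchvision's VGG definition):
# an int is a 3x3 Conv2d (padding=1) with that many output channels, followed by ReLU;
# "M" is a MaxPool2d(kernel_size=2, stride=2).
VGG11_CFG = [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"]


def feature_shapes(cfg, in_channels=3, size=224):
    """Map each "features.<i>" path to (layer type, (C, H, W)) for a square input."""
    shapes = {}
    idx = 0
    channels, h, w = in_channels, size, size
    for v in cfg:
        if v == "M":
            # 2x2 max pooling with stride 2 halves the spatial size
            h, w = h // 2, w // 2
            shapes[f"features.{idx}"] = ("MaxPool2d", (channels, h, w))
            idx += 1
        else:
            # 3x3 conv with padding=1 keeps H and W; only the channel count changes
            channels = v
            shapes[f"features.{idx}"] = ("Conv2d", (channels, h, w))
            shapes[f"features.{idx + 1}"] = ("ReLU", (channels, h, w))
            idx += 2
    return shapes


shapes = feature_shapes(VGG11_CFG)
print(shapes["features.15"])  # ('MaxPool2d', (512, 14, 14))
```

Under these assumptions, features.15 is the fourth max-pooling layer, so the 224x224 input has been halved four times (224 / 2^4 = 14), with 512 channels coming from the preceding convolution.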

Check the examples directory for more.

API

torch_featurelayer.FeatureLayer

The FeatureLayer class wraps a model and provides a hook to access the output of a specific feature layer.

  • __init__(self, model: torch.nn.Module, feature_layer_path: str)

    Initializes the FeatureLayer instance.

    • model: The model containing the feature layer.
    • feature_layer_path: The path to the feature layer in the model.
  • __call__(self, *args: Any, **kwargs: Any) -> tuple[torch.Tensor | None, torch.Tensor]

    Performs a forward pass through the model and captures the output of the hooked feature layer.

    • *args: Positional arguments passed to the model.
    • **kwargs: Keyword arguments passed to the model.

    Returns a tuple containing the feature layer output and the model output.
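Wrappers of this kind typically rely on PyTorch forward hooks: a callback registered with nn.Module.register_forward_hook runs after a layer's forward pass and receives (module, inputs, output). The sketch below imitates that pattern in plain Python so it runs without PyTorch. ToyModule and ToyFeatureLayer are hypothetical stand-ins, and this is not the library's actual implementation. The dotted path is resolved with one getattr per segment, mirroring how nn.Module exposes child modules (including nn.Sequential children named "0", "1", ...) as attributes.

```python
from functools import reduce


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module with PyTorch-style forward hooks."""

    def __init__(self, fn=None, children=None):
        self.fn = fn                    # leaf computation (None for containers)
        self.children = children or {}  # name -> ToyModule, like nn.Module._modules
        self._hooks = []

    def __getattr__(self, name):
        # Only called when normal lookup fails; fall back to child modules,
        # similar to how nn.Module.__getattr__ exposes submodules by name
        children = self.__dict__.get("children", {})
        if name in children:
            return children[name]
        raise AttributeError(name)

    def register_forward_hook(self, hook):
        self._hooks.append(hook)
        return lambda: self._hooks.remove(hook)  # stand-in for RemovableHandle.remove()

    def __call__(self, x):
        if self.fn is not None:
            out = self.fn(x)
        else:
            out = x
            for child in self.children.values():  # run children in order, like nn.Sequential
                out = child(out)
        for hook in list(self._hooks):
            hook(self, (x,), out)  # same argument order as PyTorch: (module, inputs, output)
        return out


class ToyFeatureLayer:
    """Hypothetical sketch of the FeatureLayer idea; not the library's code."""

    def __init__(self, model, feature_layer_path):
        self.model = model
        # "features.1" -> getattr(getattr(model, "features"), "1")
        self.layer = reduce(getattr, feature_layer_path.split("."), model)

    def __call__(self, *args, **kwargs):
        captured = {"output": None}  # stays None if the hooked layer never runs

        def hook(module, inputs, output):
            captured["output"] = output

        remove = self.layer.register_forward_hook(hook)
        try:
            output = self.model(*args, **kwargs)
        finally:
            remove()  # detach so repeated calls do not stack hooks
        return captured["output"], output


features = ToyModule(children={"0": ToyModule(lambda x: x + 1), "1": ToyModule(lambda x: x * 10)})
model = ToyModule(children={"features": features, "head": ToyModule(lambda x: x - 3)})

hooked_model = ToyFeatureLayer(model, "features.1")
feature_output, output = hooked_model(2)
print(feature_output, output)  # 30 27
```

This sketch registers the hook on every call and removes it in a finally block; the real FeatureLayer may manage its hook differently. If the hooked layer never runs, the captured value stays None, which is one way a torch.Tensor | None return type can arise.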

torch_featurelayer.FeatureLayers

The FeatureLayers class wraps a model and provides hooks to access the output of multiple feature layers.

  • __init__(self, model: torch.nn.Module, feature_layer_paths: list[str])

    Initializes the FeatureLayers instance.

    • model: The model containing the feature layers.
    • feature_layer_paths: A list of paths to the feature layers in the model.
  • __call__(self, *args: Any, **kwargs: Any) -> tuple[dict[str, torch.Tensor | None], torch.Tensor]

    Performs a forward pass through the model and captures the outputs of the hooked feature layers.

    • *args: Positional arguments passed to the model.
    • **kwargs: Keyword arguments passed to the model.

    Returns a tuple containing the feature layer outputs and the model output.
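For several layers, the same idea extends to one hook per path, with the captured outputs collected in a dictionary keyed by path. The plain-Python sketch below uses a hypothetical toy model (BranchyModel) whose "skip" layer only runs for negative intermediate values, to show how a requested layer can end up as None. HookableLayer and run_with_features are hypothetical helpers, not part of torch_featurelayer.

```python
from typing import Any, Callable, Dict, List, Optional


class HookableLayer:
    """Hypothetical layer with PyTorch-style forward hooks: hook(module, inputs, output)."""

    def __init__(self, fn: Callable[[Any], Any]):
        self.fn = fn
        self.hooks: List[Callable] = []

    def register_forward_hook(self, hook: Callable) -> Callable[[], None]:
        self.hooks.append(hook)
        return lambda: self.hooks.remove(hook)  # call to unregister

    def __call__(self, x):
        out = self.fn(x)
        for hook in list(self.hooks):
            hook(self, (x,), out)
        return out


class BranchyModel:
    """Hypothetical toy model: the "skip" layer only runs on negative values."""

    def __init__(self):
        self.layers = {
            "double": HookableLayer(lambda x: 2 * x),
            "skip": HookableLayer(lambda x: -x),
            "inc": HookableLayer(lambda x: x + 1),
        }

    def __call__(self, x):
        x = self.layers["double"](x)
        if x < 0:
            x = self.layers["skip"](x)
        return self.layers["inc"](x)


def run_with_features(model, paths: List[str], *args, **kwargs):
    """Return ({path: captured output or None}, model output)."""
    captured: Dict[str, Optional[Any]] = {p: None for p in paths}
    removers = []
    for p in paths:
        # p=p binds the current path now; a plain closure would only see the last p
        def hook(module, inputs, output, p=p):
            captured[p] = output

        removers.append(model.layers[p].register_forward_hook(hook))
    try:
        output = model(*args, **kwargs)
    finally:
        for remove in removers:
            remove()
    return captured, output


features, output = run_with_features(BranchyModel(), ["double", "skip"], 5)
print(features, output)  # {'double': 10, 'skip': None} 11
```

Binding p as a default argument avoids Python's late-binding closure pitfall, where every hook would otherwise write its output under the last path in the list.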

torch_featurelayer.get_layer_candidates(module: torch.nn.Module, max_depth: int = 1) -> Generator[str, None, None]

The get_layer_candidates function returns a generator of layer paths for a given model up to a specified depth.

  • module: The model to get layer paths from.
  • max_depth: The maximum depth to traverse in the model's layers.

Returns a generator of layer paths.
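A depth-limited traversal that yields dotted layer paths can be sketched as below. To keep it runnable without PyTorch, the model is represented as a hypothetical nested dict of child names; in PyTorch the same information comes from nn.Module.named_children(), and nn.Module.named_modules() already yields dotted names at every depth. The sketch assumes that max_depth=1 means direct children only and that parent paths are yielded along with their children; the library's exact conventions may differ. layer_candidates is a hypothetical helper, not the library's function.

```python
from typing import Dict, Iterator

# Hypothetical model description: child name -> that child's own children ({} = leaf).
ModuleTree = Dict[str, dict]


def layer_candidates(tree: ModuleTree, max_depth: int = 1, prefix: str = "") -> Iterator[str]:
    """Yield dotted layer paths, descending at most max_depth levels below the root."""
    if max_depth < 1:
        return
    for name, children in tree.items():
        path = f"{prefix}.{name}" if prefix else name
        yield path
        # Recurse with one less level of depth budget
        yield from layer_candidates(children, max_depth - 1, path)


vgg_like = {
    "features": {"0": {}, "1": {}, "2": {}},
    "avgpool": {},
    "classifier": {"0": {}, "1": {}},
}
print(list(layer_candidates(vgg_like, max_depth=1)))
# ['features', 'avgpool', 'classifier']
print(list(layer_candidates(vgg_like, max_depth=2)))
# ['features', 'features.0', 'features.1', 'features.2', 'avgpool', 'classifier', 'classifier.0', 'classifier.1']
```

vgg_like is a trimmed, hypothetical tree loosely modeled on the top-level children of torchvision's VGG (features, avgpool, classifier); any yielded path, such as features.1, is the kind of string FeatureLayer accepts as feature_layer_path.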

License

MIT

Download files

Source distribution: torch_featurelayer-0.1.2.tar.gz (8.2 kB)

Built distribution: torch_featurelayer-0.1.2-py3-none-any.whl (7.8 kB, Python 3)

File details: torch_featurelayer-0.1.2.tar.gz

  • Size: 8.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.10

Hashes:

  Algorithm     Hash digest
  SHA256        42c6cf350febc2c7b5b172d83d1d0624645893e01ab0ceee7a21c8469c0a9052
  MD5           89ce73fd7ad656cfc93328c4ced57135
  BLAKE2b-256   7c0f2e0b893125ef4196a59cf7fa7aa139265687ce09e06a3c5d1fe6a02347c7

File details: torch_featurelayer-0.1.2-py3-none-any.whl

Hashes:

  Algorithm     Hash digest
  SHA256        3638311e0b9ff972ecb184b0bcb39b82721972da93d922ea36ef65cfc43e8a2e
  MD5           11109bc185a7b4dfdeafe464d7d2b015
  BLAKE2b-256   c2b2118c64dbd60e13a87116bf68a7a6d4ff09ed51fd547fa64a51878c911a60
