
Useful utility functions and wrappers for hooking onto layers within PyTorch models for feature extraction.

Project description


torch-featurelayer

🧠 Simple utility functions and wrappers for hooking onto layers within PyTorch models for feature extraction.

Tip: This library is intended to be a simple, well-documented way to extract the output(s) of a PyTorch model's intermediate layers. For a more sophisticated and complete implementation, consider torchvision.models.feature_extraction or the official torch.fx toolkit.

Usage

import torch
from torchvision.models import vgg11

from torch_featurelayer import FeatureLayer

# Load a pretrained VGG-11 model
model = vgg11(weights='DEFAULT').eval()

# Hook onto layer `features.15` of the model
layer_path = 'features.15'
hooked_model = FeatureLayer(model, layer_path)

# Pass a random input tensor through the hooked model
x = torch.randn(1, 3, 224, 224)
feature_output, output = hooked_model(x)

# Print the output shapes
print(f'Feature layer output shape: {feature_output.shape}')  # [1, 512, 14, 14]
print(f'Model output shape: {output.shape}')  # [1, 1000]

Check the examples directory for more.
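The same kind of result can be obtained with PyTorch's built-in forward hooks. The sketch below illustrates that technique only and is not taken from the library's source: `run_with_hook` is a hypothetical helper, it relies on `nn.Module.get_submodule` (PyTorch 1.9+) to resolve dotted paths such as `features.15`, and a toy `nn.Sequential` stands in for VGG-11 so nothing is downloaded.

```python
import torch
from torch import nn


def run_with_hook(model: nn.Module, layer_path: str, x: torch.Tensor):
    """Hypothetical helper: return (layer_output, model_output) for one forward pass.

    A minimal sketch of the forward-hook technique; not necessarily how
    torch_featurelayer implements FeatureLayer.
    """
    captured = {}

    def hook(module, inputs, output):
        # Called by PyTorch right after `module.forward` returns
        captured['feature'] = output

    # Resolve a dotted path such as 'features.15' to the submodule
    layer = model.get_submodule(layer_path)
    handle = layer.register_forward_hook(hook)
    try:
        output = model(x)
    finally:
        # Always remove the hook so later forward passes are unaffected
        handle.remove()
    return captured.get('feature'), output


# Toy model so the sketch runs without downloading weights
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
).eval()

x = torch.randn(2, 3, 32, 32)
with torch.no_grad():
    feature, out = run_with_hook(model, '1', x)  # hook the ReLU at index 1

print(feature.shape)  # torch.Size([2, 8, 32, 32])
print(out.shape)      # torch.Size([2, 10])
```

Removing the hook in a `finally` block keeps the model reusable even if the forward pass raises.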

API

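torch_featurelayer.FeatureLayer(model: torch.nn.Module, feature_layer_path: str)

Wraps model and hooks onto the layer at the dotted path feature_layer_path (e.g. 'features.15'). Calling the wrapper on an input returns a (feature_output, output) tuple, as in the usage example above.

torch_featurelayer.FeatureLayers(model: torch.nn.Module, feature_layer_paths: list[str])

The multi-layer counterpart of FeatureLayer: hooks onto every layer in feature_layer_paths.

torch_featurelayer.get_layer_candidates(module: torch.nn.Module, max_depth: int = 1)

Helper for finding candidate layer paths in module, up to max_depth levels deep, to pass to FeatureLayer or FeatureLayers.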
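Hooking several layers follows the same pattern: register one forward hook per path and remove them all after the pass. The helpers `run_with_hooks` and `list_layer_paths` below are hypothetical illustrations of that idea; the dict return value and the "count the dots" depth rule are assumptions and may not match the actual behaviour of FeatureLayers or get_layer_candidates.

```python
import torch
from torch import nn


def run_with_hooks(model: nn.Module, layer_paths: list[str], x: torch.Tensor):
    """Hypothetical helper: capture several layer outputs in one forward pass."""
    captured = {path: None for path in layer_paths}  # None if a layer never runs
    handles = []
    for path in layer_paths:
        # Default argument binds the current `path` to this hook
        def hook(module, inputs, output, path=path):
            captured[path] = output

        handles.append(model.get_submodule(path).register_forward_hook(hook))
    try:
        output = model(x)
    finally:
        for handle in handles:
            handle.remove()
    return captured, output


def list_layer_paths(module: nn.Module, max_depth: int = 1) -> list[str]:
    """Hypothetical helper: dotted submodule paths at most `max_depth` levels deep."""
    # named_modules() yields ('', root) first; skip it, then filter by nesting depth
    return [
        name
        for name, _ in module.named_modules()
        if name and name.count('.') < max_depth
    ]


# Toy model with one nested block
model = nn.Sequential(
    nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU()),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
).eval()

print(list_layer_paths(model, max_depth=1))  # ['0', '1', '2', '3']
print(list_layer_paths(model, max_depth=2))  # ['0', '0.0', '0.1', '1', '2', '3']

with torch.no_grad():
    feats, out = run_with_hooks(model, ['0.1', '1'], torch.randn(2, 3, 32, 32))
print({k: tuple(v.shape) for k, v in feats.items()})
# {'0.1': (2, 8, 32, 32), '1': (2, 8, 1, 1)}
```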

License

MIT

Download files


Source distribution: torch_featurelayer-0.1.1.tar.gz (6.1 kB)

Built distribution: torch_featurelayer-0.1.1-py3-none-any.whl (7.2 kB, Python 3)
