
PyTorch feature extraction made simple

torchextractor: PyTorch Intermediate Feature Extraction

Introduction

Too often, model definitions get remorselessly copy-pasted just because the forward function does not return what you need. With torchextractor, you provide module names and it takes care of the extraction for you. It has never been easier to extract features, add an extra loss, or plug another head into a network. Let us know what amazing things you build with torchextractor!

Installation

pip install torchextractor  # stable
pip install git+https://github.com/antoinebrl/torchextractor.git  # latest

Requirements:

  • Python >= 3.6
  • torch >= 1.4.0

Usage

import torch
import torchvision
import torchextractor as tx

model = torchvision.models.resnet18(pretrained=True)
model = tx.Extractor(model, ["layer1", "layer2", "layer3", "layer4"])
dummy_input = torch.rand(7, 3, 224, 224)
model_output, features = model(dummy_input)
feature_shapes = {name: f.shape for name, f in features.items()}
print(feature_shapes)

# {
#   'layer1': torch.Size([7, 64, 56, 56]),
#   'layer2': torch.Size([7, 128, 28, 28]),
#   'layer3': torch.Size([7, 256, 14, 14]),
#   'layer4': torch.Size([7, 512, 7, 7]),
# }

See more examples

Read the documentation

FAQ

• How do I know the names of the modules?

You can print all module names like this:

tx.list_module_names(model)

# OR

for name, module in model.named_modules():
    print(name)

• Why do some operations not get listed?

It is not possible to add hooks to operations that are not defined as modules. Therefore, F.relu cannot be captured but nn.ReLU() can.
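
The difference is visible with plain PyTorch. The sketch below uses a hypothetical TinyNet (not part of torchextractor) that applies ReLU once as a module and once as a function. Only the module version shows up in named_modules() and can receive a forward hook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Hypothetical toy model, used only for illustration."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.act = nn.ReLU()  # registered as a submodule, so it can be hooked
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        x = self.act(self.fc1(x))
        # Functional call: no module is registered, so there is nothing to hook
        return F.relu(self.fc2(x))


net = TinyNet()

# Skip the root module, whose name is the empty string
module_names = [name for name, _ in net.named_modules() if name]
print(module_names)  # ['fc1', 'act', 'fc2']

captured = {}


def save_output(module, inputs, output):
    # Standard forward-hook signature: (module, inputs, output)
    captured["act"] = output


handle = net.act.register_forward_hook(save_output)
net(torch.rand(3, 4))
handle.remove()  # detach the hook once it is no longer needed

print(captured["act"].shape)  # torch.Size([3, 8])
```

The second ReLU never appears in the list, which is why a name-based extractor has no way to reach it.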

• How can I avoid listing all relevant modules?

You can specify a custom filtering function to hook the relevant modules:

# Hook everything !
module_filter_fn = lambda module, name: True

# Capture of all modules inside first layer
module_filter_fn = lambda module, name: name.startswith("layer1")

# Focus on all convolutions
module_filter_fn = lambda module, name: isinstance(module, torch.nn.Conv2d)

model = tx.Extractor(model, module_filter_fn=module_filter_fn)
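
Before handing a filter to tx.Extractor, it can help to preview which names it would select. The sketch below assumes the filter is evaluated against each (module, name) pair yielded by named_modules() and that the unnamed root module is skipped. preview_selection is a hypothetical helper, not part of torchextractor, and the toy model stands in for a real network.

```python
from collections import OrderedDict

import torch.nn as nn

# Toy stand-in for a real network, with named blocks
toy = nn.Sequential(OrderedDict([
    ("layer1", nn.Sequential(OrderedDict([
        ("conv", nn.Conv2d(3, 8, kernel_size=3)),
        ("relu", nn.ReLU()),
    ]))),
    ("layer2", nn.Sequential(OrderedDict([
        ("conv", nn.Conv2d(8, 16, kernel_size=3)),
        ("relu", nn.ReLU()),
    ]))),
]))


def preview_selection(model, module_filter_fn):
    """Hypothetical helper: list the module names a filter would accept."""
    return [
        name
        for name, module in model.named_modules()
        if name and module_filter_fn(module, name)  # skip the unnamed root
    ]


conv_names = preview_selection(toy, lambda module, name: isinstance(module, nn.Conv2d))
layer1_names = preview_selection(toy, lambda module, name: name.startswith("layer1"))

print(conv_names)    # ['layer1.conv', 'layer2.conv']
print(layer1_names)  # ['layer1', 'layer1.conv', 'layer1.relu']
```

Note that a prefix filter such as name.startswith("layer1") matches the container itself as well as its children.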

• Is it compatible with ONNX?

tx.Extractor is compatible with ONNX! This means you can also access intermediate feature maps after the export.

Pro-tip: name the output nodes by using output_names when calling torch.onnx.export.

• Is it compatible with TorchScript?

Bad news: TorchScript does not support a variable number of arguments or keyword-only arguments.

Good news: there is a workaround! Override the forward function of tx.Extractor to replicate the interface of the model.

import torch
import torchvision
import torchextractor as tx

class MyExtractor(tx.Extractor):
    def forward(self, x):
        # Replicate the wrapped model's signature explicitly: resnet18 takes a
        # single tensor. For a model with several inputs, list each one by name.
        output = self.model(x)
        return output, self.feature_maps

model = torchvision.models.resnet18(pretrained=True)
model = MyExtractor(model, ["layer1", "layer2", "layer3", "layer4"])
model_scripted = torch.jit.script(model)

• "One more thing!" By default, we capture the latest output of the relevant modules, but you can specify your own custom capture operation.

For example, to accumulate features over 10 forward passes you can do the following:

import torch
import torchvision
import torchextractor as tx

model = torchvision.models.resnet18(pretrained=True)

def capture_fn(module, input, output, module_name, feature_maps):
    if module_name not in feature_maps:
        feature_maps[module_name] = []
    feature_maps[module_name].append(output)

extractor = tx.Extractor(model, ["layer3", "layer4"], capture_fn=capture_fn)

for _ in range(20):
    # Accumulate features over 10 forward passes
    for _ in range(10):
        x = torch.rand(7, 3, 224, 224)
        model(x)
    feature_maps = extractor.collect()

    # Do your stuff here

    # Discard collected elements
    extractor.clear_placeholder()
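
Accumulating raw outputs keeps every activation alive, which can get expensive over many passes. As one possible variation, the sketch below keeps the same capture_fn signature as above but stores detached, globally average-pooled features instead. The name capture_pooled_fn is hypothetical, and the function assumes the hooked modules output 4D (N, C, H, W) tensors. It is called by hand here with random tensors so that it runs without torchextractor.

```python
import torch


def capture_pooled_fn(module, input, output, module_name, feature_maps):
    """Hypothetical capture function: store pooled, detached features."""
    # Detach from the autograd graph and average over the spatial dimensions
    pooled = output.detach().mean(dim=(2, 3))
    if module_name not in feature_maps:
        feature_maps[module_name] = []
    feature_maps[module_name].append(pooled)


# Simulate three forward passes by calling the function directly
feature_maps = {}
for _ in range(3):
    fake_output = torch.rand(7, 256, 14, 14, requires_grad=True)
    capture_pooled_fn(None, None, fake_output, "layer3", feature_maps)

# Concatenate the per-pass features along the batch dimension
stacked = torch.cat(feature_maps["layer3"], dim=0)
print(stacked.shape)  # torch.Size([21, 256])
```

With the extractor, such a function would be passed through the capture_fn argument in the same way as in the example above. Detaching means the stored features cannot be used for an extra loss, so this variation only suits offline analysis.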

Contributing

All feedback and contributions are welcome. Feel free to report an issue or create a pull request!

If you want to get hands-on:

  1. (Fork and) clone the repo.
  2. Create a virtual environment: virtualenv -p python3 .venv && source .venv/bin/activate
  3. Install dependencies: pip install -r requirements.txt && pip install -r requirements-dev.txt
  4. Hook auto-formatting tools: pre-commit install
  5. Hack as much as you want!
  6. Run tests: python -m unittest discover -vs ./tests/
  7. Share your work and create a pull request.

To build the documentation:

cd docs
pip install -r requirements.txt
make html
