
PyTorch feature extraction made simple

Project description

torchextractor: PyTorch Intermediate Feature Extraction

Introduction

Too many times, model definitions get remorselessly copy-pasted just because the forward function does not return what the person expects. With torchextractor, you provide module names and it takes care of the extraction for you. It's never been easier to extract features, add an extra loss, or plug another head to a network. Let us know what amazing things you build with torchextractor!

Installation

pip install git+https://github.com/antoinebrl/torchextractor.git

Requirements:

  • Python >= 3.6
  • torch >= 1.4.0

Usage

import torch
import torchvision
import torchextractor as tx

model = torchvision.models.resnet18(pretrained=True)
model = tx.Extractor(model, ["layer1", "layer2", "layer3", "layer4"])
dummy_input = torch.rand(7, 3, 224, 224)
model_output, features = model(dummy_input)
feature_shapes = {name: f.shape for name, f in features.items()}
print(feature_shapes)

# {
#   'layer1': torch.Size([7, 64, 56, 56]),
#   'layer2': torch.Size([7, 128, 28, 28]),
#   'layer3': torch.Size([7, 256, 14, 14]),
#   'layer4': torch.Size([7, 512, 7, 7]),
# }

See more examples in the Binder and Colab notebooks.

FAQ

• How do I know the names of the modules?

You can print all module names like this:

for name, module in model.named_modules():
    print(name)

• Why do some operations not get listed?

It is not possible to add hooks to operations that are not defined as modules. Therefore, F.relu cannot be captured, but nn.ReLU() can.
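To see the difference, compare two hypothetical models that compute the same thing: only the one using an nn.ReLU submodule exposes the activation through named_modules(), and therefore only that one can be hooked.

```python
import torch.nn as nn
import torch.nn.functional as F

class WithModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self.relu = nn.ReLU()  # registered as a submodule, so hookable

    def forward(self, x):
        return self.relu(self.fc(x))

class WithFunctional(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        return F.relu(self.fc(x))  # plain function call, invisible to hooks

print([name for name, _ in WithModule().named_modules()])
# ['', 'fc', 'relu']
print([name for name, _ in WithFunctional().named_modules()])
# ['', 'fc']
```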

• How can I avoid listing all relevant modules?

You can specify a custom filtering function to hook the relevant modules:

# Hook everything!
module_filter_fn = lambda module, name: True

# Capture all modules inside the first layer
module_filter_fn = lambda module, name: name.startswith("layer1")

# Focus on all convolutions
module_filter_fn = lambda module, name: isinstance(module, torch.nn.Conv2d)

model = tx.Extractor(model, module_filter_fn=module_filter_fn)

• Is it compatible with ONNX?

tx.Extractor is compatible with ONNX! This means you can also access intermediate feature maps after the export.

Pro-tip: name the output nodes by using output_names when calling torch.onnx.export.

• Is it compatible with TorchScript?

Bad news: TorchScript cannot handle a variable number of arguments or keyword-only arguments. Good news: there is a workaround! The solution is to override the forward function of tx.Extractor to replicate the interface of the model.

import torch
import torchvision
import torchextractor as tx

class MyExtractor(tx.Extractor):
    def forward(self, x1, x2, x3):
        # Assuming the model takes x1, x2 and x3 as input
        output = self.model(x1, x2, x3)
        return output, self.feature_maps

model = torchvision.models.resnet18(pretrained=True)
model = MyExtractor(model, ["layer1", "layer2", "layer3", "layer4"])
model_scripted = torch.jit.script(model)

Contributing

All feedback and contributions are welcome. Feel free to report an issue or create a pull request!

If you want to get hands-on:

  1. (Fork and) clone the repo.
  2. Create a virtual environment: virtualenv -p python3 .venv && source .venv/bin/activate
  3. Install dependencies: pip install -r requirements.txt && pip install -r requirements-dev.txt
  4. Hook auto-formatting tools: pre-commit install
  5. Hack as much as you want!
  6. Run tests: python -m unittest discover -vs ./tests/
  7. Share your work and create a pull request.
