PyTorch feature extraction made simple
torchextractor: PyTorch Intermediate Feature Extraction
Introduction
Too many times, model definitions get remorselessly copy-pasted just because the forward function does not return what the person expects. You provide module names and torchextractor takes care of the extraction for you. It's never been easier to extract features, add an extra loss or plug another head into a network.
Let us know what amazing things you build with torchextractor!
Installation
pip install git+https://github.com/antoinebrl/torchextractor.git
Requirements:
- Python >= 3.6
- torch >= 1.4.0
Usage
import torch
import torchvision
import torchextractor as tx
model = torchvision.models.resnet18(pretrained=True)
model = tx.Extractor(model, ["layer1", "layer2", "layer3", "layer4"])
dummy_input = torch.rand(7, 3, 224, 224)
model_output, features = model(dummy_input)
feature_shapes = {name: f.shape for name, f in features.items()}
print(feature_shapes)
# {
# 'layer1': torch.Size([7, 64, 56, 56]),
# 'layer2': torch.Size([7, 128, 28, 28]),
# 'layer3': torch.Size([7, 256, 14, 14]),
# 'layer4': torch.Size([7, 512, 7, 7]),
# }
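The FAQ notes that capture relies on hooks. As a rough illustration of that mechanism (a sketch, not torchextractor's actual implementation), the snippet below registers forward hooks on named submodules in plain PyTorch and collects their outputs. The helper extract_features and the small nn.Sequential model are hypothetical, chosen so the example runs without downloading weights.

```python
from typing import Dict, Iterable, Tuple

import torch
from torch import nn


def extract_features(
    model: nn.Module, names: Iterable[str], x: torch.Tensor
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Run `model` on `x` and capture the outputs of the submodules listed in `names`.

    Hypothetical helper illustrating the forward-hook mechanism.
    """
    wanted = set(names)
    features: Dict[str, torch.Tensor] = {}
    handles = []
    for name, module in model.named_modules():
        if name in wanted:
            # Bind `name` through a default argument so each hook writes to its own key.
            def hook(mod, inputs, output, name=name):
                features[name] = output

            handles.append(module.register_forward_hook(hook))
    try:
        output = model(x)
    finally:
        # Remove the hooks so the wrapped model is left unchanged.
        for handle in handles:
            handle.remove()
    return output, features


model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),
)
output, features = extract_features(model, ["0", "2"], torch.rand(7, 3, 32, 32))
print({name: tuple(f.shape) for name, f in features.items()})
# {'0': (7, 8, 32, 32), '2': (7, 16, 16, 16)}
```

Hooks fire in the order the modules run during the forward pass, so the dictionary keys follow execution order rather than the order of names.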
FAQ
• How do I know the names of the modules?
You can print all module names like this:
for name, module in model.named_modules():
    print(name)
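On a small hypothetical model with one nested block, the listing looks like this. Note that named_modules() also yields the root module under the empty name "" and gives nested submodules dotted paths such as block.conv.

```python
from collections import OrderedDict

from torch import nn

# Hypothetical toy model: a stem followed by a nested block.
model = nn.Sequential(OrderedDict([
    ("stem", nn.Conv2d(3, 8, kernel_size=3)),
    ("block", nn.Sequential(OrderedDict([
        ("conv", nn.Conv2d(8, 8, kernel_size=3)),
        ("act", nn.ReLU()),
    ]))),
]))

names = [name for name, _ in model.named_modules()]
print(names)
# ['', 'stem', 'block', 'block.conv', 'block.act']
```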
• Why do some operations not get listed?
Hooks can only be attached to modules, so operations that are not defined as modules cannot be captured. Therefore, the functional F.relu cannot be captured, but nn.ReLU() can.
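The difference is visible in named_modules(): a ReLU applied through F.relu never shows up as a submodule, while an nn.ReLU() assigned as an attribute does. Both tiny models below are hypothetical.

```python
import torch.nn.functional as F
from torch import nn


class FunctionalNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        # F.relu is a plain function call: there is no module to attach a hook to.
        return F.relu(self.fc(x))


class ModuleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self.relu = nn.ReLU()  # registered submodule, so it can be hooked

    def forward(self, x):
        return self.relu(self.fc(x))


functional_names = [name for name, _ in FunctionalNet().named_modules()]
module_names = [name for name, _ in ModuleNet().named_modules()]
print(functional_names)  # ['', 'fc']
print(module_names)      # ['', 'fc', 'relu']
```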
• How can I avoid listing all relevant modules?
You can specify a custom filtering function to hook the relevant modules:
# Hook everything!
module_filter_fn = lambda module, name: True
# Capture all modules inside the first layer
module_filter_fn = lambda module, name: name.startswith("layer1")
# Focus on all convolutions
module_filter_fn = lambda module, name: isinstance(module, torch.nn.Conv2d)
model = tx.Extractor(model, module_filter_fn=module_filter_fn)
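To preview which modules a filter function would match before building the extractor, you can apply it to named_modules() yourself. This sketch assumes tx.Extractor calls module_filter_fn(module, name) with the same argument order as the lambdas above; the helper preview_filter and the toy model are hypothetical.

```python
from typing import Callable, List

import torch
from torch import nn


def preview_filter(
    model: nn.Module, module_filter_fn: Callable[[nn.Module, str], bool]
) -> List[str]:
    """Return the names of the modules matched by `module_filter_fn` (hypothetical helper)."""
    return [name for name, module in model.named_modules() if module_filter_fn(module, name)]


# Hypothetical toy model whose top-level names mimic a ResNet.
model = nn.Sequential()
model.add_module("layer1", nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()))
model.add_module("layer2", nn.Sequential(nn.Conv2d(8, 16, 3), nn.ReLU()))

print(preview_filter(model, lambda module, name: name.startswith("layer1")))
# ['layer1', 'layer1.0', 'layer1.1']
print(preview_filter(model, lambda module, name: isinstance(module, torch.nn.Conv2d)))
# ['layer1.0', 'layer2.0']
```

The "hook everything" filter also matches the root module (name ""), since named_modules() yields it first.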
• Is it compatible with ONNX?
tx.Extractor is compatible with ONNX! This means you can also access intermediate feature maps after the export.
Pro-tip: name the output nodes by using output_names when calling torch.onnx.export.
• Is it compatible with TorchScript?
Bad news: TorchScript cannot take a variable number of arguments or keyword-only arguments.
Good news: there is a workaround! The solution is to override the forward function of tx.Extractor to replicate the interface of the model.
import torch
import torchvision
import torchextractor as tx

class MyExtractor(tx.Extractor):
    def forward(self, x: torch.Tensor):
        # Replicate the wrapped model's signature: resnet18 takes a single tensor.
        # For a model taking x1, x2 and x3, declare and forward all three instead.
        output = self.model(x)
        return output, self.feature_maps

model = torchvision.models.resnet18(pretrained=True)
model = MyExtractor(model, ["layer1", "layer2", "layer3", "layer4"])
model_scripted = torch.jit.script(model)
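The limitation itself is easy to reproduce in plain PyTorch. In this sketch, a hypothetical wrapper that forwards *args fails to compile with torch.jit.script, while the same wrapper with an explicit, typed signature compiles; the wrapper class names are illustrative only.

```python
import torch
from torch import nn


class VarArgsWrapper(nn.Module):
    """Forwards any positional arguments: TorchScript rejects this signature."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, *args):
        return self.model(*args)


class ExplicitWrapper(nn.Module):
    """Same wrapper with an explicit signature mirroring the wrapped model."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


inner = nn.Linear(4, 2)

try:
    torch.jit.script(VarArgsWrapper(inner))
    varargs_scripted = True
except Exception:  # the TorchScript frontend refuses *args in compiled functions
    varargs_scripted = False
print(varargs_scripted)  # False

scripted = torch.jit.script(ExplicitWrapper(inner))
print(scripted(torch.rand(3, 4)).shape)  # torch.Size([3, 2])
```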
Contributing
All feedback and contributions are welcome. Feel free to report an issue or create a pull request!
If you want to get hands-on:
- (Fork and) clone the repo.
- Create a virtual environment:
virtualenv -p python3 .venv && source .venv/bin/activate
- Install dependencies:
pip install -r requirements.txt && pip install -r requirements-dev.txt
- Hook auto-formatting tools:
pre-commit install
- Hack as much as you want!
- Run tests:
python -m unittest discover -vs ./tests/
- Share your work and create a pull request.
File details
Details for the file torchextractor-0.2.0.tar.gz.
File metadata
- Download URL: torchextractor-0.2.0.tar.gz
- Upload date:
- Size: 5.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/40.6.2 requests-toolbelt/0.9.1 tqdm/4.59.0 CPython/3.6.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | ab9ce331bd5f9814bf9c2be7aa50a0b1c076083f00cf1fb14d93252c7827a226
MD5 | c71329094f9067371804b8ff510dc24f
BLAKE2b-256 | e4f6fb62a29962d3ef32b8d33255597d97c0fc04cd9f5324a0585e366b30a7ca
File details
Details for the file torchextractor-0.2.0-py3-none-any.whl.
File metadata
- Download URL: torchextractor-0.2.0-py3-none-any.whl
- Upload date:
- Size: 10.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/40.6.2 requests-toolbelt/0.9.1 tqdm/4.59.0 CPython/3.6.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | c15da5dc63ea6a5cbdadc7e1e63354d9838ef8d479efe8cbe2d22f745fb5a3b6
MD5 | 31ea271379638e30f66878e631cfc6b7
BLAKE2b-256 | 1eb12be31915b38fe1e4ec6108d0429e0b701ece7a95c42ab40b575cf3263974
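To check a downloaded archive against the SHA256 digest listed for the source distribution, a small standard-library sketch could look like this. The local path is an assumption about where the file was saved; the expected digest is copied from the table for torchextractor-0.2.0.tar.gz.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Expected digest for torchextractor-0.2.0.tar.gz, taken from the table above.
EXPECTED = "ab9ce331bd5f9814bf9c2be7aa50a0b1c076083f00cf1fb14d93252c7827a226"

archive = Path("torchextractor-0.2.0.tar.gz")  # assumed download location
if archive.exists():
    print("OK" if sha256_of(archive) == EXPECTED else "MISMATCH")
```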