Simple, easy-to-use module to get the intermediate results of chosen submodules
Supports nested submodules, which are selected by dotted names such as 'nested.0.1'. Inspired by an existing intermediate-layer getter, but unlike it, this module does not assume that submodules are executed sequentially.
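One common way to build such a getter in PyTorch is with forward hooks. A hook registered on each selected submodule records that submodule's output whenever it runs, so neither the nesting depth nor the registration order matters. The sketch below illustrates this idea using only PyTorch. It assumes hooks are the mechanism and is not this package's actual implementation; `capture_intermediate` is a hypothetical helper name, and its `keep_output` flag only mirrors the behaviour described in the example.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def capture_intermediate(model, return_layers, *args, keep_output=True, **kwargs):
    """Hypothetical helper, NOT this package's API: run `model` once and record
    the outputs of the submodules named in `return_layers`.

    `return_layers` maps dotted submodule names (e.g. 'nested.0.1') to output keys.
    Outputs are stored in the order the submodules actually run, so nothing is
    assumed about the order in which they were registered.
    """
    modules = dict(model.named_modules())  # dotted name -> submodule
    outputs = OrderedDict()
    handles = []
    for name, key in return_layers.items():
        if name not in modules:
            raise KeyError(f"module {name!r} not found in model")

        # The default argument binds the current key; a plain closure would see only the last one.
        def hook(module, inputs, output, key=key):
            outputs[key] = output

        handles.append(modules[name].register_forward_hook(hook))
    try:
        model_output = model(*args, **kwargs)
    finally:
        # Remove the hooks even if the forward pass fails, leaving the model unchanged.
        for handle in handles:
            handle.remove()
    return outputs, (model_output if keep_output else None)


# Usage: 'a' is registered first but runs second, and the capture order reflects that.
class Swapped(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(2, 2)
        self.b = nn.Sequential(nn.Linear(2, 2), nn.ReLU())

    def forward(self, x):
        return self.a(self.b(x))


mids, out = capture_intermediate(Swapped(), {"a": "a", "b.0": "b_linear"}, torch.randn(1, 2))
print(list(mids))  # ['b_linear', 'a']
```

If a selected submodule runs more than once in a single forward pass, this sketch keeps only its last output.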
Installation
pip install torch_intermediate_layer_getter
Usage
Example
```python
import torch
import torch.nn as nn
from torch_intermediate_layer_getter import IntermediateLayerGetter as MidGetter


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 2)
        self.nested = nn.Sequential(
            nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 3)),
            nn.Linear(3, 1),
        )
        # Simple trick for operations not performed as modules: pass the result
        # through an nn.Identity so that it can be selected by name.
        self.interaction_idty = nn.Identity()

    def forward(self, x):
        x1 = self.fc1(x)
        x2 = self.fc2(x)
        interaction = x1 * x2
        self.interaction_idty(interaction)
        x_out = self.nested(interaction)
        return x_out


model = Model()
# Map submodule names (dotted paths for nested submodules) to output keys.
return_layers = {
    'fc2': 'fc2',
    'nested.0.1': 'nested',
    'interaction_idty': 'interaction',
}
mid_getter = MidGetter(model, return_layers=return_layers, keep_output=True)
mid_outputs, model_output = mid_getter(torch.randn(1, 2))

# The printed values depend on the random initialization and input.
print(model_output)
# >> tensor([[0.3219]], grad_fn=<AddmmBackward>)
print(mid_outputs)
# >> OrderedDict([('fc2', tensor([[-1.5125,  0.9334]], grad_fn=<AddmmBackward>)),
# >>              ('interaction', tensor([[-0.0687, -0.1462]], grad_fn=<MulBackward0>)),
# >>              ('nested', tensor([[-0.1697,  0.1432,  0.2959]], grad_fn=<AddmmBackward>))])

# model_output is None if keep_output is False;
# if keep_output is True, model_output contains the model's final output.
```
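The keys of `return_layers` in the example are the dotted names PyTorch assigns to each submodule: 'nested.0.1' is the second layer of the inner `Sequential` inside `nested`. A simple way to find candidate names for your own model is to list `named_modules()`, as in this short sketch, which needs only PyTorch. It presumably yields the names this package expects, given the example's use of 'nested.0.1', but that correspondence is an assumption. Note that `interaction_idty` appears in the list, which is why the Identity trick makes the product selectable.

```python
import torch.nn as nn


class Model(nn.Module):
    # Same submodule layout as in the example; forward() is omitted because
    # listing names does not require running the model.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 2)
        self.nested = nn.Sequential(
            nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 3)),
            nn.Linear(3, 1),
        )
        self.interaction_idty = nn.Identity()


modules = dict(Model().named_modules())
# Skip the root module, whose name is the empty string.
names = [name for name in modules if name]
print(names)
# ['fc1', 'fc2', 'nested', 'nested.0', 'nested.0.0', 'nested.0.1', 'nested.1', 'interaction_idty']
print(modules['nested.0.1'])
# Linear(in_features=2, out_features=3, bias=True)
```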
Download files
Source Distribution
Hashes for torch_intermediate_layer_getter-0.1.post1.tar.gz

Algorithm | Hash digest
---|---
SHA256 | c0e8374528d30f85e2420f6104242c0ca0495cfd7cdc551285305c01a7a21b67
MD5 | 6bd3245a597e7e0b4c620b1a2413f641
BLAKE2b-256 | 38988a37ff086257cdc9fd3e62f47b76de7d0091e9a43f3c719521411068449a