
PyTorch model summary table containing layer sizes and shapes.


torchinfo (beta)

Please use https://pypi.org/project/torch-summary/ for now, thanks!


Torchinfo provides information complementary to what is provided by print(your_model) in PyTorch, similar to TensorFlow's model.summary() API. A summary of the model's layers, output shapes, and parameter counts is helpful while debugging your network. In this project, we implement similar functionality in PyTorch and create a clean, simple interface to use in your projects.

This is a completely rewritten version of the original torchsummary and torchsummaryX projects by @sksq96 and @nmhkahn. This project addresses all of the issues and pull requests left on the original projects by introducing a completely new API.

Usage

pip install torchinfo

How To Use

from torchinfo import summary

model = ConvNet()  # ConvNet stands for your own nn.Module
summary(model, (1, 28, 28), batch_dim=0)
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Conv2d: 1-1                            [-1, 10, 24, 24]          260
├─Conv2d: 1-2                            [-1, 20, 8, 8]            5,020
├─Dropout2d: 1-3                         [-1, 20, 8, 8]            --
├─Linear: 1-4                            [-1, 50]                  16,050
├─Linear: 1-5                            [-1, 10]                  510
==========================================================================================
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.05
Params size (MB): 0.08
Estimated Total Size (MB): 0.14
==========================================================================================
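ConvNet is not defined in this snippet, but the Param # column can be checked by hand. The sketch below is plain-Python arithmetic, not torchinfo code. It assumes 5×5 kernels with bias for both Conv2d layers and 320 (= 20 × 4 × 4) input features for the first Linear layer, since those are the values consistent with the printed counts. The helpers conv2d_params and linear_params are hypothetical names.

```python
# Worked check of the "Param #" column above (pure Python, no torch needed).
# Assumed configs: 5x5 kernels with bias; first Linear takes 320 = 20 * 4 * 4 inputs.

def conv2d_params(in_ch: int, out_ch: int, k: int, bias: bool = True) -> int:
    """Weights are out_ch x in_ch x k x k, plus one bias per output channel."""
    return out_ch * in_ch * k * k + (out_ch if bias else 0)

def linear_params(in_f: int, out_f: int, bias: bool = True) -> int:
    """Weights are out_f x in_f, plus one bias per output feature."""
    return out_f * in_f + (out_f if bias else 0)

layers = {
    "Conv2d: 1-1": conv2d_params(1, 10, 5),    # 260
    "Conv2d: 1-2": conv2d_params(10, 20, 5),   # 5,020
    "Dropout2d: 1-3": 0,                       # no parameters, shown as "--"
    "Linear: 1-4": linear_params(320, 50),     # 16,050
    "Linear: 1-5": linear_params(50, 10),      # 510
}
total = sum(layers.values())

for name, n in layers.items():
    print(f"{name:<20}{n:>10,}")
print(f"Total params: {total:,}")  # Total params: 21,840
```

Dropout layers hold no parameters, which is why the table shows "--" for Dropout2d.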

This version now supports:

  • RNNs, LSTMs, and other recurrent layers
  • Sequentials & Module Lists
  • Branching output for exploring nested model layers down to a specified depth
  • Returns ModelStatistics object containing all summary data fields
  • Configurable columns

Other new features:

  • Verbose mode to show weights and bias layers
  • Accepts either input data or simply the input shape!
  • Customizable widths and batch dimension
  • Comprehensive unit/output testing, linting, and code coverage testing

Documentation

"""
Summarize the given PyTorch model. Summarized information includes:
    1) Layer names,
    2) input/output shapes,
    3) kernel shape,
    4) # of parameters,
    5) # of operations (Mult-Adds)

Args:
    model (nn.Module):
            PyTorch model to summarize

    input_data (Sequence of Sizes or Tensors):
            Example input tensor of the model (dtypes inferred from model input).
            - OR -
            Shape of input data as a List/Tuple/torch.Size
            (dtypes must match model input, default is FloatTensors).
            You should NOT include batch size in the tuple.
            - OR -
            If input_data is not provided, no forward pass through the network is
            performed, and the provided model information is limited to layer names.
            Default: None

    batch_dim (int):
            Batch dimension of input data. If batch_dim is None, assume the input
            data already contains the batch dimension, which is used in all calculations.
            Default: None

    branching (bool):
            Whether to use the branching layout for the printed output.
            Default: True

    col_names (Sequence[str]):
            Specify which columns to show in the output. Currently supported:
            ("input_size", "output_size", "num_params", "kernel_size", "mult_adds")
            If input_data is not provided, only "num_params" is used.
            Default: ("output_size", "num_params")

    col_width (int):
            Width of each column.
            Default: 25

    depth (int):
            Number of nested layers to traverse (e.g. Sequentials).
            Default: 3

    device (torch.device):
            Uses this torch device for model and input_data.
            If not specified, uses CUDA if torch.cuda.is_available() is True,
            otherwise the CPU.
            Default: None

    dtypes (List[torch.dtype]):
            For multiple inputs, specify the size of each input in input_data,
            and specify the dtype of each input here, in the same order.
            Default: None

    verbose (int):
            0 (quiet): No output
            1 (default): Print model summary
            2 (verbose): Show weight and bias layers in full detail
            Default: 1

    *args, **kwargs:
            Other arguments used in `model.forward` function.

Return:
    ModelStatistics object
            See torchinfo/model_statistics.py for more information.
"""

Examples

Get Model Summary as String

from torchinfo import summary

model_stats = summary(your_model, (3, 28, 28), verbose=0)
summary_str = str(model_stats)
# summary_str contains the string representation of the summary. See below for examples.

ResNet

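import torchvision
from torchinfo import summary

# The totals below (60,192,808 params, 11.63 G mult-adds) correspond to ResNet-152.
model = torchvision.models.resnet152()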
summary(model, (3, 224, 224), depth=3)
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Conv2d: 1-1                            [-1, 64, 112, 112]        9,408
├─BatchNorm2d: 1-2                       [-1, 64, 112, 112]        128
├─ReLU: 1-3                              [-1, 64, 112, 112]        --
├─MaxPool2d: 1-4                         [-1, 64, 56, 56]          --
├─Sequential: 1-5                        [-1, 256, 56, 56]         --
|    └─Bottleneck: 2-1                   [-1, 256, 56, 56]         --
|    |    └─Conv2d: 3-1                  [-1, 64, 56, 56]          4,096
|    |    └─BatchNorm2d: 3-2             [-1, 64, 56, 56]          128
|    |    └─ReLU: 3-3                    [-1, 64, 56, 56]          --
|    |    └─Conv2d: 3-4                  [-1, 64, 56, 56]          36,864
|    |    └─BatchNorm2d: 3-5             [-1, 64, 56, 56]          128
|    |    └─ReLU: 3-6                    [-1, 64, 56, 56]          --
|    |    └─Conv2d: 3-7                  [-1, 256, 56, 56]         16,384
|    |    └─BatchNorm2d: 3-8             [-1, 256, 56, 56]         512
|    |    └─Sequential: 3-9              [-1, 256, 56, 56]         --
|    |    └─ReLU: 3-10                   [-1, 256, 56, 56]         --

  ...
  ...
  ...

├─AdaptiveAvgPool2d: 1-9                 [-1, 2048, 1, 1]          --
├─Linear: 1-10                           [-1, 1000]                2,049,000
==========================================================================================
Total params: 60,192,808
Trainable params: 60,192,808
Non-trainable params: 0
Total mult-adds (G): 11.63
==========================================================================================
Input size (MB): 0.57
Forward/backward pass size (MB): 344.16
Params size (MB): 229.62
Estimated Total Size (MB): 574.35
==========================================================================================
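The (MB) lines can be reproduced from the table, assuming every value is stored as a 4-byte float32 and 1 MB = 2^20 bytes. These assumptions are inferred from the printed numbers rather than taken from torchinfo's source. The forward/backward pass size depends on every intermediate activation, so the sketch takes it as printed instead of deriving it; size_mb is a hypothetical helper.

```python
# Reproduce the "(MB)" lines of the ResNet summary above.
# Assumptions: float32 values (4 bytes each) and 1 MB = 2**20 bytes.
BYTES_PER_VALUE = 4
MB = 2 ** 20

def size_mb(num_values: int) -> float:
    """Memory in MB for num_values float32 numbers."""
    return num_values * BYTES_PER_VALUE / MB

input_mb = size_mb(3 * 224 * 224)   # one 3x224x224 input image
params_mb = size_mb(60_192_808)     # "Total params" from the table
fwd_bwd_mb = 344.16                 # taken as printed, not derived here

print(f"Input size (MB): {input_mb:.2f}")    # 0.57
print(f"Params size (MB): {params_mb:.2f}")  # 229.62
print(f"Estimated Total Size (MB): {input_mb + fwd_bwd_mb + params_mb:.2f}")  # 574.35
```

Under the same assumptions, size_mb(21_840) rounds to 0.08 and size_mb(3_784_580) rounds to 14.44, matching the Params size lines of the ConvNet and LSTMNet summaries.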

Multiple Inputs w/ Different Data Types

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary

class MultipleInputNetDifferentDtypes(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1a = nn.Linear(300, 50)
        self.fc1b = nn.Linear(50, 10)

        self.fc2a = nn.Linear(300, 50)
        self.fc2b = nn.Linear(50, 10)

    def forward(self, x1, x2):
        x1 = F.relu(self.fc1a(x1))
        x1 = self.fc1b(x1)
        x2 = x2.type(torch.float)
        x2 = F.relu(self.fc2a(x2))
        x2 = self.fc2b(x2)
        x = torch.cat((x1, x2), 0)
        return F.log_softmax(x, dim=1)

model = MultipleInputNetDifferentDtypes()
summary(model, [(1, 300), (1, 300)], dtypes=[torch.float, torch.long])

Alternatively, you can pass in the input data itself, and torchinfo will infer the data types automatically.

input_data = torch.randn(1, 300)
other_input_data = torch.randn(1, 300).long()
model = MultipleInputNetDifferentDtypes()

summary(model, input_data, other_input_data, ...)

Explore Different Configurations

class LSTMNet(nn.Module):
    """ Batch-first LSTM model. """
    def __init__(self, vocab_size=20, embed_dim=300, hidden_dim=512, num_layers=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.decoder = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embed = self.embedding(x)
        out, hidden = self.encoder(embed)
        out = self.decoder(out)
        out = out.view(-1, out.size(2))
        return out, hidden

summary(
    LSTMNet(),
    (100,),
    dtypes=[torch.long],
    branching=False,
    verbose=2,
    col_width=16,
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
)
========================================================================================================================
Layer (type:depth-idx)                   Kernel Shape         Output Shape         Param #              Mult-Adds
========================================================================================================================
Embedding: 1-1                           [300, 20]            [-1, 100, 300]       6,000                6,000
LSTM: 1-2                                --                   [-1, 100, 512]       3,768,320            3,760,128
  weight_ih_l0                           [2048, 300]
  weight_hh_l0                           [2048, 512]
  weight_ih_l1                           [2048, 512]
  weight_hh_l1                           [2048, 512]
Linear: 1-3                              [512, 20]            [-1, 100, 20]        10,260               10,240
========================================================================================================================
Total params: 3,784,580
Trainable params: 3,784,580
Non-trainable params: 0
Total mult-adds (M): 3.78
========================================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 1.03
Params size (MB): 14.44
Estimated Total Size (MB): 15.46
========================================================================================================================
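The parameter counts in the LSTMNet table follow from how PyTorch's nn.LSTM stores its weights: per layer, a weight_ih matrix of shape [4 × hidden, input], a weight_hh matrix of shape [4 × hidden, hidden], and two bias vectors of length 4 × hidden (consistent with the [2048, ...] weight rows shown with verbose=2). The plain-Python sketch below checks the totals; lstm_param_counts is a hypothetical helper. The last line is an observation about this particular table (each Mult-Adds value equals the layer's weight count with biases excluded), not a general statement of how torchinfo computes Mult-Adds.

```python
# Worked check of the LSTMNet rows above (pure Python).

def lstm_param_counts(input_size: int, hidden: int, num_layers: int):
    """Return (weight elements, bias elements) summed over all LSTM layers."""
    weights = biases = 0
    for layer in range(num_layers):
        in_f = input_size if layer == 0 else hidden  # later layers consume hidden states
        weights += 4 * hidden * in_f + 4 * hidden * hidden  # weight_ih + weight_hh
        biases += 2 * 4 * hidden                            # bias_ih + bias_hh
    return weights, biases

vocab, embed, hidden = 20, 300, 512
emb_w = vocab * embed                          # Embedding(20, 300): 6,000
lstm_w, lstm_b = lstm_param_counts(embed, hidden, num_layers=2)
lin_w, lin_b = hidden * vocab, vocab           # Linear(512, 20)

total_params = emb_w + (lstm_w + lstm_b) + (lin_w + lin_b)
print(f"LSTM params: {lstm_w + lstm_b:,}")      # 3,768,320
print(f"Total params: {total_params:,}")        # 3,784,580
# Weight elements only (biases excluded), matching the Mult-Adds column here:
print(f"Mult-adds: {emb_w + lstm_w + lin_w:,}")  # 3,776,368, printed as 3.78 M
```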

Sequentials & ModuleLists

class ContainerModule(nn.Module):
    """ Model using ModuleList. """

    def __init__(self):
        super().__init__()
        self._layers = nn.ModuleList()
        self._layers.append(nn.Linear(5, 5))
        self._layers.append(ContainerChildModule())
        self._layers.append(nn.Linear(5, 5))

    def forward(self, x):
        for layer in self._layers:
            x = layer(x)
        return x


class ContainerChildModule(nn.Module):
    """ Model using Sequential in different ways. """

    def __init__(self):
        super().__init__()
        self._sequential = nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
        self._between = nn.Linear(5, 5)

    def forward(self, x):
        out = self._sequential(x)
        out = self._between(out)
        for l in self._sequential:
            out = l(out)

        out = self._sequential(x)
        for l in self._sequential:
            out = l(out)
        return out

summary(ContainerModule(), (5,))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─ModuleList: 1                          []                        --
|    └─Linear: 2-1                       [-1, 5]                   30
|    └─ContainerChildModule: 2-2         [-1, 5]                   --
|    |    └─Sequential: 3-1              [-1, 5]                   --
|    |    |    └─Linear: 4-1             [-1, 5]                   30
|    |    |    └─Linear: 4-2             [-1, 5]                   30
|    |    └─Linear: 3-2                  [-1, 5]                   30
|    |    └─Sequential: 3                []                        --
|    |    |    └─Linear: 4-3             [-1, 5]                   (recursive)
|    |    |    └─Linear: 4-4             [-1, 5]                   (recursive)
|    |    └─Sequential: 3-3              [-1, 5]                   (recursive)
|    |    |    └─Linear: 4-5             [-1, 5]                   (recursive)
|    |    |    └─Linear: 4-6             [-1, 5]                   (recursive)
|    |    |    └─Linear: 4-7             [-1, 5]                   (recursive)
|    |    |    └─Linear: 4-8             [-1, 5]                   (recursive)
|    └─Linear: 2-3                       [-1, 5]                   30
==========================================================================================
Total params: 150
Trainable params: 150
Non-trainable params: 0
Total mult-adds (M): 0.00
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
==========================================================================================
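The (recursive) markers and the 150-parameter total can be reproduced with a few lines of plain Python. This is a simplified model of the behavior visible in the table, not torchinfo's implementation: every Linear call gets its own row, but a layer's parameters count toward the total only the first time it runs. The labels outer[0], seq[0], and so on are hypothetical names for the calls, traced by hand from ContainerModule.forward and ContainerChildModule.forward.

```python
# Simplified model of how the table above counts parameters for reused layers.
PARAMS_PER_LINEAR = 5 * 5 + 5  # nn.Linear(5, 5): 25 weights + 5 biases = 30

calls = [
    "outer[0]",            # self._layers[0]
    "seq[0]", "seq[1]",    # self._sequential(x)
    "between",             # self._between(out)
    "seq[0]", "seq[1]",    # for l in self._sequential
    "seq[0]", "seq[1]",    # self._sequential(x) again
    "seq[0]", "seq[1]",    # for l in self._sequential again
    "outer[2]",            # self._layers[2]
]

seen = set()
rows = []
for name in calls:
    if name in seen:
        rows.append((name, "(recursive)"))  # already counted once
    else:
        seen.add(name)
        rows.append((name, PARAMS_PER_LINEAR))

total = sum(n for _, n in rows if isinstance(n, int))
print(len(rows), total)  # 11 150
```

The 11 rows match the 11 Linear rows in the table, 6 of which are marked (recursive), and only the 5 distinct Linear layers contribute to the 150 total.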

Other Examples

================================================================
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1            [-1, 1, 16, 16]              10
              ReLU-2            [-1, 1, 16, 16]               0
            Conv2d-3            [-1, 1, 28, 28]              10
              ReLU-4            [-1, 1, 28, 28]               0
================================================================
Total params: 20
Trainable params: 20
Non-trainable params: 0
================================================================
Input size (MB): 0.77
Forward/backward pass size (MB): 0.02
Params size (MB): 0.00
Estimated Total Size (MB): 0.78
================================================================

Future Plans

  • Support all types of inputs: show tuple and dict inputs cleanly rather than only using the first tensor in the list.
  • FunctionalNet is currently unused; find a way to hook into functional layers.

Contributing

All issues and pull requests are much appreciated! If you are wondering how to build the project:

  • torchinfo is actively developed using the latest version of Python.
    • Changes should be backward compatible with Python 3.6, but this is subject to change in the future.
    • Run pip install -r requirements-dev.txt. We use the latest versions of all dev packages.
    • First, be sure to run ./scripts/install-hooks
    • To run all tests and use auto-formatting tools, check out scripts/run-tests.
    • To only run unit tests, run pytest.

References

  • Thanks to @sksq96, @nmhkahn, and @sangyx for providing the original code this project was based on.
  • Thanks to @jacobkimmel for the model size estimation.



Download files


Source Distribution

torchinfo-0.0.2.tar.gz (17.9 kB)

Uploaded Source

Built Distribution

torchinfo-0.0.2-py3-none-any.whl (15.3 kB)

Uploaded Python 3

File details

Details for the file torchinfo-0.0.2.tar.gz.

File metadata

  • Download URL: torchinfo-0.0.2.tar.gz
  • Upload date:
  • Size: 17.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/47.1.0 requests-toolbelt/0.9.1 tqdm/4.49.0 CPython/3.8.5

File hashes

Hashes for torchinfo-0.0.2.tar.gz
Algorithm Hash digest
SHA256 508b5d5e40a408caa0e63aa89610ade2ee4c6e6405af654292d8025bb43c6d57
MD5 651bd141b80d6005c130458b2f7d5fd6
BLAKE2b-256 cfbcf4b4f803666707db95b7691a691d0f9263daf56edb2bffc61b0fc45099f6


File details

Details for the file torchinfo-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: torchinfo-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 15.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/47.1.0 requests-toolbelt/0.9.1 tqdm/4.49.0 CPython/3.8.5

File hashes

Hashes for torchinfo-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 da4c04f115253dbd324fa9fbf4dfb9353b80b07d2a4587151528f1a65d57e338
MD5 5abbd0c92cde705121d3c3256bfb16e0
BLAKE2b-256 29b50e2763af01ebb16d400b1ac152fc60bc37f2a84d327dac482061158f4414

