
Utility functions that print a summary of a model.


torch-inspect


torch-inspect: a collection of utility functions for inspecting low-level information about PyTorch neural networks.

Features

  • Provides the helper function summary, which prints a Keras-style model summary.
  • Provides the helper function inspect, which returns an object with network summary information for programmatic access.
  • Supports RNNs/LSTMs.
  • The library has tests and reasonable code coverage.

Simple example

import torch.nn as nn
import torch.nn.functional as F
import torch_inspect as ti

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

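    def forward(self, x):
        # (N, 1, 32, 32) -> conv1 -> (N, 6, 30, 30) -> pool -> (N, 6, 15, 15)
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # -> conv2 -> (N, 16, 13, 13) -> pool -> (N, 16, 6, 6)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # flatten to (N, 576) for the fully connected layers
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x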

    def num_flat_features(self, x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = SimpleNet()
# input size is (channels, height, width), without the batch dimension
ti.summary(net, (1, 32, 32))

This produces the following output (the leading dimension of each output shape is the batch size):

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [100, 6, 30, 30]              60
            Conv2d-2          [100, 16, 13, 13]             880
            Linear-3                 [100, 120]          69,240
            Linear-4                  [100, 84]          10,164
            Linear-5                  [100, 10]             850
================================================================
Total params: 81,194
Trainable params: 81,194
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.39
Forward/backward pass size (MB): 6.35
Params size (MB): 0.31
Estimated Total Size (MB): 7.05
----------------------------------------------------------------
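The numbers in this table can be reproduced by hand, which helps when reading summaries of other models. The sketch below is plain arithmetic, not torch-inspect's own code. It assumes float32 tensors (4 bytes each), 1 MB = 1024**2 bytes, and activations counted twice for the forward/backward estimate (the convention pytorch-summary uses); under those assumptions the results match the table above. The helper names conv2d_params, linear_params, conv_out and pool_out are hypothetical.

```python
# Hand check of the SimpleNet summary (a sketch, not torch-inspect's code).
# Assumptions: float32 (4 bytes), MB = 1024**2 bytes, activations counted
# twice for the forward/backward estimate, batch size 100 as in the output.
BATCH = 100
BYTES_PER_FLOAT = 4
MB = 1024 ** 2


def conv2d_params(in_ch, out_ch, k):
    # weights (out_ch * in_ch * k * k) plus one bias per output channel
    return out_ch * in_ch * k * k + out_ch


def linear_params(in_f, out_f):
    # weight matrix (in_f * out_f) plus one bias per output feature
    return in_f * out_f + out_f


def conv_out(size, k):
    # stride 1, no padding
    return size - k + 1


def pool_out(size, k=2):
    # max pooling with kernel k and stride k floors the result
    return size // k


# Spatial sizes: 32 -> conv 30 -> pool 15 -> conv 13 -> pool 6
s1 = conv_out(32, 3)
s2 = conv_out(pool_out(s1), 3)
flat = 16 * pool_out(s2) ** 2  # input features of fc1

params = [
    conv2d_params(1, 6, 3),
    conv2d_params(6, 16, 3),
    linear_params(flat, 120),
    linear_params(120, 84),
    linear_params(84, 10),
]
outputs = [
    BATCH * 6 * s1 * s1,
    BATCH * 16 * s2 * s2,
    BATCH * 120,
    BATCH * 84,
    BATCH * 10,
]

input_mb = BATCH * 1 * 32 * 32 * BYTES_PER_FLOAT / MB
fwd_bwd_mb = 2 * sum(outputs) * BYTES_PER_FLOAT / MB
params_mb = sum(params) * BYTES_PER_FLOAT / MB
total_mb = input_mb + fwd_bwd_mb + params_mb

print(params, sum(params))  # [60, 880, 69240, 10164, 850] 81194
print(round(input_mb, 2), round(fwd_bwd_mb, 2),
      round(params_mb, 2), round(total_mb, 2))  # 0.39 6.35 0.31 7.05
```

The fully connected layer fc1 dominates: its 576 inputs come directly from the 16 x 6 x 6 feature map left after the second pooling step.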

For programmatic access to the network information, use the inspect function:

info = ti.inspect(net, (1, 32, 32))
print(info)
[LayerInfo(name='Conv2d-1', input_shape=[100, 1, 32, 32], output_shape=[100, 6, 30, 30], trainable_params=60, non_trainable_params=0),
 LayerInfo(name='Conv2d-2', input_shape=[100, 6, 15, 15], output_shape=[100, 16, 13, 13], trainable_params=880, non_trainable_params=0),
 LayerInfo(name='Linear-3', input_shape=[100, 576], output_shape=[100, 120], trainable_params=69240, non_trainable_params=0),
 LayerInfo(name='Linear-4', input_shape=[100, 120], output_shape=[100, 84], trainable_params=10164, non_trainable_params=0),
 LayerInfo(name='Linear-5', input_shape=[100, 84], output_shape=[100, 10], trainable_params=850, non_trainable_params=0)]
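Once the per-layer information is available as data, it can be aggregated like any other Python objects. The sketch below assumes the returned items allow attribute access to the fields shown in the printed repr above; because it is meant to run without PyTorch, it uses a hypothetical namedtuple stand-in for LayerInfo filled with the values from that output, rather than calling ti.inspect.

```python
from collections import namedtuple

# Hypothetical stand-in mirroring the fields of the printed LayerInfo repr;
# in real use these objects would come from ti.inspect(net, (1, 32, 32)).
LayerInfo = namedtuple(
    "LayerInfo",
    ["name", "input_shape", "output_shape",
     "trainable_params", "non_trainable_params"],
)

info = [
    LayerInfo("Conv2d-1", [100, 1, 32, 32], [100, 6, 30, 30], 60, 0),
    LayerInfo("Conv2d-2", [100, 6, 15, 15], [100, 16, 13, 13], 880, 0),
    LayerInfo("Linear-3", [100, 576], [100, 120], 69240, 0),
    LayerInfo("Linear-4", [100, 120], [100, 84], 10164, 0),
    LayerInfo("Linear-5", [100, 84], [100, 10], 850, 0),
]

# Totals, matching the footer of the summary table.
trainable = sum(layer.trainable_params for layer in info)
non_trainable = sum(layer.non_trainable_params for layer in info)
total = trainable + non_trainable


def layer_params(layer):
    # all parameters held by a single layer
    return layer.trainable_params + layer.non_trainable_params


# Layer holding the most parameters.
largest = max(info, key=layer_params)

# Percentage of all parameters held by each layer.
shares = {layer.name: 100 * layer_params(layer) / total for layer in info}

print(trainable, non_trainable)     # 81194 0
print(largest.name)                 # Linear-3
print(f"{shares['Linear-3']:.1f}%") # 85.3%
```

This kind of query is the main reason to prefer inspect over summary: the table is meant for reading, while the list can be filtered, sorted or checked in tests.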

Installation

Install the package with pip:

$ pip install torch-inspect

Requirements

References and Thanks

This package is based on pytorch-summary and PyTorch issue . Compared to pytorch-summary, pytorch-inspect supports RNNs/LSTMs and also provides programmatic access to the network summary information. Its more modular structure and test suite make it easier to extend and to add support for more features.

Changes

0.0.3 (2019-09-22)

  • Added LSTM support
  • Fixed support for multiple inputs/outputs
  • Added more network test cases
  • Batch size no longer -1 by default

0.0.2 (2019-09-22)

  • Added batch norm support
  • Removed device parameter

0.0.1 (2019-09-1)

  • Initial release.
