
Summary of PyTorch models, just like `model.summary()` in Keras

Project description

PyTorch Model Parameters Summary

Install using pip

pip install modelsummary-pytorch==1.0.0

Example 1

from torch import nn
from pytorchsummary import parameter_summary

# Expects 3-channel input; the shape comments assume a 28x28 image.
class CNNET(nn.Module):
    def __init__(self):
        super().__init__()

        self.layer = nn.Sequential(
            nn.Conv2d(3, 16, 5),    # 28 -> 24  (28 - 5 + 1)
            nn.ReLU(),
            nn.MaxPool2d(2, 2),     # 24 -> 12

            nn.Conv2d(16, 32, 3),   # 12 -> 10  (12 - 3 + 1)
            nn.ReLU(),
            nn.MaxPool2d(2, 2),     # 10 -> 5

            nn.Conv2d(32, 64, 5),   # 5 -> 1    (5 - 5 + 1)
            nn.ReLU(),

            nn.Conv2d(64, 10, 1),   # 1 -> 1, 10 output channels
        )

    def forward(self, x):
        return self.layer(x)


m = CNNET()
parameter_summary(m, False)
# equivalent to: parameter_summary(model=m, border=False)
# with border=True, a separator line is printed between every layer
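The shape comments in `CNNET` follow the usual output-size rule for `nn.Conv2d` and `nn.MaxPool2d` with no padding and dilation 1: out = floor((in - kernel) / stride) + 1. A minimal pure-Python sketch of that arithmetic, assuming a square 28x28 input as the comments imply; `conv_out`, `trace_sizes` and `CNNET_STEPS` are hypothetical helpers, not part of pytorchsummary:

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output size along one spatial axis for Conv2d/MaxPool2d
    (dilation 1, floor rounding: PyTorch's defaults)."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel_size, stride) of every size-changing layer in CNNET.layer.
# ReLU does not change the shape, so it is left out.
CNNET_STEPS = [
    (5, 1),  # Conv2d(3, 16, 5)
    (2, 2),  # MaxPool2d(2, 2)
    (3, 1),  # Conv2d(16, 32, 3)
    (2, 2),  # MaxPool2d(2, 2)
    (5, 1),  # Conv2d(32, 64, 5)
    (1, 1),  # Conv2d(64, 10, 1)
]


def trace_sizes(size: int, steps=CNNET_STEPS) -> list:
    """Return the spatial size before the first layer and after each step."""
    sizes = [size]
    for kernel, stride in steps:
        size = conv_out(size, kernel, stride)
        sizes.append(size)
    return sizes


print(trace_sizes(28))
```

For a 28x28 input this prints `[28, 24, 12, 10, 5, 1, 1]`, so the network ends with a 10x1x1 output, one value per output channel of the last 1x1 convolution.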

Output

LAYER TYPE                   KERNEL SHAPE     #parameters        (weights+bias)
____________________________________________________________________________________________________
 Conv2d-2                  [16, 3, 5, 5]    	1,216                (1200 + 16)
 ReLU-3                          -          	-                          -
 MaxPool2d-4                     -          	-                          -
 Conv2d-5                  [32, 16, 3, 3]   	4,640                (4608 + 32)
 ReLU-6                          -          	-                          -
 MaxPool2d-7                     -          	-                          -
 Conv2d-8                  [64, 32, 5, 5]   	51,264               (51200 + 64)
 ReLU-9                          -          	-                          -
 Conv2d-10                 [10, 64, 1, 1]   	650                 (640 + 10)
====================================================================================================

Total parameters 57,770

Return value: 57770
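Each `(weights+bias)` entry in the table can be reproduced from the kernel shape: the weight count is the product of the shape's dimensions, and a layer with a bias adds one term per output channel (the first dimension). A short sketch, assuming every convolution keeps PyTorch's default `bias=True`; `layer_params`, `total_params` and `CNNET_KERNELS` are illustrative names, not pytorchsummary functions:

```python
from math import prod


def layer_params(kernel_shape, bias=True):
    """Return (weights, biases) for a layer whose weight tensor has
    kernel_shape, e.g. [out_channels, in_channels, kH, kW] for Conv2d.
    Assumes one bias term per output channel when bias=True."""
    weights = prod(kernel_shape)
    biases = kernel_shape[0] if bias else 0
    return weights, biases


def total_params(kernels):
    """Sum weights + biases over all parameterised layers."""
    return sum(sum(layer_params(k)) for k in kernels)


# Kernel shapes copied from the Example 1 table.
CNNET_KERNELS = [[16, 3, 5, 5], [32, 16, 3, 3], [64, 32, 5, 5], [10, 64, 1, 1]]

for shape in CNNET_KERNELS:
    w, b = layer_params(shape)
    print(f"{str(shape):<16} {w + b:>8,}  ({w} + {b})")
print(f"Total parameters {total_params(CNNET_KERNELS):,}")
```

For example, `layer_params([16, 3, 5, 5])` gives `(1200, 16)`, matching the first Conv2d row, and the final line prints `Total parameters 57,770`. ReLU and MaxPool2d rows show `-` because those layers have no learnable parameters.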

Example 2

from torchvision import models
from pytorchsummary import parameter_summary

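m = models.alexnet(weights=None)  # untrained AlexNet; on torchvision < 0.13 use pretrained=False
parameter_summary(m)
# border is left at its default here, so a separator line is printed between layers;
# the function also returns the total number of parameters in the model as an int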

Output

LAYER TYPE                   KERNEL SHAPE     #parameters        (weights+bias)
____________________________________________________________________________________________________
____________________________________________________________________________________________________
 Conv2d-2                 [64, 3, 11, 11]       23,296               (23232 + 64)
____________________________________________________________________________________________________
 ReLU-3                          -              -                          -
____________________________________________________________________________________________________
 MaxPool2d-4                     -              -                          -
____________________________________________________________________________________________________
 Conv2d-5                 [192, 64, 5, 5]       307,392             (307200 + 192)
____________________________________________________________________________________________________
 ReLU-6                          -              -                          -
____________________________________________________________________________________________________
 MaxPool2d-7                     -              -                          -
____________________________________________________________________________________________________
 Conv2d-8                 [384, 192, 3, 3]      663,936             (663552 + 384)
____________________________________________________________________________________________________
 ReLU-9                          -              -                          -
____________________________________________________________________________________________________
 Conv2d-10                [256, 384, 3, 3]      884,992             (884736 + 256)
____________________________________________________________________________________________________
 ReLU-11                         -              -                          -
____________________________________________________________________________________________________
 Conv2d-12                [256, 256, 3, 3]      590,080             (589824 + 256)
____________________________________________________________________________________________________
 ReLU-13                         -              -                          -
____________________________________________________________________________________________________
 MaxPool2d-14                    -              -                          -
____________________________________________________________________________________________________
 AdaptiveAvgPool2d-15            -              -                          -
____________________________________________________________________________________________________
 Dropout-17                      -              -                          -
____________________________________________________________________________________________________
 Linear-18                  [4096, 9216]        37,752,832          (37748736 + 4096)
____________________________________________________________________________________________________
 ReLU-19                         -              -                          -
____________________________________________________________________________________________________
 Dropout-20                      -              -                          -
____________________________________________________________________________________________________
 Linear-21                  [4096, 4096]        16,781,312          (16777216 + 4096)
____________________________________________________________________________________________________
 ReLU-22                         -              -                          -
____________________________________________________________________________________________________
 Linear-23                  [1000, 4096]        4,097,000           (4096000 + 1000)
====================================================================================================

Total parameters 61,100,840
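The same weights-plus-bias rule covers the `Linear` rows, whose kernel shape is `[out_features, in_features]`. In torchvision's AlexNet, the 9216 inputs of Linear-18 come from flattening the 256x6x6 output of `AdaptiveAvgPool2d((6, 6))`. A sketch that re-adds the Example 2 rows, with the shape list copied from the table above and `count` as a hypothetical helper:

```python
from math import prod

# Weight shapes of the parameterised AlexNet layers (Conv2d, then Linear).
ALEXNET_KERNELS = [
    [64, 3, 11, 11], [192, 64, 5, 5], [384, 192, 3, 3],
    [256, 384, 3, 3], [256, 256, 3, 3],
    [4096, 256 * 6 * 6],  # Linear-18: flattened 256 x 6 x 6 feature map = 9216 inputs
    [4096, 4096],         # Linear-21
    [1000, 4096],         # Linear-23
]


def count(shape):
    """Weights (product of the shape) plus one bias per output unit (shape[0])."""
    return prod(shape) + shape[0]


total = sum(count(s) for s in ALEXNET_KERNELS)
linear = sum(count(s) for s in ALEXNET_KERNELS[5:])
print(f"Total parameters {total:,}")
print(f"Linear layers {linear:,} ({linear / total:.0%})")
```

This prints `Total parameters 61,100,840`, matching the summary, and shows that the three `Linear` layers hold 58,631,144 of those parameters, roughly 96%.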



Download files


Source Distribution

pytorchsummary-1.0.3.tar.gz (3.6 kB)


Built Distribution

pytorchsummary-1.0.3-py3-none-any.whl (4.0 kB, Python 3)
