Summary of PyTorch models, just like `model.summary()` in Keras

PyTorch Model Parameters Summary

Install using pip

pip install pytorchsummary

Example 1

from torch import nn
from pytorchsummary import parameter_summary

class CNNET(nn.Module):
    def __init__(self):
        super().__init__()

        # Spatial sizes in the comments assume a 28x28 input.
        self.layer = nn.Sequential(
            nn.Conv2d(3, 16, 5),   # 28 - 5 + 1 = 24
            nn.ReLU(),             # 24
            nn.MaxPool2d(2, 2),    # 12

            nn.Conv2d(16, 32, 3),  # 12 - 3 + 1 = 10
            nn.ReLU(),             # 10
            nn.MaxPool2d(2, 2),    # 5

            nn.Conv2d(32, 64, 5),  # 5 - 5 + 1 = 1
            nn.ReLU(),             # 1

            nn.Conv2d(64, 10, 1),  # 1
        )

    def forward(self, x):
        return self.layer(x)

m = CNNET()

# Freeze the first two parameter tensors (weight and bias of the first
# Conv2d) before printing, so the summary reports them as non-trainable.
for i, p in enumerate(m.parameters()):
    if i == 2:
        break
    p.requires_grad = False

# Equivalent to parameter_summary(model=m, border=False).
# If border is set to True, a separator line is printed between layers.
parameter_summary(m, False)
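The size comments next to each layer in CNNET come from the standard output-size formula for convolution and pooling windows. The following is a minimal sketch of that arithmetic in plain Python, assuming a square 28x28 input (the size the first comment starts from) and no padding. The helper names `conv_out` and `trace_sizes` are made up for illustration and are not part of pytorchsummary.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a square convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


# (layer, kernel size, stride) for every layer in CNNET that changes the
# spatial size. ReLU layers are omitted because they keep the size unchanged.
LAYERS = [
    ("Conv2d(3, 16, 5)", 5, 1),
    ("MaxPool2d(2, 2)", 2, 2),
    ("Conv2d(16, 32, 3)", 3, 1),
    ("MaxPool2d(2, 2)", 2, 2),
    ("Conv2d(32, 64, 5)", 5, 1),
    ("Conv2d(64, 10, 1)", 1, 1),
]


def trace_sizes(size, layers):
    """Return the side length after each layer, starting from `size`."""
    sizes = []
    for _name, kernel, stride in layers:
        size = conv_out(size, kernel, stride)
        sizes.append(size)
    return sizes


sizes = trace_sizes(28, LAYERS)  # assumed 28x28 input
for (name, _k, _s), side in zip(LAYERS, sizes):
    print(f"{name:<20} -> {side}x{side}")
```

Under these assumptions the network reduces a 28x28 image to a 1x1 map with 10 channels, which is why the last two convolutions act like fully connected layers.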

Output

LAYER TYPE                   KERNEL SHAPE     #parameters        (weights+bias)         requires_grad         
____________________________________________________________________________________________________
 Conv2d-1                  [16, 3, 5, 5]    	1,216                (1200 + 16)          False False          
 ReLU-2                          -          	-                          -                               
 MaxPool2d-3                     -          	-                          -                               
 Conv2d-4                  [32, 16, 3, 3]   	4,640                (4608 + 32)           True True           
 ReLU-5                          -          	-                          -                               
 MaxPool2d-6                     -          	-                          -                               
 Conv2d-7                  [64, 32, 5, 5]   	51,264               (51200 + 64)           True True           
 ReLU-8                          -          	-                          -                               
 Conv2d-9                  [10, 64, 1, 1]   	650                 (640 + 10)           True True           
====================================================================================================

Total parameters 57,770
Total Non-Trainable parameters 1,216
Total Trainable parameters 56,554

57770
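The numbers in the table can be cross-checked by hand: a Conv2d layer with kernel shape [out, in, kh, kw] holds out * in * kh * kw weights plus out biases. The sketch below redoes that arithmetic in plain Python with the kernel shapes copied from the table, marking the first layer as frozen. It does not call pytorchsummary, and the helper name `conv2d_params` is hypothetical.

```python
def conv2d_params(out_ch, in_ch, kh, kw):
    """Return (weights, bias) counts for a Conv2d with the given kernel shape."""
    return out_ch * in_ch * kh * kw, out_ch


# (kernel shape, frozen?) copied from the summary table above.
CONVS = [
    ((16, 3, 5, 5), True),    # Conv2d-1, requires_grad=False
    ((32, 16, 3, 3), False),  # Conv2d-4
    ((64, 32, 5, 5), False),  # Conv2d-7
    ((10, 64, 1, 1), False),  # Conv2d-9
]

total = 0
frozen = 0
for shape, is_frozen in CONVS:
    weights, bias = conv2d_params(*shape)
    total += weights + bias
    if is_frozen:
        frozen += weights + bias

print(f"Total parameters {total:,}")
print(f"Total Non-Trainable parameters {frozen:,}")
print(f"Total Trainable parameters {total - frozen:,}")
```

The three printed totals should agree with the last three lines of the summary.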

Example 2

from torchvision import models
from pytorchsummary import parameter_summary

m = models.alexnet()  # randomly initialized AlexNet, no pretrained weights
parameter_summary(m)
# parameter_summary returns the total number of
# parameters in the model as an int

Output

LAYER TYPE                   KERNEL SHAPE     #parameters        (weights+bias)         requires_grad         
____________________________________________________________________________________________________
____________________________________________________________________________________________________
 Conv2d-1                 [64, 3, 11, 11]   	23,296               (23232 + 64)           True True           
____________________________________________________________________________________________________
 ReLU-2                          -          	-                          -                               
____________________________________________________________________________________________________
 MaxPool2d-3                     -          	-                          -                               
____________________________________________________________________________________________________
 Conv2d-4                 [192, 64, 5, 5]   	307,392             (307200 + 192)           True True           
____________________________________________________________________________________________________
 ReLU-5                          -          	-                          -                               
____________________________________________________________________________________________________
 MaxPool2d-6                     -          	-                          -                               
____________________________________________________________________________________________________
 Conv2d-7                 [384, 192, 3, 3]  	663,936             (663552 + 384)           True True           
____________________________________________________________________________________________________
 ReLU-8                          -          	-                          -                               
____________________________________________________________________________________________________
 Conv2d-9                 [256, 384, 3, 3]  	884,992             (884736 + 256)           True True           
____________________________________________________________________________________________________
 ReLU-10                         -          	-                          -                               
____________________________________________________________________________________________________
 Conv2d-11                [256, 256, 3, 3]  	590,080             (589824 + 256)           True True           
____________________________________________________________________________________________________
 ReLU-12                         -          	-                          -                               
____________________________________________________________________________________________________
 MaxPool2d-13                    -          	-                          -                               
____________________________________________________________________________________________________
 AdaptiveAvgPool2d-14            -          	-                          -                               
____________________________________________________________________________________________________
 Dropout-15                      -          	-                          -                               
____________________________________________________________________________________________________
 Linear-16                  [4096, 9216]    	37,752,832          (37748736 + 4096)           True True           
____________________________________________________________________________________________________
 ReLU-17                         -          	-                          -                               
____________________________________________________________________________________________________
 Dropout-18                      -          	-                          -                               
____________________________________________________________________________________________________
 Linear-19                  [4096, 4096]    	16,781,312          (16777216 + 4096)           True True           
____________________________________________________________________________________________________
 ReLU-20                         -          	-                          -                               
____________________________________________________________________________________________________
 Linear-21                  [1000, 4096]    	4,097,000           (4096000 + 1000)           True True           
====================================================================================================

Total parameters 61,100,840
Total Non-Trainable parameters 0
Total Trainable parameters 61,100,840
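The same shape arithmetic extends to Linear layers, where the kernel shape is [out_features, in_features]. As a rough sketch, the snippet below multiplies out every kernel shape listed in the AlexNet table and adds one bias per output unit. The dictionary and the helper `n_params` are illustrative only and assume every layer has a bias, as the (weights+bias) column indicates.

```python
from math import prod

# Kernel shapes copied from the AlexNet summary table above.
SHAPES = {
    "Conv2d-1": (64, 3, 11, 11),
    "Conv2d-4": (192, 64, 5, 5),
    "Conv2d-7": (384, 192, 3, 3),
    "Conv2d-9": (256, 384, 3, 3),
    "Conv2d-11": (256, 256, 3, 3),
    "Linear-16": (4096, 9216),
    "Linear-19": (4096, 4096),
    "Linear-21": (1000, 4096),
}


def n_params(shape):
    """Weights (product of the kernel shape) plus one bias per output unit."""
    return prod(shape) + shape[0]


counts = {name: n_params(shape) for name, shape in SHAPES.items()}
conv_total = sum(n for name, n in counts.items() if name.startswith("Conv2d"))
linear_total = sum(n for name, n in counts.items() if name.startswith("Linear"))
total = conv_total + linear_total

print(f"Conv2d layers  {conv_total:,}")
print(f"Linear layers  {linear_total:,}")
print(f"Total          {total:,}")
```

Splitting the sum this way shows that the three Linear layers account for most of AlexNet's parameters.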

Download files


Source distribution: pytorchsummary-1.0.4.tar.gz (3.8 kB)

Built distribution: pytorchsummary-1.0.4-py3-none-any.whl (4.2 kB, Python 3)
