Summary of PyTorch models, just like `model.summary()` in Keras
PyTorch Model Parameters Summary
Install using pip
pip install pytorchsummary
Example 1
from torch import nn
from pytorchsummary import parameter_summary

class CNNET(nn.Module):
    def __init__(self):
        super(CNNET, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(3, 16, 5),   # 28-5+1 = 24
            nn.ReLU(),
            nn.MaxPool2d(2, 2),    # 24/2   = 12
            nn.Conv2d(16, 32, 3),  # 12-3+1 = 10
            nn.ReLU(),
            nn.MaxPool2d(2, 2),    # 10/2   = 5
            nn.Conv2d(32, 64, 5),  # 5-5+1  = 1
            nn.ReLU(),
            nn.Conv2d(64, 10, 1)
        )

    def forward(self, x):
        x = self.layer(x)
        return x

m = CNNET()
parameter_summary(m, False)
# parameter_summary(model=m, border=False)
# if border is set to True, a separator line is printed
# between every layer
Output
LAYER TYPE            KERNEL SHAPE        #parameters (weights+bias)
____________________________________________________________________________________________________
Conv2d-2              [16, 3, 5, 5]       1,216 (1200 + 16)
ReLU-3                -                   -
MaxPool2d-4           -                   -
Conv2d-5              [32, 16, 3, 3]      4,640 (4608 + 32)
ReLU-6                -                   -
MaxPool2d-7           -                   -
Conv2d-8              [64, 32, 5, 5]      51,264 (51200 + 64)
ReLU-9                -                   -
Conv2d-10             [10, 64, 1, 1]      650 (640 + 10)
====================================================================================================
Total parameters 57,770
57770
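The trailing 57770 is simply the integer that parameter_summary returns (see the note in Example 2), so the count can also be captured and reused in code. A minimal sketch, reusing the CNNET model defined above:

from pytorchsummary import parameter_summary

m = CNNET()
# parameter_summary prints the table and returns the parameter
# count as an int, so it can be stored directly
total = parameter_summary(m, border=False)
print(f"Total parameters: {total:,}")   # Total parameters: 57,770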
Example 2
from torchvision import models
from pytorchsummary import parameter_summary

m = models.alexnet(False)  # pretrained weights are not needed for a summary
parameter_summary(m)
# in addition to printing the table, the function returns
# the total number of parameters (int) in the model
Output
LAYER TYPE            KERNEL SHAPE        #parameters (weights+bias)
____________________________________________________________________________________________________
____________________________________________________________________________________________________
Conv2d-2              [64, 3, 11, 11]     23,296 (23232 + 64)
____________________________________________________________________________________________________
ReLU-3                -                   -
____________________________________________________________________________________________________
MaxPool2d-4           -                   -
____________________________________________________________________________________________________
Conv2d-5              [192, 64, 5, 5]     307,392 (307200 + 192)
____________________________________________________________________________________________________
ReLU-6                -                   -
____________________________________________________________________________________________________
MaxPool2d-7           -                   -
____________________________________________________________________________________________________
Conv2d-8              [384, 192, 3, 3]    663,936 (663552 + 384)
____________________________________________________________________________________________________
ReLU-9                -                   -
____________________________________________________________________________________________________
Conv2d-10             [256, 384, 3, 3]    884,992 (884736 + 256)
____________________________________________________________________________________________________
ReLU-11               -                   -
____________________________________________________________________________________________________
Conv2d-12             [256, 256, 3, 3]    590,080 (589824 + 256)
____________________________________________________________________________________________________
ReLU-13               -                   -
____________________________________________________________________________________________________
MaxPool2d-14          -                   -
____________________________________________________________________________________________________
AdaptiveAvgPool2d-15  -                   -
____________________________________________________________________________________________________
Dropout-17            -                   -
____________________________________________________________________________________________________
Linear-18             [4096, 9216]        37,752,832 (37748736 + 4096)
____________________________________________________________________________________________________
ReLU-19               -                   -
____________________________________________________________________________________________________
Dropout-20            -                   -
____________________________________________________________________________________________________
Linear-21             [4096, 4096]        16,781,312 (16777216 + 4096)
____________________________________________________________________________________________________
ReLU-22               -                   -
____________________________________________________________________________________________________
Linear-23             [1000, 4096]        4,097,000 (4096000 + 1000)
====================================================================================================
Total parameters 61,100,840
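The printed total can be cross-checked against PyTorch itself; a minimal sketch using the standard p.numel() idiom (plain PyTorch, not part of pytorchsummary):

from torchvision import models

m = models.alexnet(False)
# sum the element counts of every parameter tensor; this should
# match the "Total parameters" line above
total = sum(p.numel() for p in m.parameters())
print(f"{total:,}")   # 61,100,840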