
Project description

Torch Describe

Describe a PyTorch model, the PyTorch way

If you want a Keras-style model.summary(), torchsummary is there for that. But it only tells you how tensors flow through your model; it does not show the model's actual structure.

For that, print(model) does a decent job, but it prints no further information about the model. That is where torchdescribe comes in.

Install

  • Using pip,

    pip install torchdescribe
    
  • Or, to install from source, run

    $ git clone https://github.com/vidit1999/pytorch-describe
    $ cd pytorch-describe
    $ python setup.py install
    
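Either way, a quick import check confirms the install (just a sanity check, not part of the project's docs; it prints nothing on success):

    $ python -c "from torchdescribe import describe"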

Usage

from torchdescribe import describe

input_shape = ... # input_shape(s) of model

class Model(nn.Module):
    def __init__(self,...):
        super(Model, self).__init__()

        ...
        # model attribute definitions

    ...
    #  model methods
    ...

model = Model()
describe(model, input_shape)
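
As a concrete illustration of the skeleton above (the toy model below is made up for demonstration, not taken from the project's examples, and it assumes describe builds its dummy input as (batch_size, *input_shape), which matches the shapes reported in the examples that follow):

import torch.nn as nn
from torchdescribe import describe

class TinyNet(nn.Module):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.fc = nn.Linear(4, 2)  # one fully-connected layer

    def forward(self, x):
        return self.fc(x)

model = TinyNet()
describe(model, (4,))  # input_shape describes a single sample, without the batch dimension

As the examples below show, the batch dimension comes from the optional batch_size argument; when it is omitted, the report shows a batch size of 1 (see the last example).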

Examples

  1. CNN for MNIST Classification

    import torch
    import torch.nn as nn
    
    from torchdescribe import describe
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    class CNN(nn.Module):
        def __init__(self):
            super(CNN, self).__init__()
            self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
            self.conv2 = nn.Conv2d(32, 32, kernel_size=5)
            self.conv3 = nn.Conv2d(32, 64, kernel_size=5)
            self.fc1 = nn.Linear(3*3*64, 256)
            self.fc2 = nn.Linear(256, 10)
    
        def forward(self, x):
            x = torch.relu(self.conv1(x))
            x = torch.relu(torch.max_pool2d(self.conv2(x), 2))
            x = torch.relu(torch.max_pool2d(self.conv3(x), 2))
            x = x.view(-1, 3*3*64)
            x = torch.relu(self.fc1(x))
            x = self.fc2(x)
            return torch.log_softmax(x, dim=1)
    
    cnn = CNN().to(device)
    describe(model=cnn, input_shape=(1, 28, 28), batch_size=1000)
    
    -------------------------------------------------------------
                             CNN
    -------------------------------------------------------------
    =============================================================
    
    CNN(
        (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
        (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1))
        (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
        (fc1): Linear(in_features=576, out_features=256, bias=True)
        (fc2): Linear(in_features=256, out_features=10, bias=True)
    )
    
    =============================================================
    -------------------------------------------------------------
    Total parameters : 228,010
    Trainable parameters : 228,010
    Non-trainable parameters : 0
    -------------------------------------------------------------
    Model device : CPU
    Batch size : 1000
    Input shape : (1000, 1, 28, 28)
    Output shape : (1000, 10)
    Input size (MB) : 2.99
    Forward/backward pass size (MB) : 257.97
    Params size (MB) : 0.87
    Estimated Total Size (MB) : 261.83
    -------------------------------------------------------------
    
  2. VGG-16

    import torch
    from torchvision import models
    from torchdescribe import describe
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    vgg16 = models.vgg16().to(device)
    
    describe(vgg16, (3, 224, 224), 100)
    
    ------------------------------------------------------------------------------------
                                            VGG
    ------------------------------------------------------------------------------------
    ====================================================================================
    
    VGG(
        (features): Sequential(
            (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): ReLU(inplace=True)
            (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (3): ReLU(inplace=True)
            (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (6): ReLU(inplace=True)
            (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (8): ReLU(inplace=True)
            (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (11): ReLU(inplace=True)
            (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (13): ReLU(inplace=True)
            (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (15): ReLU(inplace=True)
            (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (18): ReLU(inplace=True)
            (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (20): ReLU(inplace=True)
            (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (22): ReLU(inplace=True)
            (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (25): ReLU(inplace=True)
            (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (27): ReLU(inplace=True)
            (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (29): ReLU(inplace=True)
            (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        )
        (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
        (classifier): Sequential(
            (0): Linear(in_features=25088, out_features=4096, bias=True)
            (1): ReLU(inplace=True)
            (2): Dropout(p=0.5, inplace=False)
            (3): Linear(in_features=4096, out_features=4096, bias=True)
            (4): ReLU(inplace=True)
            (5): Dropout(p=0.5, inplace=False)
            (6): Linear(in_features=4096, out_features=1000, bias=True)
        )
    )
    
    ====================================================================================
    ------------------------------------------------------------------------------------
    Total parameters : 138,357,544
    Trainable parameters : 138,357,544
    Non-trainable parameters : 0
    ------------------------------------------------------------------------------------
    Model device : CPU
    Batch size : 100
    Input shape : (100, 3, 224, 224)
    Output shape : (100, 1000)
    Input size (MB) : 57.42
    Forward/backward pass size (MB) : 21878.87
    Params size (MB) : 527.79
    Estimated Total Size (MB) : 22464.08
    ------------------------------------------------------------------------------------
    
  3. Multiple Inputs

    import torch
    import torch.nn as nn
    from torchdescribe import describe
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.layers = nn.Sequential(
                nn.Conv2d(1, 16, kernel_size=4, stride=2, padding=1),
                nn.ReLU(),
                nn.Flatten(),
            )
    
        def forward(self, x, y):
            a = self.layers(x)
            b = self.layers(y)
            return a, b
    
    
    net = Net().to(device)
    describe(net, [(1, 16, 16), (1, 28, 28)], 500)
    
    -------------------------------------------------------------------------
                                    Net
    -------------------------------------------------------------------------
    =========================================================================
    
    Net(
        (layers): Sequential(
            (0): Conv2d(1, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
            (1): ReLU()
            (2): Flatten(start_dim=1, end_dim=-1)
        )
    )
    
    =========================================================================
    -------------------------------------------------------------------------
    Total parameters : 272
    Trainable parameters : 272
    Non-trainable parameters : 0
    -------------------------------------------------------------------------
    Model device : CPU
    Batch size : 500
    Input shape : [(500, 1, 16, 16), (500, 1, 28, 28)]
    Output shape : [(500, 1024), (500, 3136)]
    Input size (MB) : 1.98
    Forward/backward pass size (MB) : 63.48
    Params size (MB) : 0.00
    Estimated Total Size (MB) : 65.46
    -------------------------------------------------------------------------
    
  4. To Suppress Errors

    import torch
    import torch.nn as nn
    from torchdescribe import describe
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    class RNN(nn.Module):
        def __init__(self):
            super(RNN, self).__init__()
            self.hidden_size = 20
    
            self.embedding = nn.Embedding(10, 20)
            self.gru = nn.GRU(20, 20)
    
        def forward(self, input, hidden):
            embedded = self.embedding(input).view(1, 1, -1)
            output = embedded
            output, hidden = self.gru(output, hidden)
            return output, hidden
    
    
    rnn = RNN().to(device)
    describe(rnn, (3, 6), suppress_error=True)
    
    --------------------------------------------------
                        RNN
    --------------------------------------------------
    ==================================================
    
    RNN(
        (embedding): Embedding(10, 20)
        (gru): GRU(20, 20)
    )
    
    ==================================================
    --------------------------------------------------
    Total parameters : 2,720
    Trainable parameters : 2,720
    Non-trainable parameters : 0
    --------------------------------------------------
    Model device : CPU
    Batch size : 1
    Input shape : (1, 3, 6)
    Output shape : []
    Input size (MB) : -1
    Forward/backward pass size (MB) : -1
    Params size (MB) : 0.01
    Estimated Total Size (MB) : -1
    --------------------------------------------------
    
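A note on the last example: RNN.forward expects two arguments (an input and a hidden state), so the single dummy tensor that describe apparently feeds to the model cannot complete a forward pass, and the shape and size fields fall back to -1. Passing suppress_error=True keeps describe from raising in that situation. You can reproduce the underlying failure directly (reusing the rnn instance defined above; the exact error message may vary by Python/PyTorch version):

dummy = torch.zeros(1, 3, 6, dtype=torch.long)  # integer indices, as nn.Embedding expects
try:
    rnn(dummy)  # fails before the embedding is reached: no `hidden` argument supplied
except TypeError as err:
    print(err)  # e.g. "forward() missing 1 required positional argument: 'hidden'"
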

References

  • The model size estimation borrows from here (a rough sketch of the arithmetic is given below).
  • This project is inspired by torchsummary.
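
As a rough illustration of how such an estimate is usually computed (the standard float32, 4-bytes-per-element arithmetic; not necessarily the exact code torchdescribe runs), the VGG-16 numbers above can be reproduced as follows:

BYTES_PER_ELEM = 4        # float32
MB = 1024 ** 2

# Params size: 138,357,544 parameters * 4 bytes  ->  ~527.79 MB
params_mb = 138_357_544 * BYTES_PER_ELEM / MB

# Input size: a batch of 100 tensors of shape (3, 224, 224)  ->  ~57.42 MB
input_mb = 100 * 3 * 224 * 224 * BYTES_PER_ELEM / MB

print(round(params_mb, 2), round(input_mb, 2))  # 527.79 57.42

The forward/backward pass figure presumably counts per-layer activations in the same way (roughly doubled to account for gradients), which is why it dominates the total for large batches.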

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchdescribe-0.0.1.tar.gz (9.0 kB)

Uploaded Source

Built Distribution

torchdescribe-0.0.1-py3-none-any.whl (7.0 kB)

Uploaded Python 3

File details

Details for the file torchdescribe-0.0.1.tar.gz.

File metadata

  • Download URL: torchdescribe-0.0.1.tar.gz
  • Upload date:
  • Size: 9.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.0.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.9.4

File hashes

Hashes for torchdescribe-0.0.1.tar.gz

  • SHA256: 7ab92d6fb95a12f8c4565a7e177ca93484c40ec83ae4b7f96dd6dd69622d3010
  • MD5: b475d2386006f3d4392eb37f65a2d2f3
  • BLAKE2b-256: 0f5c7a61464235d1f3cfbfe8558e8509d3aa0f0e2dd4f6a1d77f9961897b8cbe

See more details on using hashes here.

File details

Details for the file torchdescribe-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: torchdescribe-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 7.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.0.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.9.4

File hashes

Hashes for torchdescribe-0.0.1-py3-none-any.whl

  • SHA256: 0b9762e3f0b4d4f59d067a0db1e6635ea0ccb0089f7d9da577786572bce775c6
  • MD5: 8412366f7a435d458386eccdb78edada
  • BLAKE2b-256: 5f237b1dc6e5cd5834174cd11d178912348af77ca3968449fe698ce80e314b25

See more details on using hashes here.
