Architecture visualization of Torch models


🔥 VisualTorch 🔥


VisualTorch aims to help visualize Torch-based neural network architectures. It currently supports generating layered-style and graph-style visualizations of PyTorch Sequential and custom models. This tool is adapted from visualkeras, pytorchviz, and pytorch-summary.

Note: VisualTorch may not yet support complex models, but contributions are welcome!

[Image: layered view and graph view examples]

v0.2: Added support for custom models and implemented graph view functionality.

v0.1.1: Added layered-view support for PyTorch Sequential models.

Installation

Install from PyPI (Latest release)

pip install visualtorch

Install from source

pip install git+https://github.com/willyfh/visualtorch
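Because custom-model support was added in v0.2 (per the release notes above), it can be worth checking which version is installed before running the custom-model examples. The sketch below uses only the standard library's importlib.metadata; the helper names installed_version and supports_custom_models are hypothetical and not part of visualtorch.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def supports_custom_models(ver: str) -> bool:
    """Hypothetical check: custom models are supported from v0.2 onward."""
    match = re.match(r"(\d+)\.(\d+)", ver)  # compare only major.minor
    if match is None:
        return False
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= (0, 2)


if __name__ == "__main__":
    ver = installed_version("visualtorch")
    if ver is None:
        print("visualtorch is not installed")
    else:
        print(f"visualtorch {ver}; custom models supported: {supports_custom_models(ver)}")
```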

Usage

Sequential

import visualtorch
import torch.nn as nn

# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 28 * 28, 256),  # 64 channels; three 2x2 max-pools shrink 224x224 to 28x28
    nn.ReLU(),
    nn.Linear(256, 10)  # Assuming 10 output classes
)

input_shape = (1, 3, 224, 224)

visualtorch.layered_view(model, input_shape=input_shape, legend=True).show() # display using your system viewer

[Image: layered view of the simple CNN]
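The 64 * 28 * 28 input size of the first nn.Linear follows from the standard output-size formula for nn.Conv2d and nn.MaxPool2d with dilation 1: out = floor((in + 2 * padding - kernel_size) / stride) + 1. The stdlib-only sketch below applies it to the model above; conv_out and flattened_features are illustrative helpers, not part of visualtorch or PyTorch.

```python
def conv_out(size: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Output height/width of Conv2d or MaxPool2d (dilation=1, ceil_mode=False)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def flattened_features(height: int, width: int, channels: int, blocks: int) -> int:
    """Features after `blocks` repeats of Conv2d(k=3, p=1) followed by MaxPool2d(2, 2)."""
    for _ in range(blocks):
        # A 3x3 conv with padding 1 and stride 1 keeps the spatial size.
        height = conv_out(height, kernel_size=3, padding=1)
        width = conv_out(width, kernel_size=3, padding=1)
        # A 2x2 max-pool with stride 2 halves the spatial size.
        height = conv_out(height, kernel_size=2, stride=2)
        width = conv_out(width, kernel_size=2, stride=2)
    return channels * height * width


# Three conv/pool blocks on a 224x224 input, ending with 64 channels.
print(flattened_features(224, 224, channels=64, blocks=3))  # 50176 == 64 * 28 * 28
```

If you change input_shape, recompute this value and update the in_features of the first nn.Linear to match.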

Custom Model

In a custom model, only the components defined in the model's __init__ method are visualized. Operations called only in forward (such as F.relu and F.max_pool2d in the example below) are not visualized.

import torch.nn as nn
import torch.nn.functional as F
import visualtorch

# Example of a simple CNN model
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.conv3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

# Create an instance of the SimpleCNN
model = SimpleCNN()

input_shape = (1, 3, 224, 224)

visualtorch.layered_view(model, input_shape=input_shape, legend=True).show() # display using your system viewer

[Image: layered view of the custom SimpleCNN]
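To make the rule above concrete, the stdlib-only sketch below walks the forward pass of SimpleCNN by hand and tags each step as either a submodule assigned in __init__ or a functional call made only in forward. Under that rule, only the "module" steps would appear in the layered view. The step table and shapes are worked out by hand for a (1, 3, 224, 224) input; they are not produced by visualtorch.

```python
# Each step of SimpleCNN.forward: (name, kind, output shape without the batch dim).
# "module" = attribute assigned in __init__; "functional" = call made only in forward.
STEPS = [
    ("conv1", "module", (16, 224, 224)),
    ("F.relu", "functional", (16, 224, 224)),
    ("F.max_pool2d", "functional", (16, 112, 112)),
    ("conv2", "module", (32, 112, 112)),
    ("F.relu", "functional", (32, 112, 112)),
    ("F.max_pool2d", "functional", (32, 56, 56)),
    ("conv3", "module", (64, 56, 56)),
    ("F.relu", "functional", (64, 56, 56)),
    ("F.max_pool2d", "functional", (64, 28, 28)),
    ("x.view", "functional", (64 * 28 * 28,)),
    ("fc1", "module", (128,)),
    ("F.relu", "functional", (128,)),
    ("fc2", "module", (10,)),
]


def visualized_layers(steps):
    """Keep only the steps backed by modules registered in __init__."""
    return [(name, shape) for name, kind, shape in steps if kind == "module"]


for name, shape in visualized_layers(STEPS):
    print(f"{name:6s} -> {shape}")
```

Following the same rule, activations and pooling would only be expected to show up if they were defined as separate modules in __init__ (for example nn.ReLU() and nn.MaxPool2d(2, 2) attributes) and called in forward.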

Graph View

import torch
import torch.nn as nn
import visualtorch

class SimpleDense(nn.Module):
    def __init__(self):
        super().__init__()
        self.h0 = nn.Linear(4, 8)
        self.h1 = nn.Linear(8, 8)
        self.h2 = nn.Linear(8, 4)
        self.out = nn.Linear(4, 2)

    def forward(self, x):
        x = self.h0(x)
        x = self.h1(x)
        x = self.h2(x)
        x = self.out(x)
        return x

model = SimpleDense()

input_shape = (1, 4)

visualtorch.graph_view(model, input_shape).show()

[Image: graph view of SimpleDense]
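Every layer in SimpleDense is an nn.Linear, so each unit in one layer connects to every unit in the next. The stdlib-only sketch below counts the units and connections implied by the layer widths 4 -> 8 -> 8 -> 4 -> 2 read off the model definition. It assumes, as a simplification rather than a documented guarantee of graph_view, that the graph draws one node per unit and one edge per weight; dense_graph_stats is a hypothetical helper.

```python
from typing import Dict, List


def dense_graph_stats(widths: List[int]) -> Dict[str, int]:
    """Count units, connections (weights), and parameters of a stack of Linear layers.

    widths[0] is the input size; each later entry is one Linear layer's out_features.
    """
    edges = sum(a * b for a, b in zip(widths, widths[1:]))  # one weight per connected pair
    biases = sum(widths[1:])  # nn.Linear has one bias per output unit by default
    return {"nodes": sum(widths), "edges": edges, "parameters": edges + biases}


# h0: 4 -> 8, h1: 8 -> 8, h2: 8 -> 4, out: 4 -> 2
print(dense_graph_stats([4, 8, 8, 4, 2]))  # {'nodes': 26, 'edges': 136, 'parameters': 158}
```

The edge count grows with the product of adjacent widths, which is why a graph-style drawing is most readable for small dense models like this one.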

Save the Image

visualtorch.layered_view(model, input_shape=input_shape, legend=True, to_file='output.png')
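Assuming the image is saved in the format implied by the file extension, a quick stdlib-only way to confirm that output.png really was written as a PNG is to check its fixed 8-byte signature. The helpers has_png_signature and is_png are hypothetical and not part of visualtorch.

```python
from pathlib import Path

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"  # the fixed 8-byte header of every PNG file


def has_png_signature(data: bytes) -> bool:
    """True if the given bytes start with the PNG signature."""
    return data[: len(PNG_SIGNATURE)] == PNG_SIGNATURE


def is_png(path: str) -> bool:
    """True if `path` is an existing file whose first bytes are the PNG signature."""
    file = Path(path)
    if not file.is_file():
        return False
    with file.open("rb") as fh:
        return has_png_signature(fh.read(len(PNG_SIGNATURE)))


print(is_png("output.png"))  # run after layered_view(..., to_file='output.png')
```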

2D View

visualtorch.layered_view(model, input_shape=input_shape, draw_volume=False)

[Image: 2D layered view]

Custom Color

Set the 'fill' key to change a layer's fill color and the 'outline' key to change its outline color. The example below assumes model is the CNN from the Sequential example.

import torch.nn as nn
import visualtorch
from collections import defaultdict

color_map = defaultdict(dict)
color_map[nn.Conv2d]['fill'] = 'LightSlateGray' # Light Slate Gray
color_map[nn.ReLU]['fill'] = '#87CEFA' # Light Sky Blue
color_map[nn.MaxPool2d]['fill'] = 'LightSeaGreen' # Light Sea Green
color_map[nn.Flatten]['fill'] = '#98FB98' # Pale Green
color_map[nn.Linear]['fill'] = 'LightSteelBlue' # Light Steel Blue

input_shape = (1, 3, 224, 224)
visualtorch.layered_view(model, input_shape=input_shape, color_map=color_map)

[Image: layered view with custom colors]
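Using defaultdict(dict) lets you assign color_map[layer_type]['fill'] without first creating the inner dictionary. The stdlib-only sketch below shows that behavior together with a hypothetical resolve_colors lookup that falls back to default colors for layer types you did not map. The fallback colors, the resolve_colors helper, and the stand-in classes Conv2d, ReLU, and Linear (used instead of the real torch.nn classes so the snippet runs without PyTorch) are all assumptions, not visualtorch's actual defaults or internals.

```python
from collections import defaultdict
from typing import Dict


class Conv2d:  # stand-in for torch.nn.Conv2d
    pass


class ReLU:  # stand-in for torch.nn.ReLU
    pass


class Linear:  # stand-in for torch.nn.Linear, deliberately left unmapped
    pass


color_map = defaultdict(dict)
color_map[Conv2d]["fill"] = "LightSlateGray"  # inner dict is created on first access
color_map[Conv2d]["outline"] = "DimGray"
color_map[ReLU]["fill"] = "#87CEFA"


def resolve_colors(cmap, layer_type, default_fill="white", default_outline="black") -> Dict[str, str]:
    """Hypothetical lookup: use the mapped colors where present, else the defaults."""
    entry = cmap.get(layer_type, {})  # .get() does not insert a new key into the defaultdict
    return {
        "fill": entry.get("fill", default_fill),
        "outline": entry.get("outline", default_outline),
    }


print(resolve_colors(color_map, Conv2d))  # {'fill': 'LightSlateGray', 'outline': 'DimGray'}
print(resolve_colors(color_map, ReLU))    # {'fill': '#87CEFA', 'outline': 'black'}
print(resolve_colors(color_map, Linear))  # {'fill': 'white', 'outline': 'black'}
```

In the real example, the keys are the actual layer classes such as nn.Conv2d, and each inner dictionary may hold a 'fill' entry, an 'outline' entry, or both.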

Contributing

Please feel free to send a pull request to contribute to this project.

License

This project is available as open source under the terms of the MIT License.

This project was originally based on visualkeras (MIT License), with additional modifications inspired by pytorchviz and pytorch-summary, both of which are also licensed under the MIT License.

Citation

If this project helps your research, please cite it in your publications as follows:

@misc{Hendria2024VisualTorch,
  author = {Hendria, Willy Fitra},
  title = {visualtorch},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  note = {\url{https://github.com/willyfh/visualtorch}},
}

Download files


Source Distribution

visualtorch-0.2.0.tar.gz (15.1 kB)

Uploaded Source

Built Distribution

visualtorch-0.2.0-py3-none-any.whl (15.1 kB)

Uploaded Python 3

File details

Details for the file visualtorch-0.2.0.tar.gz.

File metadata

  • Download URL: visualtorch-0.2.0.tar.gz
  • Upload date:
  • Size: 15.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.13

File hashes

Hashes for visualtorch-0.2.0.tar.gz
Algorithm Hash digest
SHA256 70dfeea727d4c5c13cffd89b1bb100c265d889404eec1c7916221dec9af9b81f
MD5 72c40f252d159c92858b0efe91e08d0a
BLAKE2b-256 1329f89fa4872738b49992d8b680dfb7d40b78b0a2127559f747ce5d65ccdcf0


File details

Details for the file visualtorch-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: visualtorch-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 15.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.13

File hashes

Hashes for visualtorch-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 366648d3690b0778e9f6930b1e44b10116c664e570e1d396d4ac1a3b791ca169
MD5 220bd4c4fdbf79f016c84b5731c5c898
BLAKE2b-256 c8cf1fd12c83a57dd20bb24e0110aee54c92efdddd13153774d339f9f6387957
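To check a downloaded distribution against the SHA256 digests listed above, you can hash the file locally. A minimal stdlib sketch using hashlib; sha256_of_stream and verify_file are illustrative helper names, and the file path is wherever you saved the download.

```python
import hashlib
import hmac
import io
from typing import BinaryIO


def sha256_of_stream(fh: BinaryIO, chunk_size: int = 1 << 16) -> str:
    """Hex SHA256 digest of a binary stream, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    for chunk in iter(lambda: fh.read(chunk_size), b""):
        digest.update(chunk)
    return digest.hexdigest()


def verify_file(path: str, expected_hex: str) -> bool:
    """Compare a file's SHA256 digest with a published hex digest (case-insensitive)."""
    with open(path, "rb") as fh:
        actual = sha256_of_stream(fh)
    return hmac.compare_digest(actual, expected_hex.strip().lower())


# Sanity check with the standard SHA256 test vector for b"abc".
print(sha256_of_stream(io.BytesIO(b"abc")))
# ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad

# SHA256 digest published above for the wheel.
WHEEL_SHA256 = "366648d3690b0778e9f6930b1e44b10116c664e570e1d396d4ac1a3b791ca169"
# verify_file("visualtorch-0.2.0-py3-none-any.whl", WHEEL_SHA256)
```

pip can also enforce digests at install time through its hash-checking mode (--require-hashes with --hash=sha256:... entries in a requirements file).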

