
Architecture visualization of Torch models


🔥 VisualTorch 🔥


VisualTorch aims to help visualize Torch-based neural network architectures. It currently supports generating layered-style and graph-style architectures for PyTorch Sequential and Custom models. This tool is adapted from visualkeras, pytorchviz, and pytorch-summary.

Note: VisualTorch may not yet support complex models, but contributions are welcome!

layered-and-graph

v0.2: Added support for custom models and implemented graph view functionality.

v0.1.1: Added support for the layered architecture of Torch Sequential.

Installation

Install from PyPI (Latest release)

pip install visualtorch

Install from source

pip install git+https://github.com/willyfh/visualtorch

Usage

Sequential

import visualtorch
import torch.nn as nn

# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 28 * 28, 256),  # Adjusted the input size for the Linear layer
    nn.ReLU(),
    nn.Linear(256, 10)  # Assuming 10 output classes
)

input_shape = (1, 3, 224, 224)

visualtorch.layered_view(model, input_shape=input_shape, legend=True).show() # display using your system viewer

simple-cnn
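The 64 * 28 * 28 input size of the first Linear layer in the example above follows from the 224 x 224 input: each 3x3 convolution with padding 1 keeps the spatial size, and each 2x2 max-pool halves it. The following is a minimal sketch of that arithmetic using the standard output-size formula for convolution and pooling windows. The helper name `window_out` is hypothetical and is not part of VisualTorch or PyTorch.

```python
def window_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one spatial dimension for a conv or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


size = 224  # height and width of the example input
for _ in range(3):  # three Conv2d + MaxPool2d stages
    size = window_out(size, kernel=3, stride=1, padding=1)  # 3x3 conv, padding 1: size unchanged
    size = window_out(size, kernel=2, stride=2)             # 2x2 max-pool, stride 2: size halved

flat_features = 64 * size * size  # 64 channels after the last convolution
print(size, flat_features)
```

If the input resolution or the number of pooling stages changes, the first Linear layer's input size has to be recomputed the same way.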

Custom Model

In a custom model, only the components defined within the model's __init__ method are visualized. Operations that appear only in the forward function are not visualized.

import torch.nn as nn
import torch.nn.functional as F
import visualtorch

# Example of a simple CNN model
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.conv3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = x.view(x.size(0), -1)  # flatten to (batch, features)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

# Create an instance of the SimpleCNN
model = SimpleCNN()

input_shape = (1, 3, 224, 224)

visualtorch.layered_view(model, input_shape=input_shape, legend=True).show() # display using your system viewer

simple-cnn-custom
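Because only components defined in __init__ are visualized, one possible way to make activations and pooling show up is to declare them as modules instead of calling their functional counterparts in forward. The sketch below uses a hypothetical `TinyCNN` model and plain PyTorch's `named_children()` to list the registered submodules. It assumes that these registered modules are what VisualTorch picks up, as the note above describes. It does not call VisualTorch itself.

```python
import torch
import torch.nn as nn


class TinyCNN(nn.Module):  # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.act = nn.ReLU()            # declared as a module instead of F.relu
        self.pool = nn.MaxPool2d(2, 2)  # declared as a module instead of F.max_pool2d

    def forward(self, x):
        return self.pool(self.act(self.conv(x)))


model = TinyCNN()

# Submodules registered in __init__, in the order they were assigned.
registered = [(name, type(module).__name__) for name, module in model.named_children()]
print(registered)

# Sanity check that the forward pass still works.
out = model(torch.zeros(1, 3, 8, 8))
print(tuple(out.shape))
```

The same model can then be passed to visualtorch.layered_view with a matching input_shape, as in the example above.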

Graph View

import torch
import torch.nn as nn
import visualtorch

class SimpleDense(nn.Module):
    def __init__(self):
        super().__init__()
        self.h0 = nn.Linear(4, 8)
        self.h1 = nn.Linear(8, 8)
        self.h2 = nn.Linear(8, 4)
        self.out = nn.Linear(4, 2)

    def forward(self, x):
        x = self.h0(x)
        x = self.h1(x)
        x = self.h2(x)
        x = self.out(x)
        return x

model = SimpleDense()

input_shape = (1, 4)

visualtorch.graph_view(model, input_shape).show()

graph

Save the Image

visualtorch.layered_view(model, input_shape=input_shape, legend=True, to_file='output.png')

2D View

visualtorch.layered_view(model, input_shape=input_shape, draw_volume=False)

2d-view

Custom Color

Use 'fill' to change the color of the layer, and use 'outline' to change the color of the lines.

from collections import defaultdict

color_map = defaultdict(dict)
color_map[nn.Conv2d]['fill'] = 'LightSlateGray' # Light Slate Gray
color_map[nn.ReLU]['fill'] = '#87CEFA' # Light Sky Blue
color_map[nn.MaxPool2d]['fill'] = 'LightSeaGreen' # Light Sea Green
color_map[nn.Flatten]['fill'] = '#98FB98' # Pale Green
color_map[nn.Linear]['fill'] = 'LightSteelBlue' # Light Steel Blue

input_shape = (1, 3, 224, 224)
visualtorch.layered_view(model, input_shape=input_shape, color_map=color_map)

custom-color
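The example above sets only 'fill'. A minimal sketch of a color map that also sets 'outline' might look like the following. The specific colors are illustrative choices, and it is an assumption that 'outline' accepts the same kinds of color strings (names and hex codes) that the 'fill' example uses.

```python
from collections import defaultdict

import torch.nn as nn

# Map layer types to drawing options; both keys come from the description above.
color_map = defaultdict(dict)
color_map[nn.Conv2d]['fill'] = 'LightSlateGray'
color_map[nn.Conv2d]['outline'] = 'black'    # illustrative color choice
color_map[nn.Linear]['fill'] = 'LightSteelBlue'
color_map[nn.Linear]['outline'] = '#4682B4'  # illustrative hex value (Steel Blue)

# Layer types that were never assigned have no entry in the map.
print(dict(color_map[nn.Conv2d]))
```

The resulting map is passed to layered_view through the color_map argument, as in the example above.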

Contributing

Please feel free to send a pull request to contribute to this project.

License

This project is available as open source under the terms of the MIT License.

This project was originally based on visualkeras (MIT license), with additional modifications inspired by pytorchviz and pytorch-summary, both of which are also licensed under the MIT license.

Citation

Please cite this project in your publications if it helps your research as follows:

@misc{Hendria2024VisualTorch,
  author = {Hendria, Willy Fitra},
  title = {visualtorch},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  note = {\url{https://github.com/willyfh/visualtorch}},
}
