Architecture visualization of Torch models
VisualTorch aims to help visualize Torch-based neural network architectures. It currently supports generating layered-style and graph-style architectures for PyTorch Sequential and Custom models. This tool is adapted from visualkeras, pytorchviz, and pytorch-summary.
Note: VisualTorch may not yet support complex models, but contributions are welcome!
v0.2: Added support for custom models and implemented graph view functionality.
v0.1.1: Added support for the layered architecture of Torch Sequential.
Installation
Install from PyPI (Latest release)
pip install visualtorch
Install from source
pip install git+https://github.com/willyfh/visualtorch
Usage
Sequential
import visualtorch
import torch.nn as nn
# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 28 * 28, 256),  # Adjusted the input size for the Linear layer
    nn.ReLU(),
    nn.Linear(256, 10),  # Assuming 10 output classes
)
input_shape = (1, 3, 224, 224)
visualtorch.layered_view(model, input_shape=input_shape, legend=True).show() # display using your system viewer
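The 64 * 28 * 28 input size of the first Linear layer follows from the input shape and the three convolution and pooling stages. The following is a minimal sketch of that arithmetic, using the standard output-size formula for convolution and pooling layers (dilation 1). The helper name out_size is hypothetical and is not part of VisualTorch or PyTorch.

```python
def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a conv or pooling layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


side = 224  # height and width taken from input_shape = (1, 3, 224, 224)
for _ in range(3):  # three Conv2d -> ReLU -> MaxPool2d stages
    side = out_size(side, kernel=3, padding=1)  # Conv2d(kernel_size=3, padding=1) keeps the size
    side = out_size(side, kernel=2, stride=2)   # MaxPool2d(2, 2) halves it

channels = 64  # output channels of the last Conv2d
flat_features = channels * side * side
print(side, flat_features)  # 28 50176
```

For a different input resolution, the same calculation gives the value to pass as the first argument of the first nn.Linear.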
Custom Model
In a custom model, only the components defined in the model's __init__ method are visualized. Operations that appear only in the forward method, such as functional calls, are not visualized.
import torch.nn as nn
import torch.nn.functional as F
import visualtorch
# Example of a simple CNN model
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.conv3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        print(x.shape)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x
# Create an instance of the SimpleCNN
model = SimpleCNN()
input_shape = (1, 3, 224, 224)
visualtorch.layered_view(model, input_shape=input_shape, legend=True).show() # display using your system viewer
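To see why functional operations are absent from the picture, it can help to list what a model actually registers as submodules. The sketch below uses only PyTorch's named_children(). It assumes that the visualization is built from registered modules and does not show VisualTorch's internal traversal. TinyNet is a hypothetical model used only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):  # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        # Layers assigned here are registered as submodules.
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.fc = nn.Linear(8 * 4 * 4, 10)

    def forward(self, x):
        # F.relu and F.max_pool2d are plain function calls, not submodules.
        x = F.max_pool2d(F.relu(self.conv(x)), 2, 2)
        return self.fc(x.flatten(1))


tiny = TinyNet()
registered = [(name, type(layer).__name__) for name, layer in tiny.named_children()]
print(registered)  # only conv and fc are listed

out = tiny(torch.zeros(1, 3, 8, 8))
print(tuple(out.shape))
```

If the activation or pooling steps should appear as components of the model, one option is to declare them as nn.ReLU and nn.MaxPool2d modules in __init__ and call those in forward.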
Graph View
import torch
import torch.nn as nn
import visualtorch
class SimpleDense(nn.Module):
    def __init__(self):
        super(SimpleDense, self).__init__()
        self.h0 = nn.Linear(4, 8)
        self.h1 = nn.Linear(8, 8)
        self.h2 = nn.Linear(8, 4)
        self.out = nn.Linear(4, 2)

    def forward(self, x):
        x = self.h0(x)
        x = self.h1(x)
        x = self.h2(x)
        x = self.out(x)
        return x
model = SimpleDense()
input_shape = (1, 4)
visualtorch.graph_view(model, input_shape).show()
Save the Image
visualtorch.layered_view(model, input_shape=input_shape, legend=True, to_file='output.png')
2D View
visualtorch.layered_view(model, input_shape=input_shape, draw_volume=False)
Custom Color
Use 'fill' to change the color of the layer, and use 'outline' to change the color of the lines.
from collections import defaultdict
color_map = defaultdict(dict)
color_map[nn.Conv2d]['fill'] = 'LightSlateGray' # Light Slate Gray
color_map[nn.ReLU]['fill'] = '#87CEFA' # Light Sky Blue
color_map[nn.MaxPool2d]['fill'] = 'LightSeaGreen' # Light Sea Green
color_map[nn.Flatten]['fill'] = '#98FB98' # Pale Green
color_map[nn.Linear]['fill'] = 'LightSteelBlue' # Light Steel Blue
input_shape = (1, 3, 224, 224)
visualtorch.layered_view(model, input_shape=input_shape, color_map=color_map)
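A minimal sketch that sets both keys for several layer types at once. The colors and the helper name make_color_map are illustrative choices and are not part of VisualTorch. Only the 'fill' and 'outline' keys come from the description above.

```python
from collections import defaultdict

import torch.nn as nn


def make_color_map(styles):
    """Build a color map from {layer_class: (fill, outline)} pairs (hypothetical helper)."""
    color_map = defaultdict(dict)
    for layer_class, (fill, outline) in styles.items():
        color_map[layer_class]["fill"] = fill        # color of the layer
        color_map[layer_class]["outline"] = outline  # color of the lines
    return color_map


color_map = make_color_map({
    nn.Conv2d: ("LightSlateGray", "black"),
    nn.Linear: ("LightSteelBlue", "#333333"),
})
print(color_map[nn.Conv2d])
```

The resulting dictionary can be passed as color_map=color_map to visualtorch.layered_view in the same way as the hand-built one.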
Contributing
Please feel free to send a pull request to contribute to this project.
License
This project is available as open source under the terms of the MIT License.
This project was originally based on visualkeras (MIT license), with additional modifications inspired by pytorchviz and pytorch-summary, both of which are also licensed under the MIT license.
Citation
If this project helps your research, please cite it in your publications as follows:
@misc{Hendria2024VisualTorch,
author = {Hendria, Willy Fitra},
title = {visualtorch},
year = {2024},
publisher = {GitHub},
journal = {GitHub repository},
note = {\url{https://github.com/willyfh/visualtorch}},
}