Architecture visualization of Torch models
⭐ VisualTorch ⭐
VisualTorch helps visualize Torch-based neural network architectures. It currently supports generating layered-style architecture diagrams for Torch Sequential (nn.Sequential) models. The package is adapted from visualkeras by @paulgavrikov.
v0.1: Support for layered architecture of torch Sequential.
Installation
Install from PyPI (Latest release)
pip install visualtorch
Install from source
pip install git+https://github.com/willyfh/visualtorch
Usage
Display with Legend
import visualtorch
import torch.nn as nn
# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 28 * 28, 256),  # 64 channels x 28 x 28 feature map for a 224 x 224 input
    nn.ReLU(),
    nn.Linear(256, 10)  # Assuming 10 output classes
)
input_shape = (1, 3, 224, 224)
visualtorch.layered_view(model, input_shape=input_shape, legend=True).show() # display using your system viewer
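The input size of the first Linear layer (64 * 28 * 28) follows from the input shape and the three pooling stages. The sketch below reproduces that arithmetic in plain Python, without importing torch. The helper name conv_out is hypothetical and not part of visualtorch or PyTorch; it applies the standard output-size formula for Conv2d and MaxPool2d with dilation 1.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length along one spatial dimension (dilation assumed to be 1)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 224  # height and width taken from input_shape = (1, 3, 224, 224)
for _ in range(3):
    # Conv2d(kernel_size=3, padding=1) keeps the spatial size unchanged
    size = conv_out(size, kernel=3, stride=1, padding=1)
    # MaxPool2d(2, 2) halves the spatial size
    size = conv_out(size, kernel=2, stride=2)

channels = 64  # out_channels of the last Conv2d in the example model
flattened_size = channels * size * size
print(size, flattened_size)  # 28 50176
```

If the input resolution or the number of pooling layers changes, the in_features argument of the first Linear layer has to change accordingly.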
Save the Image
visualtorch.layered_view(model, input_shape=input_shape, legend=True, to_file='output.png')
2D View
visualtorch.layered_view(model, input_shape=input_shape, draw_volume=False)
Custom Color
Use 'fill' to change the color of the layer, and use 'outline' to change the color of the lines.
from collections import defaultdict
color_map = defaultdict(dict)
color_map[nn.Conv2d]['fill'] = '#FF6F61' # Coral red
color_map[nn.ReLU]['fill'] = 'skyblue'
color_map[nn.MaxPool2d]['fill'] = '#88B04B' # Sage green
color_map[nn.Flatten]['fill'] = 'gold'
color_map[nn.Linear]['fill'] = '#6B5B95' # Royal purple
input_shape = (1, 3, 224, 224)
visualtorch.layered_view(model, input_shape=input_shape, color_map=color_map)
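The example above only sets 'fill'. As a minimal sketch of how the same defaultdict(dict) pattern can also carry an 'outline' entry, the snippet below uses two placeholder classes in place of real layer types so that it runs without PyTorch installed. The placeholder classes are an assumption for illustration only; in practice the keys would be layer classes such as nn.Conv2d. How visualtorch treats layer types that have no entry is not covered here.

```python
from collections import defaultdict


# Placeholder layer types (hypothetical stand-ins for nn.Conv2d and nn.ReLU)
class Conv2d:
    pass


class ReLU:
    pass


# defaultdict(dict) creates the inner dict on first access,
# so per-layer settings can be assigned without any setup code.
color_map = defaultdict(dict)
color_map[Conv2d]['fill'] = '#FF6F61'   # color of the layer body
color_map[Conv2d]['outline'] = 'black'  # color of the lines around it

print(color_map[Conv2d])  # {'fill': '#FF6F61', 'outline': 'black'}
print(ReLU in color_map)  # False: nothing has been set for this type
```

Using a defaultdict rather than a plain dict avoids a KeyError when a layer type is configured for the first time.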
Contributing
Please feel free to send a pull request to contribute to this project.
License
This project is available as open source under the terms of the MIT License.
This project was originally based on visualkeras (also under the MIT License).
Citation
Please cite this project in your publications if it helps your research as follows:
@misc{Hendria2024VisualTorch,
  author = {Hendria, Willy Fitra},
  title = {visualtorch},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  note = {\url{https://github.com/willyfh/visualtorch}},
}