
Architecture visualization of Torch models


⭐ VisualTorch ⭐

VisualTorch helps visualize the architectures of PyTorch neural networks. Currently, the package generates layered-style architecture diagrams for torch.nn.Sequential models. It is adapted from visualkeras by @paulgavrikov.

v0.1: Support for layered-style diagrams of torch.nn.Sequential models.

Installation

Install from PyPI (latest release)

pip install visualtorch

Install from source

pip install git+https://github.com/willyfh/visualtorch

Usage

Display with Legend

import visualtorch
import torch.nn as nn

# Example of a simple CNN model using nn.Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 28 * 28, 256),  # 64 channels x 28 x 28: three 2x2 max-pools shrink 224 to 28
    nn.ReLU(),
    nn.Linear(256, 10)  # Assuming 10 output classes
)

input_shape = (1, 3, 224, 224)

visualtorch.layered_view(model, input_shape=input_shape, legend=True).show() # display using your system viewer

[image: simple-cnn]
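The in_features value of the first nn.Linear follows from the output-size formula in the PyTorch Conv2d and MaxPool2d documentation, out = floor((in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1. The plain-Python sketch below (no PyTorch needed; the helper names out_size and flattened_features are hypothetical) applies it to the three conv/pool stages of the model above: each 3x3 convolution with padding=1 keeps the spatial size, and each 2x2 max-pool halves it, so 224 becomes 112, 56, then 28.

```python
def out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length along one spatial dimension, using the floor formula from the
    PyTorch Conv2d/MaxPool2d docs (MaxPool2d's default ceil_mode=False)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def flattened_features(height, width, out_channels=64, stages=3):
    """Number of features reaching nn.Flatten() after `stages` blocks of
    Conv2d(kernel_size=3, padding=1) -> ReLU -> MaxPool2d(2, 2)."""
    for _ in range(stages):
        # 3x3 convolution with padding=1 and stride 1 keeps the spatial size
        height = out_size(height, 3, padding=1)
        width = out_size(width, 3, padding=1)
        # 2x2 max-pool with stride 2 halves it (rounding down); ReLU keeps the shape
        height = out_size(height, 2, stride=2)
        width = out_size(width, 2, stride=2)
    return out_channels * height * width


print(flattened_features(224, 224))
```

With a 32 x 32 input, for instance, the same stages end at 4 x 4, so the first Linear layer would need in_features = 64 * 4 * 4 = 1024.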

Save the Image

visualtorch.layered_view(model, input_shape=input_shape, legend=True, to_file='output.png')

2D View

visualtorch.layered_view(model, input_shape=input_shape, draw_volume=False)

[image: 2d-view]

Custom Color

Pass a color_map keyed by layer class: the 'fill' entry sets a layer's color, and the 'outline' entry sets the color of its border lines.

from collections import defaultdict

color_map = defaultdict(dict)
color_map[nn.Conv2d]['fill'] = '#FF6F61' # Coral red
color_map[nn.ReLU]['fill'] = 'skyblue'
color_map[nn.MaxPool2d]['fill'] = '#88B04B' # Sage green
color_map[nn.Flatten]['fill'] = 'gold'
color_map[nn.Linear]['fill'] = '#6B5B95'    # Royal purple

input_shape = (1, 3, 224, 224)
visualtorch.layered_view(model, input_shape=input_shape, color_map=color_map)

[image: custom-color]
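The snippet below is a minimal sketch of why color_map is a defaultdict(dict) and how 'fill' and 'outline' entries might be resolved per layer. The stand-in classes replace the torch.nn types so it runs without PyTorch, and resolve_colors together with its default colors is hypothetical, not visualtorch's internal code. With the real package, an outline is set the same way as a fill, e.g. color_map[nn.Conv2d]['outline'] = 'black'.

```python
from collections import defaultdict


# Stand-ins for torch.nn layer classes (assumption: lets the sketch run without PyTorch).
class Conv2d:
    pass


class ReLU:
    pass


class Linear:
    pass


# defaultdict(dict) creates the inner dict on first access, so nested
# assignment works without first writing color_map[Conv2d] = {}.
color_map = defaultdict(dict)
color_map[Conv2d]['fill'] = '#FF6F61'
color_map[Conv2d]['outline'] = 'black'
color_map[Linear]['fill'] = '#6B5B95'

# Hypothetical fallbacks; visualtorch's real defaults may differ.
DEFAULT_FILL = 'white'
DEFAULT_OUTLINE = 'black'


def resolve_colors(layer, color_map):
    """Hypothetical lookup: use the entry for the layer's class, else the defaults."""
    entry = color_map.get(type(layer), {})  # .get does not insert empty entries
    return entry.get('fill', DEFAULT_FILL), entry.get('outline', DEFAULT_OUTLINE)


for layer in (Conv2d(), ReLU(), Linear()):
    print(type(layer).__name__, resolve_colors(layer, color_map))
```

Classes without an entry (ReLU here) simply fall back to the defaults, which is why only the layers you want to recolor need to appear in the map.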

Contributing

Please feel free to send a pull request to contribute to this project.

License

This project is available as open source under the terms of the MIT License.

This project was originally based on visualkeras, which is also released under the MIT License.

Citation

If this project helps your research, please cite it in your publications as follows:

@misc{Hendria2024VisualTorch,
  author = {Hendria, Willy Fitra},
  title = {visualtorch},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  note = {\url{https://github.com/willyfh/visualtorch}},
}


Download files


Source Distribution

visualtorch-0.1.1.tar.gz (9.7 kB)

Built Distribution

visualtorch-0.1.1-py3-none-any.whl (9.5 kB)

File details

Details for the file visualtorch-0.1.1.tar.gz.

File metadata

  • Download URL: visualtorch-0.1.1.tar.gz
  • Size: 9.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.13

File hashes

Hashes for visualtorch-0.1.1.tar.gz
Algorithm Hash digest
SHA256 53a04760ae9526ef984fac5c52b84aec1ab52e4f317e6a0204336356bc21a2c6
MD5 cfefc529cb0184bf6114731968f6f87f
BLAKE2b-256 a614f435301a68951ab1ddb7a0e88a40d78646143ce89ab565f1c66a2a321fd1
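To confirm that a downloaded file matches the digest published above, its SHA-256 can be recomputed locally. A minimal sketch with Python's standard hashlib module follows; the function names are hypothetical, and the expected value is the SHA256 digest listed above for the source distribution. pip can also enforce such digests during installation through its --require-hashes mode.

```python
import hashlib

# SHA256 digest listed above for visualtorch-0.1.1.tar.gz
EXPECTED_SDIST_SHA256 = "53a04760ae9526ef984fac5c52b84aec1ab52e4f317e6a0204336356bc21a2c6"


def sha256_of(path, chunk_size=1 << 16):
    """Hash the file in chunks so large downloads are not loaded into memory at once."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_hash(path, expected=EXPECTED_SDIST_SHA256):
    """True if the file's SHA-256 hex digest equals the expected one (case-insensitive)."""
    return sha256_of(path) == expected.lower()
```

For example, assuming the sdist was saved to the current directory, matches_published_hash("visualtorch-0.1.1.tar.gz") should return True for an intact download.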


File details

Details for the file visualtorch-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: visualtorch-0.1.1-py3-none-any.whl
  • Size: 9.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.13

File hashes

Hashes for visualtorch-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 b962470351a3a0f6969bc30c1297ce2b1d9b633e3b394ae46b25ba4d6a924652
MD5 62c9431f3d98eb85eccd911b6d560c8d
BLAKE2b-256 c356b845fdea25db2c938f4fd252524c0e5bc993fafe19ea84f91d1150a3a26e
