
Framework-agnostic neural network architecture visualization for Jupyter notebooks


modelviz

modelviz-ai




modelviz generates publication-ready neural network architecture diagrams from PyTorch and TensorFlow/Keras models. Pass in a model object and get a diagram back, with no manual drawing required.

✨ Features

| Feature | Description |
|---|---|
| 🔍 Auto-detection | Automatically detects PyTorch and TensorFlow/Keras models |
| 📊 2D Diagrams | Graphviz diagrams showing layer types, output shapes, and parameter counts |
| 🎮 3D Interactive | Interactive Three.js visualizations with a distinct shape per layer type |
| 🔄 Skip Connections | ResNet-style residual paths, dense connections, and branching architectures |
| 🎨 Smart Styling | Color-coded nodes for Conv, Linear, Pooling, and Activation layers |
| 📦 Block Grouping | Auto-merges common patterns (Conv+ReLU, Conv+BN+ReLU) |
| 📓 Notebook-native | Renders inline in Jupyter, Colab, and VS Code notebooks |
| 💾 Export | Save as PNG, SVG, PDF, or interactive HTML |

🎮 3D Visualization Preview

Each layer type has a distinct, meaningful 3D representation:

| Layer | Shape | Rationale |
|---|---|---|
| Conv2d | 3D Box | Feature maps are 3D volumes (C×H×W) |
| Linear | Flat Plane | Weight matrix is 2D |
| Pooling | Small Cube | Reduces spatial dimensions |
| Activation | Sphere | Element-wise uniform operation |
| BatchNorm | Thin Slab | Normalizes distribution |
| Flatten | Cone | Funnels data to 1D |
| Dropout | Wireframe | Sparse/dropped neurons |
| RNN/LSTM | Cylinder | Recurrent/cyclical flow |
| Attention | Octahedron | Multi-head patterns |

🚀 Installation

From PyPI

# Basic installation
pip install modelviz-ai

# With PyTorch support (quoting the extras keeps shells such as zsh from expanding the brackets)
pip install "modelviz-ai[torch]"

# With TensorFlow support
pip install "modelviz-ai[tf]"

# All frameworks + development tools
pip install "modelviz-ai[all,dev]"

From Source

git clone https://github.com/shreyanshjain05/modelviz.git
cd modelviz
pip install -e ".[dev]"
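To see which optional backends are importable in the current environment (for example, to decide between the [torch] and [tf] extras), a standard-library check like the one below can help. This is a small sketch, not part of modelviz; the helper name `available_frameworks` is hypothetical.

```python
import importlib.util


# Hypothetical helper (not part of modelviz): reports which optional deep
# learning frameworks are installed, using find_spec so nothing is imported.
def available_frameworks() -> dict[str, bool]:
    modules = {"pytorch": "torch", "tensorflow": "tensorflow"}
    return {
        name: importlib.util.find_spec(module) is not None
        for name, module in modules.items()
    }


if __name__ == "__main__":
    for name, installed in available_frameworks().items():
        print(f"{name}: {'installed' if installed else 'missing'}")
```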

System Requirements

For 2D Graphviz diagrams, install the Graphviz system package:

# macOS
brew install graphviz

# Ubuntu/Debian
sudo apt-get install graphviz

# Windows, or any platform via Conda
conda install -c conda-forge graphviz

Note: Three.js 3D visualizations work without any system dependencies.
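Since visualize() returns a graphviz.Digraph, the 2D path goes through the Python graphviz package, which renders by calling the Graphviz `dot` executable. A quick way to confirm the system package is installed is to look for `dot` on PATH. This standard-library sketch is not part of modelviz, and the helper name is hypothetical.

```python
import shutil


# The Graphviz system package provides the `dot` layout executable.
# If it is missing from PATH, rendering a graphviz.Digraph will typically fail.
def graphviz_available() -> bool:
    return shutil.which("dot") is not None


if __name__ == "__main__":
    if graphviz_available():
        print("Graphviz found at", shutil.which("dot"))
    else:
        print("Graphviz `dot` not found on PATH; install the system package first.")
```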

🎯 Quick Start

2D Visualization (Graphviz)

import torch.nn as nn
from modelviz import visualize

model = nn.Sequential(
    nn.Conv2d(1, 32, 3),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(32 * 13 * 13, 10)
)

# Renders inline in Jupyter
visualize(model, input_shape=(1, 1, 28, 28))

# Save to file
visualize(model, input_shape=(1, 1, 28, 28), save_path="model.png")
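The 32 * 13 * 13 input size of the final Linear layer follows from the standard convolution/pooling size formula, out = floor((in + 2 * padding - kernel) / stride) + 1 (no dilation). A 3×3 convolution without padding maps 28 to 26, and MaxPool2d(2) (stride defaults to the kernel size in PyTorch) maps 26 to 13. The plain-Python sketch below reproduces that arithmetic; the helper name `conv_out` is illustrative.

```python
# Output size of a conv or pooling layer along one spatial dimension
# (no dilation): floor((size + 2 * padding - kernel) / stride) + 1
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    return (size + 2 * padding - kernel) // stride + 1


h = conv_out(28, kernel=3)           # Conv2d(1, 32, 3): 28 -> 26
h = conv_out(h, kernel=2, stride=2)  # MaxPool2d(2): 26 -> 13
flat = 32 * h * h                    # Flatten: 32 channels * 13 * 13
print(flat)  # 5408
```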

3D Visualization (Three.js)

from modelviz import visualize_threejs

# Creates an interactive HTML file
visualize_threejs(
    model,
    input_shape=(1, 1, 28, 28),
    save_path="model_3d.html"
)
# Open model_3d.html in your browser!
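If you prefer to open the saved file from Python rather than by hand, the standard library can do it. This is a small sketch that is not part of modelviz; the helper names are hypothetical, and on a headless server webbrowser.open simply returns False.

```python
import webbrowser
from pathlib import Path


def html_uri(path: str) -> str:
    """Return an absolute file:// URI for a local HTML file."""
    return Path(path).resolve().as_uri()


def open_in_browser(path: str) -> bool:
    """Open a saved visualization in the default browser.

    Returns False if no browser could be launched (e.g. on a headless server).
    """
    return webbrowser.open(html_uri(path))
```

After visualize_threejs(..., save_path="model_3d.html"), calling open_in_browser("model_3d.html") opens the page.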

📖 Examples

PyTorch CNN

import torch.nn as nn
from modelviz import visualize, visualize_threejs

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 8 * 8, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 10),
        )
    
    def forward(self, x):
        return self.classifier(self.features(x))

model = CNN()

# 2D diagram with layer grouping
visualize(model, input_shape=(1, 3, 32, 32), title="CNN Architecture")

# 3D interactive visualization
visualize_threejs(model, input_shape=(1, 3, 32, 32), save_path="cnn_3d.html")
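As a cross-check for the parameter counts shown with show_params, the totals for this CNN can be worked out by hand. The sketch below uses the standard PyTorch formulas and should agree with sum(p.numel() for p in model.parameters()); we have not verified that modelviz counts parameters in exactly the same way. Note that the first Linear layer holds about 96% of all parameters.

```python
# Trainable parameter counts using the standard PyTorch formulas
# (weights + biases). BatchNorm running mean/var are buffers, not parameters.
def conv2d_params(c_in: int, c_out: int, k: int) -> int:
    return c_out * c_in * k * k + c_out


def batchnorm_params(c: int) -> int:
    return 2 * c  # learnable scale (weight) and shift (bias) per channel


def linear_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out


layers = {
    "features.0 Conv2d(3, 64, 3)": conv2d_params(3, 64, 3),
    "features.1 BatchNorm2d(64)": batchnorm_params(64),
    "features.4 Conv2d(64, 128, 3)": conv2d_params(64, 128, 3),
    "features.5 BatchNorm2d(128)": batchnorm_params(128),
    "classifier.1 Linear(8192, 256)": linear_params(128 * 8 * 8, 256),
    "classifier.4 Linear(256, 10)": linear_params(256, 10),
}
total = sum(layers.values())
print(f"{total:,}")  # 2,176,010
```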

TensorFlow/Keras

import tensorflow as tf
from modelviz import visualize

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(2),
    tf.keras.layers.Conv2D(64, 3, activation='relu'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# No input_shape needed: this model starts with an Input layer, so Keras has already built it
visualize(model, save_path="keras_model.svg")

🎮 3D Visualization

The Three.js renderer produces an interactive 3D diagram, returned as an HTML string and optionally saved to an HTML file:

from modelviz import visualize_threejs

# `model` stands for a ResNet-style PyTorch model that accepts 3×224×224 inputs (not defined here)
html = visualize_threejs(
    model,
    input_shape=(1, 3, 224, 224),
    title="ResNet Block",
    show_shapes=True,      # Show tensor dimensions
    show_params=True,      # Show parameter counts
    group_blocks=True,     # Merge Conv+BN+ReLU
    save_path="resnet.html"
)

Controls

| Action | Control |
|---|---|
| Rotate | Drag mouse |
| Zoom | Scroll wheel |
| Pan | Shift + Drag |
| Details | Hover over layer |

Features

  • Horizontal layout — Data flows left to right
  • Text labels — Layer type and output shape above each node
  • Animated particles — Shows data flow between layers
  • Hover tooltips — Full layer information on mouseover
  • Legend — Color and shape guide

⚙️ API Reference

visualize()

Generate a 2D Graphviz diagram.

visualize(
    model,                          # PyTorch or Keras model
    input_shape=(1, 3, 224, 224),  # Required for PyTorch
    framework="auto",               # "auto", "pytorch", "tensorflow"
    show_shapes=True,               # Show output tensor shapes
    show_params=True,               # Show parameter counts
    group_blocks=True,              # Merge Conv+ReLU patterns
    save_path="model.png",          # Optional: save to file
    title="My Model",               # Optional: diagram title
) -> graphviz.Digraph
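With framework="auto", the framework is detected automatically (utils/framework_detect.py). One way such detection can work without importing either framework is to inspect the module names along the model class's MRO. This is a hypothetical sketch, not modelviz's actual implementation; the function name `detect_framework` is illustrative.

```python
# Hypothetical framework detection: walk the model class's MRO and check the
# top-level package each base class is defined in. No framework is imported.
def detect_framework(model: object) -> str:
    for cls in type(model).__mro__:
        root = cls.__module__.split(".")[0]
        if root == "torch":
            return "pytorch"
        if root in ("tensorflow", "keras", "tf_keras"):
            return "tensorflow"
    raise ValueError(
        f"Could not detect framework for {type(model).__name__}; "
        "pass framework='pytorch' or framework='tensorflow' explicitly."
    )
```

A user-defined class CNN(nn.Module) is detected as PyTorch because torch.nn.Module appears in its MRO. Checking module names rather than calling isinstance avoids importing TensorFlow just to rule it out.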

visualize_threejs()

Generate an interactive 3D Three.js visualization.

visualize_threejs(
    model,                          # PyTorch or Keras model
    input_shape=(1, 3, 224, 224),  # Required for PyTorch
    framework="auto",               # "auto", "pytorch", "tensorflow"
    show_shapes=True,               # Show shapes in labels
    show_params=True,               # Show params in tooltips
    group_blocks=True,              # Merge Conv+ReLU patterns
    save_path="model.html",         # Save as HTML file
    title="My Model 3D",            # Visualization title
) -> str  # Returns HTML string
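Because visualize_threejs() returns the page as an HTML string, it can also be embedded in a notebook without saving a file. A common approach, sketched here under the assumption that the returned string is a complete standalone HTML document, is to place it in an iframe's srcdoc attribute so its scripts run in an isolated frame. The helper name `iframe_srcdoc` is hypothetical.

```python
import html


# Wrap a standalone HTML document in an iframe via the srcdoc attribute.
# html.escape(quote=True) escapes &, <, > and both quote characters, so the
# document can be placed safely inside the attribute value.
def iframe_srcdoc(page: str, width: str = "100%", height: int = 600) -> str:
    escaped = html.escape(page, quote=True)
    return (
        f'<iframe srcdoc="{escaped}" width="{width}" height="{height}" '
        'style="border: none;"></iframe>'
    )
```

In Jupyter, display(HTML(iframe_srcdoc(page))), with HTML and display imported from IPython.display, renders the returned page inline.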

visualize_3d()

Generate a Plotly 3D visualization (simpler fallback).

visualize_3d(
    model,
    input_shape=(1, 3, 224, 224),
    layout="tower",                 # "tower", "spiral", "grid"
    save_path="model.png",
) -> plotly.graph_objects.Figure
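The layout options "tower", "spiral", and "grid" are not described further here, so the following is only a guess at what each name suggests: stacking layers vertically, winding them around a vertical axis, or arranging them on a square grid. The function name and all geometry below are assumptions for illustration, not the Plotly renderer's code.

```python
import math


# Illustrative (hypothetical) 3D positions for n layers under three layouts.
def layer_positions(n: int, layout: str = "tower", spacing: float = 1.0):
    if layout == "tower":
        # One layer per level, stacked along z.
        return [(0.0, 0.0, i * spacing) for i in range(n)]
    if layout == "spiral":
        # Unit-radius circle around the z axis, a quarter turn per layer.
        return [
            (math.cos(i * math.pi / 2), math.sin(i * math.pi / 2), i * spacing)
            for i in range(n)
        ]
    if layout == "grid":
        # Fill rows of a roughly square grid on the z = 0 plane.
        cols = math.ceil(math.sqrt(n))
        return [((i % cols) * spacing, (i // cols) * spacing, 0.0) for i in range(n)]
    raise ValueError(f"unknown layout: {layout!r}")
```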

🎨 Styling

2D Node Colors (Graphviz)

| Layer Type | Color | Hex |
|---|---|---|
| Convolution | Indigo | #6366f1 |
| Linear/Dense | Purple | #8b5cf6 |
| Pooling | Cyan | #06b6d4 |
| Activation | Amber | #f59e0b |
| Normalization | Emerald | #10b981 |
| Flatten | Pink | #ec4899 |
| Dropout | Red | #ef4444 |
| Embedding | Lime | #84cc16 |
| RNN/LSTM | Teal | #14b8a6 |
| Attention | Orange | #f97316 |
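One way to apply this palette is a lookup from a layer's class name to its category. Only the hex values come from the color table; the keyword rules, their order, and the fallback gray are assumptions for illustration and not modelviz's actual styling code.

```python
# Hex colors from the color table, keyed by layer category.
CATEGORY_COLORS = {
    "convolution": "#6366f1",
    "linear": "#8b5cf6",
    "pooling": "#06b6d4",
    "activation": "#f59e0b",
    "normalization": "#10b981",
    "flatten": "#ec4899",
    "dropout": "#ef4444",
    "embedding": "#84cc16",
    "recurrent": "#14b8a6",
    "attention": "#f97316",
}

# Hypothetical keyword rules mapping class names to categories.
# Order matters: the first keyword found in the lowercased name wins.
CATEGORY_KEYWORDS = [
    ("conv", "convolution"),
    ("pool", "pooling"),
    ("norm", "normalization"),
    ("attention", "attention"),
    ("lstm", "recurrent"), ("gru", "recurrent"), ("rnn", "recurrent"),
    ("embedding", "embedding"),
    ("flatten", "flatten"),
    ("dropout", "dropout"),
    ("linear", "linear"), ("dense", "linear"),
    ("relu", "activation"), ("sigmoid", "activation"), ("tanh", "activation"),
    ("softmax", "activation"), ("gelu", "activation"), ("activation", "activation"),
]


def node_color(layer_type: str, default: str = "#9ca3af") -> str:
    """Return the node color for a layer class name (hypothetical gray fallback)."""
    name = layer_type.lower()
    for keyword, category in CATEGORY_KEYWORDS:
        if keyword in name:
            return CATEGORY_COLORS[category]
    return default
```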

Block Grouping

Common patterns are automatically merged:

  • Conv2d → BatchNorm2d → ReLU becomes Conv2d + BatchNorm2d + ReLU
  • Conv2d → ReLU becomes Conv2d + ReLU
  • Linear → ReLU becomes Linear + ReLU
  • Dense → Activation becomes Dense + Activation

Disable with group_blocks=False.
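Block grouping can be pictured as a greedy left-to-right scan over the flat layer sequence that tries the longest pattern first. This is a minimal sketch under that assumption, not the code in utils/grouping.py; the pattern list mirrors the merged patterns above, and the `enabled` flag only imitates group_blocks.

```python
# Longest pattern first, so Conv2d + BatchNorm2d + ReLU wins over the shorter
# Conv2d + ReLU check when a BatchNorm2d follows the Conv2d.
PATTERNS = [
    ("Conv2d", "BatchNorm2d", "ReLU"),
    ("Conv2d", "ReLU"),
    ("Linear", "ReLU"),
    ("Dense", "Activation"),
]


def group_layers(layer_types: list[str], enabled: bool = True) -> list[str]:
    """Merge consecutive layers that match a known pattern into one label."""
    if not enabled:  # analogous to group_blocks=False
        return list(layer_types)
    grouped, i = [], 0
    while i < len(layer_types):
        for pattern in PATTERNS:
            if tuple(layer_types[i : i + len(pattern)]) == pattern:
                grouped.append(" + ".join(pattern))
                i += len(pattern)
                break
        else:
            # No pattern starts here: keep the layer as its own node.
            grouped.append(layer_types[i])
            i += 1
    return grouped
```

For the first half of the CNN example, ["Conv2d", "BatchNorm2d", "ReLU", "MaxPool2d"] becomes ["Conv2d + BatchNorm2d + ReLU", "MaxPool2d"].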

🏗️ Architecture

modelviz/
├── modelviz/
│   ├── __init__.py              # Public API
│   ├── visualize.py             # Main API functions
│   ├── graph/
│   │   ├── layer_node.py        # LayerNode dataclass
│   │   └── builder.py           # Graph construction
│   ├── parsers/
│   │   ├── torch_parser.py      # PyTorch model parsing
│   │   ├── tf_parser.py         # TensorFlow/Keras parsing
│   │   └── fx_tracer.py         # Skip connection detection (NEW)
│   ├── renderers/
│   │   ├── graphviz_renderer.py # 2D Graphviz output
│   │   ├── plotly_renderer.py   # 3D Plotly output
│   │   └── threejs_renderer.py  # 3D Three.js output
│   └── utils/
│       ├── framework_detect.py  # Auto-detection
│       └── grouping.py          # Layer pattern grouping
├── tests/                       # Test suite
├── examples/                    # Demo scripts
├── docs/                        # Documentation
└── pyproject.toml              # Package config
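The tree lists a LayerNode dataclass (graph/layer_node.py) and a skip connection detector (parsers/fx_tracer.py), but their fields and logic are not documented here. The following hypothetical sketch shows one way a node graph can represent branching: each node records the names of the nodes feeding it, and an edge counts as a skip connection when it jumps over at least one node in execution order. The field names and the skip rule are assumptions.

```python
from dataclasses import dataclass, field


@dataclass
class LayerNode:
    """Hypothetical node: one layer plus the names of the nodes feeding it."""
    name: str
    layer_type: str
    output_shape: tuple[int, ...] = ()
    num_params: int = 0
    inputs: list[str] = field(default_factory=list)


def skip_edges(nodes: list[LayerNode]) -> list[tuple[str, str]]:
    """Return (src, dst) edges that bypass at least one node.

    Assumes `nodes` is listed in execution (topological) order.
    """
    position = {node.name: i for i, node in enumerate(nodes)}
    return [
        (src, node.name)
        for node in nodes
        for src in node.inputs
        if position[node.name] - position[src] > 1
    ]


# A ResNet-style basic block: x -> conv1 -> relu -> conv2 -> add(conv2, x)
shape = (1, 64, 32, 32)
block = [
    LayerNode("x", "Input", shape),
    LayerNode("conv1", "Conv2d", shape, 36928, ["x"]),
    LayerNode("relu", "ReLU", shape, 0, ["conv1"]),
    LayerNode("conv2", "Conv2d", shape, 36928, ["relu"]),
    LayerNode("add", "Add", shape, 0, ["conv2", "x"]),
]
print(skip_edges(block))  # [('x', 'add')]
```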

🧪 Testing

# Run all tests
pytest tests/ -v

# With coverage
pytest tests/ --cov=modelviz --cov-report=html

# Run specific test
pytest tests/test_grouping.py -v

🗺️ Roadmap

  • Branching graph support (ResNet, UNet skip connections)
  • Transformer attention pattern visualization
  • Interactive web dashboard
  • Custom color themes
  • Model comparison (side-by-side)
  • FLOPs/MACs calculation
  • ONNX model support

🤝 Contributing

We welcome contributions! See CONTRIBUTING.md for guidelines.

Quick Start

git clone https://github.com/shreyanshjain05/modelviz.git
cd modelviz
python -m venv .venv
source .venv/bin/activate
pip install -e ".[dev,torch,tf]"
pytest tests/ -v

Code Style

  • Python 3.10+
  • Type hints on all public functions
  • Google-style docstrings
  • Black + isort formatting

📄 License

MIT License - see LICENSE for details.

🙏 Acknowledgments


Made with ❤️ for the deep learning community

⭐ Star this repo if you find it useful!



Download files

Download the file for your platform.

Source Distribution

modelviz_ai-0.1.0.tar.gz (36.1 kB)

Built Distribution


modelviz_ai-0.1.0-py3-none-any.whl (36.9 kB)

File details

Details for the file modelviz_ai-0.1.0.tar.gz.

File metadata

  • Download URL: modelviz_ai-0.1.0.tar.gz
  • Size: 36.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for modelviz_ai-0.1.0.tar.gz

| Algorithm | Hash digest |
|---|---|
| SHA256 | 1ec27e738f8342537b84583f571a7594112201b1367011ac0303c81a80764159 |
| MD5 | 006b0210a2444004fe2192a722e1745c |
| BLAKE2b-256 | 6338454252e12121c25d81a8d53135805962e22ea2cf5f59252885774aedb65b |


Provenance

The following attestation bundles were made for modelviz_ai-0.1.0.tar.gz:

Publisher: publish.yml on shreyanshjain05/modelviz

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file modelviz_ai-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: modelviz_ai-0.1.0-py3-none-any.whl
  • Size: 36.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for modelviz_ai-0.1.0-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | 496c330f4b4eb791772a6fc2ea1afec9944f9a6b6a25e6e280ee0ce08966200a |
| MD5 | 414d6ccd982423bf223a7bfef16fd023 |
| BLAKE2b-256 | 1cf829ae16716fc7e4be45ce37447853327fc29d92d7cfb56323f474a40fcb57 |


Provenance

The following attestation bundles were made for modelviz_ai-0.1.0-py3-none-any.whl:

Publisher: publish.yml on shreyanshjain05/modelviz

