
Enhanced PyTorch neural network architecture visualization with flowchart diagrams. Install with 'pip install pytorch-graph'; import with 'import pytorch_graph'.


PyTorch Graph

Professional PyTorch neural network visualization toolkit with complete computational graph analysis. Transform your PyTorch models into publication-ready diagrams with comprehensive architecture visualization and computational graph tracking.

Requires Python 3.8+ and PyTorch. License: MIT.


Demo

Try the library right away with the included demo:

# Clone the repository
git clone https://github.com/PluvioXO/pytorch-graph.git
cd pytorch-graph

# Run the demo
python3 demo.py

The demo generates:

  • Architecture diagrams (flowchart + research paper styles)
  • Computational graph visualization
  • Model analysis with parameter counts and memory usage
  • Publication-quality output

Output files are saved to the demo_outputs/ directory.

Key Features

Architecture Visualization

  • Professional Flowchart Diagrams: Clean, vertical flowchart visualization with enhanced styling
  • Research Paper Style: Publication-ready diagrams with academic formatting
  • High-Quality Export: PNG output with configurable DPI (up to 300 DPI)
  • Comprehensive Layer Analysis: Parameter counts, memory usage, and tensor shape transformations

Complete Computational Graph Analysis

  • Maximal Graph Traversal: Captures the entire computational graph without artificial limits
  • Full Operation Names: Displays complete method and object names (no truncation)
  • Smart Arrow Positioning: Arrows connect node edges properly without crossing over boxes
  • Compact Layout: Eliminates gaps and breaks for continuous graph flow
  • Real-time Execution Tracking: Monitors forward/backward passes and tensor operations

Professional Quality

  • Enhanced Color Schemes: Color-coded operation types and layer categories
  • Intelligent Legend Positioning: Automatic legend placement without overlap
  • Memory Analysis: Per-layer and total memory usage estimates
  • Model Complexity Assessment: Automatic size classification (Small/Medium/Large)
  • Data Flow Visualization: Tensor sizes and shapes displayed on connections
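The cutoffs behind the Small/Medium/Large classification are not stated in this README. The sketch below only shows the general shape of such a rule; the thresholds (1 million and 100 million parameters) and the function name are made up for illustration and are not the library's values.

```python
# Hypothetical thresholds chosen for illustration only; the library's actual
# cutoffs are not documented here.
SMALL_MAX = 1_000_000
MEDIUM_MAX = 100_000_000


def classify_model_size(num_params: int) -> str:
    """Bucket a parameter count into Small / Medium / Large."""
    if num_params < SMALL_MAX:
        return "Small"
    if num_params < MEDIUM_MAX:
        return "Medium"
    return "Large"


for n in (109_386, 25_000_000, 350_000_000):
    print(f"{n:,} parameters -> {classify_model_size(n)}")
```

A count-based rule like this is cheap to compute from the parameter totals the analysis already reports; a real implementation might also weigh memory or operation counts.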

Installation

# Basic installation
pip install pytorch-graph

# With enhanced features (quotes keep shells such as zsh from expanding the brackets)
pip install "pytorch-graph[full]"

# With development dependencies
pip install "pytorch-graph[dev]"

Quick Start

Architecture Visualization

import torch
import torch.nn as nn
from pytorch_graph import generate_architecture_diagram

# Define your model
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10)
)

# Generate professional architecture diagram
generate_architecture_diagram(
    model=model,
    input_shape=(1, 784),
    output_path="model_architecture.png",
    title="Neural Network Architecture"
)
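As a worked check of the parameter counts such a diagram reports: each nn.Linear(in_features, out_features) layer stores an out_features × in_features weight matrix plus out_features bias terms, and ReLU has no parameters. The sketch does this arithmetic in plain Python for the model above. It illustrates the formula and is not the library's own counting code; with PyTorch, sum(p.numel() for p in model.parameters()) should give the same total.

```python
# Parameter count of the Quick Start MLP, computed by hand.
# nn.Linear(in_f, out_f) stores an (out_f x in_f) weight matrix plus out_f biases.


def linear_params(in_features: int, out_features: int) -> int:
    """Number of trainable parameters in a linear layer with bias."""
    return in_features * out_features + out_features


# (in_features, out_features) for each nn.Linear in the Sequential model;
# the ReLU layers have no parameters and are omitted.
layers = [(784, 128), (128, 64), (64, 10)]

per_layer = [linear_params(i, o) for i, o in layers]
total = sum(per_layer)

for (i, o), n in zip(layers, per_layer):
    print(f"Linear({i}, {o}): {n:,} parameters")
print(f"Total: {total:,} parameters")
```

Running it prints 100,480, 8,256 and 650 parameters for the three layers, 109,386 in total.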

Complete Computational Graph Analysis

import torch
from pytorch_graph import ComputationalGraphTracker

# model is the nn.Sequential defined in the previous example

# Create a tracker for your model
tracker = ComputationalGraphTracker(model)

# Start tracking
tracker.start_tracking()

# Run your model
input_tensor = torch.randn(1, 784)
output = model(input_tensor)

# Stop tracking and save the complete graph
tracker.stop_tracking()
tracker.save_graph_png("complete_computational_graph.png")
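If the forward pass raises an exception, the example above never reaches stop_tracking(). One way to guard against that is a small context manager that always calls stop_tracking() in a finally clause. The tracking() wrapper below is a hypothetical helper, not part of pytorch_graph; it only assumes the start_tracking()/stop_tracking() methods shown above, and it is demonstrated with a stand-in tracker so it runs without the library.

```python
from contextlib import contextmanager


@contextmanager
def tracking(tracker):
    """Hypothetical helper: call tracker.start_tracking() on entry and
    guarantee tracker.stop_tracking() on exit, even if the body raises."""
    tracker.start_tracking()
    try:
        yield tracker
    finally:
        tracker.stop_tracking()


class _StubTracker:
    """Stand-in with the same two methods, used only to demonstrate the helper."""

    def __init__(self):
        self.events = []

    def start_tracking(self):
        self.events.append("start")

    def stop_tracking(self):
        self.events.append("stop")


stub = _StubTracker()
try:
    with tracking(stub):
        raise RuntimeError("forward pass failed")
except RuntimeError:
    pass
print(stub.events)  # ['start', 'stop']: tracking was stopped despite the error
```

With the real tracker, the pattern would be: with tracking(ComputationalGraphTracker(model)) as tracker: output = model(input_tensor), then call tracker.save_graph_png(...) after the block.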

Comprehensive Examples

CNN Architecture Visualization

import torch.nn as nn
from pytorch_graph import generate_architecture_diagram

# Convolutional neural network for 3x32x32 inputs
cnn_model = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(64 * 8 * 8, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

# Generate multiple diagram styles
generate_architecture_diagram(
    cnn_model,
    input_shape=(1, 3, 32, 32),
    output_path="cnn_flowchart.png",
    style="flowchart"
)

generate_architecture_diagram(
    cnn_model,
    input_shape=(1, 3, 32, 32),
    output_path="cnn_research.png",
    style="research_paper"
)
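The 64 * 8 * 8 input size of the first nn.Linear follows from PyTorch's output-size formula for Conv2d and MaxPool2d with dilation 1: out = floor((in + 2 * padding - kernel) / stride) + 1. Conv2d's stride defaults to 1, and MaxPool2d's stride defaults to its kernel size. The sketch traces the spatial size of a 32x32 input through the model in plain Python; it is a hand calculation, not something the library exposes.

```python
def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of Conv2d / MaxPool2d with dilation=1."""
    return (size + 2 * padding - kernel) // stride + 1


size, channels = 32, 3
trace = [("input", channels, size)]

# (name, out_channels or None to keep the channel count, kernel, stride, padding)
stages = [
    ("Conv2d(3, 32, 3, padding=1)", 32, 3, 1, 1),
    ("MaxPool2d(2)", None, 2, 2, 0),
    ("Conv2d(32, 64, 3, padding=1)", 64, 3, 1, 1),
    ("MaxPool2d(2)", None, 2, 2, 0),
]
for name, out_ch, k, s, p in stages:
    size = out_size(size, k, s, p)
    channels = out_ch if out_ch is not None else channels
    trace.append((name, channels, size))

for name, c, hw in trace:
    print(f"{name:<30} -> ({c}, {hw}, {hw})")

flattened = channels * size * size
print(f"Flatten -> {flattened} features")  # 64 * 8 * 8 = 4096
```

The padded 3x3 convolutions keep the spatial size unchanged, while each MaxPool2d(2) halves it: 32 -> 16 -> 8.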

Advanced Computational Graph Analysis

from pytorch_graph import track_computational_graph, analyze_computational_graph

# model and input_tensor are defined in the Quick Start examples above.
# Track the complete computational graph
tracker = track_computational_graph(
    model=model,
    input_tensor=input_tensor,
    track_memory=True,
    track_timing=True,
    track_tensor_ops=True
)

# Save high-quality computational graph
tracker.save_graph_png(
    filepath="complete_graph.png",
    width=1600,
    height=1200,
    dpi=300,
    show_legend=True
)

# Analyze the computational graph
analysis = analyze_computational_graph(model, input_tensor, detailed=True)
print(f"Total operations: {analysis['summary']['total_nodes']}")
print(f"Execution time: {analysis['summary']['execution_time']:.4f}s")

Diagram Styles

Enhanced Flowchart (Default)

generate_architecture_diagram(model, input_shape, "flowchart.png", style="flowchart")

Features:

  • Lightning bolt icons for activation functions
  • Memory usage per layer (e.g., "~1.2MB")
  • Data flow indicators on arrows
  • Summary panel with total parameters and memory
  • Color-coded model complexity
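How the per-layer memory figures are computed is not documented here. A common rough estimate, assumed in this sketch, counts float32 parameters (4 bytes each) plus the layer's output activations for one batch, then formats the sum in megabytes. The function names and the "~x.xMB" formatting are hypothetical and merely mirror the example label above.

```python
BYTES_PER_FLOAT32 = 4


def layer_memory_bytes(num_params: int, output_numel: int,
                       bytes_per_element: int = BYTES_PER_FLOAT32) -> int:
    """Rough estimate: parameter storage plus output activation storage."""
    return (num_params + output_numel) * bytes_per_element


def format_mb(num_bytes: int) -> str:
    """Format a byte count like '~1.2MB' (1 MB = 1024 * 1024 bytes here)."""
    return f"~{num_bytes / (1024 * 1024):.1f}MB"


# First layer of the Quick Start MLP: Linear(784, 128) on a (1, 784) input.
params = 784 * 128 + 128   # 100,480 parameters
activations = 1 * 128      # output tensor of shape (1, 128)
bytes_used = layer_memory_bytes(params, activations)
print(bytes_used, format_mb(bytes_used))
```

This ignores gradients, optimizer state and framework overhead, so it is a lower bound on what training actually uses.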

Research Paper Style

generate_architecture_diagram(model, input_shape, "paper.png", style="research_paper")

Features:

  • Academic formatting and typography
  • Clean, minimal design
  • Publication-ready quality
  • Professional color scheme

Standard Style

generate_architecture_diagram(model, input_shape, "standard.png", style="standard")

Features:

  • Classic neural network visualization
  • Balanced information density
  • Traditional layout

Computational Graph Features

Complete Graph Traversal

  • No Artificial Limits: Traverses the entire autograd graph without depth or operation-count restrictions
  • Cycle Detection: Prevents infinite recursion while capturing the complete structure
  • Full Operation Coverage: Shows every operation in the computational graph
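PyTorch exposes the backward graph through a tensor's grad_fn attribute: each node's next_functions is a tuple of (node, input_index) pairs, where node is None for inputs that do not require gradients. A traversal that records visited nodes reaches every node once and cannot loop, even when several paths share a node. The sketch below is a generic illustration of that idea, not the library's implementation; it uses an explicit stack instead of recursion so very deep graphs do not hit Python's recursion limit, and it is demonstrated on stand-in node objects so it runs without PyTorch.

```python
from collections import namedtuple


def walk_autograd_graph(root):
    """Return every node reachable from root via next_functions, each once.

    Works with torch autograd nodes (e.g. loss.grad_fn) or any object whose
    next_functions is a sequence of (node_or_None, index) pairs.
    """
    order, seen, stack = [], set(), [root]
    while stack:
        node = stack.pop()
        if node is None or id(node) in seen:
            continue  # input without grad, or a node already visited
        seen.add(id(node))
        order.append(node)
        for child, _ in getattr(node, "next_functions", ()):
            stack.append(child)
    return order


# Stand-in nodes for a tiny graph add(mul(x, w), w) in which "w" is shared.
Node = namedtuple("Node", ["name", "next_functions"])
w = Node("AccumulateGrad(w)", ())
mul = Node("MulBackward0", ((None, 0), (w, 0)))
add = Node("AddBackward0", ((mul, 0), (w, 0)))

print([n.name for n in walk_autograd_graph(add)])
```

With PyTorch installed, walk_autograd_graph(model(input_tensor).sum().grad_fn) should list the backward nodes, and type(node).__name__ gives names such as AddmmBackward0 (exact names vary by PyTorch version).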

Smart Visualization

  • Full Method Names: Displays complete operation names without truncation
  • Proper Arrow Connections: Arrows connect node edges without crossing over boxes
  • Compact Layout: Eliminates empty depth levels for continuous flow
  • Enhanced Spacing: Optimized node positioning and spacing
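One simple way to eliminate empty depth levels is to keep only the depths that actually contain nodes and renumber them consecutively, preserving their order. The sketch assumes nodes have already been assigned integer depths (for example, longest-path distance from the output); compact_depths is a hypothetical name, not a library function.

```python
def compact_depths(depths: dict) -> dict:
    """Map each node's depth onto consecutive levels 0..k-1, dropping empty levels."""
    used = sorted(set(depths.values()))
    remap = {old: new for new, old in enumerate(used)}
    return {node: remap[d] for node, d in depths.items()}


# Levels 1, 2 and 4 are empty in this input layout.
depths = {"output": 0, "addmm": 3, "relu": 3, "weight": 5, "input": 6}
print(compact_depths(depths))
# {'output': 0, 'addmm': 1, 'relu': 1, 'weight': 2, 'input': 3}
```

Because the remapping is order-preserving, every edge still points from a lower level to a higher one, so arrow directions in the drawn layout are unchanged.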

Professional Output

  • High-Resolution Images: Up to 300 DPI for publication quality
  • Intelligent Legends: Automatic positioning without overlap
  • Color-Coded Operations: Different colors for different operation types
  • Clean Typography: Professional fonts and text formatting

Model Analysis

from pytorch_graph import analyze_model

# Comprehensive model analysis
analysis = analyze_model(
    model=model,
    input_shape=(1, 784),
    detailed=True
)

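# Print each field, falling back to "N/A" when a key is missing
# (formatting the string "N/A" with ":," or ":.2f" would raise ValueError)
def fmt(key, spec=""):
    value = analysis.get(key)
    return "N/A" if value is None else format(value, spec)

print(f"Total Parameters: {fmt('total_params', ',')}")
print(f"Trainable Parameters: {fmt('trainable_params', ',')}")
print(f"Model Size: {fmt('model_size_mb', '.2f')} MB")
print(f"Layer Count: {fmt('layer_count')}")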
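Total and trainable parameter counts differ when some layers are frozen, as in fine-tuning. In plain PyTorch they are usually computed as sum(p.numel() for p in model.parameters()) and the same sum restricted to parameters with requires_grad set. The sketch reproduces that logic without PyTorch, using hypothetical (shape, requires_grad) records for the Quick Start MLP with its first layer frozen, and assumes 4-byte float32 values for the size estimate.

```python
from math import prod

# (shape, requires_grad) for each parameter tensor of the Quick Start MLP,
# with the first Linear layer frozen as in a typical fine-tuning setup.
param_records = [
    ((128, 784), False), ((128,), False),  # Linear(784, 128), frozen
    ((64, 128), True), ((64,), True),      # Linear(128, 64)
    ((10, 64), True), ((10,), True),       # Linear(64, 10)
]

total_params = sum(prod(shape) for shape, _ in param_records)
trainable_params = sum(prod(shape) for shape, trainable in param_records if trainable)
model_size_mb = total_params * 4 / (1024 * 1024)  # assumes float32 storage

print(f"Total Parameters: {total_params:,}")
print(f"Trainable Parameters: {trainable_params:,}")
print(f"Model Size: {model_size_mb:.2f} MB")
```

Freezing the first layer removes 100,480 of the 109,386 parameters from training, leaving 8,906 trainable, while the stored model size (about 0.42 MB) stays the same.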

Advanced Configuration

Custom Computational Graph Settings

# Create tracker with custom settings
tracker = ComputationalGraphTracker(
    model=model,
    track_memory=True,      # Track memory usage
    track_timing=True,      # Track execution timing
    track_tensor_ops=True   # Track tensor operations
)

# Save with custom parameters
tracker.save_graph_png(
    filepath="custom_graph.png",
    width=2000,             # Custom width
    height=1500,            # Custom height
    dpi=300,                # High DPI
    show_legend=True,       # Show legend
    node_size=25,           # Node size
    font_size=12            # Font size
)

Architecture Diagram Customization

generate_architecture_diagram(
    model=model,
    input_shape=(1, 784),
    output_path="custom_architecture.png",
    title="Custom Model Architecture",
    style="flowchart",
    dpi=300,
    show_legend=True
)

Visual Examples

Architecture Diagrams

  • Flowchart Style: Professional vertical flow with enhanced information
  • Research Paper Style: Clean, academic formatting
  • Standard Style: Traditional neural network visualization

Computational Graphs

  • Complete Graph: Full autograd traversal without breaks
  • Smart Layout: Compact positioning with proper arrow connections
  • Full Names: Complete operation names without truncation

Performance Features

  • Memory Tracking: Real-time memory usage monitoring
  • Execution Timing: Performance analysis and timing
  • Tensor Operations: Complete tensor operation tracking
  • Optimized Rendering: Fast diagram generation
  • Efficient Layout: Smart positioning algorithms
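Execution timing of a forward pass can be approximated with time.perf_counter. The time_call helper below is a hypothetical sketch, not the library's timer: it runs a few warm-up calls that are excluded from the measurement, then reports the mean wall-clock time of repeated calls. Timing GPU work this way would also require calling torch.cuda.synchronize() before each clock read, because CUDA kernels run asynchronously.

```python
import time
from statistics import mean


def time_call(fn, *args, warmup: int = 2, repeats: int = 10) -> float:
    """Return the mean wall-clock seconds per call of fn(*args)."""
    for _ in range(warmup):  # warm-up runs are not measured
        fn(*args)
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        samples.append(time.perf_counter() - start)
    return mean(samples)


# Stand-in workload; with PyTorch this could be time_call(model, input_tensor).
elapsed = time_call(sum, range(100_000))
print(f"Execution time: {elapsed:.4f}s")
```

Averaging several runs after warm-up smooths out one-off costs such as lazy initialization and cache effects that would otherwise dominate a single measurement.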

Requirements

  • Python: ≥ 3.8
  • PyTorch: ≥ 1.8.0
  • matplotlib: ≥ 3.3.0
  • numpy: ≥ 1.19.0

API Reference

Core Functions

  • generate_architecture_diagram(): Create architecture diagrams
  • track_computational_graph(): Track computational graph execution
  • analyze_computational_graph(): Analyze graph structure and performance
  • analyze_model(): Comprehensive model analysis

Classes

  • ComputationalGraphTracker: Complete computational graph tracking
  • GraphNode: Individual graph node representation
  • GraphEdge: Graph edge representation

Contributing

We welcome contributions! Please see our Contributing Guidelines for details.

Development Setup

# Fork the repository on GitHub, then clone your fork
git clone https://github.com/your-username/pytorch-graph.git
cd pytorch-graph
pip install -e ".[dev]"

License

MIT License - see LICENSE file for details.

Acknowledgments

  • Built for the PyTorch community
  • Inspired by the need for better model visualization tools
  • Designed for researchers, practitioners, and educators


PyTorch Graph - Professional PyTorch model visualization made simple, beautiful, and comprehensive.


Download files

Source Distribution

pytorch_graph-0.2.5.tar.gz (50.4 kB)

Built Distribution

pytorch_graph-0.2.5-py3-none-any.whl (55.1 kB)

File details

Details for the file pytorch_graph-0.2.5.tar.gz.

File metadata

  • Download URL: pytorch_graph-0.2.5.tar.gz
  • Size: 50.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.6

File hashes

Hashes for pytorch_graph-0.2.5.tar.gz:

  • SHA256: ae785cf8f30b40caec9e84a4e7dc0778c6c49ff7c496c3509c4ad8b5d7b35d73
  • MD5: a48cf6ac715282424b8488a17f642f2f
  • BLAKE2b-256: 1ede53ff5c89d5a15760762c5304b9650d4b6da4c0af39c686919bf1b1dd8a7c

File details

Details for the file pytorch_graph-0.2.5-py3-none-any.whl.

File metadata

  • Download URL: pytorch_graph-0.2.5-py3-none-any.whl
  • Size: 55.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.6

File hashes

Hashes for pytorch_graph-0.2.5-py3-none-any.whl:

  • SHA256: 3d73afc1d4daad37d82acd023f8a764ff5535fc0c19bb3c2a9373308a2ae7912
  • MD5: 2224b5bd2df658fa606a479ac7f8fef1
  • BLAKE2b-256: d22f9c6f2c95e0bb6f1d6a40c117bb44270fca6fde947e56ee115705504f0c1f
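To check a downloaded archive against the SHA256 digests listed for each file, hash it in chunks with Python's hashlib and compare hex digests. The sketch below writes a throwaway file so it runs anywhere; to use it, point verify_sha256 at the real download (for example pytorch_graph-0.2.5.tar.gz) together with its listed digest. The helper names are illustrative. As an alternative, pip's hash-checking mode (a requirements line such as pytorch-graph==0.2.5 --hash=sha256:<digest>, installed with --require-hashes) performs the check at install time.

```python
import hashlib
import os
import tempfile


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Hex SHA256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path: str, expected_hex: str) -> bool:
    """True if the file's SHA256 matches the published hex digest."""
    return sha256_of(path) == expected_hex.strip().lower()


# Self-contained demo on a temporary file; replace with the real download path.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"hello")
expected = "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
print(verify_sha256(tmp.name, expected))  # True
os.remove(tmp.name)
```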
