Enhanced PyTorch neural network architecture visualization with flowchart diagrams. Install with 'pip install pytorch-graph', import with 'import pytorch_graph'
PyTorch Graph
Professional PyTorch neural network visualization toolkit with complete computational graph analysis. Transform your PyTorch models into publication-ready diagrams with comprehensive architecture visualization and computational graph tracking.
Documentation
- Complete Documentation - Full API reference, guides, and examples
- Quick Start Guide - Get started in minutes
- API Reference - Complete function and class documentation
- Examples - Comprehensive usage examples
Demo
Try the library immediately with our included demo:
# Clone the repository
git clone https://github.com/PluvioXO/pytorch-graph.git
cd pytorch-graph
# Run the demo
python3 demo.py
The demo generates:
- Architecture diagrams (flowchart + research paper styles)
- Computational graph visualization
- Model analysis with parameter counts and memory usage
- Professional quality output ready for publications
Output files are saved to demo_outputs/ directory.
Key Features
Architecture Visualization
- Professional Flowchart Diagrams: Clean, vertical flowchart visualization with enhanced styling
- Research Paper Style: Publication-ready diagrams with academic formatting
- PNG Export: High-quality PNG output with customizable DPI (up to 300 DPI)
- Comprehensive Layer Analysis: Parameter counts, memory usage, and tensor shape transformations
Complete Computational Graph Analysis
- Maximal Graph Traversal: Captures the entire computational graph without artificial limits
- Full Operation Names: Displays complete method and object names (no truncation)
- Smart Arrow Positioning: Arrows connect node edges properly without crossing over boxes
- Compact Layout: Eliminates gaps and breaks for continuous graph flow
- Real-time Execution Tracking: Monitors forward/backward passes and tensor operations
Professional Quality
- Enhanced Color Schemes: Color-coded operation types and layer categories
- Intelligent Legend Positioning: Automatic legend placement without overlap
- Memory Analysis: Per-layer and total memory usage estimates
- Model Complexity Assessment: Automatic size classification (Small/Medium/Large)
- Data Flow Visualization: Tensor sizes and shapes displayed on connections
Installation
# Basic installation
pip install pytorch-graph
# With enhanced features (quotes stop shells such as zsh from expanding the brackets)
pip install "pytorch-graph[full]"
# With development dependencies
pip install "pytorch-graph[dev]"
Quick Start
Architecture Visualization
import torch
import torch.nn as nn
from pytorch_graph import generate_architecture_diagram
# Define your model
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10)
)
# Generate professional architecture diagram
generate_architecture_diagram(
    model=model,
    input_shape=(1, 784),
    output_path="model_architecture.png",
    title="Neural Network Architecture"
)
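As a sanity check on the parameter counts the diagram reports, the numbers for this MLP can be worked out by hand. The sketch below is plain Python (no pytorch_graph involved) and assumes float32 weights at 4 bytes per parameter; the library's own memory estimates may count more than the weights, so they need not match exactly.

```python
# Parameter count of a fully connected layer (as in nn.Linear):
# in_features * out_features weights, plus one bias per output feature.
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    return in_features * out_features + (out_features if bias else 0)

# (in_features, out_features) of the three Linear layers in the model above;
# the ReLU layers have no parameters.
layers = [(784, 128), (128, 64), (64, 10)]
per_layer = [linear_params(i, o) for i, o in layers]
total = sum(per_layer)

# Assumption: float32 parameters, i.e. 4 bytes each.
param_bytes = total * 4

print(per_layer)           # [100480, 8256, 650]
print(total, param_bytes)  # 109386 437544
```

That is about 0.42 MiB of weights for the whole model, dominated by the first layer.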
Complete Computational Graph Analysis
from pytorch_graph import ComputationalGraphTracker
import torch
# Create a tracker for your model
tracker = ComputationalGraphTracker(model)
# Start tracking
tracker.start_tracking()
# Run your model
input_tensor = torch.randn(1, 784)
output = model(input_tensor)
# Stop tracking and save complete graph
tracker.stop_tracking()
tracker.save_graph_png("complete_computational_graph.png")
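If the forward pass raises, the tracker above is never stopped. One way to guard against that is a small context manager. The `tracking` helper below is hypothetical (it is not part of pytorch_graph) and relies only on the `start_tracking()` and `stop_tracking()` methods used in the example above.

```python
from contextlib import contextmanager

@contextmanager
def tracking(tracker):
    """Hypothetical helper: start tracking on entry, always stop on exit.

    Only uses start_tracking()/stop_tracking(), as shown in the example above.
    """
    tracker.start_tracking()
    try:
        yield tracker
    finally:
        # Runs even if the model call raises, so tracking never stays active.
        tracker.stop_tracking()

# Usage with the tracker from above:
# with tracking(ComputationalGraphTracker(model)) as tracker:
#     output = model(torch.randn(1, 784))
# tracker.save_graph_png("complete_computational_graph.png")
```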
Comprehensive Examples
CNN Architecture Visualization
# Convolutional Neural Network
cnn_model = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(64 * 8 * 8, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
# Generate multiple diagram styles
generate_architecture_diagram(
    cnn_model,
    input_shape=(1, 3, 32, 32),
    output_path="cnn_flowchart.png",
    style="flowchart"
)
generate_architecture_diagram(
    cnn_model,
    input_shape=(1, 3, 32, 32),
    output_path="cnn_research.png",
    style="research_paper"
)
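The `nn.Linear(64 * 8 * 8, 128)` input size follows from the 32×32 input: each padded 3×3 convolution preserves the spatial size and each `MaxPool2d(2)` halves it (32 → 16 → 8). A plain-Python check using the output-size formula from the PyTorch `Conv2d`/`MaxPool2d` documentation:

```python
def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    # Output-size formula from the PyTorch Conv2d / MaxPool2d documentation.
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

side = 32  # spatial size from input_shape=(1, 3, 32, 32)
for _ in range(2):  # two Conv2d -> ReLU -> MaxPool2d stages
    side = out_size(side, kernel=3, padding=1)  # Conv2d(..., 3, padding=1): size unchanged
    side = out_size(side, kernel=2, stride=2)   # MaxPool2d(2): stride defaults to kernel_size
flat_features = 64 * side * side  # 64 channels out of the last Conv2d

print(side, flat_features)  # 8 4096
```

If `input_shape` changes, this number changes too, and the first `nn.Linear` must be updated or the forward pass fails with a shape mismatch.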
Advanced Computational Graph Analysis
from pytorch_graph import track_computational_graph, analyze_computational_graph
# Track complete computational graph
tracker = track_computational_graph(
    model=model,
    input_tensor=input_tensor,
    track_memory=True,
    track_timing=True,
    track_tensor_ops=True
)
# Save high-quality computational graph
tracker.save_graph_png(
    filepath="complete_graph.png",
    width=1600,
    height=1200,
    dpi=300,
    show_legend=True
)
# Analyze the computational graph
analysis = analyze_computational_graph(model, input_tensor, detailed=True)
print(f"Total operations: {analysis['summary']['total_nodes']}")
print(f"Execution time: {analysis['summary']['execution_time']:.4f}s")
Diagram Styles
Enhanced Flowchart (Default)
generate_architecture_diagram(model, input_shape=(1, 784), output_path="flowchart.png", style="flowchart")
Features:
- Lightning bolt icons for activation functions
- Memory usage per layer (e.g., "~1.2MB")
- Data flow indicators on arrows
- Summary panel with total parameters and memory
- Color-coded model complexity
Research Paper Style
generate_architecture_diagram(model, input_shape=(1, 784), output_path="paper.png", style="research_paper")
Features:
- Academic formatting and typography
- Clean, minimal design
- Publication-ready quality
- Professional color scheme
Standard Style
generate_architecture_diagram(model, input_shape=(1, 784), output_path="standard.png", style="standard")
Features:
- Classic neural network visualization
- Balanced information density
- Traditional layout
Computational Graph Features
Complete Graph Traversal
- No Artificial Limits: Traverses entire autograd graph without depth/operation restrictions
- Cycle Detection: Prevents infinite recursion while capturing complete structure
- Full Operation Coverage: Shows every operation in the computational graph
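A minimal sketch of unlimited graph traversal with a visited set. This is not pytorch_graph's implementation; it assumes only that each node exposes `next_functions`, a tuple of `(node, index)` pairs in which `node` may be `None`, which is how PyTorch autograd nodes reachable from `tensor.grad_fn` are linked. Because the autograd graph is a DAG, the visited set mainly stops shared nodes from being visited twice, and it would equally stop a cycle.

```python
from collections import deque

def collect_autograd_nodes(root):
    """Breadth-first walk over an autograd-style graph, with no depth limit.

    Assumes each node exposes `next_functions`: a tuple of (node, index)
    pairs where node may be None, as PyTorch autograd nodes do.
    """
    if root is None:
        return []
    order = []
    seen = {root}           # visited set: each node is queued at most once
    queue = deque([root])   # explicit queue, so no Python recursion limit
    while queue:
        node = queue.popleft()
        order.append(node)
        for child, _ in getattr(node, "next_functions", ()):
            if child is not None and child not in seen:
                seen.add(child)
                queue.append(child)
    return order
```

With PyTorch, `nodes = collect_autograd_nodes(output.grad_fn)` on the Quick Start output returns the backward nodes, and `[type(n).__name__ for n in nodes]` gives names such as `AddmmBackward0` (exact names vary by PyTorch version).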
Smart Visualization
- Full Method Names: Displays complete operation names without truncation
- Proper Arrow Connections: Arrows connect node edges without crossing over boxes
- Compact Layout: Eliminates empty depth levels for continuous flow
- Enhanced Spacing: Optimized node positioning and spacing
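Eliminating empty depth levels can be sketched as renumbering only the levels that actually hold nodes. This is a hypothetical illustration of the idea, not the library's layout code; `compact_depths` and the node names are made up for the example.

```python
def compact_depths(depths: dict) -> dict:
    """Hypothetical sketch: renumber occupied depth levels as 0, 1, 2, ...

    Levels with no nodes disappear, so the layout has no vertical gaps.
    """
    levels = sorted(set(depths.values()))
    rank = {level: i for i, level in enumerate(levels)}
    return {node: rank[d] for node, d in depths.items()}

print(compact_depths({"input": 0, "linear": 2, "relu": 2, "output": 5}))
# {'input': 0, 'linear': 1, 'relu': 1, 'output': 2}
```

Relative order between levels is preserved, so edges still point in the same direction after compaction.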
Professional Output
- High-Resolution Images: Up to 300 DPI for publication quality
- Intelligent Legends: Automatic positioning without overlap
- Color-Coded Operations: Different colors for different operation types
- Clean Typography: Professional fonts and text formatting
Model Analysis
from pytorch_graph import analyze_model
# Comprehensive model analysis
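analysis = analyze_model(
    model=model,
    input_shape=(1, 784),
    detailed=True
)
# Format numbers only when present: formatting the string 'N/A' with ',' or '.2f' raises ValueError
def fmt(value, spec):
    return format(value, spec) if isinstance(value, (int, float)) else "N/A"
print(f"Total Parameters: {fmt(analysis.get('total_params'), ',')}")
print(f"Trainable Parameters: {fmt(analysis.get('trainable_params'), ',')}")
print(f"Model Size: {fmt(analysis.get('model_size_mb'), '.2f')} MB")
print(f"Layer Count: {analysis.get('layer_count', 'N/A')}")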
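To cross-check the totals reported by `analyze_model`, the same numbers can be computed with the standard PyTorch idiom. The function below does not depend on pytorch_graph; it only assumes a `torch.nn.Module`-like object whose `parameters()` yields tensors with `numel()` and `requires_grad`.

```python
def count_parameters(model):
    """Return (total, trainable) parameter counts.

    Standard PyTorch idiom: sum numel() over model.parameters(), counting
    only tensors with requires_grad=True as trainable.
    """
    total = trainable = 0
    for p in model.parameters():
        n = p.numel()
        total += n
        if p.requires_grad:
            trainable += n
    return total, trainable
```

For the Quick Start MLP, `count_parameters(model)` should give 109,386 for both values, since every parameter is trainable; freezing a layer (`requires_grad_(False)`) lowers only the second number.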
Advanced Configuration
Custom Computational Graph Settings
# Create tracker with custom settings
tracker = ComputationalGraphTracker(
    model=model,
    track_memory=True,      # Track memory usage
    track_timing=True,      # Track execution timing
    track_tensor_ops=True   # Track tensor operations
)
# Save with custom parameters
tracker.save_graph_png(
    filepath="custom_graph.png",
    width=2000,        # Custom width
    height=1500,       # Custom height
    dpi=300,           # High DPI
    show_legend=True,  # Show legend
    node_size=25,      # Node size
    font_size=12       # Font size
)
Architecture Diagram Customization
generate_architecture_diagram(
    model=model,
    input_shape=(1, 784),
    output_path="custom_architecture.png",
    title="Custom Model Architecture",
    style="flowchart",
    dpi=300,
    show_legend=True
)
Visual Examples
Architecture Diagrams
- Flowchart Style: Professional vertical flow with enhanced information
- Research Paper Style: Clean, academic formatting
- Standard Style: Traditional neural network visualization
Computational Graphs
- Complete Graph: Full autograd traversal without breaks
- Smart Layout: Compact positioning with proper arrow connections
- Full Names: Complete operation names without truncation
Performance Features
- Memory Tracking: Real-time memory usage monitoring
- Execution Timing: Performance analysis and timing
- Tensor Operations: Complete tensor operation tracking
- Optimized Rendering: Fast diagram generation
- Efficient Layout: Smart positioning algorithms
Requirements
- Python: ≥ 3.8
- PyTorch: ≥ 1.8.0
- matplotlib: ≥ 3.3.0
- numpy: ≥ 1.19.0
API Reference
Core Functions
- generate_architecture_diagram(): Create architecture diagrams
- track_computational_graph(): Track computational graph execution
- analyze_computational_graph(): Analyze graph structure and performance
- analyze_model(): Comprehensive model analysis
Classes
- ComputationalGraphTracker: Complete computational graph tracking
- GraphNode: Individual graph node representation
- GraphEdge: Graph edge representation
Contributing
We welcome contributions! Please see our Contributing Guidelines for details.
Development Setup
# Clone your fork (replace your-username with your GitHub username)
git clone https://github.com/your-username/pytorch-graph.git
cd pytorch-graph
pip install -e ".[dev]"
License
MIT License - see LICENSE file for details.
Acknowledgments
- Built for the PyTorch community
- Inspired by the need for better model visualization tools
- Designed for researchers, practitioners, and educators
Support
- Issues: GitHub Issues
- Discussions: GitHub Discussions
- Documentation: Full Documentation
PyTorch Graph - Professional PyTorch model visualization made simple, beautiful, and comprehensive.
File details
Details for the file pytorch_graph-0.2.5.tar.gz.
File metadata
- Download URL: pytorch_graph-0.2.5.tar.gz
- Size: 50.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ae785cf8f30b40caec9e84a4e7dc0778c6c49ff7c496c3509c4ad8b5d7b35d73 |
| MD5 | a48cf6ac715282424b8488a17f642f2f |
| BLAKE2b-256 | 1ede53ff5c89d5a15760762c5304b9650d4b6da4c0af39c686919bf1b1dd8a7c |
File details
Details for the file pytorch_graph-0.2.5-py3-none-any.whl.
File metadata
- Download URL: pytorch_graph-0.2.5-py3-none-any.whl
- Size: 55.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3d73afc1d4daad37d82acd023f8a764ff5535fc0c19bb3c2a9373308a2ae7912 |
| MD5 | 2224b5bd2df658fa606a479ac7f8fef1 |
| BLAKE2b-256 | d22f9c6f2c95e0bb6f1d6a40c117bb44270fca6fde947e56ee115705504f0c1f |
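A downloaded file can be checked against the SHA256 digests listed above with the standard library alone. A minimal sketch, assuming the source distribution was saved under its original file name:

```python
import hashlib

# SHA256 digest of pytorch_graph-0.2.5.tar.gz, as listed above.
EXPECTED_SDIST_SHA256 = "ae785cf8f30b40caec9e84a4e7dc0778c6c49ff7c496c3509c4ad8b5d7b35d73"

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of a file, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def verify(path: str, expected: str = EXPECTED_SDIST_SHA256) -> bool:
    # hexdigest() is lowercase, so normalise the expected value before comparing.
    return sha256_of(path) == expected.strip().lower()

# Usage (assumes the sdist was downloaded to the current directory):
# print(verify("pytorch_graph-0.2.5.tar.gz"))
```

pip can also enforce this at install time: a requirements line such as `pytorch-graph==0.2.5 --hash=sha256:<digest>` installed with `pip install --require-hashes -r requirements.txt` rejects files whose digest does not match. List the digests of both the sdist and the wheel so either can be chosen, and note that hash-checking mode also requires pinned, hashed entries for every dependency.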