A PyTorch model profiling tool for FLOPs, memory, and visualization.
📊 model_profiler
A lightweight PyTorch model profiler that reports FLOPs, memory usage, parameters, and input/output shapes, and exports the results to Excel with color-coded tags.
It also integrates with Torchview to generate computation graph diagrams.
🚀 Features
- Profile per-layer FLOPs, memory (bytes), parameters, input/output shapes.
- Supports three modes:
  - raw: list all layers (Conv, BN, ReLU, …).
  - cba: merge `Conv+BN+Activation` into CBA blocks.
  - block: merge into higher-level blocks (e.g., backbone / neck / head).
- Export results to Excel with:
  - Full profiling table.
  - Color-coded rows:
    - 🔴 Memory-bound
    - 🟢 Compute-bound
    - ⚪ Balanced
  - Automatic statistics sheet (summary counts & totals).
- Generate a computation-graph PNG with Torchview.
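The `cba` mode's merging step can be pictured as a pass over per-layer rows that folds each Conv, an optional following BatchNorm, and an optional following activation into one record. The sketch below is a minimal illustration under assumptions: `LayerRecord`, `merge_cba`, and the set of activation class names are hypothetical (not the package's API), and the toy numbers are made up. Summing bytes is a simplification; a truly fused kernel would not write intermediate tensors back to memory, and the real tool may account for this differently.

```python
from dataclasses import dataclass

# Hypothetical per-layer record; the real profiler's row format may differ.
@dataclass
class LayerRecord:
    name: str
    kind: str       # module class name, e.g. "Conv2d"
    flops: int
    mem_bytes: int
    params: int

# Assumed set of activation class names that can close a CBA block.
ACTIVATIONS = {"ReLU", "ReLU6", "SiLU", "GELU", "LeakyReLU", "Hardswish"}

def merge_cba(layers):
    """Fold Conv [-> BatchNorm] [-> Activation] runs into single 'CBA' records.

    A Conv followed by neither a BatchNorm nor an activation is kept as-is.
    FLOPs, bytes, and parameters of merged layers are summed.
    """
    merged = []
    i = 0
    while i < len(layers):
        cur = layers[i]
        if cur.kind.startswith("Conv"):
            group = [cur]
            j = i + 1
            if j < len(layers) and layers[j].kind.startswith("BatchNorm"):
                group.append(layers[j])
                j += 1
            if j < len(layers) and layers[j].kind in ACTIVATIONS:
                group.append(layers[j])
                j += 1
            if len(group) > 1:
                merged.append(LayerRecord(
                    name="+".join(layer.name for layer in group),
                    kind="CBA",
                    flops=sum(layer.flops for layer in group),
                    mem_bytes=sum(layer.mem_bytes for layer in group),
                    params=sum(layer.params for layer in group),
                ))
                i = j
                continue
        merged.append(cur)
        i += 1
    return merged

# Toy input (made-up numbers) mirroring the SimpleCNN layer order.
rows = [
    LayerRecord("conv1", "Conv2d", 100, 10, 448),
    LayerRecord("bn1", "BatchNorm2d", 5, 8, 32),
    LayerRecord("relu1", "ReLU", 2, 8, 0),
    LayerRecord("pool", "MaxPool2d", 1, 4, 0),
]
for r in merge_cba(rows):
    print(r.name, r.kind, r.flops, r.params)
# conv1+bn1+relu1 CBA 107 480
# pool MaxPool2d 1 0
```

A `block` mode would follow the same idea with a coarser grouping key, for example the top-level submodule name (backbone, neck, head).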
📦 Installation
Clone the repository and install it in editable mode:
```
git clone https://github.com/yourname/model_profiler.git
cd model_profiler
pip install -e .
```
Alternatively, install the dependencies directly:
```
pip install torch torchvision prettytable openpyxl torchview graphviz
```
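The released package can also be installed from PyPI (distribution name `torch_model_profiler`):
```
pip install torch_model_profiler
```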
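To confirm that the Python dependencies from the install command are importable in the active environment, a quick standard-library check with `importlib.util.find_spec` is enough. `missing_modules` is a hypothetical helper, and the sketch assumes each package's import name matches its pip name (which holds for these six). It does not check the Graphviz system executable, which must be installed separately.

```python
import importlib.util

# Import names of the Python dependencies listed in the install command.
REQUIRED_MODULES = ["torch", "torchvision", "prettytable", "openpyxl", "torchview", "graphviz"]

def missing_modules(names=REQUIRED_MODULES):
    """Return the top-level import names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]

if __name__ == "__main__":
    missing = missing_modules()
    print("Missing Python packages:", ", ".join(missing) or "none")
```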
🛠 Usage
Below is an example of profiling a model with profile_flops_and_memory_layername:
```python
from model_profiler import profile_flops_and_memory_layername

# `model` is the torch.nn.Module to profile
stats = profile_flops_and_memory_layername(
    model,
    input_size=(1, 3, 224, 224),  # Dummy input size for the model
    threshold_low=10,             # Threshold for detecting memory-bound layers
    threshold_high=100,           # Threshold for detecting compute-bound layers
    mode="raw",                   # Profiling mode: "raw", "cba", or "block"
)
```
Parameters:
- `model`: The PyTorch model to be analyzed.
- `input_size`: The shape of the dummy input tensor (batch, channels, height, width).
- `threshold_low`: If the FLOPs-to-memory ratio is below this value, the layer is tagged memory-bound.
- `threshold_high`: If the FLOPs-to-memory ratio is above this value, the layer is tagged compute-bound.
- `mode`:
  - `"raw"` → Show every layer individually (Conv, BN, ReLU, etc.). (default)
  - `"cba"` → Combine Conv+BN+Activation into a single CBA block.
  - `"block"` → Merge into large functional blocks (e.g., backbone, neck, head).
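The two thresholds describe a simple arithmetic-intensity rule (FLOPs per byte of memory traffic). Here is a minimal sketch of that rule as described above; `tag_layer` is a hypothetical helper, its defaults of 10 and 100 are copied from the usage example rather than from the library's documented defaults, and treating zero-byte layers as compute-bound is an assumption.

```python
def tag_layer(flops, mem_bytes, threshold_low=10.0, threshold_high=100.0):
    """Classify a layer by its FLOPs-to-memory ratio (FLOPs per byte)."""
    if mem_bytes <= 0:
        # No measured memory traffic: assume the layer is compute-bound.
        return "Compute-bound"
    ratio = flops / mem_bytes
    if ratio < threshold_low:
        return "Memory-bound"
    if ratio > threshold_high:
        return "Compute-bound"
    # Ratios between the thresholds (inclusive of both) count as balanced.
    return "Balanced"

print(tag_layer(4_000, 1_000))    # ratio 4 -> Memory-bound
print(tag_layer(50_000, 1_000))   # ratio 50 -> Balanced
print(tag_layer(400_000, 1_000))  # ratio 400 -> Compute-bound
```

Because the comparisons are strict, a ratio exactly equal to either threshold is tagged Balanced, matching the "<" and ">" wording in the parameter list.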
Output:
The function prints a detailed table to the console and also returns a list (`layer_stats`) containing, for each layer:
- Layer name
- Input shape / Output shape
- FLOPs
- Memory usage
- FLOPs-to-Memory ratio
- Number of parameters
- Tag (Memory-bound, Compute-bound, Balanced)
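The returned list can be post-processed in the same spirit as the Excel statistics sheet (tag counts and totals). The exact entry format of `layer_stats` is not documented here, so this sketch assumes each entry is a dict with `flops`, `memory`, `params`, and `tag` keys; `summarize` and the toy rows are hypothetical and should be adapted to what the profiler actually returns.

```python
from collections import Counter

def summarize(layer_stats):
    """Count tags and total FLOPs, memory, and params over per-layer dicts.

    Assumes dict entries with 'flops', 'memory', 'params', and 'tag' keys.
    """
    tag_counts = Counter(row["tag"] for row in layer_stats)
    totals = {key: sum(row[key] for row in layer_stats) for key in ("flops", "memory", "params")}
    return tag_counts, totals

# Toy rows for illustration only (not real profiler output).
toy_stats = [
    {"name": "conv1", "flops": 1000, "memory": 100, "params": 448, "tag": "Balanced"},
    {"name": "bn1", "flops": 50, "memory": 80, "params": 32, "tag": "Memory-bound"},
    {"name": "relu1", "flops": 20, "memory": 80, "params": 0, "tag": "Memory-bound"},
]
counts, totals = summarize(toy_stats)
print(dict(counts))  # {'Balanced': 1, 'Memory-bound': 2}
print(totals)        # {'flops': 1070, 'memory': 260, 'params': 480}
```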
Example (mode='raw'):
```python
import torch
import torch.nn as nn
from model_profiler import (
    profile_flops_and_memory_layername,
    export_profile_to_excel,
    draw_model_with_tags,
)

# Define a simple CNN
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = nn.ReLU()
        self.pool = nn.MaxPool2d(2)
        # 16 channels x 16 x 16 after pooling a 32x32 input
        self.fc = nn.Linear(16 * 16 * 16, 10)

    def forward(self, x):
        x = self.pool(self.relu1(self.bn1(self.conv1(x))))
        x = x.view(x.size(0), -1)  # flatten to (batch, 4096)
        return self.fc(x)

# Create model
model = SimpleCNN()

# Run profiler
stats = profile_flops_and_memory_layername(
    model,
    input_size=(1, 3, 32, 32),
    mode="raw",
)

# Export to Excel
export_profile_to_excel(stats, "cnn_profile.xlsx")

# Draw graph (requires Graphviz installed)
draw_model_with_tags(model, (1, 3, 32, 32), stats, filename="cnn_graph")
```
📊 Output in Command Line
| Layer(Name) | Input Shape | Output Shape | FLOPs (M) | Memory (KB) | FLOP/Byte | Params (K) | Tag |
|---|---|---|---|---|---|---|---|
| conv1 (Conv2d) | (1, 3, 32, 32) | (1, 16, 32, 32) | 0.29 | 12.3 | 23.5 | 448 | Balanced |
| bn1 (BatchNorm2d) | (1, 16, 32, 32) | (1, 16, 32, 32) | 0.03 | 8.0 | 0.5 | 32 | Memory-bound ❗ |
| relu1 (ReLU) | (1, 16, 32, 32) | (1, 16, 32, 32) | 0.02 | 8.0 | 0.2 | 0 | Memory-bound ❗ |
| pool (MaxPool2d) | (1, 16, 32, 32) | (1, 16, 16, 16) | 0.01 | 4.0 | 0.1 | 0 | Memory-bound ❗ |
| fc (Linear) | (1, 4096) | (1, 10) | 0.04 | 16.0 | 2.5 | 40 | Balanced |
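As a sanity check on the Params column, the standard parameter formulas for the SimpleCNN layers can be evaluated by hand. The helper functions below are hypothetical, written for `groups=1` convolutions with bias. They reproduce the 448 and 32 parameters listed for conv1 and bn1, and give 40,970 parameters (about 41 K) for fc. FLOP conventions vary (some tools report multiply-accumulates, others 2 × MACs), and the profiler's exact convention is not stated here, so its FLOPs column need not match these MAC counts.

```python
def conv2d_params(c_in, c_out, k, bias=True):
    # One k x k kernel per (input channel, output channel) pair, plus one bias per output channel.
    return c_in * c_out * k * k + (c_out if bias else 0)

def conv2d_macs(c_in, c_out, k, h_out, w_out):
    # One multiply-accumulate per kernel weight per output element (groups=1).
    return c_in * k * k * c_out * h_out * w_out

def batchnorm_params(c):
    # Learnable weight and bias per channel; running statistics are buffers, not parameters.
    return 2 * c

def linear_params(n_in, n_out, bias=True):
    return n_in * n_out + (n_out if bias else 0)

def linear_macs(n_in, n_out):
    return n_in * n_out

print(conv2d_params(3, 16, 3))          # 448
print(batchnorm_params(16))             # 32
print(linear_params(16 * 16 * 16, 10))  # 40970
print(conv2d_macs(3, 16, 3, 32, 32))    # 442368
print(linear_macs(16 * 16 * 16, 10))    # 40960
```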
📊 Excel Output Example
📊 Model structure as .png file
📌 Roadmap
- Make the package installable via `pip install`
- Add CUDA memory usage profiling
- Add latency measurement
- Add visualization in Jupyter notebooks
📌 Problems you may encounter
⚠️ You must also install the Graphviz system package (see Graphviz Setup)
If `draw_model_with_tags` fails, see Graphviz Executable Not Found below.
⚠️ Graphviz Executable Not Found
If you see an error like this:
```
graphviz.backend.execute.ExecutableNotFound: failed to execute WindowsPath('dot'), make sure the Graphviz executables are on your systems' PATH
```
✅ Cause
The Python graphviz package is only a binding. It requires the actual Graphviz executables (especially dot) to be installed on your system and available in the PATH.
🔧 Solution: Graphviz Setup
- Download Graphviz: Get the Windows installer from the official site: 👉 https://graphviz.gitlab.io/_pages/Download/Download_windows.html and use the stable release installer (EXE).
- Install Graphviz: With the default installation, the executables are usually placed in:
  ```
  C:\Program Files\Graphviz\bin
  ```
- Add Graphviz to PATH: Open Windows search, type Environment Variables, and open Edit the system environment variables. Click Environment Variables…, select Path under System variables, click Edit, and add:
  ```
  C:\Program Files\Graphviz\bin
  ```
  Click OK to save and close.
- Restart your terminal (Anaconda Prompt / CMD / PowerShell) so the new PATH takes effect.
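After restarting the terminal, you can confirm from Python that `dot` is visible. This is a small sketch using only the standard library (`shutil.which`, `subprocess.run`); `find_dot` is a hypothetical helper name. `dot -V` prints the Graphviz version, which the script reads from stderr and falls back to stdout.

```python
import shutil
import subprocess

def find_dot():
    """Return the full path of Graphviz's `dot` executable, or None if it is not on PATH."""
    return shutil.which("dot")

if __name__ == "__main__":
    dot = find_dot()
    if dot is None:
        print("'dot' not found: add the Graphviz bin folder to PATH and restart the terminal.")
    else:
        # `dot -V` reports the installed Graphviz version.
        result = subprocess.run([dot, "-V"], capture_output=True, text=True)
        print(dot, (result.stderr or result.stdout).strip())
```

If this prints a path but `draw_model_with_tags` still fails, the IDE or notebook kernel was probably started before PATH was updated and needs a restart as well.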
📄 License
MIT License © 2025 Chih-Sheng (Tommy) Huang
Download files
File details
Details for the file torch_model_profiler-0.1.0.tar.gz.
File metadata
- Download URL: torch_model_profiler-0.1.0.tar.gz
- Upload date:
- Size: 10.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 07f0e717a7487a800780095d2908749ea2dd881c617c6b1698c5accc023bd581 |
| MD5 | fe840d79f662b4f0101c40e07ee7e47d |
| BLAKE2b-256 | dc18c36f90247344cc06007a59884e8ccbe5dc786af0563a65a4148074384ad3 |
File details
Details for the file torch_model_profiler-0.1.0-py3-none-any.whl.
File metadata
- Download URL: torch_model_profiler-0.1.0-py3-none-any.whl
- Upload date:
- Size: 9.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3147b2b7c76ac965800cb49e02301c6edd38b856e003f8b5825c4989fcc92531 |
| MD5 | 111ea56b7f99fcdab35178a0cbc1cd81 |
| BLAKE2b-256 | 860b9ac3fef0c11e4c04499592c9bca74b2769f13ec69d3493a97280f8563b77 |