
A PyTorch model profiling tool for FLOPs, memory, and visualization.


📊 torch-model-profiler

A lightweight PyTorch model profiler that reports FLOPs, memory usage, parameters, input/output shapes, and automatically exports results to Excel with colored tags.

It also integrates with Torchview to generate computation graph diagrams.


🚀 Features

  • Profile per-layer FLOPs, memory (bytes), parameters, input/output shapes.
  • Supports three modes:
    • raw: list all layers (Conv, BN, ReLU, …).
    • cba: merge Conv+BN+Activation into CBA blocks.
    • block: merge into higher-level blocks (e.g., backbone / neck / head).
  • Export results to Excel with:
    • Full profiling table.
    • Color-coded rows:
      • 🔴 Memory-bound
      • 🟢 Compute-bound
      • ⚪ Balanced
    • Automatic statistics sheet (summary counts & totals).
  • Generate graph PNG with Torchview.
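To make the "cba" mode more concrete, here is a minimal sketch of how raw per-layer rows could be folded into CBA blocks. The Row record, its field names, and the ACTIVATIONS set are hypothetical illustrations, not the package's internal data structures. The real implementation may group layers differently.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Row:
    # Hypothetical per-layer record; the package's real stats entries may differ.
    name: str
    kind: str            # layer type, e.g. "Conv2d", "BatchNorm2d", "ReLU"
    flops: float
    memory_bytes: float
    params: int


# Assumed set of activation layer types that may close a CBA block.
ACTIVATIONS = {"ReLU", "LeakyReLU", "ReLU6", "SiLU", "GELU", "Sigmoid"}


def merge_cba(rows: List[Row]) -> List[Row]:
    """Fold Conv2d [+ BatchNorm2d] [+ activation] runs into single rows."""
    merged: List[Row] = []
    i = 0
    while i < len(rows):
        if rows[i].kind != "Conv2d":
            merged.append(rows[i])  # non-conv layers pass through unchanged
            i += 1
            continue
        group = [rows[i]]
        j = i + 1
        if j < len(rows) and rows[j].kind == "BatchNorm2d":
            group.append(rows[j])
            j += 1
        if j < len(rows) and rows[j].kind in ACTIVATIONS:
            group.append(rows[j])
            j += 1
        merged.append(Row(
            name="+".join(r.name for r in group),
            kind="CBA" if len(group) > 1 else "Conv2d",
            flops=sum(r.flops for r in group),
            memory_bytes=sum(r.memory_bytes for r in group),
            params=sum(r.params for r in group),
        ))
        i = j
    return merged
```

Summing memory_bytes is the simplest choice. A truly fused CBA kernel keeps the intermediate activations on-chip, so a more faithful model might count only the block's input, output, and weights.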

📦 Installation

Clone the repository and install:

git clone https://github.com/yourname/model_profiler.git
cd model_profiler
pip install -e .

or install the dependencies and the package from PyPI:

pip install torch torchvision prettytable openpyxl torchview graphviz
pip install torch-model-profiler

🛠 Usage

Key functions in this module:

  1. profile_flops_and_memory_layername
  2. estimate_inference_time
  3. export_profile_to_excel_withinferencetime

Below is an example of profiling a model with profile_flops_and_memory_layername:

stats = profile_flops_and_memory_layername(
    model,
    input_size=(1, 3, 224, 224),   # Dummy input size for the model
    threshold_low=10,              # Threshold for detecting memory-bound layers
    threshold_high=100,            # Threshold for detecting compute-bound layers
    mode="raw",                    # Profiling mode: "raw", "cba", or "block"
)

Parameters:

  • model: The PyTorch model to be analyzed.

  • input_size: The shape of the dummy input tensor (batch, channels, height, width).

  • threshold_low: If FLOPs-to-Memory ratio < this value, the layer is considered memory-bound.

  • threshold_high: If FLOPs-to-Memory ratio > this value, the layer is considered compute-bound.

  • mode: Profiling mode (default: "raw").

    • "raw" → Show every layer individually (Conv, BN, ReLU, etc.).

    • "cba" → Combine Conv+BN+Activation into a single CBA block.

    • "block" → Merge into large functional blocks (e.g., backbone, neck, head).

  • skip_bn: Whether to skip BatchNorm layers when profiling (default: True).

  • skip_act: Whether to skip activation-function layers when profiling (default: True).

  • skip_Sequential: Whether to skip nn.Sequential container modules when profiling (default: True).
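The threshold rule described above can be sketched as a small standalone function. This is an illustration of the documented rule, not the package's own code. The default values 10 and 100 are simply taken from the usage example, and treating a layer with zero memory traffic as compute-bound is an assumption.

```python
def classify_layer(flops: float, memory_bytes: float,
                   threshold_low: float = 10.0,
                   threshold_high: float = 100.0) -> str:
    """Tag a layer by its FLOPs-to-memory ratio (arithmetic intensity)."""
    # Assumption: a layer that moves no bytes is treated as compute-bound.
    ratio = flops / memory_bytes if memory_bytes > 0 else float("inf")
    if ratio < threshold_low:
        return "Memory-bound"
    if ratio > threshold_high:
        return "Compute-bound"
    return "Balanced"  # threshold_low <= ratio <= threshold_high
```

For example, a layer with 884,736 FLOPs and 79,616 bytes of traffic has a ratio of about 11.1, which lands just inside the "Balanced" band with these thresholds.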

Output:

The function prints a detailed table to the console and also returns a list (layer_stats) containing, for each layer:

  • Layer name
  • Input shape / Output shape
  • FLOPs
  • Memory usage
  • FLOPs-to-Memory ratio
  • Number of parameters
  • Tag (Memory-bound, Compute-bound, Balanced)
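The FLOPs, memory, and parameter fields can be worked out by hand for a single Conv2d. Below is a minimal pure-Python sketch under stated conventions: 1 multiply-accumulate (MAC) = 2 FLOPs, fp32 tensors (4 bytes per element), and the input, weights, and output each read or written exactly once. These conventions are assumptions. The package may count differently, so its reported numbers need not match these.

```python
from math import prod
from typing import Tuple


def conv2d_cost(in_shape: Tuple[int, int, int, int], out_channels: int,
                kernel: int, stride: int = 1, padding: int = 0,
                bias: bool = True, bytes_per_elem: int = 4):
    """Estimate output shape, FLOPs, memory bytes, and params of one Conv2d.

    Assumes groups=1 and dilation=1.
    """
    n, c_in, h, w = in_shape
    h_out = (h + 2 * padding - kernel) // stride + 1
    w_out = (w + 2 * padding - kernel) // stride + 1
    out_shape = (n, out_channels, h_out, w_out)
    # Each output element needs c_in * kernel * kernel MACs.
    macs = prod(out_shape) * c_in * kernel * kernel
    flops = 2 * macs  # assumption: 1 MAC = 2 FLOPs; bias additions ignored
    params = out_channels * c_in * kernel * kernel + (out_channels if bias else 0)
    # Assumption: traffic = read input + read weights + write output, once each.
    memory_bytes = (prod(in_shape) + params + prod(out_shape)) * bytes_per_elem
    return out_shape, flops, memory_bytes, params


# conv1 from the example below: Conv2d(3, 16, 3, padding=1) on a (1, 3, 32, 32) input.
shape, flops, mem, params = conv2d_cost((1, 3, 32, 32), 16, kernel=3, padding=1)
print(shape, flops, mem, params, round(flops / mem, 2))
```

Under these conventions the script prints `(1, 16, 32, 32) 884736 79616 448 11.11`.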

Example (mode='raw'):

import torch
import torch.nn as nn
from model_profiler import (
    profile_flops_and_memory_layername,
    export_profile_to_excel,
    draw_model_with_tags
)

# Define a simple CNN
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(16)
        self.relu1 = nn.ReLU()
        self.pool  = nn.MaxPool2d(2)
        self.fc    = nn.Linear(16*16*16, 10)

    def forward(self, x):
        x = self.pool(self.relu1(self.bn1(self.conv1(x))))
        x = x.view(x.size(0), -1)
        return self.fc(x)

# Create model
model = SimpleCNN()

# Run profiler
stats = profile_flops_and_memory_layername(model,
                                           input_size=(1, 3, 32, 32),
                                           mode="raw")

# Export to Excel
export_profile_to_excel(stats, "cnn_profile.xlsx")

# Draw graph (requires Graphviz installed)
draw_model_with_tags(model, (1, 3, 32, 32), stats, filename="cnn_graph")

📊 Output in console

Layer (Name)        Input Shape       Output Shape      FLOPs (M)  Memory (KB)  FLOP/Byte  Params (K)  Tag
conv1 (Conv2d)      (1, 3, 32, 32)    (1, 16, 32, 32)   0.29       12.3         23.5       448         Balanced
bn1 (BatchNorm2d)   (1, 16, 32, 32)   (1, 16, 32, 32)   0.03       8.0          0.5        32          Memory-bound ❗
relu1 (ReLU)        (1, 16, 32, 32)   (1, 16, 32, 32)   0.02       8.0          0.2        0           Memory-bound ❗
pool (MaxPool2d)    (1, 16, 32, 32)   (1, 16, 16, 16)   0.01       4.0          0.1        0           Memory-bound ❗
fc (Linear)         (1, 4096)         (1, 10)           0.04       16.0         2.5        40          Balanced

📊 Excel Output Example
[Image: Excel file output]

📊 Model structure as .png file

[Image: model structure diagram]

Below is an example of estimating inference latency with estimate_inference_time:

latency = estimate_inference_time(stats, 
                                  compute_tops=1,   # 1-TOPS NPU
                                  mem_bw_gbs=1,     # 1 GB/s DRAM
                                  sram_size_mb=1)   # 1MB SRAM                              

Parameters

stats (list): Output from profile_flops_and_memory_layername. Contains per-layer FLOPs, memory usage, and other profiling data.

compute_tops (float): Peak compute performance of the target accelerator, in TOPS (Tera Operations Per Second). Example: 100 = 100 TOPS = 100e12 ops/sec.

mem_bw_gbs (float): External memory (DRAM) bandwidth, in GB/s. Example: 200 = 200 GB/s.

sram_size_mb (float): On-chip SRAM buffer size, in MB. Determines whether a layer can fit entirely on-chip (fast access) or must use DRAM (slower).

Returns

latency (float): Estimated total inference latency (in milliseconds) for the given model on the target hardware. The estimation considers:

  • Compute-bound time = FLOPs / (compute_tops × 1e12)
  • Memory-bound time = Memory Bytes / (mem_bw_gbs × 1e9)
  • SRAM vs DRAM penalty (if working set > SRAM size).
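The three terms above can be combined in several ways, and the exact formula used by estimate_inference_time is not documented here. Below is a minimal roofline-style sketch under explicit assumptions: each layer takes the slower of its compute and memory time, "MB" is decimal (1e6 bytes, matching the 1e9 used for GB/s), and a layer whose traffic fits in SRAM sees a hypothetical sram_speedup times the DRAM bandwidth. LayerCost, estimate_latency_ms, and sram_speedup are illustrative names, not part of the package.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class LayerCost:
    # Hypothetical per-layer record; real profiler stats may use other field names.
    name: str
    flops: float         # operations executed by the layer
    memory_bytes: float  # bytes moved (inputs + weights + outputs)


def estimate_latency_ms(stats: List[LayerCost], compute_tops: float,
                        mem_bw_gbs: float, sram_size_mb: float,
                        sram_speedup: float = 10.0) -> float:
    """Roofline-style latency sketch; not the package's actual implementation."""
    peak_ops = compute_tops * 1e12   # ops/s, as in the compute-time formula above
    dram_bw = mem_bw_gbs * 1e9       # bytes/s, as in the memory-time formula above
    sram_bytes = sram_size_mb * 1e6  # assumption: decimal MB
    total_s = 0.0
    for layer in stats:
        t_compute = layer.flops / peak_ops
        # Assumption: traffic that fits in SRAM runs at sram_speedup x DRAM bandwidth;
        # otherwise the layer pays the full DRAM cost (the "penalty").
        fits_on_chip = layer.memory_bytes <= sram_bytes
        bandwidth = dram_bw * sram_speedup if fits_on_chip else dram_bw
        t_memory = layer.memory_bytes / bandwidth
        # The slower resource bounds the layer.
        total_s += max(t_compute, t_memory)
    return total_s * 1e3  # seconds -> milliseconds
```

With compute_tops=1, mem_bw_gbs=1, and sram_size_mb=1, a layer doing 1 GFLOP while moving 2 MB overflows SRAM. It is then memory-bound at 2 ms, even though its compute time is only 1 ms.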

Below is an example of exporting the profile, together with per-layer latency estimates, using export_profile_to_excel_withinferencetime:

export_profile_to_excel_withinferencetime(stats, 
                                          filename="estimatetime_profile_report.xlsx",
                                          compute_tops=1, 
                                          mem_bw_gbs=1, 
                                          sram_size_mb=1)

Parameters

stats (list): Profiling results from profile_flops_and_memory_layername.

filename (str): Path to the output Excel file. Example: "report.xlsx".

compute_tops (float): Target hardware compute capability in TOPS (same as above).

mem_bw_gbs (float): DRAM bandwidth in GB/s (same as above).

sram_size_mb (float): SRAM buffer size in MB (same as above).

Excel Output

This function creates an Excel file with:

  • Profile Report sheet

    • Layer name
    • Input / Output shape
    • FLOPs
    • Memory usage
    • Params
    • Estimated latency (ms) for each layer
    • Bound type (Memory-bound / Compute-bound / Balanced)
    • Row colored (red = memory-bound, green = compute-bound, gray = balanced).
  • Statistics sheet

    • Count of layers by bound type
    • Total FLOPs, Params, Memory
    • Total estimated inference latency.
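The Statistics sheet is essentially an aggregation over the per-layer rows. Below is a minimal sketch of that aggregation step, which runs before anything is written to Excel. It assumes each row is a dict with the hypothetical keys tag, flops, params, memory_bytes, and latency_ms. The package's real column names may differ.

```python
from collections import Counter
from typing import Dict, List


def summarize(rows: List[Dict]) -> Dict:
    """Aggregate per-layer rows into the totals a Statistics sheet would show."""
    counts = Counter(r["tag"] for r in rows)  # layers per bound type
    return {
        "count_by_tag": dict(counts),
        "total_flops": sum(r["flops"] for r in rows),
        "total_params": sum(r["params"] for r in rows),
        "total_memory_bytes": sum(r["memory_bytes"] for r in rows),
        "total_latency_ms": sum(r["latency_ms"] for r in rows),
    }
```

Keeping the aggregation separate from the spreadsheet code makes it easy to print or assert on these totals without opening the generated .xlsx file.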

Example (mode='raw'):

import torch
import torch.nn as nn
from model_profiler import profile_flops_and_memory_layername
from model_profiler import estimate_inference_time, export_profile_to_excel_withinferencetime

# === build a simple CNN model ===
class CBA(nn.Module):
    '''
    conv+BN+LeakyReLU
    '''

    def __init__(self,
                 in_channels,
                 out_channels,
                 k_size,
                 padding,
                 stride,
                 bias=False,
                 dilation=1):
        super(CBA, self).__init__()

        self.cba_unit = nn.Sequential(
                        nn.Conv2d(in_channels,
                                out_channels,
                                k_size,
                                padding=padding,
                                stride=stride,
                                bias=bias,
                                dilation=dilation), 
                        nn.BatchNorm2d(out_channels), 
                        nn.LeakyReLU()
                        )

    def forward(self, inputs):
        outputs = self.cba_unit(inputs)
        return outputs
    

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        self.conv3 = CBA(in_channels=32,
                        out_channels=64,
                        k_size=3,
                        padding=1,
                        stride=2)

        self.fc = nn.Linear(64 * 8 * 8, 10)  # assumes 32x32 input: 32 -> 16 -> 8 after two stride-2 convs

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

if __name__ == "__main__":
    model = SimpleCNN()

    # Profile model
    stats = profile_flops_and_memory_layername(
        model,
        input_size=(1, 3, 32, 32),
        mode="raw",   # raw / cba / block
        skip_bn=True,
        skip_act=True,
        skip_Sequential=True
    )

    # Assume a 1 TOPS NPU + 1 GB/s DRAM + 1 MB SRAM
    latency = estimate_inference_time(stats, compute_tops=1, mem_bw_gbs=1, sram_size_mb=1)
    print(f"Estimated latency: {latency:.3f} ms")
    export_profile_to_excel_withinferencetime(stats, filename="estimatetime_profile_report.xlsx",
                                              compute_tops=1, mem_bw_gbs=1, sram_size_mb=1)

📌 Roadmap

  • Add this repository to pip install

  • Add latency measurement

  • Add visualization in Jupyter Notebook

📌 Problems you may encounter

⚠️ You must also install the Graphviz system package (see Graphviz Setup).

If draw_model_with_tags raises an error, first check that the Graphviz executable can be found, as described below.

⚠️ Graphviz Executable Not Found

If you see an error like this:

graphviz.backend.execute.ExecutableNotFound: failed to execute WindowsPath('dot'),
make sure the Graphviz executables are on your systems' PATH

✅ Cause

The Python graphviz package is only a binding. It requires the actual Graphviz executables (especially dot) to be installed on your system and available in the PATH.

🔧 Solution: Graphviz Setup

  1. Download Graphviz: get the Windows installer from the official site: 👉 https://graphviz.gitlab.io/_pages/Download/Download_windows.html and use the stable release installer (EXE).

  2. Install Graphviz. The default installation path is usually:

     C:\Program Files\Graphviz\bin

  3. Add Graphviz to PATH: open Windows search, type Environment Variables, and edit the System Environment Variables → Path. Add:

     C:\Program Files\Graphviz\bin

     Click OK to save and close.

  4. Restart your terminal (Anaconda Prompt / CMD / PowerShell) so the new PATH takes effect.
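After restarting the terminal, you can confirm that Python now sees the `dot` executable by looking it up on PATH with the standard library. This is a generic check, not part of model_profiler. `dot -V` prints the Graphviz version, usually to stderr, so both streams are checked.

```python
import shutil
import subprocess
from typing import Optional


def find_dot() -> Optional[str]:
    """Return the full path of the Graphviz `dot` executable, or None if not on PATH."""
    return shutil.which("dot")


if __name__ == "__main__":
    dot = find_dot()
    if dot is None:
        print("Graphviz 'dot' not found on PATH; draw_model_with_tags will likely fail.")
    else:
        # `dot -V` reports the installed Graphviz version.
        result = subprocess.run([dot, "-V"], capture_output=True, text=True)
        print(result.stderr.strip() or result.stdout.strip())
```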

📄 License

MIT License © 2025 Chih-Sheng (Tommy) Huang

Download files

Source Distribution

torch_model_profiler-0.1.1.tar.gz (14.4 kB)

Uploaded Source

Built Distribution


torch_model_profiler-0.1.1-py3-none-any.whl (12.7 kB)

Uploaded Python 3

File details

Details for the file torch_model_profiler-0.1.1.tar.gz.

File metadata

  • Download URL: torch_model_profiler-0.1.1.tar.gz
  • Upload date:
  • Size: 14.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.7

File hashes

Hashes for torch_model_profiler-0.1.1.tar.gz
Algorithm Hash digest
SHA256 d5744772d1114b530d7de0f9826a596ee0292d7ff8079f0bb9b233a2d2af24c7
MD5 b1e91e7f231b227567e7eb33c0fc5d2c
BLAKE2b-256 e4343d9d79c811a69e0265adc9877e294064fdfb33aa688c59d23778541ed5e3


File details

Details for the file torch_model_profiler-0.1.1-py3-none-any.whl.

File metadata

File hashes

Hashes for torch_model_profiler-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 78fd71cde0aa146e51b39cc55be1542439f91d46558dc585eecf4fa8a4d0f902
MD5 2216c43bf54e9b49c25a1ca6d59e3cbb
BLAKE2b-256 542aca7e0698ba66af034880156e5e017f859a9df98b6a4dfdd65fb7b31f3cf6

