A PyTorch model analyzer
Torch Analyzer
This tool analyzes the run time, GPU memory usage, input/output information, and FLOPs of each layer in PyTorch models.
Run time, GPU memory usage, and FLOPs are measured at the CUDA operator level, which is more accurate than existing module-based analysis. Even models with custom operators can be analyzed.
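For example, a model that wraps a custom operator can be analyzed in exactly the same way as a stock model. The sketch below uses the ModelTimeMemAnalyzer and TorchViser API from the Usage section; the ScaledReLU operator is hypothetical and defined here only for illustration:
import torch
import torch.nn as nn
from torchanalyzer import ModelTimeMemAnalyzer, TorchViser

# Hypothetical custom operator, defined only for this example.
class ScaledReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.save_for_backward(x)
        ctx.scale = scale
        return x.clamp(min=0) * scale

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * (x > 0) * ctx.scale, None

class ScaledReLUModule(nn.Module):
    def forward(self, x):
        return ScaledReLU.apply(x, 2.0)

model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), ScaledReLUModule()).cuda()
inputs = torch.randn(1, 3, 224, 224).cuda()

analyzer = ModelTimeMemAnalyzer(model)
info = analyzer.analyze(inputs)
TorchViser().show(model, info)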
Installation
Install from pip:
pip install torch-analyzer
Install from source:
git clone https://github.com/IrisRainbowNeko/torch-analyzer.git
cd torch-analyzer
pip install -e .
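To confirm the package is importable after installation (the import name torchanalyzer is taken from the examples below):
python -c "import torchanalyzer"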
Usage
Example:
import torch
import torchvision.models as models
from torchanalyzer import ModelTimeMemAnalyzer, TorchViser
model = models.resnet18().cuda()
inputs = torch.randn(1, 3, 224, 224).cuda()
analyzer = ModelTimeMemAnalyzer(model)
info = analyzer.analyze(inputs)
TorchViser().show(model, info)
Analyze model
Analyze run time of each layer:
from torchanalyzer import ModelTimeMemAnalyzer
analyzer = ModelTimeMemAnalyzer(model)
info = analyzer.analyze(inputs)
Analyze input/output information of each layer:
from torchanalyzer import ModelIOAnalyzer
analyzer = ModelIOAnalyzer(model)
info = analyzer.analyze(inputs)
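A self-contained sketch of an input/output analysis, reusing the resnet18 setup from the Usage example and the TableViser described in the next section:
import torch
import torchvision.models as models
from torchanalyzer import ModelIOAnalyzer, TableViser

model = models.resnet18().cuda()
inputs = torch.randn(1, 3, 224, 224).cuda()

analyzer = ModelIOAnalyzer(model)
info = analyzer.analyze(inputs)  # per-layer input/output information
TableViser().show(model, info)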
Analyze FLOPs of each layer:
from torchanalyzer import ModelFlopsAnalyzer
analyzer = ModelFlopsAnalyzer(model)
info = analyzer.analyze(inputs)
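A similar sketch for a FLOPs analysis, again assuming the resnet18 setup from the Usage example and the FlowViser described below:
import torch
import torchvision.models as models
from torchanalyzer import ModelFlopsAnalyzer, FlowViser

model = models.resnet18().cuda()
inputs = torch.randn(1, 3, 224, 224).cuda()

analyzer = ModelFlopsAnalyzer(model)
info = analyzer.analyze(inputs)  # per-layer FLOPs
FlowViser().show(model, info)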
Show Analyzed Information
Show with a style like print(model) in PyTorch:
from torchanalyzer import TorchViser
TorchViser().show(model, info)
Show with table style:
from torchanalyzer import TableViser
TableViser().show(model, info)
Show with flow style:
from torchanalyzer import FlowViser
FlowViser().show(model, info)
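All three visers take the same (model, info) pair, so a single analysis can be rendered in any of the three styles; the sketch below reuses the resnet18 setup from the Usage example:
import torch
import torchvision.models as models
from torchanalyzer import ModelTimeMemAnalyzer, TorchViser, TableViser, FlowViser

model = models.resnet18().cuda()
inputs = torch.randn(1, 3, 224, 224).cuda()

info = ModelTimeMemAnalyzer(model).analyze(inputs)

# The same analysis result, rendered three ways.
TorchViser().show(model, info)  # print(model)-like layout
TableViser().show(model, info)  # table layout
FlowViser().show(model, info)   # flow layout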
Backward Analysis
Analyze the run time and memory usage of each layer during the backward pass:
from torchanalyzer import ModelTimeMemAnalyzer
analyzer = ModelTimeMemAnalyzer(model)
info = analyzer.analyze(inputs, with_backward=True)
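A complete sketch of a backward analysis, again using the resnet18 setup from the Usage example; the result is displayed with the same visers as the forward analysis:
import torch
import torchvision.models as models
from torchanalyzer import ModelTimeMemAnalyzer, TorchViser

model = models.resnet18().cuda()
inputs = torch.randn(1, 3, 224, 224).cuda()

analyzer = ModelTimeMemAnalyzer(model)
info = analyzer.analyze(inputs, with_backward=True)  # forward + backward statistics
TorchViser().show(model, info)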