
Count the layer-wise MACs and number of parameters of your PyTorch model.

Project description

pytorch-op-counter-layerwise

This repository provides thoplw, a Python module that computes the MACs (multiply–accumulate operations) and the number of parameters of each layer of a neural network model implemented in PyTorch.
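For reference, one MAC is a multiplication followed by an addition into an accumulator (`acc += a * b`). The plain-Python sketch below, which is independent of thoplw and uses hypothetical helper names, counts MACs for a dot product and for a matrix multiplication:

```python
def dot_with_mac_count(a, b):
    """Return (dot product, number of MACs) for two equal-length sequences."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    acc = 0
    macs = 0
    for x, y in zip(a, b):
        acc += x * y  # one multiply-accumulate
        macs += 1
    return acc, macs


def matmul_macs(m, k, n):
    """MACs for multiplying an (m x k) matrix by a (k x n) matrix."""
    # Each of the m * n outputs is a length-k dot product.
    return m * n * k


print(dot_with_mac_count([1, 2, 3], [4, 5, 6]))  # (32, 3)
print(matmul_macs(1, 512, 1000))  # 512000
```

A fully connected layer from 512 to 1000 features is one such 1 x 512 by 512 x 1000 product, which accounts for 512,000 of the 513,000 MACs reported for the `fc` row in the example output.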

Installation

Dependencies

The thoplw module requires:

  • PyTorch >= 2.0.0 (older versions may work but have not been tested)

and the sample code requires:

  • Torchvision >= 0.15.0

User installation

pip install thoplw

Usage

Minimal example

import torch, torchvision, thoplw

# Instantiate the target model.
model = torchvision.models.resnet18()

# Compute MACs, number of parameters, and details of each layer.
macs, params, layers_info = thoplw.profile(model, tensor=torch.randn(1, 3, 224, 224))

# Print the total MACs and number of parameters.
print("Total MACs and params:")
print("  - Macs   =", macs)
print("  - Params =", params)
print()

# Print layer details.
print(layers_info.summary())

Running the code above produces the output below (part of the table is omitted because it is too long to show in full in this README).

Total MACs and params:
  - Macs   = 1824010216
  - Params = 11699112

| Name                  | Class             | Input shape    | Output shape   | MACs       | Params   |
+-----------------------+-------------------+----------------+----------------+------------+----------+
| conv1                 | Conv2d            | 3 x 224 x 224  | 64 x 112 x 112 | 118013952  | 9408     |
| bn1                   | BatchNorm2d       | 64 x 112 x 112 | 64 x 112 x 112 | 3211264    | 256      |
| relu                  | ReLU              | 64 x 112 x 112 | 64 x 112 x 112 | 0          | 0        |
| maxpool               | MaxPool2d         | 64 x 112 x 112 | 64 x 56 x 56   | 0          | 0        |
...
| layer4.1.conv2        | Conv2d            | 512 x 7 x 7    | 512 x 7 x 7    | 115605504  | 2359296  |
| layer4.1.bn2          | BatchNorm2d       | 512 x 7 x 7    | 512 x 7 x 7    | 100352     | 2048     |
| avgpool               | AdaptiveAvgPool2d | 512 x 7 x 7    | 512 x 1 x 1    | 1024       | 0        |
| fc                    | Linear            | 512            | 1000           | 513000     | 513000   |
+-----------------------+-------------------+----------------+----------------+------------+----------+
| Total                 | ResNet            | 3 x 224 x 224  | 1000           | 1824010216 | 11699112 |
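As a sanity check, several rows of the table above can be reproduced by hand. The Conv2d formula below is the standard one (one MAC per kernel weight per output position). The BatchNorm2d factors and the bias term in the Linear MACs are not documented in this README; they are inferred from the bn1, layer4.1.bn2 and fc rows, so treat them as assumptions. All helper names are hypothetical.

```python
def conv2d_params(c_in, c_out, k, groups=1, bias=False):
    """Weights (plus optional bias) of a Conv2d with a square k x k kernel."""
    return c_out * (c_in // groups) * k * k + (c_out if bias else 0)


def conv2d_macs(c_in, c_out, k, h_out, w_out, groups=1):
    """One MAC per kernel weight per output position (bias ignored)."""
    return c_out * h_out * w_out * (c_in // groups) * k * k


def batchnorm2d_macs(c, h, w):
    # Assumption: 4 ops per element, inferred from the bn1 / layer4.1.bn2 rows.
    return 4 * c * h * w


def batchnorm2d_params(c):
    # Assumption: weight, bias, running_mean, running_var -> 4 values per channel.
    return 4 * c


def linear_macs(in_f, out_f):
    # Assumption: one extra op per output for the bias, inferred from the fc row.
    return out_f * (in_f + 1)


def linear_params(in_f, out_f):
    """Weight matrix plus bias vector."""
    return out_f * in_f + out_f


# conv1: 3 x 224 x 224 -> 64 x 112 x 112, 7x7 kernel, no bias
print(conv2d_macs(3, 64, 7, 112, 112), conv2d_params(3, 64, 7))  # 118013952 9408
# bn1 on a 64 x 112 x 112 input
print(batchnorm2d_macs(64, 112, 112), batchnorm2d_params(64))  # 3211264 256
# layer4.1.conv2: 512 x 7 x 7 -> 512 x 7 x 7, 3x3 kernel, no bias
print(conv2d_macs(512, 512, 3, 7, 7), conv2d_params(512, 512, 3))  # 115605504 2359296
# fc: 512 -> 1000 features
print(linear_macs(512, 1000), linear_params(512, 1000))  # 513000 513000
```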

Clever formatting

thoplw provides a clever_format function that, like its counterpart in the thop package, converts numbers into compact human-readable strings.

macs, params = thoplw.clever_format([macs, params], "%.3f")
print("Total MACs and params:")
print("  - Macs   =", macs)
print("  - Params =", params)
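To illustrate what this kind of formatting does, here is a hypothetical stand-in written in the style of thop's helper: divide by the largest fitting power of 1000 and append T, G, M, or K. The thresholds and suffixes are assumptions, not thoplw's actual code, so compare against the library's real output.

```python
def clever_format_sketch(values, fmt="%6.2f"):
    """Hypothetical stand-in for thoplw.clever_format; not the library's code."""
    single = not isinstance(values, (list, tuple))
    nums = [values] if single else list(values)
    # Assumed thresholds and suffixes, in the style of thop's helper.
    suffixes = [(1e12, "T"), (1e9, "G"), (1e6, "M"), (1e3, "K")]
    out = []
    for num in nums:
        for scale, suffix in suffixes:
            if num >= scale:
                out.append(fmt % (num / scale) + suffix)
                break
        else:
            out.append(fmt % num)  # numbers below 1000 get no suffix
    return out[0] if single else out


print(clever_format_sketch([1824010216, 11699112], "%.3f"))  # ['1.824G', '11.699M']
```

Returning a list for list input keeps the `macs, params = ...` unpacking pattern from the example working.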

The layer-detail table supports three number formats: raw numbers (the default), clever formatting as done by clever_format, and ratios.

# Print the table with clever formatting.
print(layers_info.summary(kind="text", fmt="clever"))

# Print the table with ratio formatting.
print(layers_info.summary(kind="text", fmt="ratio"))
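The README does not define the "ratio" format. Assuming it shows each layer's share of the model total, the values could be derived from the raw numbers as in this sketch; `layer_ratios` is a hypothetical helper, not part of thoplw.

```python
def layer_ratios(layer_macs, total_macs):
    """Hypothetical helper: each layer's fraction of the model's total MACs."""
    if total_macs <= 0:
        raise ValueError("total_macs must be positive")
    return {name: macs / total_macs for name, macs in layer_macs.items()}


# Raw MACs copied from the resnet18 table above.
rows = {"conv1": 118013952, "layer4.1.conv2": 115605504, "fc": 513000}
ratios = layer_ratios(rows, 1824010216)
for name, share in ratios.items():
    print(f"{name:15s} {share:.2%}")
```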

Table type

The example above prints the model summary as plain text, but the summary can also be exported in CSV or Markdown format. The following example saves the table in each format.

# Save as CSV format (open the file in write mode).
with open("summary.csv", "w") as ofp:
    ofp.write(layers_info.summary(kind="csv"))

# Save as Markdown format.
with open("summary.md", "w") as ofp:
    ofp.write(layers_info.summary(kind="md"))
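Once saved, the CSV can be post-processed with the standard library, for example to list the most expensive layers. This sketch assumes the CSV has a header row using the same column names as the text table (Name, ..., MACs, Params), raw integer values (the default fmt="raw"), and possibly a Total row; adjust the field names if thoplw's actual CSV differs. `top_layers` and the abridged sample are hypothetical.

```python
import csv
import io


def top_layers(csv_text, key="MACs", n=3):
    """Return the n (name, value) pairs with the largest `key`, skipping a Total row."""
    reader = csv.DictReader(io.StringIO(csv_text))
    rows = [r for r in reader if r["Name"] != "Total"]
    rows.sort(key=lambda r: int(r[key]), reverse=True)
    return [(r["Name"], int(r[key])) for r in rows[:n]]


# Usage with the saved file:
# with open("summary.csv", newline="") as ifp:
#     print(top_layers(ifp.read()))

# Abridged, hypothetical sample built from rows of the resnet18 table.
sample = """Name,Class,MACs,Params
conv1,Conv2d,118013952,9408
bn1,BatchNorm2d,3211264,256
fc,Linear,513000,513000
Total,ResNet,1824010216,11699112
"""
print(top_layers(sample, n=2))  # [('conv1', 118013952), ('bn1', 3211264)]
```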

API reference

thoplw.profile

macs, params, layers_info = thoplw.profile(model: torch.nn.Module,
                                           tensor: torch.Tensor,
                                           custom_ops: dict = {},
                                           verbose: bool = True)

Computes the MACs and the number of parameters.

  • Args
    • model: the target NN model.
    • tensor: input tensor for the model.
    • custom_ops: optional custom operations.
    • verbose: prints extra messages to the terminal if True.
  • Returns
    • macs: total MACs of the target model for the given input tensor.
    • params: number of parameters of the target model.
    • layers_info: a LayerInfo instance that stores the details of each layer.
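The params value can differ from `sum(p.numel() for p in model.parameters())`. For resnet18, torchvision reports 11,689,512 parameters, while the example above reports 11,699,112. The gap of 9,600 equals two values per BatchNorm2d channel, which suggests (an inference from the numbers, not documented behaviour) that thoplw also counts the running_mean and running_var buffers. The arithmetic, using resnet18's BatchNorm2d layout:

```python
# BatchNorm2d channel counts in torchvision's resnet18:
# the stem bn1, then per stage 2 BasicBlocks x 2 BNs,
# plus one downsample BN in stages 2-4.
stem = [64]
stages = []
for width, has_downsample in [(64, False), (128, True), (256, True), (512, True)]:
    stages += [width] * 4
    if has_downsample:
        stages.append(width)

bn_channels = sum(stem + stages)
torchvision_params = 11_689_512  # learnable parameters of resnet18
thoplw_params = 11_699_112       # "Params" total in the example output

print(bn_channels)                         # 4800
print(thoplw_params - torchvision_params)  # 9600
print(2 * bn_channels)                     # 9600
```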

thoplw.clever_format

formatted_values = thoplw.clever_format(values: int or list, fmt: str = "%6.2f")

Returns the formatted string(s) for the given integer(s).

  • Args
    • values: a single value or a list of values.
    • fmt: printf-style format specifier, e.g. "%.3f".
  • Returns
    • formatted_values: the formatted result(s).

LayerInfo class

class LayerInfo:
    ...
    def summary(self,
                kind: str = "text",
                fmt: str = "raw") -> str:
        ...

A class that stores the details of each layer. Only the summary method is exposed to users.

  • Args
    • kind: table type to return ("text" for a plain-text table, "csv" for CSV, or "md" for Markdown).
    • fmt: number format ("raw" for raw integers, "clever" for automatic unit formatting as in clever_format, or "ratio" for ratios).
  • Returns
    • the formatted table as a string.

Results of Recent Models

The following results can be obtained by running tests/test_benchmarks.py. Click a model name to see its layer details.

| Model name         | MACs     | Params   |
+--------------------+----------+----------+
| alexnet            | 714.22 M | 61.10 M  |
| vgg11              | 7.61 G   | 132.86 M |
| vgg11_bn           | 7.64 G   | 132.87 M |
| vgg13              | 11.31 G  | 133.05 M |
| vgg13_bn           | 11.36 G  | 133.06 M |
| vgg16              | 15.47 G  | 138.36 M |
| vgg16_bn           | 15.52 G  | 138.37 M |
| vgg19              | 19.63 G  | 143.67 M |
| vgg19_bn           | 19.69 G  | 143.69 M |
| resnet18           | 1.82 G   | 11.70 M  |
| resnet34           | 3.68 G   | 21.81 M  |
| resnet50           | 4.13 G   | 25.61 M  |
| resnet101          | 7.87 G   | 44.65 M  |
| resnet152          | 11.60 G  | 60.34 M  |
| wide_resnet50_2    | 11.46 G  | 68.95 M  |
| wide_resnet101_2   | 22.84 G  | 127.02 M |
| resnext50_32x4d    | 4.29 G   | 25.10 M  |
| resnext101_32x8d   | 16.54 G  | 88.99 M  |
| densenet121        | 2.90 G   | 8.06 M   |
| densenet161        | 7.85 G   | 28.90 M  |
| densenet169        | 3.44 G   | 14.31 M  |
| densenet201        | 4.39 G   | 20.24 M  |
| googlenet          | 1.51 G   | 6.64 M   |
| inception_v3       | 5.75 G   | 23.87 M  |
| squeezenet1.0      | 818.93 M | 1.25 M   |
| squeezenet1.1      | 349.16 M | 1.24 M   |
| mobilenet_v2       | 327.49 M | 3.54 M   |
| mobilenet_v3_small | 62.17 M  | 2.55 M   |
| mobilenet_v3_large | 234.21 M | 5.51 M   |
| shufflenet_v2_x0.5 | 44.57 M  | 1.37 M   |
| shufflenet_v2_x1.0 | 152.71 M | 2.29 M   |
| mnasnet_0.5        | 116.72 M | 2.24 M   |
| mnasnet_1.0        | 336.24 M | 4.42 M   |

Gratitude

  • Developers and maintainers of pytorch-OpCounter. The author learned a lot from the repository.



Download files

Download the file for your platform.

Source Distribution

thoplw-2024.1.25.tar.gz (67.7 kB)

Uploaded Source

Built Distribution


thoplw-2024.1.25-py3-none-any.whl (13.4 kB)

Uploaded Python 3

File details

Details for the file thoplw-2024.1.25.tar.gz.

File metadata

  • Download URL: thoplw-2024.1.25.tar.gz
  • Size: 67.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.10

File hashes

Hashes for thoplw-2024.1.25.tar.gz
  • SHA256: 8ce151b14bcebc0d28b8c4d7457504db639107526b4f68f79fbbe4f1b7163c6d
  • MD5: 7b4eccec4a8d185d5461b6d787b838ab
  • BLAKE2b-256: c81493c05ed33bc90e15050e3eff3472f8e61b4087ec160449d04b6875592ec7


File details

Details for the file thoplw-2024.1.25-py3-none-any.whl.

File metadata

  • Download URL: thoplw-2024.1.25-py3-none-any.whl
  • Upload date:
  • Size: 13.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.10

File hashes

Hashes for thoplw-2024.1.25-py3-none-any.whl
  • SHA256: a820b04d58846cc5ae0cfbf7d6aadfc911edf2ef650ae4c58baa9f44bb4baaac
  • MD5: 6837f9159d63656a959e2a10b3325245
  • BLAKE2b-256: 8e0981701b5a4dbec3409ba9422da000ce702cfd0eeb1b0dbe2f37d3d905c3bc

