
Estimate FLOPs of neural networks



pytorch-estimate-flops

A simple PyTorch utility that estimates the number of FLOPs for a given network. For now, only some basic operations are supported (essentially the ones I needed for my models). More will be added soon.

All contributions are welcome.

Installation

You can install the package using pip:

pip install pthflops

or directly from the GitHub repository:

git clone https://github.com/1adrianb/pytorch-estimate-flops && cd pytorch-estimate-flops
python setup.py install

Note: PyTorch 1.8 or newer is recommended.
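To check which versions are installed before running the examples, something like the following sketch can be used. It relies only on the standard library (`importlib.metadata`, Python 3.8+), so it works without importing PyTorch. The helper names `installed_version` and `at_least` are hypothetical, and comparing only the leading major.minor numbers is a simplification; the `packaging` library handles pre-releases and other edge cases properly.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed distribution's version string, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def at_least(version_string, minimum=(1, 8)):
    """Compare the leading major.minor of a version string against `minimum`.

    Numeric tuples are compared (so "1.10" > "1.8"), and local suffixes such as
    "+cu111" are ignored.
    """
    match = re.match(r"(\d+)\.(\d+)", version_string)
    if match is None:
        raise ValueError(f"unrecognised version string: {version_string!r}")
    return (int(match.group(1)), int(match.group(2))) >= tuple(minimum)


if __name__ == "__main__":
    for pkg in ("torch", "pthflops"):
        print(pkg, installed_version(pkg) or "not installed")
    torch_version = installed_version("torch")
    if torch_version is not None and not at_least(torch_version):
        print("PyTorch older than 1.8 detected; upgrading is recommended.")
```

The comparison is done on integer tuples rather than raw strings because a plain string comparison would wrongly rank "1.10" below "1.8".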

Example

import torch
from torchvision.models import resnet18

from pthflops import count_ops

# Create a network and a corresponding input (falls back to the CPU if no GPU is available)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = resnet18().to(device)
inp = torch.rand(1, 3, 224, 224).to(device)

# Count the number of FLOPs
count_ops(model, inp)
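For a rough sanity check of the numbers reported for ResNet-18, the cost of a single convolution can be worked out by hand. The sketch below uses the standard Conv2d output-size formula and counts one multiply-accumulate (MAC) per kernel weight per output element, ignoring bias. The helper names are hypothetical, and this is not necessarily the convention pthflops uses internally: some tools report FLOPs as 2 x MACs, so compare orders of magnitude rather than exact values.

```python
def conv2d_output_size(size, kernel, stride=1, padding=0, dilation=1):
    """Spatial output size of a 2-D convolution along one dimension."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv2d_macs(c_in, c_out, h_in, w_in, kernel, stride=1, padding=0, groups=1):
    """Multiply-accumulates of a square-kernel Conv2d, bias ignored."""
    h_out = conv2d_output_size(h_in, kernel, stride, padding)
    w_out = conv2d_output_size(w_in, kernel, stride, padding)
    # Every output element combines (c_in / groups) * kernel * kernel inputs.
    return c_out * h_out * w_out * (c_in // groups) * kernel * kernel


# ResNet-18's first layer: Conv2d(3, 64, kernel_size=7, stride=2, padding=3) on 224x224.
macs = conv2d_macs(3, 64, 224, 224, kernel=7, stride=2, padding=3)
print(f"conv1: {macs:,} MACs (~{2 * macs / 1e9:.2f} GFLOPs if 1 MAC = 2 FLOPs)")
```

For the first ResNet-18 convolution this gives a 112x112 output and 118,013,952 MACs; summing the same quantity over every layer is what a full counter automates.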

Ignoring certain layers:

import torch
from torch import nn
from pthflops import count_ops

class CustomLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(5, 5, 1, 1, 0)
        # ... other layers present inside will also be ignored

    def forward(self, x):
        return self.conv1(x)

# Create a network and a corresponding input
inp = torch.rand(1,5,7,7)
net = nn.Sequential(
    nn.Conv2d(5, 5, 1, 1, 0),
    nn.ReLU(inplace=True),
    CustomLayer()
)

# Count the number of FLOPs, jit mode:
count_ops(net, inp, ignore_layers=['CustomLayer'])

# Note: if you are using PyTorch 1.8 or newer with fx instead of jit, the naming convention is different,
# so you have to pass ['_2_conv1'] instead. Check your model definition to find the right names.
# Count the number of FLOPs, fx mode:
count_ops(net, inp, ignore_layers=['_2_conv1'])
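The `'_2_conv1'` name corresponds to the path of the convolution inside `net` (`'2.conv1'`, as listed by `net.named_modules()`), rewritten into a valid Python identifier. torch.fx traces through user-defined modules by default and keeps only `torch.nn` modules as leaves, which is presumably why the inner convolution is named rather than `CustomLayer` itself. A rough, hypothetical helper for predicting such names is sketched below. It models only the two rules visible in this example (separators become underscores and a leading digit gets a `_` prefix) and ignores the camelCase conversion and de-duplication suffixes that torch.fx may also apply, so always confirm against your actual model.

```python
import re


def fx_style_name(qualified_name):
    """Approximate the torch.fx node name for a submodule path such as '2.conv1'.

    Assumption: characters outside [0-9a-zA-Z_] become '_' and a leading digit
    gets a '_' prefix. camelCase handling and name-collision suffixes are NOT
    modelled.
    """
    name = re.sub(r"[^0-9a-zA-Z_]+", "_", qualified_name)
    if name and name[0].isdigit():
        name = "_" + name
    return name


# Leaf submodule paths of the example `net`, as named_modules() would report them.
for path in ["0", "1", "2.conv1"]:
    print(f"{path!r:>10} -> {fx_style_name(path)!r}")
```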



Download files


Source Distribution

pthflops-0.4.2.tar.gz (10.9 kB)

Built Distributions

pthflops-0.4.2-py3-none-any.whl (11.1 kB, py3)

pthflops-0.4.2-py2.py3-none-any.whl (11.1 kB, py2 and py3)
