Estimate FLOPs of neural networks
pytorch-estimate-flops
A simple PyTorch utility that estimates the number of FLOPs for a given network. Only some basic operations are currently supported (essentially the ones the author needed for their own models). More will be added soon.
All contributions are welcome.
Installation
You can install the package using pip:
pip install pthflops
or directly from the GitHub repository:
git clone https://github.com/1adrianb/pytorch-estimate-flops && cd pytorch-estimate-flops
python setup.py install
Note: PyTorch 1.8 or newer is recommended.
Example
import torch
from torchvision.models import resnet18
from pthflops import count_ops
# Create a network and a corresponding input
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when no GPU is present
model = resnet18().to(device)
inp = torch.rand(1,3,224,224).to(device)
# Count the number of FLOPs
count_ops(model, inp)
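As a sanity check on the numbers a FLOP counter reports, the cost of a single convolution can be worked out by hand. The sketch below is not taken from pthflops: `conv2d_output_size` and `conv2d_macs` are hypothetical helpers that count multiply-accumulate operations (MACs) for one `Conv2d` layer using the standard output-size formula. Whether pthflops reports MACs or twice that figure is not stated here, so treat the result as a rough cross-check rather than the library's exact output.

```python
def conv2d_output_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution along one dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv2d_macs(c_in, c_out, kernel, in_h, in_w, stride=1, padding=0, groups=1):
    """Hypothetical helper: multiply-accumulate count for one square-kernel Conv2d.

    Bias additions are ignored; they are small compared to the MAC count.
    """
    out_h = conv2d_output_size(in_h, kernel, stride, padding)
    out_w = conv2d_output_size(in_w, kernel, stride, padding)
    # Each output element needs one MAC per input channel (in its group)
    # per kernel position.
    macs_per_output = (c_in // groups) * kernel * kernel
    return out_h * out_w * c_out * macs_per_output


# First layer of resnet18: Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
# applied to a 224x224 input.
print(conv2d_macs(3, 64, 7, 224, 224, stride=2, padding=3))  # 118013952

# A 1x1 convolution with 5 input and 5 output channels on a 7x7 input.
print(conv2d_macs(5, 5, 1, 7, 7))  # 1225
```

Summing such per-layer figures over a model gives an estimate that can be compared against the total printed by `count_ops`.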
Ignoring certain layers:
import torch
from torch import nn
from pthflops import count_ops
class CustomLayer(nn.Module):
    def __init__(self):
        super(CustomLayer, self).__init__()
        self.conv1 = nn.Conv2d(5, 5, 1, 1, 0)
        # ... other layers present inside will also be ignored

    def forward(self, x):
        return self.conv1(x)

# Create a network and a corresponding input
inp = torch.rand(1, 5, 7, 7)
net = nn.Sequential(
    nn.Conv2d(5, 5, 1, 1, 0),
    nn.ReLU(inplace=True),
    CustomLayer()
)
# Count the number of FLOPs, jit mode:
count_ops(net, inp, ignore_layers=['CustomLayer'])
# Note: if you are using PyTorch 1.8 or newer with fx instead of jit, the naming convention is different:
# layers are identified by their fx node name rather than their class name, so you will have to pass ['_2_conv1'].
# Please check your model definition to account for this.
# Count the number of FLOPs, fx mode:
count_ops(net, inp, ignore_layers=['_2_conv1'])
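The fx-mode name `_2_conv1` corresponds to the submodule path `2.conv1` (index 2 of the `nn.Sequential`, attribute `conv1` of `CustomLayer`). The sketch below shows one way to derive such names from a module path. `fx_style_name` is a hypothetical helper, not part of pthflops or PyTorch, and the rule it applies (replace non-identifier characters with underscores, prefix an underscore when the name starts with a digit) is an assumption based on how the example above is named. It does not handle name collisions or reserved words, so confirm the result against the node names in `torch.fx.symbolic_trace(net).graph` for your own model.

```python
import re


def fx_style_name(qualified_name):
    """Hypothetical helper: approximate the fx node name for a submodule path.

    Assumed rule: runs of characters that are not letters, digits or
    underscores become a single underscore, and a leading digit gets an
    underscore prefix so the result is a valid Python identifier.
    """
    name = re.sub(r"[^0-9a-zA-Z_]+", "_", qualified_name)
    if name and name[0].isdigit():
        name = "_" + name
    return name


# Submodule paths as returned by net.named_modules() for the example above.
print(fx_style_name("2.conv1"))         # _2_conv1
print(fx_style_name("layer1.0.conv1"))  # layer1_0_conv1
```

The derived names can then be passed as `ignore_layers=[...]` when counting in fx mode.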