
Simplification of pruned models for accelerated inference


Simplify



Installation

Simplify can be installed using pip:

pip3 install torch-simplify

or if you want to run the latest version of the code, you can install from git:

git clone https://github.com/EIDOSlab/simplify
cd simplify
pip3 install -r requirements.txt

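Example usage

Fuse

The fuse module folds each BatchNorm layer into the convolution that precedes it, so that every (conv, bn) pair becomes a single layer.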

from torchvision.models import resnet18
from simplify import fuse

model = resnet18()
model.eval()  # BatchNorm folding relies on the running statistics, so switch to evaluation mode
bn_folding = ...  # List of pairs (conv, bn) to fuse in a single layer
model = fuse(model, bn_folding)
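To make the folding step concrete, here is a minimal sketch of what fusing one (conv, bn) pair computes in evaluation mode: each output filter is rescaled by gamma / sqrt(running_var + eps), and the new bias is (b - running_mean) * gamma / sqrt(running_var + eps) + beta. The helper fold_conv_bn is hypothetical and is not the fuse implementation from Simplify; PyTorch also ships a comparable utility, torch.nn.utils.fusion.fuse_conv_bn_eval.

```python
import torch
import torch.nn as nn


def fold_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Hypothetical helper: return a Conv2d equivalent to bn(conv(x)) in eval mode."""
    # Per-output-channel factor gamma / sqrt(running_var + eps)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused = nn.Conv2d(
        conv.in_channels, conv.out_channels, conv.kernel_size,
        stride=conv.stride, padding=conv.padding, dilation=conv.dilation,
        groups=conv.groups, bias=True,
    )
    with torch.no_grad():
        # Rescale every output filter by its channel's factor
        fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
        conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
        # Fold the BatchNorm shift: (b - mu) * scale + beta
        fused.bias.copy_((conv_bias - bn.running_mean) * scale + bn.bias)
    return fused


# Usage: compare the fused layer with the original pair
torch.manual_seed(0)
conv, bn = nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8)
with torch.no_grad():
    # Non-trivial statistics so that the comparison is meaningful
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
conv.eval()
bn.eval()
fused = fold_conv_bn(conv, bn)
x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    reference = bn(conv(x))
    output = fused(x)
print(torch.allclose(reference, output, atol=1e-5))
```

The fused layer does one convolution instead of a convolution followed by a normalization, which is where the speed-up in evaluation mode comes from.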

Propagate

The propagate module removes the non-zero bias from zeroed-out neurons, propagating its contribution to the following layers, so that those neurons can then be removed.

import torch
from simplify import propagate_bias
from torchvision.models import resnet18

zeros = torch.zeros(1, 3, 224, 224)  # Dummy input with the shape the model expects
model = resnet18()
pinned_out = ...  # List of layers for which the bias should not be propagated
propagate_bias(model, zeros, pinned_out)
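The idea behind bias propagation can be sketched on two fully connected layers with a ReLU in between. A neuron whose weight row is all zeros still outputs the constant ReLU(b_i) for every input, so that constant can be added to the next layer's bias (through column i of its weights) and b_i set to zero without changing the network's output. This is a simplified illustration under those assumptions (Linear layers, an activation with act(0) = 0); propagate_zeroed_bias is a hypothetical helper, not the propagate_bias API.

```python
import torch
import torch.nn as nn


def propagate_zeroed_bias(fc1: nn.Linear, act: nn.Module, fc2: nn.Linear) -> None:
    """Hypothetical helper: fold the constant output of zeroed fc1 neurons into fc2.bias."""
    with torch.no_grad():
        # A neuron is zeroed when its entire weight row is zero
        zeroed = fc1.weight.abs().sum(dim=1) == 0
        # Those neurons output act(bias) for any input; the others get act(0) = 0 here
        const = act(fc1.bias * zeroed)
        # fc2 receives that constant through its weight columns, so add it to fc2's bias
        fc2.bias.add_(fc2.weight @ const)
        # With a zero bias the zeroed neurons now output act(0) = 0 and can be removed
        fc1.bias[zeroed] = 0.0


torch.manual_seed(0)
fc1, act, fc2 = nn.Linear(4, 6), nn.ReLU(), nn.Linear(6, 3)
with torch.no_grad():
    fc1.weight[[1, 4]] = 0.0                     # "prune" neurons 1 and 4
    fc1.bias[[1, 4]] = torch.tensor([0.7, 0.3])  # but leave a non-zero bias behind
model = nn.Sequential(fc1, act, fc2)
x = torch.randn(5, 4)
with torch.no_grad():
    before = model(x)
    propagate_zeroed_bias(fc1, act, fc2)
    after = model(x)
print(torch.allclose(before, after, atol=1e-6))
```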

Remove

The remove module actually removes the zeroed neurons from the model architecture.

import torch
from simplify import remove_zeroed
from torchvision.models import resnet18

zeros = torch.zeros(1, 3, 224, 224)  # Dummy input with the shape the model expects
model = resnet18()
pinned_out = ...  # List of layers in which the output should not change shape
remove_zeroed(model, zeros, pinned_out)
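Once a zeroed neuron has zero weights and a zero bias it contributes nothing, so it can be dropped: its row is removed from the producing layer and the matching input column from the consuming layer. The sketch below does this for a pair of Linear layers separated by a ReLU; remove_zeroed_neurons is a hypothetical helper, not the remove_zeroed implementation. The consuming layer keeps its output shape, which is similar in spirit to the constraint that pinned_out expresses.

```python
from typing import Tuple

import torch
import torch.nn as nn


def remove_zeroed_neurons(fc1: nn.Linear, fc2: nn.Linear) -> Tuple[nn.Linear, nn.Linear]:
    """Hypothetical helper: drop fc1 neurons that have zero weights and zero bias."""
    with torch.no_grad():
        # Keep a neuron if its weight row or its bias is non-zero
        keep = (fc1.weight.abs().sum(dim=1) != 0) | (fc1.bias != 0)
        idx = keep.nonzero(as_tuple=True)[0]
        # Producer: keep only the surviving rows
        new_fc1 = nn.Linear(fc1.in_features, idx.numel())
        new_fc1.weight.copy_(fc1.weight[idx])
        new_fc1.bias.copy_(fc1.bias[idx])
        # Consumer: keep the matching input columns; its output shape is unchanged
        new_fc2 = nn.Linear(idx.numel(), fc2.out_features)
        new_fc2.weight.copy_(fc2.weight[:, idx])
        new_fc2.bias.copy_(fc2.bias)
    return new_fc1, new_fc2


torch.manual_seed(0)
fc1, fc2 = nn.Linear(4, 6), nn.Linear(6, 3)
with torch.no_grad():
    fc1.weight[[1, 4]] = 0.0
    fc1.bias[[1, 4]] = 0.0  # bias already propagated away
small1, small2 = remove_zeroed_neurons(fc1, fc2)
x = torch.randn(5, 4)
with torch.no_grad():
    before = fc2(torch.relu(fc1(x)))
    after = small2(torch.relu(small1(x)))
print(small1.out_features, torch.allclose(before, after, atol=1e-6))
```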

Utilities

We also provide utilities that compute bn_folding and pinned_out for standard PyTorch models.

from torchvision.models import resnet18
from utils import get_bn_folding, get_pinned_out

model = resnet18()
bn_folding = get_bn_folding(model)
pinned_out = get_pinned_out(model)
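For a model whose modules are registered in execution order, a crude way to build a bn_folding-style list is to pair every BatchNorm2d with the Conv2d registered immediately before it. The find_conv_bn_pairs helper below is a hypothetical heuristic, not the implementation of get_bn_folding: it returns module names rather than module objects, and it can miss or mispair layers in models whose registration order differs from the forward pass.

```python
from typing import List, Tuple

import torch.nn as nn


def find_conv_bn_pairs(model: nn.Module) -> List[Tuple[str, str]]:
    """Hypothetical heuristic: pair each BatchNorm2d with the Conv2d registered just before it."""
    pairs = []
    prev_name, prev_module = None, None
    for name, module in model.named_modules():
        if isinstance(module, nn.BatchNorm2d) and isinstance(prev_module, nn.Conv2d):
            pairs.append((prev_name, name))
        # Only leaf modules count as the "previous layer"; containers are skipped
        if len(list(module.children())) == 0:
            prev_name, prev_module = name, module
    return pairs


model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),
    nn.Conv2d(8, 8, 3), nn.ReLU(),
    nn.Conv2d(8, 16, 3), nn.BatchNorm2d(16),
)
print(find_conv_bn_pairs(model))
```

The second convolution is not followed by a BatchNorm, so it does not appear in the list.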

Tests

Inference time benchmarks

Evaluation mode (fuses BatchNorm)

Update timestamp 08/10/2021 14:26:25

Random structured pruning amount = 50.0%

Architecture | Dense time | Pruned time | Simplified time
alexnet | 7.58ms ± 0.29 | 7.55ms ± 0.28 | 2.95ms ± 0.02
densenet121 | 36.41ms ± 4.88 | 34.31ms ± 3.85 | 21.87ms ± 1.45
googlenet | 15.44ms ± 3.19 | 13.68ms ± 0.09 | 10.31ms ± 0.82
inception_v3 | 25.29ms ± 7.31 | 21.68ms ± 2.90 | 13.22ms ± 2.23
mnasnet1_0 | 17.66ms ± 0.57 | 13.64ms ± 0.13 | 11.59ms ± 0.07
mobilenet_v3_large | 13.74ms ± 0.67 | 12.18ms ± 0.46 | 11.95ms ± 0.21
resnet50 | 24.39ms ± 4.48 | 26.19ms ± 5.84 | 18.21ms ± 1.98
resnext101_32x8d | 76.11ms ± 15.79 | 77.35ms ± 20.04 | 65.68ms ± 16.41
shufflenet_v2_x2_0 | 18.07ms ± 2.23 | 14.32ms ± 0.21 | 13.06ms ± 0.08
squeezenet1_1 | 4.50ms ± 0.06 | 4.39ms ± 0.05 | 4.09ms ± 0.50
vgg19_bn | 40.41ms ± 12.13 | 38.56ms ± 10.72 | 12.39ms ± 0.19
wide_resnet101_2 | 79.40ms ± 25.57 | 82.86ms ± 22.47 | 60.16ms ± 10.77
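The benchmark reports a mean ± standard deviation per forward pass after 50% random structured pruning. A rough way to take that kind of measurement on CPU is sketched below, using PyTorch's torch.nn.utils.prune.random_structured to zero out half of each convolution's output channels. The helpers prune_conv_channels and time_inference, the number of runs, and the warm-up count are assumptions, not the settings used for the table; on a GPU you would also need torch.cuda.synchronize() around the timed region.

```python
import statistics
import time

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


def prune_conv_channels(model: nn.Module, amount: float = 0.5) -> None:
    """Hypothetical helper: zero a random `amount` of output channels in every Conv2d."""
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            prune.random_structured(module, name="weight", amount=amount, dim=0)
            prune.remove(module, "weight")  # make the zeros permanent, drop the mask


@torch.no_grad()
def time_inference(model: nn.Module, x: torch.Tensor, runs: int = 20, warmup: int = 3):
    """Hypothetical helper: (mean, stdev) of CPU forward-pass latency in milliseconds."""
    model.eval()
    for _ in range(warmup):
        model(x)
    times = []
    for _ in range(runs):
        start = time.perf_counter()
        model(x)
        times.append((time.perf_counter() - start) * 1e3)
    return statistics.mean(times), statistics.stdev(times)


torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.Conv2d(32, 32, 3, padding=1)
)
x = torch.randn(1, 3, 64, 64)
dense_ms = time_inference(model, x)
prune_conv_channels(model, amount=0.5)
pruned_ms = time_inference(model, x)
# Count output channels of the first conv whose filters are now entirely zero
zeroed = int((model[0].weight.flatten(1).abs().sum(dim=1) == 0).sum())
print(dense_ms, pruned_ms, zeroed)
```

Pruning only zeroes weights, so the pruned tensors keep their original shapes; this is consistent with the pruned timings above staying close to the dense ones, whereas physically removing the channels shrinks the computation.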

Status of torchvision.models

:heavy_check_mark: all good

:x: output differs from the original model

:cursing_face: an exception occurred

:man_shrugging: test skipped because the previous one failed
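These statuses boil down to comparing a model's output before and after a transformation. A minimal sketch of such a check, assuming a single input tensor and an absolute tolerance of 1e-5 (both hypothetical choices, not necessarily those of the Simplify test suite), maps the outcome to the first three legend entries; skipping a later step is left to the caller.

```python
from typing import Callable

import torch
import torch.nn as nn


def check_transform(model: nn.Module, transform: Callable[[nn.Module], nn.Module],
                    x: torch.Tensor, atol: float = 1e-5) -> str:
    """Hypothetical check returning 'ok', 'different' or 'exception'."""
    model.eval()
    with torch.no_grad():
        # Reference output, computed before a possibly in-place transformation
        expected = model(x)
        try:
            actual = transform(model)(x)
        except Exception:
            return "exception"
    return "ok" if torch.allclose(expected, actual, atol=atol) else "different"


def unsupported(model: nn.Module) -> nn.Module:
    # Stand-in for a transformation that fails on some layer
    raise RuntimeError("unsupported layer")


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
x = torch.randn(2, 4)
print(check_transform(model, lambda m: m, x))                # identity transformation
print(check_transform(model, lambda m: nn.Linear(4, 4), x))  # replaces the model entirely
print(check_transform(model, unsupported, x))                # raises inside the transform
```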

Fuse BatchNorm

Update timestamp 06/10/2021 20:26:15

Architecture | BatchNorm Folding | Bias Propagation | Simplification
alexnet | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
densenet121 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
googlenet | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
inception_v3 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
mnasnet1_0 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
mobilenet_v3_large | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
resnet50 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
resnext101_32x8d | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
shufflenet_v2_x2_0 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
squeezenet1_1 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
vgg19_bn | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
wide_resnet101_2 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:

Keep BatchNorm

Update timestamp 06/10/2021 20:36:11

Architecture | BatchNorm Folding | Bias Propagation | Simplification
alexnet | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
densenet121 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
googlenet | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
inception_v3 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
mnasnet1_0 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
mobilenet_v3_large | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
resnet50 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
resnext101_32x8d | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
shufflenet_v2_x2_0 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
squeezenet1_1 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
vgg19_bn | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:
wide_resnet101_2 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark:

Download files

Source distribution: torch-simplify-1.1.0.tar.gz (15.6 kB)

Built distribution: torch_simplify-1.1.0-py3-none-any.whl (16.7 kB, Python 3)
