
Simplification of pruned models for accelerated inference

Project description

Simplify


Installation

Simplify can be installed using pip:

pip3 install torch-simplify

or, if you want to run the latest version of the code, you can install from git:

git clone https://github.com/EIDOSlab/simplify
cd simplify
pip3 install -r requirements.txt

Example usage

from torchvision.models import resnet18
from simplify import fuse

model = resnet18()
model.eval()
bn_folding = ...  # List of pairs (conv, bn) to fuse in a single layer
model = fuse(model, bn_folding)
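For reference, the arithmetic behind BatchNorm folding can be sketched in plain Python for a single scalar channel. This is a simplified illustration of the underlying math, not the library's implementation: the BatchNorm parameters (gamma, beta, running mean, running variance) are absorbed into the preceding layer's weight and bias.

```python
import math

def fold_bn(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold a BatchNorm (gamma, beta, mean, var) into the preceding
    layer's weight w and bias b, returning the fused parameters."""
    scale = gamma / math.sqrt(var + eps)
    return w * scale, (b - mean) * scale + beta

# One scalar "channel": BN(w*x + b) must equal w_f*x + b_f.
w, b = 2.0, 0.5
gamma, beta, mean, var = 1.5, -0.2, 0.3, 0.8
w_f, b_f = fold_bn(w, b, gamma, beta, mean, var)

x = 1.7
bn_out = (w * x + b - mean) * gamma / math.sqrt(var + 1e-5) + beta
fused_out = w_f * x + b_f
assert abs(bn_out - fused_out) < 1e-9
```

In the library, `fuse` applies this folding to each `(conv, bn)` pair in `bn_folding`, so the BatchNorm layers disappear from the fused model.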

Propagate

The propagate module propagates the non-zero biases of zeroed-out neurons into the subsequent layers, so that those neurons can later be removed.

import torch
from simplify import propagate_bias
from torchvision.models import resnet18

zeros = torch.zeros(1, 3, 224, 224)
model = resnet18()
pinned_out = ...  # List of layers for which the bias should not be propagated
propagate_bias(model, zeros, pinned_out)
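As a simplified illustration of why this step is needed (a pure-Python sketch with hypothetical toy weights, ignoring activation functions, not the library's actual implementation): a neuron whose weights are all zero still outputs a constant equal to its bias, and that constant contribution can be folded into the next layer's bias without changing the output.

```python
def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

# Layer 1: neuron 1 has all-zero weights but a non-zero bias.
W1 = [[1.0, 2.0], [0.0, 0.0]]
b1 = [0.5, 3.0]
# Layer 2 consumes both neurons of layer 1.
W2 = [[1.0, -1.0], [2.0, 0.5]]
b2 = [0.0, 1.0]

x = [0.3, -0.7]
h = [v + b for v, b in zip(matvec(W1, x), b1)]
ref = [v + b for v, b in zip(matvec(W2, h), b2)]

# Propagate: fold neuron 1's constant output (3.0) into layer 2's bias
# and zero its own bias, so the dead neuron now outputs exactly zero.
b2_new = [b + row[1] * b1[1] for b, row in zip(b2, W2)]
b1_new = [b1[0], 0.0]
h2 = [v + b for v, b in zip(matvec(W1, x), b1_new)]
out = [v + b for v, b in zip(matvec(W2, h2), b2_new)]
assert all(abs(a - c) < 1e-9 for a, c in zip(ref, out))
```

In a real network this only holds when the intervening activation maps the constant predictably, which is why `propagate_bias` traces the actual model and takes a `pinned_out` list of layers to leave untouched.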

Remove

The remove module actually removes the zeroed neurons from the model architecture.

import torch
from simplify import remove_zeroed
from torchvision.models import resnet18

zeros = torch.zeros(1, 3, 224, 224)
model = resnet18()
pinned_out = ...  # List of layers in which the output should not change shape
remove_zeroed(model, zeros, pinned_out)
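Continuing the toy sketch from above (hypothetical weights, pure Python, no activations, not the library's implementation): once its bias has been propagated, a zeroed neuron outputs exactly zero, so it can be dropped by deleting its row in one layer and the matching column in the next. The architecture shrinks while the output stays identical.

```python
def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

# Neuron 1 of layer 1 is fully zeroed (weights and bias).
W1 = [[1.0, 2.0], [0.0, 0.0]]
b1 = [0.5, 0.0]
W2 = [[1.0, -1.0], [2.0, 0.5]]
b2 = [0.0, 1.0]

x = [0.3, -0.7]
h = [v + b for v, b in zip(matvec(W1, x), b1)]
ref = [v + b for v, b in zip(matvec(W2, h), b2)]

# Remove the dead neuron: drop its row in layer 1 and the matching
# column in layer 2. Shapes shrink; the output does not change.
W1s, b1s = [W1[0]], [b1[0]]
W2s = [[row[0]] for row in W2]
hs = [v + b for v, b in zip(matvec(W1s, x), b1s)]
out = [v + b for v, b in zip(matvec(W2s, hs), b2)]
assert out == ref
```

This is why `remove_zeroed` takes the same `pinned_out` list: layers whose output shape must be preserved (e.g. because of residual connections) cannot have channels removed.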

Utilities

We also provide a set of utilities used to define bn_folding and pinned_out for standard PyTorch models.

from torchvision.models import resnet18
from utils import get_bn_folding, get_pinned_out

model = resnet18()
bn_folding = get_bn_folding(model)
pinned_out = get_pinned_out(model)

Tests

Inference time benchmarks

Evaluation mode (fuses BatchNorm)

Update timestamp 08/10/2021 14:26:25

Random structured pruning amount = 50.0%

| Architecture | Dense time | Pruned time | Simplified time |
|---|---|---|---|
| alexnet | 7.58ms ± 0.29 | 7.55ms ± 0.28 | 2.95ms ± 0.02 |
| densenet121 | 36.41ms ± 4.88 | 34.31ms ± 3.85 | 21.87ms ± 1.45 |
| googlenet | 15.44ms ± 3.19 | 13.68ms ± 0.09 | 10.31ms ± 0.82 |
| inception_v3 | 25.29ms ± 7.31 | 21.68ms ± 2.90 | 13.22ms ± 2.23 |
| mnasnet1_0 | 17.66ms ± 0.57 | 13.64ms ± 0.13 | 11.59ms ± 0.07 |
| mobilenet_v3_large | 13.74ms ± 0.67 | 12.18ms ± 0.46 | 11.95ms ± 0.21 |
| resnet50 | 24.39ms ± 4.48 | 26.19ms ± 5.84 | 18.21ms ± 1.98 |
| resnext101_32x8d | 76.11ms ± 15.79 | 77.35ms ± 20.04 | 65.68ms ± 16.41 |
| shufflenet_v2_x2_0 | 18.07ms ± 2.23 | 14.32ms ± 0.21 | 13.06ms ± 0.08 |
| squeezenet1_1 | 4.50ms ± 0.06 | 4.39ms ± 0.05 | 4.09ms ± 0.50 |
| vgg19_bn | 40.41ms ± 12.13 | 38.56ms ± 10.72 | 12.39ms ± 0.19 |
| wide_resnet101_2 | 79.40ms ± 25.57 | 82.86ms ± 22.47 | 60.16ms ± 10.77 |

Status of torchvision.models

:heavy_check_mark:: all good

:x:: gives different results

:cursing_face:: an exception occurred

:man_shrugging:: test skipped due to failing of the previous one

Fuse BatchNorm

Update timestamp 06/10/2021 20:26:15

| Architecture | BatchNorm Folding | Bias Propagation | Simplification |
|---|---|---|---|
| alexnet | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| densenet121 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| googlenet | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| inception_v3 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| mnasnet1_0 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| mobilenet_v3_large | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| resnet50 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| resnext101_32x8d | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| shufflenet_v2_x2_0 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| squeezenet1_1 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| vgg19_bn | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| wide_resnet101_2 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |

Keep BatchNorm

Update timestamp 06/10/2021 20:36:11

| Architecture | BatchNorm Folding | Bias Propagation | Simplification |
|---|---|---|---|
| alexnet | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| densenet121 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| googlenet | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| inception_v3 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| mnasnet1_0 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| mobilenet_v3_large | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| resnet50 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| resnext101_32x8d | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| shufflenet_v2_x2_0 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| squeezenet1_1 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| vgg19_bn | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| wide_resnet101_2 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torch-simplify-1.1.0.tar.gz (15.6 kB)

Uploaded Source

Built Distribution

torch_simplify-1.1.0-py3-none-any.whl (16.7 kB)

Uploaded Python 3

File details

Details for the file torch-simplify-1.1.0.tar.gz.

File metadata

  • Download URL: torch-simplify-1.1.0.tar.gz
  • Upload date:
  • Size: 15.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.5.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.8.8

File hashes

Hashes for torch-simplify-1.1.0.tar.gz
Algorithm Hash digest
SHA256 38f30c0f0ef25eb802541a0a747b6c96bd0bd48da1ea58437ed0705bb7c18be7
MD5 72f1dd90f5a6092201866dced341163d
BLAKE2b-256 88aaef36219182c74e6ec64e4812ec13986c207d6e65a4fdb6f6b8996ae8b096

File details

Details for the file torch_simplify-1.1.0-py3-none-any.whl.

File metadata

  • Download URL: torch_simplify-1.1.0-py3-none-any.whl
  • Upload date:
  • Size: 16.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.5.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.8.8

File hashes

Hashes for torch_simplify-1.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 e7a9556d947ae382385c7b04f191bcdbd96e95f27401f6cbe300447ea873c1ea
MD5 6e54a622f8ef593f9a7a2dfb77dcde84
BLAKE2b-256 06d069c2af310e8ac41dda1f8d102041cd957b69bea70b2059ffc1b6237b05fb
