Simplification of pruned models for accelerated inference
Simplify
Installation
Simplify can be installed using pip:
pip3 install torch-simplify
Alternatively, to run the latest version of the code, install it from git:
git clone https://github.com/EIDOSlab/simplify
cd simplify
pip3 install -r requirements.txt
Example usage
from torchvision.models import resnet18
from simplify import fuse
model = resnet18()
model.eval()
bn_folding = ... # List of (conv, bn) pairs to fuse into a single layer
model = fuse(model, bn_folding)
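To make the idea of BatchNorm folding concrete, here is a minimal sketch that folds a BatchNorm into the preceding layer. It is not the library's implementation: it uses plain Python lists and a dense layer instead of PyTorch tensors and convolutions, and the helper names (`linear`, `batchnorm_eval`, `fold_bn`) are hypothetical. It relies only on the standard evaluation-mode BatchNorm formula, which is also why the model is put in `eval()` mode before fusing.

```python
import math

def linear(x, weight, bias):
    """Dense layer: one output per row of `weight`."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weight, bias)]

def batchnorm_eval(y, gamma, beta, mean, var, eps=1e-5):
    """BatchNorm in evaluation mode, using fixed running statistics."""
    return [g * (v - m) / math.sqrt(s + eps) + b
            for v, g, b, m, s in zip(y, gamma, beta, mean, var)]

def fold_bn(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Return (weight, bias) of one layer equivalent to linear + BatchNorm."""
    scale = [g / math.sqrt(s + eps) for g, s in zip(gamma, var)]
    folded_w = [[s * w for w in row] for row, s in zip(weight, scale)]
    folded_b = [(b - m) * s + bt
                for b, m, s, bt in zip(bias, mean, scale, beta)]
    return folded_w, folded_b

# Toy parameters, chosen only for illustration.
weight = [[0.5, -1.0], [2.0, 0.25]]
bias = [0.1, -0.3]
gamma, beta = [1.5, 0.8], [0.2, -0.1]
mean, var = [0.05, 0.4], [0.9, 1.6]
x = [1.0, 2.0]

# Two layers applied in sequence versus the single folded layer.
reference = batchnorm_eval(linear(x, weight, bias), gamma, beta, mean, var)
folded_w, folded_b = fold_bn(weight, bias, gamma, beta, mean, var)
fused = linear(x, folded_w, folded_b)
max_diff = max(abs(a - b) for a, b in zip(reference, fused))
print(max_diff)  # difference is at floating-point rounding level
```

The folded layer does the same work in a single pass, so the BatchNorm adds no inference cost once it is fused.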
Propagate
The propagate module removes the non-zero bias from zeroed-out neurons, so that the neurons themselves can then be removed.
import torch
from simplify import propagate_bias
from torchvision.models import resnet18
zeros = torch.zeros(1, 3, 224, 224)
model = resnet18()
pinned_out = ... # List of layers for which the bias should not be propagated
propagate_bias(model, zeros, pinned_out)
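Why a leftover bias matters can be seen on a tiny two-layer network. The sketch below is an illustration under assumptions, not the library's code: it uses plain Python lists, a ReLU activation, and a hypothetical helper `propagate_bias_sketch`. A neuron whose weights are all zero still outputs a constant (its activated bias), so that constant is moved into the bias of the next layer before the neuron's own bias is zeroed.

```python
def relu(v):
    return [max(0.0, a) for a in v]

def linear(x, weight, bias):
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weight, bias)]

def forward(x, w1, b1, w2, b2):
    """Two-layer network: linear -> ReLU -> linear."""
    return linear(relu(linear(x, w1, b1)), w2, b2)

def propagate_bias_sketch(w1, b1, w2, b2):
    """Move the constant output of zeroed neurons into the next layer's bias."""
    new_b1, new_b2 = list(b1), list(b2)
    for j, row in enumerate(w1):
        if all(w == 0.0 for w in row):
            constant = max(0.0, b1[j])  # output of neuron j for any input
            for i in range(len(new_b2)):
                new_b2[i] += w2[i][j] * constant
            new_b1[j] = 0.0             # neuron j now outputs exactly zero
    return new_b1, new_b2

w1 = [[0.0, 0.0], [1.0, -2.0], [0.5, 0.5]]  # neuron 0 has been pruned...
b1 = [0.7, 0.1, -0.2]                       # ...but still carries a bias
w2 = [[1.0, 2.0, -1.0], [0.5, 0.0, 3.0]]
b2 = [0.0, 1.0]
x = [0.3, -0.4]

before = forward(x, w1, b1, w2, b2)
new_b1, new_b2 = propagate_bias_sketch(w1, b1, w2, b2)
after = forward(x, w1, new_b1, w2, new_b2)
```

After this step the network computes the same function, and neuron 0 contributes nothing, which is what makes it safe to delete.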
Remove
The remove module actually removes the zeroed neurons from the model architecture.
import torch
from simplify import remove_zeroed
from torchvision.models import resnet18
zeros = torch.zeros(1, 3, 224, 224)
model = resnet18()
pinned_out = ... # List of layers in which the output should not change shape
remove_zeroed(model, zeros, pinned_out)
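As a rough picture of what removal means for the weights, the following sketch drops a zeroed neuron from a two-layer network written with plain Python lists. The helper `remove_zeroed_sketch` is hypothetical and only mirrors the idea, not the library's implementation: the neuron's row is deleted from the first layer and the matching input column is deleted from the second. The last layer keeps its output size, in the spirit of the `pinned_out` argument.

```python
def relu(v):
    return [max(0.0, a) for a in v]

def linear(x, weight, bias):
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weight, bias)]

def forward(x, w1, b1, w2, b2):
    """Two-layer network: linear -> ReLU -> linear."""
    return linear(relu(linear(x, w1, b1)), w2, b2)

def remove_zeroed_sketch(w1, b1, w2):
    """Drop neurons with all-zero weights and zero bias, and their consumers."""
    keep = [j for j, row in enumerate(w1)
            if any(w != 0.0 for w in row) or b1[j] != 0.0]
    small_w1 = [w1[j] for j in keep]                   # fewer output rows
    small_b1 = [b1[j] for j in keep]
    small_w2 = [[row[j] for j in keep] for row in w2]  # fewer input columns
    return small_w1, small_b1, small_w2

# Neuron 0 is fully zeroed (weights and bias), e.g. after bias propagation.
w1 = [[0.0, 0.0], [1.0, -2.0], [0.5, 0.5]]
b1 = [0.0, 0.1, -0.2]
w2 = [[1.0, 2.0, -1.0], [0.5, 0.0, 3.0]]
b2 = [0.7, 1.35]
x = [0.3, -0.4]

dense_out = forward(x, w1, b1, w2, b2)
small_w1, small_b1, small_w2 = remove_zeroed_sketch(w1, b1, w2)
small_out = forward(x, small_w1, small_b1, small_w2, b2)
```

The smaller network has one hidden neuron fewer but returns the same outputs, which is where the inference speed-up comes from.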
Utilities
We also provide a set of utilities used to define bn_folding and pinned_out for standard PyTorch models.
from torchvision.models import resnet18
from simplify.utils import get_bn_folding, get_pinned_out
model = resnet18()
bn_folding = get_bn_folding(model)
pinned_out = get_pinned_out(model)
Tests
Inference time benchmarks
Evaluation mode (fuses BatchNorm)
Update timestamp 08/10/2021 14:26:25
Random structured pruning amount = 50.0%
Architecture | Dense time | Pruned time | Simplified time |
---|---|---|---|
alexnet | 7.58ms ± 0.29 | 7.55ms ± 0.28 | 2.95ms ± 0.02 |
densenet121 | 36.41ms ± 4.88 | 34.31ms ± 3.85 | 21.87ms ± 1.45 |
googlenet | 15.44ms ± 3.19 | 13.68ms ± 0.09 | 10.31ms ± 0.82 |
inception_v3 | 25.29ms ± 7.31 | 21.68ms ± 2.90 | 13.22ms ± 2.23 |
mnasnet1_0 | 17.66ms ± 0.57 | 13.64ms ± 0.13 | 11.59ms ± 0.07 |
mobilenet_v3_large | 13.74ms ± 0.67 | 12.18ms ± 0.46 | 11.95ms ± 0.21 |
resnet50 | 24.39ms ± 4.48 | 26.19ms ± 5.84 | 18.21ms ± 1.98 |
resnext101_32x8d | 76.11ms ± 15.79 | 77.35ms ± 20.04 | 65.68ms ± 16.41 |
shufflenet_v2_x2_0 | 18.07ms ± 2.23 | 14.32ms ± 0.21 | 13.06ms ± 0.08 |
squeezenet1_1 | 4.50ms ± 0.06 | 4.39ms ± 0.05 | 4.09ms ± 0.50 |
vgg19_bn | 40.41ms ± 12.13 | 38.56ms ± 10.72 | 12.39ms ± 0.19 |
wide_resnet101_2 | 79.40ms ± 25.57 | 82.86ms ± 22.47 | 60.16ms ± 10.77 |
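To read the table as speed-ups rather than raw times, the short sketch below divides the dense time by the simplified time for a few rows. The numbers are copied from the table above; the ± spread is ignored, so the ratios are only indicative, especially for architectures with a large deviation.

```python
# (dense, pruned, simplified) mean inference times in ms, copied from the table.
timings_ms = {
    "alexnet": (7.58, 7.55, 2.95),
    "resnet50": (24.39, 26.19, 18.21),
    "vgg19_bn": (40.41, 38.56, 12.39),
    "mobilenet_v3_large": (13.74, 12.18, 11.95),
}

def speedup(dense_ms, simplified_ms):
    """How many times faster the simplified model is than the dense one."""
    return dense_ms / simplified_ms

speedups = {name: round(speedup(dense, simplified), 2)
            for name, (dense, _pruned, simplified) in timings_ms.items()}

# Print the architectures from largest to smallest speed-up.
for name, value in sorted(speedups.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name:20s} {value:.2f}x")
```

Note that the pruned column is close to the dense one in most rows: zeroing weights alone does not shrink the computation, whereas the simplified model is actually smaller.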
Status of torchvision.models
:heavy_check_mark:: all good
:x:: gives different results
:cursing_face:: an exception occurred
:man_shrugging:: test skipped because the previous one failed
Fuse BatchNorm
Update timestamp 06/10/2021 20:26:15
Architecture | BatchNorm Folding | Bias Propagation | Simplification |
---|---|---|---|
alexnet | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
densenet121 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
googlenet | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
inception_v3 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
mnasnet1_0 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
mobilenet_v3_large | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
resnet50 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
resnext101_32x8d | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
shufflenet_v2_x2_0 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
squeezenet1_1 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
vgg19_bn | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
wide_resnet101_2 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
Keep BatchNorm
Update timestamp 06/10/2021 20:36:11
Architecture | BatchNorm Folding | Bias Propagation | Simplification |
---|---|---|---|
alexnet | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
densenet121 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
googlenet | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
inception_v3 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
mnasnet1_0 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
mobilenet_v3_large | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
resnet50 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
resnext101_32x8d | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
shufflenet_v2_x2_0 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
squeezenet1_1 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
vgg19_bn | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
wide_resnet101_2 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |