Skip to main content

A pytorch toolkit for structured neural network pruning and automatic layer dependency maintaining.

Project description

Torch-Pruning

A pytorch toolkit for structured neural network pruning and layer dependency maintaining

This tool will automatically detect and handle layer dependencies (channel consistency) during pruning. It is able to handle various network architectures such as DenseNet, ResNet, and Inception. See examples/test_models.py for more supported models.

Supported Modules: Conv, Linear, BatchNorm, Transposed Conv, PReLU

Feel free to open a pull request if you have some interesting ideas!

How it works

This package will run your model with fake inputs and collect forward information just like torch.jit. Then a dependency graph is established to describe the computational graph. When a pruning function (e.g. torch_pruning.prune_conv ) is applied on certain layer through DependencyGraph.get_pruning_plan, this package will traverse the whole graph to fix inconsistent modules such as BN. The pruning index will be automatically mapped to correct position if there is torch.split or torch.cat in your model.

Tip: please remember to save the whole model object (weights+architecture) rather than model weights only:

# save a pruned model
# torch.save(model.state_dict(), 'model.pth') # weights only
torch.save(model, 'model.pth') # obj (arch) + weights

# load a pruned model
model = torch.load('model.pth') # no load_state_dict
Dependency Visualization Example
Conv-Conv AlexNet
Conv-FC (Global Pooling or Flatten) ResNet, VGG
Skip Connection ResNet
Concatenation DenseNet, ASPP
Split torch.chunk

Known Issues:

  • When groups>1, only depthwise conv is supported, i.e. groups=in_channels=out_channels.
  • Customized operations will be treated as element-wise op, e.g. subclass of torch.autograd.Function.

Installation

pip install torch_pruning # v0.2.4

Quickstart

A minimal example

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True)

# 1. setup strategy (L1 Norm)
strategy = tp.strategy.L1Strategy() # or tp.strategy.RandomStrategy()

# 2. build layer dependency for resnet18
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))

# 3. get a pruning plan from the dependency graph.
pruning_idxs = strategy(model.conv1.weight, amount=0.4) # or manually selected pruning_idxs=[2, 6, 9]
pruning_plan = DG.get_pruning_plan( model.conv1, tp.prune_conv, idxs=pruning_idxs )
print(pruning_plan)

# 4. execute this plan (prune the model)
pruning_plan.exec()

Pruning the resnet.conv1 will affect several layers. Let's inspect the pruning plan (with pruning_idxs=[2, 6, 9]):

-------------
[ <DEP: prune_conv => prune_conv on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False))>, Index=[2, 6, 9], NumPruned=441]
[ <DEP: prune_conv => prune_batchnorm on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))>, Index=[2, 6, 9], NumPruned=6]
[ <DEP: prune_batchnorm => _prune_elementwise_op on _ElementWiseOp()>, Index=[2, 6, 9], NumPruned=0]
[ <DEP: _prune_elementwise_op => _prune_elementwise_op on _ElementWiseOp()>, Index=[2, 6, 9], NumPruned=0]
[ <DEP: _prune_elementwise_op => prune_related_conv on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))>, Index=[2, 6, 9], NumPruned=1728]
[ <DEP: _prune_elementwise_op => _prune_elementwise_op on _ElementWiseOp()>, Index=[2, 6, 9], NumPruned=0]
[ <DEP: _prune_elementwise_op => prune_batchnorm on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))>, Index=[2, 6, 9], NumPruned=6]
[ <DEP: prune_batchnorm => prune_conv on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))>, Index=[2, 6, 9], NumPruned=1728]
[ <DEP: _prune_elementwise_op => _prune_elementwise_op on _ElementWiseOp()>, Index=[2, 6, 9], NumPruned=0]
[ <DEP: _prune_elementwise_op => prune_related_conv on layer1.1.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))>, Index=[2, 6, 9], NumPruned=1728]
[ <DEP: _prune_elementwise_op => _prune_elementwise_op on _ElementWiseOp()>, Index=[2, 6, 9], NumPruned=0]
[ <DEP: _prune_elementwise_op => prune_batchnorm on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))>, Index=[2, 6, 9], NumPruned=6]
[ <DEP: prune_batchnorm => prune_conv on layer1.1.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))>, Index=[2, 6, 9], NumPruned=1728]
[ <DEP: _prune_elementwise_op => _prune_elementwise_op on _ElementWiseOp()>, Index=[2, 6, 9], NumPruned=0]
[ <DEP: _prune_elementwise_op => prune_related_conv on layer2.0.conv1 (Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False))>, Index=[2, 6, 9], NumPruned=3456]
[ <DEP: _prune_elementwise_op => prune_related_conv on layer2.0.downsample.0 (Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False))>, Index=[2, 6, 9], NumPruned=384]
11211 parameters will be pruned
-------------

Low-level pruning functions

In absence of DependencyGraph, We have to manually handle the broken dependencies layer by layer.

tp.prune_conv( model.conv1, idxs=[2,6,9] )

# fix the broken dependencies manually
tp.prune_batchnorm( model.bn1, idxs=[2,6,9] )
tp.prune_related_conv( model.layer2[0].conv1, idxs=[2,6,9] )
...

Customized Layers

Please refer to 'examples/customize_layer.py' for pruning customized layers with this package. A detailed tutorial is on the way!

Layer Dependency

During structured pruning, we need to maintain the channel consistency between different layers.

A Simple Case

More Complicated Cases

the layer dependency becomes much more complicated when the model contains skip connections or concatenations.

Residual Block:

Concatenation:

See paper Pruning Filters for Efficient ConvNets for more details.

Example: ResNet18 on Cifar10

1. Train the model

cd examples
python prune_resnet18_cifar10.py --mode train # 11.1M, Acc=0.9248

2. Pruning and fintuning

python prune_resnet18_cifar10.py --mode prune --round 1 --total_epochs 30 --step_size 20 # 4.5M, Acc=0.9229
python prune_resnet18_cifar10.py --mode prune --round 2 --total_epochs 30 --step_size 20 # 1.9M, Acc=0.9207
python prune_resnet18_cifar10.py --mode prune --round 3 --total_epochs 30 --step_size 20 # 0.8M, Acc=0.9176
python prune_resnet18_cifar10.py --mode prune --round 4 --total_epochs 30 --step_size 20 # 0.4M, Acc=0.9102
python prune_resnet18_cifar10.py --mode prune --round 5 --total_epochs 30 --step_size 20 # 0.2M, Acc=0.9011
...

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torch-pruning-0.2.5.tar.gz (13.8 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

torch_pruning-0.2.5-py3-none-any.whl (13.5 kB view details)

Uploaded Python 3

File details

Details for the file torch-pruning-0.2.5.tar.gz.

File metadata

  • Download URL: torch-pruning-0.2.5.tar.gz
  • Upload date:
  • Size: 13.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/52.0.0.post20210125 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.7.0

File hashes

Hashes for torch-pruning-0.2.5.tar.gz
Algorithm Hash digest
SHA256 11eadd2441cf5f1207f5cbaf4b56389d024a93310eddffb7cc59e2e41a34508d
MD5 3d8e16732cd1754cf0b4a6a0e7376eb9
BLAKE2b-256 d1aaced5875afd58771e0f931ccc2b16c48557560247f380d1e2f6254e52ab59

See more details on using hashes here.

File details

Details for the file torch_pruning-0.2.5-py3-none-any.whl.

File metadata

  • Download URL: torch_pruning-0.2.5-py3-none-any.whl
  • Upload date:
  • Size: 13.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/52.0.0.post20210125 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.7.0

File hashes

Hashes for torch_pruning-0.2.5-py3-none-any.whl
Algorithm Hash digest
SHA256 777aa9daa8046c91f0bf8da43f4222a6ee647db3c69eefda5d29eabdcd2233e6
MD5 d542bd3ba3497c1e8c36305031dcab0a
BLAKE2b-256 2eca13dc7cde6c453a717cbaa6b28466e413cd745f76beb478a05888f3876fa3

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page