
Torch-Pruning

A pytorch toolkit for neural network pruning and layer dependency maintaining.

Torch-Pruning automatically detects and maintains the layer dependencies required for structured pruning, and provides reusable implementations of common pruning functions. With dependency management handled for you, you can focus on the design of the pruning algorithm itself.

This toolkit has the following features:

  • Basic pruning functions for Convolutional Neural Networks
  • Layer dependency management
  • Dependency customization for complex modules

Installation

pip install torch_pruning

Layer Dependency

A Simple Dependency

More Complicated Cases

Layer dependencies become much more complicated when the model contains skip connections or concatenations.

Residual Block:

Concatenation:

See paper Pruning Filters for Efficient ConvNets for more details.
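The constraints imposed by skip connections and concatenations can be illustrated with plain index arithmetic. The helpers below are an illustrative sketch, not Torch-Pruning's internals: an element-wise add forces both inputs to drop exactly the same channel indices, while a concatenation splits and shifts the output indices by the first input's channel count.

```python
def add_coupling(idxs):
    # Element-wise add: out[i] = a[i] + b[i], so both inputs must
    # drop exactly the same channel indices as the output.
    return {"a": list(idxs), "b": list(idxs)}

def concat_coupling(idxs, a_channels):
    # Concatenation: channels [0, a_channels) come from input a, the
    # rest from input b, so output indices are split and shifted.
    a_idxs = [i for i in idxs if i < a_channels]
    b_idxs = [i - a_channels for i in idxs if i >= a_channels]
    return {"a": a_idxs, "b": b_idxs}

# Pruning channels 2 and 66 of a concat of two 64-channel tensors
# maps to channel 2 of the first input and channel 2 of the second.
print(concat_coupling([2, 66], a_channels=64))  # {'a': [2], 'b': [2]}
```

This is why pruning one branch of a residual block silently requires pruning the other branch (and the shortcut) with matching indices.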

How It Works

Torch-Pruning provides a DependencyGraph to detect and manage the dependencies between layers. It requires a fake input to run the model and collect layer information from the dynamic computational graph. DependencyGraph.get_pruning_plan detects the dependencies that your pruning operation would break and prepares an executable PruningPlan containing all the required pruning operations.
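Conceptually, building the plan is a breadth-first walk over a graph of coupled layers. The sketch below is a simplified stand-in (the coupling graph is hand-written here, whereas the real DependencyGraph derives it from the traced model, and this `get_pruning_plan` is a hypothetical reimplementation for illustration):

```python
from collections import deque

# Toy coupling graph: edges connect layers whose channels must be
# pruned together (a conv, its batchnorm, and the next conv's inputs).
# Layer names mirror resnet18 but the graph is hand-written.
coupled = {
    "conv1": ["bn1", "layer1.0.conv1"],
    "bn1": ["conv1"],
    "layer1.0.conv1": ["conv1"],
}

def get_pruning_plan(root, idxs):
    # Breadth-first walk: starting from the pruned layer, visit every
    # layer tied to it and record one pruning operation per layer.
    plan, seen, queue = [], {root}, deque([root])
    while queue:
        layer = queue.popleft()
        plan.append((layer, idxs))
        for dep in coupled.get(layer, []):
            if dep not in seen:
                seen.add(dep)
                queue.append(dep)
    return plan

plan = get_pruning_plan("conv1", [2, 6, 9])
# [('conv1', [2, 6, 9]), ('bn1', [2, 6, 9]), ('layer1.0.conv1', [2, 6, 9])]
```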

Quickstart

Pruning with DependencyGraph

import torch
from torchvision.models import resnet18
import torch_pruning as pruning

model = resnet18(pretrained=True)
# build the layer dependency graph for resnet18
DG = pruning.DependencyGraph(model, fake_input=torch.randn(1, 3, 224, 224))
# get a pruning plan from the dependency graph; idxs are the indices of the pruned filters
pruning_plan = DG.get_pruning_plan(model.conv1, pruning.prune_conv, idxs=[2, 6, 9])
print(pruning_plan)
# execute the plan (prune the model)
pruning_plan.exec()

Pruning model.conv1 affects several other modules. The pruning plan:

[ prune_conv on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), Indexs=[2, 6, 9], NumPruned=441]
[ prune_batchnorm on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Indexs=[2, 6, 9], NumPruned=6]
[ _prune_elementwise_op on elementwise (_ElementWiseOp()), Indexs=[2, 6, 9], NumPruned=0]
[ _prune_elementwise_op on elementwise (_ElementWiseOp()), Indexs=[2, 6, 9], NumPruned=0]
[ prune_related_conv on layer2.0.conv1 (Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), Indexs=[2, 6, 9], NumPruned=3456]
[ prune_batchnorm on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Indexs=[2, 6, 9], NumPruned=6]
[ prune_conv on layer1.1.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Indexs=[2, 6, 9], NumPruned=1728]
[ prune_related_conv on layer2.0.downsample.0 (Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), Indexs=[2, 6, 9], NumPruned=384]
[ prune_related_conv on layer1.1.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Indexs=[2, 6, 9], NumPruned=1728]
[ prune_batchnorm on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Indexs=[2, 6, 9], NumPruned=6]
[ prune_conv on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Indexs=[2, 6, 9], NumPruned=1728]
[ prune_related_conv on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Indexs=[2, 6, 9], NumPruned=1728]
11211 parameters will be pruned
-------------
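The NumPruned figures follow directly from the layer shapes: removing k output filters of a Conv2d deletes k whole in_channels x kH x kW kernels, removing k input channels deletes one kH x kW slice per output filter, and each BatchNorm2d channel carries a learnable weight and bias. As a sanity check (the helper names below are ours, not the library's), the plan above sums to exactly 11211:

```python
def conv_filters_pruned(k, in_ch, kh, kw):
    # k output filters removed: k whole in_ch x kh x kw kernels.
    return k * in_ch * kh * kw

def conv_inputs_pruned(k, out_ch, kh, kw):
    # k input channels removed: one kh x kw slice per output filter.
    return k * out_ch * kh * kw

def bn_pruned(k):
    # Each BatchNorm2d channel has a learnable weight and a bias.
    return k * 2

k = 3  # filters [2, 6, 9]
total = (
    conv_filters_pruned(k, 3, 7, 7)      # conv1: 441
    + bn_pruned(k)                       # bn1: 6
    + conv_inputs_pruned(k, 128, 3, 3)   # layer2.0.conv1: 3456
    + bn_pruned(k)                       # layer1.1.bn2: 6
    + conv_filters_pruned(k, 64, 3, 3)   # layer1.1.conv2: 1728
    + conv_inputs_pruned(k, 128, 1, 1)   # layer2.0.downsample.0: 384
    + conv_inputs_pruned(k, 64, 3, 3)    # layer1.1.conv1: 1728
    + bn_pruned(k)                       # layer1.0.bn2: 6
    + conv_filters_pruned(k, 64, 3, 3)   # layer1.0.conv2: 1728
    + conv_inputs_pruned(k, 64, 3, 3)    # layer1.0.conv1: 1728
)
print(total)  # 11211
```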

Pruning with low-level pruning functions

Without DependencyGraph, you have to handle the dependencies between layers manually. See examples/example_pruning_fn.py for more details about the low-level pruning functions.

pruning.prune_conv(model.conv1, idxs=[2, 6, 9])

# fix the broken dependencies manually
pruning.prune_batchnorm(model.bn1, idxs=[2, 6, 9])
pruning.prune_related_conv(model.layer2[0].conv1, idxs=[2, 6, 9])
...
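When pruning manually, the same index bookkeeping must be applied consistently everywhere: every coupled layer is rebuilt around the surviving channels. A minimal sketch of that bookkeeping (the helper name is hypothetical):

```python
def remaining_channels(num_channels, idxs):
    # Channels kept after pruning; every coupled layer (batchnorm,
    # downstream convs) must be rebuilt with this reduced width for
    # tensor shapes to keep matching.
    pruned = set(idxs)
    return [i for i in range(num_channels) if i not in pruned]

# Pruning filters [2, 6, 9] of a 64-channel conv leaves 61 channels.
keep = remaining_channels(64, [2, 6, 9])
print(len(keep))  # 61
```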

Example: ResNet18 on Cifar10

1. Train the model

cd examples
python prune_resnet18.py --mode train # 11.1M, Acc=0.9248

2. Pruning and finetuning

python prune_resnet18.py --mode prune --round 1 --total_epochs 30 --step_size 20 # 4.5M, Acc=0.9229
python prune_resnet18.py --mode prune --round 2 --total_epochs 30 --step_size 20 # 1.9M, Acc=0.9207
python prune_resnet18.py --mode prune --round 3 --total_epochs 30 --step_size 20 # 0.8M, Acc=0.9176
python prune_resnet18.py --mode prune --round 4 --total_epochs 30 --step_size 20 # 0.4M, Acc=0.9102
python prune_resnet18.py --mode prune --round 5 --total_epochs 30 --step_size 20 # 0.2M, Acc=0.9011
...
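From the numbers reported above, each round roughly halves the parameter count at a small accuracy cost; after five rounds the model is about 55x smaller with a ~0.024 accuracy drop. A quick calculation over the reported figures:

```python
# (parameters in millions, accuracy) reported for each round above.
rounds = [(11.1, 0.9248), (4.5, 0.9229), (1.9, 0.9207),
          (0.8, 0.9176), (0.4, 0.9102), (0.2, 0.9011)]

base_params, base_acc = rounds[0]
for params, acc in rounds[1:]:
    ratio = base_params / params       # compression vs. the unpruned model
    drop = base_acc - acc              # absolute accuracy lost
    print(f"{params:>4}M: {ratio:5.1f}x smaller, acc drop {drop:.4f}")
```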

TODO

  • Documents
  • Predefined pruning algorithms
  • Test the toolkit with Densenet / MobileNet / ...
