Torch-Pruning
Structural Pruning for Model Acceleration
Torch-Pruning is a general-purpose library for structural network pruning, which supports a wide variety of neural networks such as Vision Transformers, ResNet, DenseNet, RegNet, ResNeXt, FCN, DeepLab, VGG, etc. Please refer to tests/test_torchvision_models.py for more details about prunable models.
Features:
- Channel pruning for CNNs (e.g. ResNet, DenseNet, Deeplab) and Transformers (e.g. ViT)
- High-level pruners: MagnitudePruner, BNScalePruner, GroupPruner, etc.
- Graph Tracing and dependency fixing.
- Supported modules: Conv, Linear, BatchNorm, LayerNorm, Transposed Conv, PReLU, Embedding, MultiheadAttention, nn.Parameters and customized modules.
- Supported operations: split, concatenation, skip connection, flatten, etc.
- Pruning strategies: Random, L1, L2, etc.
- Low-level pruning functions
- Benchmarks and tutorials
Plans:
- More high-level pruners like FisherPruner, SoftPruner, GeometricPruner, etc.
- Support more Transformers like Vision Transformers (✔), Swin Transformers, PoolFormers.
- Pruning benchmarks for CIFAR and ImageNet.
- A paper about this repo (✔, will be released ASAP)
How it works
Torch-Pruning forwards your model with fake inputs and traces the computational graph, much like torch.jit. A dependency graph is built to record the coupling between layers. Torch-Pruning then collects all affected layers by propagating your pruning operations through the whole graph and returns a pruning group for pruning. All pruning indices are automatically transformed if there are operations like torch.split or torch.cat.
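For instance, the following sketch (a hypothetical two-branch network, built only with the DependencyGraph API used in the Quickstart below) shows how out-channel indices selected on one branch are translated into offset in-channel indices of the layer that consumes the torch.cat output:
import torch
import torch.nn as nn
import torch_pruning as tp

class ConcatNet(nn.Module):  # hypothetical toy model for illustration
    def __init__(self):
        super().__init__()
        self.branch1 = nn.Conv2d(3, 8, 3, padding=1)
        self.branch2 = nn.Conv2d(3, 8, 3, padding=1)
        self.head = nn.Conv2d(16, 4, 1)  # consumes torch.cat([branch1, branch2], dim=1)
    def forward(self, x):
        return self.head(torch.cat([self.branch1(x), self.branch2(x)], dim=1))

model = ConcatNet()
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1, 3, 32, 32))
# Pruning out-channels [0, 1] of branch2 should also prune in-channels [8, 9] of head,
# i.e. the indices are shifted past branch1's 8 channels by the dependency graph.
group = DG.get_pruning_group(model.branch2, tp.prune_conv_out_channels, idxs=[0, 1])
print(group)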
Installation
git clone https://github.com/VainF/Torch-Pruning.git
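Alternatively, install the released package from PyPI (this page describes version 1.0.0):
pip install torch-pruning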
Quickstart
Here we provide a quick start for Torch-Pruning. More detailed explanations can be found in the tutorials.
0. Dependency
Dependency | Examples
---|---
Conv-Conv | AlexNet
Conv-FC (Global Pooling or Flatten) | ResNet, VGG
Skip Connection | ResNet
Concatenation | DenseNet, ASPP
Split | torch.chunk
1. A minimal example
import torch
from torchvision.models import resnet18
import torch_pruning as tp
model = resnet18(pretrained=True).eval()
# 1. build dependency graph for resnet18
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))
# 2. Select channels for pruning, here we prune the channels indexed by [2, 6, 9].
pruning_idxs = [2, 6, 9]
pruning_group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs )
# 3. prune all grouped layers that are coupled with model.conv1
if DG.check_pruning_group(pruning_group):
    pruning_group.exec()
# 4. save & load the pruned model
torch.save(model, 'model.pth') # save the model object
model_loaded = torch.load('model.pth') # load the full model object; no load_state_dict needed
In this example, pruning model.conv1 will affect several layers. Let's inspect the pruning group (with pruning_idxs=[2, 6, 9]):
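The listing below is the group's printed form, obtained e.g. with:
print(pruning_group)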
--------------------------------
Pruning Group
--------------------------------
[0] [DEP] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), #Pruned=3
[1] [DEP] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #Pruned=3
[2] [DEP] prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp(ReluBackward0), #Pruned=3
[3] [DEP] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_out_channels on _ElementWiseOp(MaxPool2DWithIndicesBackward0), #Pruned=3
[4] [DEP] prune_out_channels on _ElementWiseOp(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp(AddBackward0), #Pruned=3
[5] [DEP] prune_out_channels on _ElementWiseOp(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #Pruned=3
[6] [DEP] prune_out_channels on _ElementWiseOp(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #Pruned=3
[7] [DEP] prune_out_channels on _ElementWiseOp(AddBackward0) => prune_out_channels on _ElementWiseOp(ReluBackward0), #Pruned=3
[8] [DEP] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_out_channels on _ElementWiseOp(AddBackward0), #Pruned=3
[9] [DEP] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #Pruned=3
[10] [DEP] prune_out_channels on _ElementWiseOp(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #Pruned=3
[11] [DEP] prune_out_channels on _ElementWiseOp(AddBackward0) => prune_out_channels on _ElementWiseOp(ReluBackward0), #Pruned=3
[12] [DEP] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(61, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), #Pruned=3
[13] [DEP] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(61, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), #Pruned=3
[14] [DEP] prune_out_channels on layer1.1.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #Pruned=3
[15] [DEP] prune_out_channels on layer1.0.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #Pruned=3
--------------------------------
2. High-level Pruners
We provide some model-level pruners in this repo for convenience. You can specify the channel sparsity to prune the whole model and then finetune it using your own training code. Please refer to tests/test_pruner.py for more details. More examples can be found in benchmarks/main.py.
import torch
from torchvision.models import resnet18
import torch_pruning as tp
model = resnet18(pretrained=True)
# Global metrics
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)
ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!
iterative_steps = 5
pruner = tp.pruner.MagnitudePruner(
model,
example_inputs,
importance=imp,
iterative_steps=iterative_steps,
ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
ignored_layers=ignored_layers,
)
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    # finetune your model here
    # finetune(model)
    # ...
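To check the effect of each step, the statistics collected above can be compared against the baseline, for example:
# report the overall compression (uses only variables already defined in the snippet above)
print(f"MACs: {base_macs / 1e9:.4f} G -> {macs / 1e9:.4f} G")
print(f"Params: {base_nparams / 1e6:.4f} M -> {nparams / 1e6:.4f} M")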
3. Low-level pruning functions
You can also try to prune your model manually with low-level functions.
tp.prune_conv_out_channels( model.conv1, idxs=[2,6,9] )
# fix the broken dependencies manually
tp.prune_batchnorm( model.bn1, idxs=[2,6,9] )
tp.prune_conv_in_channels( model.layer2[0].conv1, idxs=[2,6,9] )
...
The following pruning functions are available:
tp.prune_conv_in_channels
tp.prune_conv_out_channels
tp.prune_depthwise_conv_out_channels
tp.prune_batchnorm
tp.prune_linear_in_channel
tp.prune_linear_out_channel
tp.prune_prelu
tp.prune_layernorm
tp.prune_embedding
tp.prune_parameter
tp.prune_multihead_attention
4. Customized Layers
Please refer to tests/test_customized_layer.py.
File details
Details for the file torch-pruning-1.0.0.tar.gz.
File metadata
- Download URL: torch-pruning-1.0.0.tar.gz
- Upload date:
- Size: 28.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | cab9a3a2f80cbec5b8c0523fe10fcf14e73c6ca02477c9cadcb42dbc06bdd0e7
MD5 | a4465c62c4e16a5fe1a716f1c4676ed2
BLAKE2b-256 | a0ab67d662957bf8c154e329b1a9c274c939b8f5212d3ca10bcfde4a66ab45bf
File details
Details for the file torch_pruning-1.0.0-py3-none-any.whl.
File metadata
- Download URL: torch_pruning-1.0.0-py3-none-any.whl
- Upload date:
- Size: 31.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | 0d94574f67718baa51d3345dc7992e594909ba509c98754c0f17f2b7f374214d
MD5 | d10dd96bda4f68bcbe3d3dbb1dd19877
BLAKE2b-256 | d1cafaea370e3236d912f20e684886a265fc6435928a1f0687d5debecef7365f