
Structural Pruning for Model Acceleration.


Torch-Pruning

Towards Any Structural Pruning

Torch-Pruning (TP) is a versatile library for structural network pruning that supports a wide range of neural networks, including Vision Transformers, YOLOv7, FasterRCNN, SSD, ResNet, DenseNet, ConvNeXt, RegNet, ResNeXt, FCN, DeepLab, VGG, etc. Unlike torch.nn.utils.prune, which zeroes out parameters through masking, Torch-Pruning uses a (non-deep) graph algorithm called DepGraph to physically remove coupled parameters (channels) from models. To explore more prunable models, please refer to benchmarks/prunability. So far, TP is compatible with 73 of the 85 models (85.8%) in Torchvision 0.13.1. This repo also hosts a continuously updated resource list for practical structural pruning.

For more technical details, please refer to our preprint paper:

DepGraph: Towards Any Structural Pruning
Gongfan Fang, Xinyin Ma, Mingli Song, Michael Bi Mi, Xinchao Wang

Please do not hesitate to open a discussion or issue if you encounter any problems with the library or have any questions related to the paper. We are always happy to assist you and address any concerns you may have.

Features:

  • Structural (channel) pruning for CNNs (e.g. ResNet, DenseNet, DeepLab), Transformers (e.g. ViT) and detectors (e.g. YOLOv7, FasterRCNN, SSD)
  • High-level pruners: MagnitudePruner, BNScalePruner, GroupPruner (a simple pruner used in our paper), RandomPruner, etc.
  • Computational Graph Tracing and Dependency Modeling.
  • Supported modules: Conv, Linear, BatchNorm, LayerNorm, Transposed Conv, PReLU, Embedding, MultiheadAttention, nn.Parameter and customized modules.
  • Supported operations: split, concatenation, skip connection, flatten, reshape, view, all element-wise ops, etc.
  • Low-level pruning functions
  • Benchmarks and tutorials
  • A resource list for practical structural pruning.
  • Automatic pruning of unwrapped nn.Parameter objects that do not belong to any standard layer or op.

Plans:

We have a wealth of ideas, but unfortunately only a handful of contributors at the moment. We hope to attract more talented contributors to join us in bringing these ideas to fruition and making Torch-Pruning a practical library.

  • A benchmark for Torchvision compatibility (73/85 = 85.8% ✔) and timm compatibility.
  • More detectors (we are working on pruning the YOLO series, e.g. YOLOv7 ✔ and YOLOv8).
  • Pruning from Scratch / at Initialization.
  • Language, Speech and Generative Models.
  • More high-level pruners like FisherPruner, GrowingReg, etc.
  • More standard layers: GroupNorm, InstanceNorm, Shuffle Layers, etc.
  • More Transformers like Vision Transformers (✔), Swin Transformers, PoolFormers.
  • Block/Layer/Depth Pruning
  • Pruning benchmarks for CIFAR, ImageNet and COCO.

Installation

pip install torch-pruning # v1.1.1

or

git clone https://github.com/VainF/Torch-Pruning.git # recommended

Quickstart

Here we provide a quick start for Torch-Pruning. More detailed explanations can be found in the tutorials.

0. How it works

In complex network structures, dependencies can arise among groups of parameters, necessitating their simultaneous pruning. Our work addresses this challenge by providing an automated mechanism for grouping parameters to facilitate their efficient removal for acceleration. Specifically, Torch-Pruning accomplishes this by forwarding your model with a fake input, tracing the network to establish a graph, and recording the dependencies between layers. When you prune a single layer, Torch-Pruning identifies and groups all coupled layers by returning a tp.Group. Moreover, all pruning indices will be automatically aligned if operations like torch.split or torch.cat are present.
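The grouping step can be pictured as a graph search: start from the pruning root, follow dependency edges, and collect every pruning operation they reach. The sketch below is a toy illustration under stated assumptions, not DepGraph itself: the edges are typed in by hand (mirroring part of the resnet18 group printed in the minimal example) instead of being traced from a forward pass, they are simplified to a single direction, and the names DEPENDENCIES and collect_group are hypothetical.

```python
from collections import deque

# Hypothetical, hand-written dependency edges: (layer, op) -> operations it triggers.
# Torch-Pruning derives such edges by tracing the model; here they are typed in by hand
# and simplified to one direction for illustration.
DEPENDENCIES = {
    ("conv1", "out"): [("bn1", "out")],
    ("bn1", "out"): [("relu", "out")],
    ("relu", "out"): [("maxpool", "out")],
    ("maxpool", "out"): [("layer1.0.conv1", "in"), ("add", "out")],
    ("add", "out"): [("layer1.0.bn2", "out")],
    ("layer1.0.bn2", "out"): [("layer1.0.conv2", "out")],
}

def collect_group(root):
    """Breadth-first search from the pruning root; returns every (layer, op)
    reachable through dependency edges, in discovery order."""
    group, seen = [], {root}
    queue = deque([root])
    while queue:
        node = queue.popleft()
        group.append(node)
        for nxt in DEPENDENCIES.get(node, []):
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return group

for layer, op in collect_group(("conv1", "out")):
    print(f"prune_{op}_channels on {layer}")
```

Every operation in the returned list must use the same channel indices, which is why pruning a single layer in a real network ends up touching many.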

With DepGraph, it is easy to design "group-level" criteria that estimate the importance of a whole group rather than a single layer. In our paper, we craft a simple GroupPruner that learns consistent sparsity across coupled layers.
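As an illustration of a group-level criterion, the sketch below scores each output channel of a conv, BatchNorm, conv chain by combining three per-layer signals: the L2 norm of the first conv's filter, the BatchNorm scale |gamma|, and the L2 norm of the matching input slice of the next conv. The per-layer normalization and the summation are assumptions chosen for illustration; this is not the criterion used by GroupPruner or tp.importance, and group_importance is a hypothetical name.

```python
import torch
import torch.nn as nn

def group_importance(conv: nn.Conv2d, bn: nn.BatchNorm2d, next_conv: nn.Conv2d) -> torch.Tensor:
    """One score per output channel of `conv`, aggregated over the whole coupled group."""
    per_layer = [
        # L2 norm of each output filter of the first conv: shape [C]
        conv.weight.detach().flatten(1).norm(p=2, dim=1),
        # magnitude of the BatchNorm scale (gamma) of each channel: shape [C]
        bn.weight.detach().abs(),
        # L2 norm of each input-channel slice of the next conv: shape [C]
        next_conv.weight.detach().transpose(0, 1).flatten(1).norm(p=2, dim=1),
    ]
    # Normalize each layer's scores to sum to 1 so no single layer dominates, then add them up.
    return sum(s / s.sum() for s in per_layer)

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, bias=False)
bn = nn.BatchNorm2d(8)
next_conv = nn.Conv2d(8, 16, kernel_size=3, bias=False)

imp = group_importance(conv, bn, next_conv)
prune_idxs = torch.argsort(imp)[:2].tolist()  # the two least important channels of the group
print(prune_idxs)
```

Because the same index is removed from all three layers, scoring the channel jointly avoids keeping a filter that is strong in one layer but whose coupled weights elsewhere are negligible.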

1. A minimal example

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(weights="IMAGENET1K_V1").eval()  # torchvision>=0.13; on older versions use pretrained=True

# 1. build dependency graph for resnet18
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))

# 2. Specify the to-be-pruned channels. Here we prune those channels indexed by [2, 6, 9].
pruning_idxs = [2, 6, 9]
pruning_group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs )

print(pruning_group.details())  # or print(pruning_group)

# 3. prune all grouped layers that are coupled with model.conv1 (included).
if DG.check_pruning_group(pruning_group): # avoid over-pruning, i.e., leaving a layer with 0 channels.
    pruning_group.prune()

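# 4. save & load the pruned model
torch.save(model, 'model.pth') # save the whole model object: the pruned architecture no longer matches resnet18()
model_loaded = torch.load('model.pth', weights_only=False) # not load_state_dict; weights_only=False is required on PyTorch>=2.6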

This example demonstrates the basic pruning pipeline with DepGraph. Note that model.conv1 is coupled with several layers. Printing the resulting group shows how one pruning operation "triggers" others. In the following output, A => B means that pruning operation A triggers pruning operation B, and group[0] is the pruning root specified by DG.get_pruning_group.

--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), idxs=[2, 6, 9] (Pruning Root)
[1] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[2] prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp(ReluBackward0), idxs=[2, 6, 9]
[3] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_out_channels on _ElementWiseOp(MaxPool2DWithIndicesBackward0), idxs=[2, 6, 9]
[4] prune_out_channels on _ElementWiseOp(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp(AddBackward0), idxs=[2, 6, 9]
[5] prune_out_channels on _ElementWiseOp(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[6] prune_out_channels on _ElementWiseOp(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[7] prune_out_channels on _ElementWiseOp(AddBackward0) => prune_out_channels on _ElementWiseOp(ReluBackward0), idxs=[2, 6, 9]
[8] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_out_channels on _ElementWiseOp(AddBackward0), idxs=[2, 6, 9]
[9] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[10] prune_out_channels on _ElementWiseOp(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[11] prune_out_channels on _ElementWiseOp(AddBackward0) => prune_out_channels on _ElementWiseOp(ReluBackward0), idxs=[2, 6, 9]
[12] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), idxs=[2, 6, 9]
[13] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
--------------------------------

For more details about grouping, please refer to tutorials/2 - Exploring Dependency Groups

How to scan all groups:

As in the MetaPruner, you can use DG.get_all_groups(ignored_layers, root_module_types) to scan all groups sequentially. Each group begins with a layer whose type matches one of the types in root_module_types. By default, each group contains a full index list idxs=[0,1,2,3,...,K] that covers all prunable parameters. If you intend to prune only some of the channels/dimensions, use group.prune(idxs=idxs).

import torch.nn as nn

for group in DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[nn.Conv2d, nn.Linear]):
    # handle groups in sequential order
    idxs = [2,4,6] # your pruning indices
    group.prune(idxs=idxs)
    print(group)
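The loop above prunes the fixed indices [2,4,6] in every group. In practice the indices usually come from an importance score. Below is a minimal, library-independent sketch of such a selector; select_prune_idxs is a hypothetical helper, and how you compute one score per channel for each group is left to you. It always keeps at least one channel, so a layer is never pruned down to zero channels.

```python
def select_prune_idxs(scores, ratio):
    """Return the sorted indices of the lowest-scoring channels to prune.

    `scores` holds one importance value per channel and `ratio` is the fraction
    of channels to remove. At least one channel is always kept.
    """
    n = len(scores)
    k = max(0, min(int(n * ratio), n - 1))
    # Channel indices ordered from least to most important.
    order = sorted(range(n), key=lambda i: scores[i])
    return sorted(order[:k])

print(select_prune_idxs([0.9, 0.1, 0.5, 0.3], ratio=0.5))  # [1, 3]
```

Inside the loop you would then call group.prune(idxs=select_prune_idxs(scores, 0.3)), assuming scores holds one value per channel of that group.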

2. High-level Pruners

Built on the DependencyGraph, this repository provides several high-level pruners for effortless pruning. By specifying the desired channel sparsity, you can prune the entire model and then fine-tune it with your own training code. For details on this process, please consult this tutorial. More practical examples can be found in benchmarks/main.py.

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(weights="IMAGENET1K_V1")  # torchvision>=0.13; on older versions use pretrained=True

# Importance criteria
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)

ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

iterative_steps = 5 # progressive pruning
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    pruner.step()
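    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print(f"step {i + 1}/{iterative_steps}: MACs {base_macs / 1e9:.2f}G -> {macs / 1e9:.2f}G, params {base_nparams / 1e6:.2f}M -> {nparams / 1e6:.2f}M")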
    # finetune your model here
    # finetune(model)
    # ...
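With iterative_steps=5 and ch_sparsity=0.5, the target sparsity is reached progressively rather than in one shot. The sketch below works through the arithmetic, assuming a linear schedule (10%, 20%, ..., 50%) and simple truncating rounding; the library's own scheduler and rounding may differ, and all helper names are hypothetical. It also shows why halving widths cuts parameters by roughly 4x: a bias-free k x k convolution has c_in * c_out * k * k weights, so halving both c_in and c_out divides that count by four.

```python
def linear_schedule(target_sparsity, steps):
    """Channel sparsity after each step, ramping linearly up to the target."""
    return [target_sparsity * (i + 1) / steps for i in range(steps)]

def remaining_channels(widths, sparsity):
    """Channels kept per layer, assuming simple truncation when rounding."""
    return [int(w * (1 - sparsity)) for w in widths]

def conv_params(c_in, c_out, k=3):
    """Number of weights in a bias-free k x k convolution."""
    return c_in * c_out * k * k

stage_widths = [64, 128, 256, 512]  # ResNet-18 stage widths
for step, sparsity in enumerate(linear_schedule(0.5, 5), start=1):
    print(f"step {step}: sparsity {sparsity:.1f} -> widths {remaining_channels(stage_widths, sparsity)}")

# Halving both c_in and c_out of a 3x3 conv divides its weight count by 4.
print(conv_params(512, 512), conv_params(256, 256))  # 2359296 589824
```

The final step yields {32, 64, 128, 256}, matching the ResNet18_Half widths in the ch_sparsity comment above.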

Sparse Training

Some pruners, such as BNScalePruner and GroupNormPruner, require sparse training before pruning. This needs only one extra line in your training script: pruner.regularize(model). The pruner then updates the gradients of the trainable parameters, so the call must come after loss.backward() and before optimizer.step().

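import torch.nn.functional as F

# model, pruner, optimizer, train_loader, device and epochs are assumed to be defined already.
for epoch in range(epochs):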
    model.train()
    for i, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = F.cross_entropy(out, target)
        loss.backward()
        pruner.regularize(model) # <== for sparse learning
        optimizer.step()
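For intuition on what sparse training does, the sketch below implements the classic L1 penalty on BatchNorm scaling factors (as in Network Slimming): after loss.backward(), the subgradient reg * sign(gamma) is added to each BatchNorm weight's gradient, pushing the scales of unimportant channels toward zero so they can later be pruned. We assume BNScalePruner's regularizer is similar in spirit, but this is not its actual code; bn_l1_regularize and the reg value are hypothetical.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def bn_l1_regularize(model: nn.Module, reg: float = 1e-5) -> None:
    """Add the subgradient of reg * sum(|gamma|) to every BatchNorm scale's gradient.

    Call it after loss.backward() and before optimizer.step(), at the same place
    as pruner.regularize(model) in the loop above.
    """
    for m in model.modules():
        if isinstance(m, BN_TYPES) and m.weight is not None and m.weight.grad is not None:
            m.weight.grad.add_(reg * torch.sign(m.weight.detach()))

# Tiny check: the extra gradient equals reg * sign(gamma).
model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4))
with torch.no_grad():
    model[1].weight.copy_(torch.tensor([0.5, -0.2, 0.0, 1.0]))
model(torch.randn(2, 3, 8, 8)).sum().backward()
grad_before = model[1].weight.grad.clone()
bn_l1_regularize(model, reg=0.1)
extra = model[1].weight.grad - grad_before
print(extra)
```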

Interactive Pruning

All high-level pruners support interactive pruning. You can use pruner.step(interactive=True) to get all groups and interactively prune them by calling group.prune(). This feature is useful if you want to control/monitor the pruning process.

for i in range(iterative_steps):
    for group in pruner.step(interactive=True): # Warning: groups must be handled sequentially. Do not keep them as a list.
        print(group) 
        # do whatever you like with the group 
        # ...
        group.prune() # you should manually call the group.prune()
        # group.prune(idxs=[0, 2, 6]) # you can even change the pruning behaviour with the idxs parameter
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    # finetune your model here
    # finetune(model)
    # ...

3. Low-level pruning functions

While it is possible to manually prune your model using low-level functions, this approach can be quite laborious, as it requires careful management of the associated dependencies. As a result, we recommend utilizing the aforementioned high-level pruners to streamline the pruning process.

tp.prune_conv_out_channels( model.conv1, idxs=[2,6,9] )

# fix the broken dependencies manually
tp.prune_batchnorm_out_channels( model.bn1, idxs=[2,6,9] )
tp.prune_conv_in_channels( model.layer2[0].conv1, idxs=[2,6,9] )
...
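To make "physically remove" concrete, here is a pure-PyTorch sketch of what the three low-level calls above do to a simple conv, BatchNorm, conv chain: the selected output channels are sliced out of the first conv, the matching entries out of the BatchNorm parameters and running statistics, and the matching input channels out of the next conv. The function names are hypothetical, the sketch assumes ordinary (groups=1) convolutions, and Torch-Pruning's own functions handle more layer types and edge cases.

```python
import torch
import torch.nn as nn

def keep_indices(num_channels, pruned_idxs):
    """Indices of the channels that survive pruning, in their original order."""
    pruned = set(pruned_idxs)
    return [i for i in range(num_channels) if i not in pruned]

def prune_conv_bn_conv(conv, bn, next_conv, idxs):
    """Remove output channels `idxs` of `conv` together with the matching
    BatchNorm entries and the matching input channels of `next_conv`."""
    keep = torch.tensor(keep_indices(conv.out_channels, idxs), dtype=torch.long)
    # conv: drop output filters (dim 0 of the weight) and their biases
    conv.weight = nn.Parameter(conv.weight.detach()[keep].clone())
    if conv.bias is not None:
        conv.bias = nn.Parameter(conv.bias.detach()[keep].clone())
    conv.out_channels = len(keep)
    # BatchNorm: drop affine parameters and running statistics of those channels
    bn.weight = nn.Parameter(bn.weight.detach()[keep].clone())
    bn.bias = nn.Parameter(bn.bias.detach()[keep].clone())
    bn.running_mean = bn.running_mean[keep].clone()
    bn.running_var = bn.running_var[keep].clone()
    bn.num_features = len(keep)
    # next conv: drop the matching input channels (dim 1 of the weight)
    next_conv.weight = nn.Parameter(next_conv.weight.detach()[:, keep].clone())
    next_conv.in_channels = len(keep)

torch.manual_seed(0)
net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 4, kernel_size=3, padding=1),
)
prune_conv_bn_conv(net[0], net[1], net[3], idxs=[2, 6])
out = net(torch.randn(1, 3, 16, 16))
print(net[0].weight.shape, net[3].weight.shape, out.shape)
```

Unlike masking with torch.nn.utils.prune, the weight tensors actually shrink, so the forward pass does less work. Forgetting any one of the three slices would leave mismatched shapes and the forward pass would fail, which is exactly the bookkeeping DepGraph automates.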

The following pruning functions are available:

tp.prune_conv_out_channels,
tp.prune_conv_in_channels,
tp.prune_depthwise_conv_out_channels,
tp.prune_depthwise_conv_in_channels,
tp.prune_batchnorm_out_channels,
tp.prune_batchnorm_in_channels,
tp.prune_linear_out_channels,
tp.prune_linear_in_channels,
tp.prune_prelu_out_channels,
tp.prune_prelu_in_channels,
tp.prune_layernorm_out_channels,
tp.prune_layernorm_in_channels,
tp.prune_embedding_out_channels,
tp.prune_embedding_in_channels,
tp.prune_parameter_out_channels,
tp.prune_parameter_in_channels,
tp.prune_multihead_attention_out_channels,
tp.prune_multihead_attention_in_channels,

4. Customized Layers

Please refer to tests/test_customized_layer.py.

5. Benchmarks

Our results on ResNet-56 / CIFAR-10 at about 2.00x speed up:

Method          Base (%)   Pruned (%)   Δ Acc (%)   Speed Up
NIPS [1]        -          -            -0.03       1.76x
Geometric [2]   93.59      93.26        -0.33       1.70x
Polar [3]       93.80      93.83        +0.03       1.88x
CP [4]          92.80      91.80        -1.00       2.00x
AMC [5]         92.80      91.90        -0.90       2.00x
HRank [6]       93.26      93.17        -0.09       2.00x
SFP [7]         93.59      93.36        +0.23       2.11x
ResRep [8]      93.71      93.71        +0.00       2.12x
Ours-L1         93.53      92.93        -0.60       2.12x
Ours-BN         93.53      93.29        -0.24       2.12x
Ours-Group      93.53      *93.91       +0.38       2.13x
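The Δ Acc column is the pruned accuracy minus the base accuracy, in percentage points. The sketch below reproduces it for a few rows and, assuming Speed Up is the ratio of MACs before and after pruning (for example, as measured with tp.utils.count_ops_and_params), computes that ratio too. Both helpers are hypothetical, and the MAC counts in the last line are made up for illustration.

```python
def delta_acc(base, pruned):
    """Accuracy change in percentage points, with an explicit sign."""
    return f"{pruned - base:+.2f}"

def speedup(base_macs, pruned_macs):
    """Speed up as the ratio of MACs before and after pruning (assumed definition)."""
    return f"{base_macs / pruned_macs:.2f}x"

print(delta_acc(93.53, 93.91))  # Ours-Group: +0.38
print(delta_acc(93.53, 93.29))  # Ours-BN: -0.24
print(speedup(2.0e8, 1.0e8))    # made-up MAC counts: 2.00x
```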

Please refer to benchmarks for more details.

Citation

@article{fang2023depgraph,
  title={DepGraph: Towards Any Structural Pruning},
  author={Fang, Gongfan and Ma, Xinyin and Song, Mingli and Mi, Michael Bi and Wang, Xinchao},
  journal={The IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2023}
}

Download files

Source Distribution: torch-pruning-1.1.2.tar.gz (39.9 kB)

Built Distribution: torch_pruning-1.1.2-py3-none-any.whl (35.2 kB, Python 3)
