Pruning Engine for CNN
Introduction
This project is a pruning engine for neural network models. It provides a set of tools and methods for pruning the weights and filters of neural networks. Pruning is a technique that reduces the size of a neural network by removing redundant parameters, which improves model efficiency and reduces the computational resources required for inference.
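For a sense of scale, here is a back-of-the-envelope sketch (plain Python, independent of this package) of how removing filters shrinks a convolution layer's parameter count, using VGG16's second conv layer as an example:

```python
def conv_params(out_ch, in_ch, k):
    """Weight parameter count of a Conv2d(in_ch, out_ch, k x k) layer (bias ignored)."""
    return out_ch * in_ch * k * k

# VGG16's second conv layer: Conv2d(64, 64, kernel_size=3).
before = conv_params(64, 64, 3)
# Pruning 10% of its 64 output filters leaves 58.
after = conv_params(58, 64, 3)
print(before, after)  # -> 36864 33408
```

The savings compound because pruning a layer's output filters also shrinks the next layer's input kernels, as the usage walkthrough below shows.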
Experiment Result
Figure 1: Experiment Result
Pretrained weights are available for reproducing the experiment.
Installation
To install the project, follow these steps:
Clone the repository:
git clone https://github.com/MIC-Laboratory/CNN-Pruning-Engine.git
Install the dependencies:
pip install -r requirements.txt
Usage
To use the pruning engine, follow these steps:
- Import the pruning engine module:
from Pruning_engine import pruning_engine
- Create an instance of the pruning engine:
pruner = pruning_engine.PruningEngine(pruning_method="L1norm", individual=True)
- Load your neural network model:
import torch
from torchvision.models import vgg16_bn, VGG16_BN_Weights

weights = VGG16_BN_Weights.DEFAULT
model = vgg16_bn(weights=weights)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)
- Based on the structure of the neural network, choose the layers that need to be pruned, e.g. VGG16 [1]. The VGG16 configuration as printed in Python:
VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
...
)
)
Say we want to prune 10% of the filters in the first Conv2d layer using the L1-norm criterion. The following BatchNorm layer must then drop the corresponding 10% of its channels, and the next Conv2d layer 10% of its input kernels:
# Prune 10% of the filters in the first Conv2d layer (the "main" layer).
pruned_layer = model.features[0]
pruner.set_pruning_ratio(0.1)
pruner.set_layer(pruned_layer, main_layer=True)
remove_filter_idx = pruner.get_remove_filter_idx()["current_layer"]
model.features[0] = pruner.remove_filter_by_index(remove_filter_idx)

# Remove the matching channels from the following BatchNorm layer.
pruned_layer = model.features[1]
pruner.set_pruning_ratio(0.1)
pruner.set_layer(pruned_layer)
remove_filter_idx = pruner.get_remove_filter_idx()["current_layer"]
model.features[1] = pruner.remove_Bn(remove_filter_idx)

# Remove the matching input kernels from the next Conv2d layer.
pruned_layer = model.features[3]
pruner.set_pruning_ratio(0.1)
pruner.set_layer(pruned_layer)
remove_filter_idx = pruner.get_remove_filter_idx()["current_layer"]
model.features[3] = pruner.remove_kernel_by_index(remove_filter_idx)
- Retrain the model; see the training repo: https://github.com/MIC-Laboratory/Pytorch-Cifar
- Save the pruned model:
torch.save(model, 'path_to_pruned_model.pt')
Pruning Methods
The pruning engine supports multiple pruning methods, including:
- L1 norm pruning: Removes least important weights based on their L1 norm.
- K-means clustering pruning: Clusters weights and removes weak clusters based on their importance.
- Taylor pruning: Measures weight importance using the Taylor expansion and removes less important weights.
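As a rough illustration of the first method, L1-norm ranking can be sketched without the engine: compute each filter's L1 norm and mark the smallest fraction for removal. This is a minimal plain-Python sketch, not the package's implementation:

```python
def l1norm_remove_indices(filters, ratio):
    """Rank filters by L1 norm and return the indices of the weakest `ratio` fraction.

    `filters` is a list of flat per-filter weight lists.
    """
    norms = [sum(abs(w) for w in f) for f in filters]
    n_remove = int(len(filters) * ratio)
    # Sort filter indices by ascending L1 norm; the smallest are pruned first.
    ranked = sorted(range(len(filters)), key=lambda i: norms[i])
    return ranked[:n_remove]

filters = [[0.1, -0.1], [2.0, 1.0], [0.5, 0.5], [3.0, -3.0]]
print(l1norm_remove_indices(filters, 0.5))  # -> [0, 2]
```

The engine's `get_remove_filter_idx()` plays the role of this function, with the criterion selected by the `pruning_method` argument.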
Project Structure
Pruning Engine Architecture
Figure 2: Pruning Engine UML diagram
| Functions Name | Description |
|---|---|
| Pruning_base | Provides basic pruning functionalities such as removing filters by index and kernels by index. |
| Kmean_base | Offers basic K-means clustering pruning algorithms. Given weights, k, and a sorting method within each group, it returns the indices of filters that are not important. |
| Pruning_engine | Main file for the pruning engine, integrating all the pruning methods. |
| K_L1norm | Implements K-means clustering to cluster similar CNN filters and uses L1 norm to sort the filters within each cluster. |
| K_Taylor | Applies K-means clustering to cluster similar filters and uses Taylor estimation to sort the filters within each cluster. |
| L1norm | Pruning method that considers filters with the least magnitude as non-important. |
| Taylor | Pruning method that computes the product of feature maps and gradients for each filter, with smaller values indicating less importance. |
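To make the Taylor row concrete: the criterion scores each filter by the magnitude of the summed product of its feature map and gradient, with smaller scores meaning less important. A minimal sketch with synthetic numbers (not the engine's implementation, which hooks real feature maps and gradients):

```python
def taylor_scores(feature_maps, gradients):
    """Per-filter Taylor importance: |sum over the feature map of activation * gradient|.

    `feature_maps` and `gradients` are lists (one entry per filter) of flat value lists.
    """
    return [
        abs(sum(a * g for a, g in zip(fmap, grad)))
        for fmap, grad in zip(feature_maps, gradients)
    ]

# Two filters: the first barely affects the loss, the second clearly does.
fmaps = [[1.0, 2.0], [3.0, -1.0]]
grads = [[0.0, 0.5], [0.5, -0.5]]
print(taylor_scores(fmaps, grads))  # -> [1.0, 2.0]
```

In the K_Taylor variant, these scores are used to sort filters within each K-means cluster rather than globally.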
To add your own pruning criterion:
- Create a Python file in the "pruning_criterion" folder containing your pruning algorithm. This file should include a function that accepts the layer weight as a parameter and returns the filters ranked by importance according to your algorithm.
- In the pruning engine file, follow the structure of other pruning methods. In the __init__ function of the Pruning Engine, initialize your distinct pruning algorithm, similar to the way other methods are initialized.
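A hypothetical criterion file following the contract described above (the file and function names here are illustrative, not the package's API): it accepts a layer weight and returns the filter indices ranked from least to most important, here by L2 norm.

```python
# Hypothetical contents of pruning_criterion/l2norm.py (names are illustrative).
def l2norm_rank(layer_weight):
    """Accept a layer's weight (list of flat per-filter weight lists) and
    return filter indices ranked from least to most important by L2 norm."""
    norms = [sum(w * w for w in f) ** 0.5 for f in layer_weight]
    return sorted(range(len(layer_weight)), key=lambda i: norms[i])

weight = [[3.0, 4.0], [0.0, 1.0], [6.0, 8.0]]
print(l2norm_rank(weight))  # -> [1, 0, 2]
```

The engine would then take the first N indices of this ranking as the filters to remove, where N is set by the pruning ratio.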
Test Case Architecture
Figure 3: Testcase UML diagram
Adding a custom network
To incorporate your own network into the pruning experiment:
- Place your model file in the "Models" folder, and then initialize the new network in "testcase_base.py".
- Create a new Python file in the "Example" folder for pruning your model, along with the corresponding config file. This file should include the following fully implemented functions: pruning, hook_function, and get_layer_store. These functions are essential for implementing the K-means and Taylor-related pruning methods. You can refer to the examples of other networks to implement them.
Here is a brief description of each function:
| Function Name | Description |
|---|---|
| pruning | This function is used to remove filters and kernels from the CNN. It does not have any return values or parameters. |
| hook_function | This function allows you to hook the feature map and gradient of filters in the CNN. It is required for the Taylor relative pruning algorithm. |
| get_layer_store | This function specifies the layers used to store additional pruning information, such as the number of clusters in the K-means algorithm. These layers should be the same ones that you want to hook the feature map and gradient to. |
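The three functions above can be sketched as a skeleton class. This is an illustrative outline only, assuming a class-based Example file; the actual base class, method signatures, and hook mechanics in "testcase_base.py" may differ:

```python
# Hypothetical skeleton of an Example file; names are illustrative.
class MyNetworkTestcase:
    def __init__(self, model):
        self.model = model
        self.feature_maps = {}   # filled by hook_function during a forward pass
        self.gradients = {}      # filled by hook_function during a backward pass

    def pruning(self):
        """Remove filters and kernels from the CNN in place.
        No parameters, no return value, per the table above."""
        # Walk the layers to prune, mirroring the VGG16 walkthrough in Usage.
        pass

    def hook_function(self):
        """Register forward/backward hooks that capture each target layer's
        feature map and gradient (required by Taylor-related criteria)."""
        pass

    def get_layer_store(self):
        """Return the layers that store extra pruning information
        (e.g. the number of K-means clusters); these should be the same
        layers that hook_function attaches to."""
        return []
```

Refer to the existing network examples in the "Example" folder for complete implementations.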
Reference
[1] Simonyan, K. and Zisserman, A., 2014. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556.
Contributing
Contributions to this project are welcome. If you find any bugs or have suggestions for new features, please open an issue on the GitHub repository.
License
This project is licensed under the MIT License. See the LICENSE file for more information.