# Batch-Normalization Folding
In this repository, we propose an implementation of the batch-normalization folding algorithm from IJCAI 2022. Batch-normalization folding removes batch-normalization layers by folding them into an appropriate neighboring layer, without changing the predictive function defined by the neural network. The simplest scenario is a fully-connected layer followed by a batch-normalization layer, for which we get
$$x \mapsto \gamma \frac{Ax + b - \mu}{\sigma + \epsilon} + \beta = \frac{\gamma A}{\sigma + \epsilon} x + \gamma \frac{b - \mu}{\sigma + \epsilon} + \beta$$

Thus, at inference, the two layers can be expressed as a single fully-connected layer without any change in the predictive function.
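The folded parameters can be checked numerically. The following sketch (plain NumPy, hypothetical shapes, and using the same $\sigma + \epsilon$ denominator as the formula above) folds a dense layer's weights and verifies that the predictive function is unchanged:

```python
import numpy as np

rng = np.random.default_rng(0)

# Fully-connected layer parameters (hypothetical shapes for illustration).
A = rng.normal(size=(4, 3))   # weight matrix
b = rng.normal(size=4)        # bias

# Batch-normalization parameters (inference mode).
gamma = rng.normal(size=4)
beta = rng.normal(size=4)
mu = rng.normal(size=4)                 # moving mean
sigma = rng.uniform(1.0, 2.0, size=4)   # moving standard deviation
eps = 1e-3

def dense_then_bn(x):
    """Dense layer followed by batch normalization."""
    y = A @ x + b
    return gamma * (y - mu) / (sigma + eps) + beta

# Folded parameters: a single affine layer with the same predictive function.
A_f = (gamma / (sigma + eps))[:, None] * A
b_f = gamma * (b - mu) / (sigma + eps) + beta

def folded(x):
    """Single dense layer using the folded parameters."""
    return A_f @ x + b_f

x = rng.normal(size=3)
assert np.allclose(dense_then_bn(x), folded(x))
```

Since both functions are the same affine map, the outputs agree up to floating-point rounding.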
## Use

This repository is available as a pip package:

```shell
pip install tensorflow-batchnorm-folding
```
This implementation is compatible with `tf.keras.Model` instances. It was tested with the following models:

- ResNet 50
- MobileNet V2
- MobileNet V3
- EfficientNet B0
To run a simple test:

```python
from batch_normalization_folding.folder import fold_batchnormalization_layers
import tensorflow as tf

mod = tf.keras.applications.efficientnet.EfficientNetB0()
folded_model, output_str = fold_batchnormalization_layers(mod, verbose=True)
```
The returned `output_str` is either the ratio `num_layers_folded/num_layers_not_folded` or `'failed'`, which indicates a failure in the process.
## Parameters

The function `fold_batchnormalization_layers` can be called with the following parameters.
| Parameter | Options |
|---|---|
| `model` | Model to fold. |
| `folding_mechanism` | `ban-off` uses BaN-OFF as the folding mechanism (recommended); `simple` uses a folding mechanism that only folds neighboring layers. Default: `ban-off`. |
| `verbose` | `True` prints additional information during the folding process; `False` disables it. Default: `False`. |
## Layer folding

This implementation supports folding batch-normalization parameters into the following layer types:
- Dense
- Conv1D
- Conv2D
- DepthwiseConv2D
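For convolutional layers the same algebra applies per output channel: each filter is scaled by $\gamma_c / (\sigma_c + \epsilon)$ and the bias is adjusted accordingly. A minimal NumPy sketch (hypothetical sizes, Keras `Conv2D` kernel layout `(kh, kw, in_ch, out_ch)`, checked at a single spatial position):

```python
import numpy as np

rng = np.random.default_rng(1)

# Conv2D kernel in Keras layout (kh, kw, in_ch, out_ch) and per-output-channel
# batch-normalization parameters (hypothetical sizes for illustration).
K = rng.normal(size=(3, 3, 2, 4))
b = rng.normal(size=4)
gamma, beta = rng.normal(size=4), rng.normal(size=4)
mu, sigma, eps = rng.normal(size=4), rng.uniform(1.0, 2.0, size=4), 1e-3

# Fold: scale each output-channel filter and fix up the bias.
scale = gamma / (sigma + eps)   # same denominator as the formula above
K_f = K * scale                 # broadcasts over the last (out_ch) axis
b_f = scale * (b - mu) + beta

# Check conv -> BN against the folded conv on one receptive-field patch.
patch = rng.normal(size=(3, 3, 2))
y = np.tensordot(patch, K, axes=3) + b           # conv output at one position
bn = gamma * (y - mu) / (sigma + eps) + beta     # batch norm, inference mode
y_f = np.tensordot(patch, K_f, axes=3) + b_f     # folded conv output
assert np.allclose(bn, y_f)
```

For `DepthwiseConv2D` the scaling is analogous, applied along the depthwise channel axis of the kernel.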
## To do

- unit tests on all Keras applications models
- check package installation
- deal with `Concatenate` layers
## Cite

```bibtex
@inproceedings{yvinec2022fold,
  title={To Fold or Not to Fold: a Necessary and Sufficient Condition on Batch-Normalization Layers Folding},
  author={Yvinec, Edouard and Dapogny, Arnaud and Bailly, Kevin},
  booktitle={IJCAI},
  year={2022}
}
```
## Performance on Base Models

| Model | BN layers folded | BN layers not folded |
|---|---|---|
| ResNet 50 | 53 | 0 |
| EfficientNet B0 | 49 | 0 |
| MobileNet V2 | 52 | 0 |
| MobileNet V3 | 34 | 0 |
| Inception ResNet V2 | 204 | 0 |
| Inception V3 | 94 | 0 |
| NASNet | 28 | 164 |
| DenseNet 121 | 59 | 62 |