onnx2tf
A self-created tool to convert ONNX files (NCHW) to TensorFlow format (NHWC). The purpose of this tool is to solve the problem of onnx-tensorflow (onnx-tf) inserting a massive number of Transpose OPs. I don't need a Star, but give me a pull request.
Key concept
- onnx-tensorflow is a very useful tool, but the performance of the generated TensorFlow models is significantly degraded because a large number of Transpose OPs are inserted before and after each OP during the format conversion from NCHW to NHWC. Therefore, I created this tool myself as a derivative of onnx-tensorflow that does not insert Transpose OPs.
- Most of the internal processing of the tool is written from scratch, but some of the more complex OPs have been adapted from onnx-tensorflow. I am very grateful to the engineers at International Business Machines Corporation / LeapMind / Microsoft for developing onnx-tensorflow.
- It handles not only 4-dimensional inputs, such as NCHW to NHWC, but also inputs with 3, 5, or even more dimensions, for example NCDHW to NDHWC. However, since 1-D, 2-D, 3-D, and 6-D inputs may produce patterns that are mechanically difficult to convert, parameters can be given to modify the tool's behavior externally. See Parameter replacement.
- If there are undefined dimensions in the input OP, the model structure is not fully optimized and conversion errors are very likely to occur.
- Immediately after a Reshape OP that compresses or decompresses dimensions, there is a 95% probability that the model conversion will be disrupted and errors will occur. Examples are patterns such as [1,200,200,5] -> [1,200,-1], [10,20,30,40,50] -> [10,2,10,30,10,4,50], or Flatten.
- TensorFlow's Convolution has no operation equivalent to ONNX's Padding. Therefore, a Pad OP is inserted immediately before a Convolution with Padding of size greater than 1.
- Supports conversion to TensorFlow saved_model and TFLite (Float32/Float16).
- Does not support quantization to INT8. For quantization, use the official TensorFlow Lite converter to quantize the generated saved_model yourself.
- ONNX files exceeding the 2GB Protocol Buffers file size limit are not supported. Therefore, the external data format is not supported at this initial stage of the tool's development.
- If there are ONNX OPs that are not supported by TensorFlow, use simple-onnx-processing-tools to replace them with harmless OPs in advance, and then use this tool to convert the model. In other words, with some effort you can convert any model.
- ONNX splitting, merging, generating OPs, rewriting OP attributes, BGR<->RGB conversion, converting to JSON and editing in the IDE, batch size changes for undefined dimensions, and various other processing can be done with simple-onnx-processing-tools. Therefore, it is recommended that models with very complex structures be modified beforehand and then converted to TFLite.
- BatchNormalization supports only inference mode.
- Only opset=11 or higher is supported.
- If you do not like the generated TFLite OP names, edit them using tflite2json2tflite.
- The generated Keras models cannot be used for retraining. If you want to train, you must build your own model.
- When converting to TensorFlow.js, CoreML, etc., generate the saved_model with the --output_signaturedefs option and use that saved_model as input to the respective converters: tensorflowjs_converter, coremltools, edgetpu_compiler.
- Many ONNX OPs are not supported by EdgeTPU. Therefore, if you need to generate an EdgeTPU model, specify --replace_***_to_pseudo_*** when converting your model. onnx2tf will attempt to replace such OPs with EdgeTPU-compatible OPs whenever possible.
- The main factors that cause accuracy degradation after model conversion are as follows:
  - Differences in Padding specifications
  - Differences in Python division specifications during model conversion (errors due to round-half-to-even rounding)
  - Division performed without taking epsilon into account
  - Deprecated TrueDivision
  - Differences in support for powers
  - Differences in interpolation specifications during resizing
  - Differences in the arithmetic precision supported by each operation
  - Calculation errors caused by scaling up or down with a specified scale when resizing images
  The above differences often cannot be resolved by simply converting the model as-is. Therefore, you need to replace such parts of the model yourself, in advance, with operations that are less prone to errors.
- Implement the Resize process for the 5D tensor.
- Add process to replace Asin with pseudo-Asin.
- Add process to replace Acos with pseudo-Acos.
- Add process to replace GatherND with pseudo-GatherND.
- Add process to replace HardSwish with pseudo-HardSwish.
- Add process to replace GridSample with pseudo-GridSample.
- Add process to replace LeakyRelu with pseudo-LeakyRelu.
- Add process to replace Power with pseudo-Power.
- Add process to replace Neg with pseudo-Neg.
- Added option to fix dynamic batch size N to a specified number.
- Automatically run onnx-simplifier (onnxsim) backend and optimize onnx files before model conversion.
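The Pad OP note above can be illustrated with a small NumPy sketch (an illustration only, not onnx2tf's code). ONNX Conv stores explicit padding in its pads attribute as all spatial begin values followed by all end values ([x1_begin, x2_begin, ..., x1_end, x2_end, ...]). For a channel-last tensor these become per-axis (begin, end) pairs, with no padding on the batch and channel axes; this nested layout is what np.pad (and tf.pad) accept. The helper name onnx_pads_to_nhwc_paddings is hypothetical.

```python
import numpy as np


def onnx_pads_to_nhwc_paddings(pads):
    """Hypothetical helper: ONNX Conv `pads` -> per-axis paddings for N, spatial..., C.

    ONNX order:  [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
    Result:      [[0, 0], [x1_begin, x1_end], [x2_begin, x2_end], ..., [0, 0]]
    """
    if len(pads) % 2 != 0:
        raise ValueError("pads must hold one begin and one end value per spatial axis")
    n_spatial = len(pads) // 2
    begins, ends = pads[:n_spatial], pads[n_spatial:]
    spatial = [[b, e] for b, e in zip(begins, ends)]
    # Batch axis first and channel axis last are never padded.
    return [[0, 0]] + spatial + [[0, 0]]


# Asymmetric 2-D padding: top=1, left=2, bottom=0, right=1 on an NHWC tensor.
x_nhwc = np.ones((1, 4, 4, 3), dtype=np.float32)
paddings = onnx_pads_to_nhwc_paddings([1, 2, 0, 1])
padded = np.pad(x_nhwc, paddings)  # zero (constant) padding, like an explicit Pad OP
# A convolution with padding="VALID" would then run on `padded`.

print(paddings)      # [[0, 0], [1, 0], [2, 1], [0, 0]]
print(padded.shape)  # (1, 5, 7, 3)
```

Making the padding an explicit, separate step keeps asymmetric ONNX pads exact, which a convolution's own "SAME"/"VALID" modes cannot always express.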
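Two of the accuracy factors listed above, round-half-to-even rounding and resizing with a floating-point scale, can be reproduced in plain Python. This sketch only demonstrates the arithmetic; the height 100 and scale 0.29 are arbitrary illustrative values, and the assumption is that the output length is computed as floor(input_length * scale).

```python
import math

values = (0.5, 1.5, 2.5, 3.5)

# Python's built-in round() uses round-half-to-even ("banker's rounding"),
# so exact halves do not always round up.
half_even = [round(v) for v in values]

# Rounding half up (for non-negative values) gives a different answer.
half_up = [math.floor(v + 0.5) for v in values]

# Deriving a resized length from a float scale can land just below the
# intended integer; flooring then loses one pixel.
height = 100
scale = 0.29
scaled = height * scale              # 28.999999999999996, not 29.0
out_from_scale = math.floor(scaled)  # 28
out_from_size = 29                   # specifying the target size directly avoids the error

print(half_even)                      # [0, 2, 2, 4]
print(half_up)                        # [1, 2, 3, 4]
print(out_from_scale, out_from_size)  # 28 29
```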
Demo
The video is slowed down to approximately 1/50 of actual speed.
Environment
- onnx
- onnx-simplifier
- onnx_graphsurgeon
- simple_onnx_processing_tools
- tensorflow>=2.10.0
Sample Usage
$ docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.0.13
or
$ pip install -U onnx2tf
or
$ pip install -e .
$ wget https://github.com/PINTO0309/onnx2tf/releases/download/0.0.2/resnet18-v1-7.onnx
$ onnx2tf -i resnet18-v1-7.onnx -o saved_model
CLI Parameter
$ onnx2tf -h
usage: onnx2tf
[-h]
-i INPUT_ONNX_FILE_PATH
[-o OUTPUT_FOLDER_PATH]
[-osd]
[-nuo]
[-b BATCH_SIZE]
[-ois OVERWRITE_INPUT_SHAPE [OVERWRITE_INPUT_SHAPE ...]]
[-k KEEP_NCW_OR_NCHW_OR_NCDHW_INPUT_NAMES [KEEP_NCW_OR_NCHW_OR_NCDHW_INPUT_NAMES ...]]
[-rari64 | -rarf32]
[-rasin]
[-racos]
[-rlr]
[-rpw]
[-rgn]
[-rng]
[-rhs]
[-me]
[-prf PARAM_REPLACEMENT_FILE]
[-n]
optional arguments:
-h, --help
show this help message and exit
-i INPUT_ONNX_FILE_PATH, --input_onnx_file_path INPUT_ONNX_FILE_PATH
Input onnx file path.
-o OUTPUT_FOLDER_PATH, --output_folder_path OUTPUT_FOLDER_PATH
Output folder path. Default: "saved_model"
-osd, --output_signaturedefs
Signature is added to the output for serving or for conversion
to other model formats. However, this can significantly reduce the speed
of model conversion and significantly increase the size of the model.
-nuo, --not_use_onnxsim
No optimization by onnx-simplifier is performed.
If this option is used, the probability of a conversion error is very high.
-b BATCH_SIZE, --batch_size BATCH_SIZE
Fixes the dynamic batch size to the specified numeric batch size.
A value of 1 or more must be specified.
-ois OVERWRITE_INPUT_SHAPE [OVERWRITE_INPUT_SHAPE ...], \
--overwrite_input_shape OVERWRITE_INPUT_SHAPE [OVERWRITE_INPUT_SHAPE ...]
Overwrite the input shape.
The format is
"i1:dim0,...,dimN" "i2:dim0,...,dimN" "i3:dim0,...,dimN"
When there is only one input, for example,
"data:1,3,224,224"
When there are multiple inputs, for example,
"data1:1,3,224,224" "data2:1,3,112" "data3:5"
A value of 1 or more must be specified.
Numerical values other than dynamic dimensions are ignored.
If --batch_size is specified at the same time, --batch_size is ignored.
-k KEEP_NCW_OR_NCHW_OR_NCDHW_INPUT_NAMES [KEEP_NCW_OR_NCHW_OR_NCDHW_INPUT_NAMES ...], \
--keep_ncw_or_nchw_or_ncdhw_input_names KEEP_NCW_OR_NCHW_OR_NCDHW_INPUT_NAMES \
[KEEP_NCW_OR_NCHW_OR_NCDHW_INPUT_NAMES ...]
Keeps the NCW, NCHW, or NCDHW layout of the input shape for the specified INPUT OP names.
If a nonexistent INPUT OP name is specified, it is ignored.
Valid only for 3D, 4D and 5D input tensors.
e.g. --keep_ncw_or_nchw_or_ncdhw_input_names "input0" "input1" "input2"
-rari64, --replace_argmax_to_reducemax_and_indicies_is_int64
Replace ArgMax with a ReduceMax. The returned indices are int64.
Only one of replace_argmax_to_reducemax_and_indicies_is_int64
and replace_argmax_to_reducemax_and_indicies_is_float32 can be specified.
-rarf32, --replace_argmax_to_reducemax_and_indicies_is_float32
Replace ArgMax with a ReduceMax. The returned indices are float32.
Only one of replace_argmax_to_reducemax_and_indicies_is_int64
and replace_argmax_to_reducemax_and_indicies_is_float32 can be specified.
-rasin, --replace_asin_to_pseudo_asin
Replace Asin with a pseudo Asin.
-racos, --replace_acos_to_pseudo_acos
Replace Acos with a pseudo Acos.
-rlr, --replace_leakyrelu_to_pseudo_leakyrelu
Replace LeakyReLU with a pseudo LeakyReLU.
-rpw, --replace_power_to_pseudo_power
Replace Power with a pseudo Power.
-rgn, --replace_gathernd_to_pseudo_gathernd
Replace GatherND with a pseudo GatherND.
-rng, --replace_neg_to_pseudo_neg
Replace Neg with a pseudo Neg.
-rhs, --replace_hardswish_to_pseudo_hardswish
Replace HardSwish with a pseudo HardSwish.
-me, --mvn_epsilon
For MeanVarianceNormalization.
The number to be added to the variance to avoid division by zero
when normalizing the value.
(input_tensor - mean) / tf.sqrt(variance + mvn_epsilon)
Default: 0.0000000001
-prf PARAM_REPLACEMENT_FILE, --param_replacement_file PARAM_REPLACEMENT_FILE
Parameter replacement file path. (.json)
-n, --non_verbose
Do not show all information logs. Only error logs are displayed.
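The -rasin, -racos, -rlr, -rng and -rhs options replace an OP with a "pseudo" version built from more widely supported primitives. The exact decompositions onnx2tf uses are not described here; the NumPy sketch below only shows standard mathematical identities of the kind such replacements can rely on, and checks them against the direct operations. All function names are hypothetical.

```python
import numpy as np


def pseudo_hardswish(x):
    # HardSwish(x) = x * relu6(x + 3) / 6, using only add, clip and multiply.
    return x * np.clip(x + 3.0, 0.0, 6.0) / 6.0


def pseudo_leakyrelu(x, alpha=0.01):
    # For 0 <= alpha <= 1, LeakyRelu(x) equals max(x, alpha * x).
    return np.maximum(x, alpha * x)


def pseudo_neg(x):
    # Neg expressed as a multiplication by -1.
    return x * -1.0


def pseudo_asin(x):
    # asin(x) = atan2(x, sqrt(1 - x^2)) for x in [-1, 1].
    return np.arctan2(x, np.sqrt(1.0 - x * x))


def pseudo_acos(x):
    # acos(x) = atan2(sqrt(1 - x^2), x) for x in [-1, 1].
    return np.arctan2(np.sqrt(1.0 - x * x), x)


x = np.linspace(-5.0, 5.0, 101)
u = np.linspace(-1.0, 1.0, 101)

# Reference definitions: ONNX HardSwish uses alpha=1/6, beta=0.5.
ref_hardswish = x * np.clip(x / 6.0 + 0.5, 0.0, 1.0)
ref_leakyrelu = np.where(x >= 0, x, 0.01 * x)

print(np.allclose(pseudo_hardswish(x), ref_hardswish))  # True
print(np.allclose(pseudo_asin(u), np.arcsin(u)))        # True
```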
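The -me/--mvn_epsilon formula above can be written out directly. A minimal NumPy sketch of MeanVarianceNormalization, assuming normalization over axes (0, 2, 3) of an NCHW tensor (the ONNX default axes); it shows where mvn_epsilon enters and why it matters when a channel is constant. This is an illustration, not onnx2tf's implementation.

```python
import numpy as np


def mean_variance_normalization(x, axes=(0, 2, 3), mvn_epsilon=1e-10):
    # (input_tensor - mean) / sqrt(variance + mvn_epsilon), per remaining axis.
    mean = np.mean(x, axis=axes, keepdims=True)
    variance = np.mean(np.square(x - mean), axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(variance + mvn_epsilon)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 4)).astype(np.float32)
x[:, 1] = 7.0  # channel 1 is constant, so its variance is exactly 0

y = mean_variance_normalization(x)
# Without epsilon, channel 1 would be 0 / 0 = NaN; with epsilon it becomes 0.
print(np.isnan(y).any())  # False
```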
In-script Usage
>>> from onnx2tf import convert
>>> help(convert)
Help on function convert in module onnx2tf:
convert(
input_onnx_file_path: Union[str, NoneType] = '',
onnx_graph: Union[onnx.onnx_ml_pb2.ModelProto, NoneType] = None,
output_folder_path: Union[str, NoneType] = 'saved_model',
output_signaturedefs: Optional[bool] = False,
not_use_onnxsim: Optional[bool] = False,
batch_size: Union[int, NoneType] = None,
overwrite_input_shape: Union[List[str], NoneType] = None,
keep_ncw_or_nchw_or_ncdhw_input_names: Union[List[str], NoneType] = None,
replace_argmax_to_reducemax_and_indicies_is_int64: Union[bool, NoneType] = False,
replace_argmax_to_reducemax_and_indicies_is_float32: Union[bool, NoneType] = False,
replace_asin_to_pseudo_asin: Union[bool, NoneType] = False,
replace_acos_to_pseudo_acos: Union[bool, NoneType] = False,
replace_leakyrelu_to_pseudo_leakyrelu: Union[bool, NoneType] = False,
replace_power_to_pseudo_power: Optional[bool] = False,
replace_gathernd_to_pseudo_gathernd: Optional[bool] = False,
replace_neg_to_pseudo_neg: Optional[bool] = False,
replace_hardswish_to_pseudo_hardswish: Optional[bool] = False,
mvn_epsilon: Union[float, NoneType] = 0.0000000001,
param_replacement_file: Optional[str] = '',
non_verbose: Union[bool, NoneType] = False
) -> keras.engine.training.Model
Convert ONNX to TensorFlow models.
Parameters
----------
input_onnx_file_path: Optional[str]
Input onnx file path.
Either input_onnx_file_path or onnx_graph must be specified.
onnx_graph: Optional[onnx.ModelProto]
onnx.ModelProto.
Either input_onnx_file_path or onnx_graph must be specified.
If onnx_graph is specified, input_onnx_file_path is ignored and onnx_graph is processed.
output_folder_path: Optional[str]
Output tensorflow model folder path.
Default: "saved_model"
output_signaturedefs: Optional[bool]
Signature is added to the output for serving or for conversion
to other model formats. However, this can significantly reduce the speed
of model conversion and significantly increase the size of the model.
not_use_onnxsim: Optional[bool]
No optimization by onnx-simplifier is performed.
If this option is used, the probability of a conversion error is very high.
batch_size: Optional[int]
Fixes the dynamic batch size to the specified numeric batch size.
A value of 1 or more must be specified.
overwrite_input_shape: Optional[List[str]]
Overwrite the input shape.
The format is
['i1:dim0,dim1,...,dimN', 'i2:dim0,dim1,...,dimN', 'i3:dim0,dim1,...,dimN']
When there is only one input, for example,
['data:1,3,224,224']
When there are multiple inputs, for example,
['data1:1,3,224,224','data2:1,3,112','data3:5']
A value of 1 or more must be specified.
Numerical values other than dynamic dimensions are ignored.
If batch_size is specified at the same time, batch_size is ignored.
keep_ncw_or_nchw_or_ncdhw_input_names: Optional[List[str]]
Keeps the NCW, NCHW, or NCDHW layout of the input shape for the specified INPUT OP names.
If a nonexistent INPUT OP name is specified, it is ignored.
Valid only for 3D, 4D and 5D input tensors.
e.g.
keep_ncw_or_nchw_or_ncdhw_input_names=['input0', 'input1', 'input2']
replace_argmax_to_reducemax_and_indicies_is_int64: Optional[bool]
Replace ArgMax with a ReduceMax. The returned indices are int64.
Only one of replace_argmax_to_reducemax_and_indicies_is_int64 and
replace_argmax_to_reducemax_and_indicies_is_float32 can be specified.
Default: False
replace_argmax_to_reducemax_and_indicies_is_float32: Optional[bool]
Replace ArgMax with a ReduceMax. The returned indices are float32.
Only one of replace_argmax_to_reducemax_and_indicies_is_int64 and
replace_argmax_to_reducemax_and_indicies_is_float32 can be specified.
Default: False
replace_asin_to_pseudo_asin: Optional[bool]
Replace Asin with a pseudo Asin.
replace_acos_to_pseudo_acos: Optional[bool]
Replace Acos with a pseudo Acos.
replace_leakyrelu_to_pseudo_leakyrelu: Optional[bool]
Replace LeakyReLU with a pseudo LeakyReLU.
replace_power_to_pseudo_power: Optional[bool]
Replace Power with a pseudo Power.
replace_gathernd_to_pseudo_gathernd: Optional[bool]
Replace GatherND with a pseudo GatherND.
replace_neg_to_pseudo_neg: Optional[bool]
Replace Neg with a pseudo Neg.
replace_hardswish_to_pseudo_hardswish: Optional[bool]
Replace HardSwish with a pseudo HardSwish.
mvn_epsilon: Optional[float]
For MeanVarianceNormalization.
The number to be added to the variance to avoid division by zero
when normalizing the value.
(input_tensor - mean) / tf.sqrt(variance + mvn_epsilon)
Default: 0.0000000001
param_replacement_file: Optional[str]
Parameter replacement file path. (.json)
non_verbose: Optional[bool]
Do not show all information logs. Only error logs are displayed.
Default: False
Returns
----------
model: tf.keras.Model
Model
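A minimal sketch of calling convert() from a script, using only parameters listed in the signature above. The helper build_overwrite_input_shape is hypothetical; it formats shapes as the "name:dim0,...,dimN" strings described for overwrite_input_shape. The input name "data" for resnet18-v1-7.onnx is an assumption; check your own model's input names. run_conversion() is not called here, because it needs onnx2tf installed and the model file downloaded as in Sample Usage.

```python
from typing import Dict, List


def build_overwrite_input_shape(shapes: Dict[str, List[int]]) -> List[str]:
    """Hypothetical helper: {"name": [dims...]} -> ["name:dim0,...,dimN", ...]."""
    for name, dims in shapes.items():
        if any(d < 1 for d in dims):
            raise ValueError(f"{name}: every dimension must be 1 or more")
    return [f"{name}:{','.join(str(d) for d in dims)}" for name, dims in shapes.items()]


def run_conversion():
    # Imported lazily so this file can be loaded without onnx2tf installed.
    from onnx2tf import convert

    # Assumption: the model's input is named "data".
    return convert(
        input_onnx_file_path="resnet18-v1-7.onnx",
        output_folder_path="saved_model",
        overwrite_input_shape=build_overwrite_input_shape({"data": [1, 3, 224, 224]}),
        non_verbose=True,
    )


print(build_overwrite_input_shape({"data1": [1, 3, 224, 224], "data2": [1, 3, 112], "data3": [5]}))
# ['data1:1,3,224,224', 'data2:1,3,112', 'data3:5']
```

Building the list programmatically avoids a common Python pitfall: adjacent string literals without commas, such as 'a' 'b', are silently concatenated into one string.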
Parameter replacement
This tool converts NCW to NWC, NCHW to NHWC, NCDHW to NDHWC, NCDDHW to NDDHWC, and NCDDDDDDHW to NDDDDDDHWC. Therefore, as stated in Key concept, the conversion will inevitably break down at some point in the model. You need to look at the entire conversion log to see which OP transpositions are failing and correct them yourself. I deliberately keep this explanation minimal, because I know that no matter how much detail I put in the README, you guys will not read it at all. An attribute, INPUT constant, or INPUT Initializer can be replaced with a specified value.
- "A conversion error occurs."
- "Output results are wrong."
Please don't post such low-level questions as issues.
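The channel-first to channel-last conversions listed above all follow one pattern, which also explains perm values such as [0,2,1], [0,2,3,1] and [0,3,1,2] in param_replacement.json. A minimal pure-Python sketch; the function names are hypothetical and not part of onnx2tf.

```python
from typing import List


def to_channel_last_perm(rank: int) -> List[int]:
    """Perm for N, C, D1, ..., Dk -> N, D1, ..., Dk, C (e.g. NCHW -> NHWC)."""
    if rank < 3:
        raise ValueError("rank must be at least 3 (N, C and one spatial axis)")
    return [0] + list(range(2, rank)) + [1]


def to_channel_first_perm(rank: int) -> List[int]:
    """Inverse perm for N, D1, ..., Dk, C -> N, C, D1, ..., Dk (e.g. NHWC -> NCHW)."""
    if rank < 3:
        raise ValueError("rank must be at least 3 (N, C and one spatial axis)")
    return [0, rank - 1] + list(range(1, rank - 1))


def apply_perm(shape: List[int], perm: List[int]) -> List[int]:
    # Same axis reordering that a Transpose with this perm applies to a shape.
    return [shape[axis] for axis in perm]


for rank, name in [(3, "NCW"), (4, "NCHW"), (5, "NCDHW"), (10, "NCDDDDDDHW")]:
    print(name, to_channel_last_perm(rank), to_channel_first_perm(rank))

print(apply_perm([1, 3, 224, 224], to_channel_last_perm(4)))  # [1, 224, 224, 3]
```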
- convert option
  --param_replacement_file param_replacement.json
- param_replacement.json
  {
    "format_version": 1,
    "operations": [
      {
        "op_name": "StatefulPartitionedCall/Tile_4",
        "param_target": "inputs", # attributes or inputs
        "param_name": "const_fold_opt__677",
        "values": [1,1,17] # Disable parameter transposition or overwrite parameters
      },
      {
        "op_name": "StatefulPartitionedCall/Cast_3",
        "param_target": "attributes", # attributes or inputs
        "param_name": "to",
        "values": 1 # Disable parameter transposition or overwrite "to" parameters
      },
      {
        "op_name": "Resize__697",
        "param_target": "inputs",
        "param_name": "Concat__696:0",
        "values": [26,26] # Replacement of unk__x (Resize OP, sizes height/width parameter)
      },
      {
        "op_name": "Transpose__927",
        "param_target": "attributes",
        "param_name": "perm",
        "values": [0,1,2,3] # Disable parameter transposition or overwrite "perm" parameters
      },
      {
        "op_name": "StatefulPartitionedCall/functional_1/max_unpooling2d_2/Reshape_1",
        "param_target": "inputs",
        "param_name": "const_fold_opt__911",
        "values": [4,131072] # Overwrite "shape" parameters
      },
      {
        "op_name": "Reshape_25",
        "param_target": "outputs",
        "param_name": "onnx::InstanceNormalization_270",
        "post_process_transpose_perm": [0,2,1] # Extrapolate 3D Transpose after Reshape
      },
      {
        "op_name": "Reshape_30",
        "param_target": "outputs",
        "param_name": "onnx::Mul_275",
        "post_process_transpose_perm": [0,2,3,1] # Extrapolate 4D Transpose after Reshape
      },
      {
        "op_name": "flatten_1127",
        "param_target": "inputs",
        "param_name": "dropout0",
        "pre_process_transpose_perm": [0,3,1,2]
      }
    ]
  }
  The # annotations are explanations only. Standard JSON does not allow comments, so remove them in the actual file.
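Because the annotated example above is not valid JSON as written, it can be convenient to generate the file from Python. A minimal sketch using only the standard json module; the two entries are copied from the example and are specific to that model, so the OP and tensor names must come from your own conversion log.

```python
import json

replacements = {
    "format_version": 1,
    "operations": [
        {
            # Disable parameter transposition or overwrite "perm" parameters.
            "op_name": "Transpose__927",
            "param_target": "attributes",
            "param_name": "perm",
            "values": [0, 1, 2, 3],
        },
        {
            # Insert a 4D Transpose after Reshape.
            "op_name": "Reshape_30",
            "param_target": "outputs",
            "param_name": "onnx::Mul_275",
            "post_process_transpose_perm": [0, 2, 3, 1],
        },
    ],
}

# json.dump writes strict JSON, so no comments end up in the file.
with open("param_replacement.json", "w", encoding="utf-8") as f:
    json.dump(replacements, f, indent=2)
```

The file can then be passed with --param_replacement_file param_replacement.json.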
Replacement Supported OPs
No. | OP type | Remarks
---|---|---
1 | Cast | Type values: float16=10, int8=3, float32=1, int16=5, float64=11, int32=6, bool=9, int64=7, uint8=2, uint16=4, uint32=12, uint64=13
2 | Div |
3 | Gemm |
4 | Mul |
5 | Reshape | 1. "param_target": "inputs"; values: value of shape; pre_process_transpose_perm: Transpose is applied to the tensor before the Reshape operation, using the specified perm as pre-processing. 2. "param_target": "outputs"; post_process_transpose_perm: Transpose is applied to the tensor after the Reshape operation, using the specified perm as post-processing.
6 | Flatten | 1. "param_target": "attributes"; axis: value of axis. 2. "param_target": "inputs"; pre_process_transpose_perm: Transpose is applied to the tensor before the Flatten operation, using the specified perm as pre-processing. 3. "param_target": "outputs"; post_process_transpose_perm: Transpose is applied to the tensor after the Flatten operation, using the specified perm as post-processing.
7 | Resize |
8 | Sub |
9 | Tile |
10 | Transpose | 1. "param_target": "attributes"; perm: value of perm. 2. "param_target": "inputs"; values: value of tensor
11 | NonMaxSuppression |
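Why Reshape and Flatten may need pre_process_transpose_perm can be seen in a small NumPy sketch (an illustration of the layout problem, not onnx2tf's implementation). Once NCHW data has been converted to NHWC, applying the ONNX target shape directly reorders the elements; transposing back with perm [0,3,1,2] before the Reshape, as in the flatten_1127 entry of the example JSON, restores the ONNX result.

```python
import numpy as np

# ONNX side: a small NCHW tensor (N=1, C=5, H=4, W=4).
nchw = np.arange(1 * 5 * 4 * 4).reshape(1, 5, 4, 4)

# The ONNX model flattens H and W while keeping channels in axis 1.
onnx_result = nchw.reshape(1, 5, -1)  # shape (1, 5, 16)

# TensorFlow side: the same data after NCHW -> NHWC conversion.
nhwc = nchw.transpose(0, 2, 3, 1)  # shape (1, 4, 4, 5)

# Reusing the ONNX target shape on NHWC data mixes up the elements.
naive = nhwc.reshape(1, 5, -1)

# Transposing back to NCHW first (perm [0, 3, 1, 2]) and then reshaping
# reproduces the ONNX result.
fixed = nhwc.transpose(0, 3, 1, 2).reshape(1, 5, -1)

print(np.array_equal(naive, onnx_result))  # False
print(np.array_equal(fixed, onnx_result))  # True
```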
Generated Model
- YOLOv7-tiny with Post-Process (NMS) ONNX to TFLite Float32
  https://github.com/PINTO0309/onnx2tf/releases/download/0.0.33/yolov7_tiny_head_0.768_post_480x640.onnx
  Figure: onnx2tf vs. onnx-tensorflow (Super redundant + Broken)
- YOLACT-Edge MobileNetV2 with Post-Process (MultiClass-NMS) ONNX to TFLite Float32
  https://github.com/PINTO0309/onnx2tf/releases/download/1.0.11/yolact_edge_mobilenetv2_550x550.onnx
Related tools
Download files
File details
Details for the file onnx2tf-1.0.13.tar.gz.
File metadata
- Download URL: onnx2tf-1.0.13.tar.gz
- Upload date:
- Size: 75.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.9.15
File hashes
Algorithm | Hash digest
---|---
SHA256 | 0574b0de905d0f836c1af2742d06cf589caec2bda52f6ca7c37b4605bcbb8c5b
MD5 | fc26ab21de166e1b40877c72a7d1aa50
BLAKE2b-256 | dc082d209f6d063381cf51d5f29a78cb161b5b91cb2c9d481fad69111d734a8c
File details
Details for the file onnx2tf-1.0.13-py3-none-any.whl.
File metadata
- Download URL: onnx2tf-1.0.13-py3-none-any.whl
- Upload date:
- Size: 176.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.9.15
File hashes
Algorithm | Hash digest
---|---
SHA256 | ade5c588db3f5f754b763b3a5dc9deb6587594b36c48993946cfbcdb924a7272
MD5 | 91e38bd9b578ccfec46c39f7a9ba8c98
BLAKE2b-256 | 165620b61cbab3c10cad694d02bd3ea051a82e5ad13f2c22b021e113ee863ca7