tf2pb: convert TensorFlow ckpt and Keras h5 models to frozen pb or serving pb
1. Convert TF ckpt to pb, and Keras h5 to pb
Both ordinary pb and FasterTransformer pb conversion are supported.
1. FasterTransformer pb gives a 1.9x - 3.x speedup; FasterTransformer currently only supports the official BERT/Transformer family.
2. Keras h5 models can be converted to pb.
It is recommended to run the exported pb models with nn-sdk for inference.
fastertransformer 4.0
# cuda 11.3: pip install fastertransformer==4.0.0.113
# cuda 11.6: pip install fastertransformer==4.0.0.116
fastertransformer 5.0
# cuda 11.3: pip install fastertransformer==5.0.0.113
# cuda 11.6: pip install fastertransformer==5.0.0.116
A recommended TensorFlow build is available at the link below; a CUDA 11.3.1 environment with TensorFlow 1.15 is suggested.
TensorFlow download: https://pan.baidu.com/s/1PXelYOJ2yqWfWfY7qAL4wA  extraction code: rpxv
Tested with this TF build, BERT achieves roughly a 3.x speedup.
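Before installing one of the fastertransformer wheels above, it can help to confirm which TensorFlow build and CUDA toolkit are actually active, so the wheel suffix matches (.113 for CUDA 11.3, .116 for CUDA 11.6). This is only a minimal sketch and assumes nvcc is on the PATH:

# -*- coding: utf-8 -*-
import subprocess
import tensorflow as tf

# TensorFlow 1.15 is the recommended version for this workflow
print('tensorflow version :', tf.__version__)
print('built with cuda    :', tf.test.is_built_with_cuda())

# nvcc reports the installed CUDA toolkit version; pick the wheel suffix accordingly.
# Assumption: nvcc is installed and reachable on the PATH.
try:
    print(subprocess.check_output(['nvcc', '--version']).decode())
except (OSError, subprocess.CalledProcessError):
    print('nvcc not found, check your CUDA installation')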
2. Convert ckpt to pb
# -*- coding: utf-8 -*-
import os
import tensorflow as tf
import shutil
import tf2pb

# If you are not using fastertransformer, it is not advised to change this block.
try:
    # cuda 11.3: pip install fastertransformer==4.0.0.113
    # cuda 11.6: pip install fastertransformer==4.0.0.116
    import fastertransformer
    convert_config = {
        "fastertransformer": {
            "floatx": "float32",
            "remove_padding": False,
            "int8_mode": 0,  # needs NVIDIA card support, not recommended
        }
    }
except Exception:
    convert_config = {}


# BertModel_module loads the official bert modeling module or the fastertransformer one.
# For an ordinary pb you can use the official modeling module directly: import modeling
def load_model_tensor(bert_config_file, max_seq_len, num_labels):
    BertModel_module = tf2pb.get_modeling(convert_config)
    if BertModel_module is None:
        raise Exception('tf2pb get_modeling failed')
    bert_config = BertModel_module.BertConfig.from_json_file(bert_config_file)

    def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, num_labels, use_one_hot_embeddings):
        """Creates a classification model."""
        model = BertModel_module.BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings)
        output_layer = model.get_pooled_output()
        hidden_size = output_layer.shape[-1].value
        output_weights = tf.get_variable(
            "output_weights", [num_labels, hidden_size],
            dtype="float32",
            initializer=tf.truncated_normal_initializer(stddev=0.02))
        output_bias = tf.get_variable(
            "output_bias", [num_labels],
            dtype="float32",
            initializer=tf.zeros_initializer())
        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        probabilities = tf.nn.softmax(logits, axis=-1)
        return probabilities

    input_ids = tf.placeholder(tf.int32, (None, max_seq_len), 'input_ids')
    input_mask = tf.placeholder(tf.int32, (None, max_seq_len), 'input_mask')
    segment_ids = None
    # A simple classification head is used here; adapt it to your own task.
    probabilities = create_model(bert_config, False, input_ids, input_mask, segment_ids, num_labels, False)
    save_config = {
        "input_tensor": {
            'input_ids': input_ids,
            'input_mask': input_mask
        },
        "output_tensor": {
            "pred_ids": probabilities
        },
    }
    return save_config


if __name__ == '__main__':
    # trained ckpt weights
    weight_file = r'/home/tk/tk_nlp/script/ner/ner_output/bert/model.ckpt-2704'
    output_dir = r'/home/tk/tk_nlp/script/ner/ner_output/bert'
    bert_config_file = r'/data/nlp/pre_models/tf/bert/chinese_L-12_H-768_A-12/bert_config.json'
    if not os.path.exists(bert_config_file):
        raise Exception("bert_config does not exist")
    max_seq_len = 340
    num_labels = 16 * 4 + 1

    # ordinary pb
    pb_config = {
        "ckpt_filename": weight_file,  # trained ckpt weights
        "save_pb_file": os.path.join(output_dir, 'bert_inf.pb'),
    }
    # serving pb
    pb_serving_config = {
        'use': False,  # saving a serving model is disabled by default
        "ckpt_filename": weight_file,  # trained ckpt weights
        "save_pb_path_serving": os.path.join(output_dir, 'serving'),  # tf_serving export path
        'serve_option': {
            'method_name': 'tensorflow/serving/predict',
            'tags': ['serve'],
        }
    }
    if pb_config['save_pb_file'] and os.path.exists(pb_config['save_pb_file']):
        os.remove(pb_config['save_pb_file'])
    if pb_serving_config['use'] and pb_serving_config['save_pb_path_serving'] and os.path.exists(pb_serving_config['save_pb_path_serving']):
        shutil.rmtree(pb_serving_config['save_pb_path_serving'])

    def convert2pb(is_save_serving):
        def create_network_fn():
            save_config = load_model_tensor(bert_config_file=bert_config_file, max_seq_len=max_seq_len, num_labels=num_labels)
            save_config.update(pb_serving_config if is_save_serving else pb_config)
            return save_config
        if not is_save_serving:
            ret = tf2pb.freeze_pb(create_network_fn)
            if ret == 0:
                tf2pb.pb_show(pb_config['save_pb_file'])
            else:
                print('tf2pb.freeze_pb failed ', ret)
        else:
            ret = tf2pb.freeze_pb_serving(create_network_fn)
            if ret == 0:
                tf2pb.pb_serving_show(pb_serving_config['save_pb_path_serving'], pb_serving_config['serve_option']['tags'])  # inspect the exported serving model
            else:
                print('tf2pb.freeze_pb_serving failed ', ret)

    convert2pb(is_save_serving=False)
    if pb_serving_config['use']:
        convert2pb(is_save_serving=True)
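After the conversion succeeds, the frozen graph can be sanity-checked with plain TensorFlow 1.x before handing it to nn-sdk. The sketch below is an assumption-laden example, not part of tf2pb: it assumes the frozen output node ends up named pred_ids (i.e. that tf2pb names it after the output_tensor key); run tf2pb.pb_show(...) above first to confirm the actual node names.

# -*- coding: utf-8 -*-
# Minimal sketch: run the frozen bert_inf.pb with plain TF 1.x and dummy inputs.
# Assumption: the output node is named 'pred_ids' (verify with tf2pb.pb_show).
import numpy as np
import tensorflow as tf

pb_file = '/home/tk/tk_nlp/script/ner/ner_output/bert/bert_inf.pb'
max_seq_len = 340

graph_def = tf.GraphDef()
with tf.gfile.GFile(pb_file, 'rb') as f:
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
    input_ids = graph.get_tensor_by_name('input_ids:0')
    input_mask = graph.get_tensor_by_name('input_mask:0')
    pred_ids = graph.get_tensor_by_name('pred_ids:0')

with tf.Session(graph=graph) as sess:
    feed = {
        input_ids: np.zeros((1, max_seq_len), dtype=np.int32),
        input_mask: np.ones((1, max_seq_len), dtype=np.int32),
    }
    print(sess.run(pred_ids, feed_dict=feed).shape)

For the serving export, the standard saved_model_cli tool that ships with TensorFlow can show the same signature information: saved_model_cli show --dir <output_dir>/serving --all.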
3. Convert h5 to pb
import sys
import os
import tensorflow as tf
import tf2pb
from keras.models import Model, load_model

# bert_model is constructed by your own source code; output_dir is your model directory.
weight_file = os.path.join(output_dir, 'best_model.h5')
bert_model.load_weights(weight_file, by_name=False)
# or: bert_model = load_model(weight_file)
print(bert_model.inputs)

# rename the output tensor
pred_ids = tf.identity(bert_model.output, "pred_ids")
config = {
    'model': bert_model,  # the model you trained
    'input_tensor': {
        "Input-Token": bert_model.inputs[0],    # input tensor, e.g. bert_model.inputs[0]
        "Input-Segment": bert_model.inputs[1],  # input tensor, e.g. bert_model.inputs[1]
    },
    'output_tensor': {
        "pred_ids": pred_ids,  # output tensor
    },
    'save_pb_file': r'/root/save_pb_file.pb',  # pb filename
}
if os.path.exists(config['save_pb_file']):
    os.remove(config['save_pb_file'])
# convert directly
tf2pb.freeze_keras_pb(config)
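As with the ckpt path, the exported pb can be inspected with the same helper used in section 2 to confirm the final input/output node names before serving it:

import tf2pb
# print the node names recorded in the exported graph
tf2pb.pb_show(r'/root/save_pb_file.pb')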