Skip to main content

tf2pb: convert TensorFlow model checkpoints (ckpt) or Keras h5 models to pb or serving pb

Project description

tf2pb: convert TensorFlow model checkpoints (ckpt) or Keras h5 models to pb or serving pb

# -*- coding: utf-8 -*-

1. Convert a tf ckpt to pb, or a tf h5 model to pb

    支持普通pb和fastertransformer pb转换
    1. fastertransformer pb 可提高1.9x - 3.x加速, fastertransformer 目前只支持官方bert transformer系列
    2. keras h5py模型转换pb

    建议pb模型均可以通过nn-sdk推理
    fastertransformer 4.0
         #cuda 11.3 pip install fastertransformer==4.0.0.113
         #cuda 11.6 pip install fastertransformer==4.0.0.116
    fastertransformer 5.0
         #cuda 11.3 pip install fastertransformer==5.0.0.113
         #cuda 11.6 pip install fastertransformer==5.0.0.116
    推荐 tensorflow 链接如下,建议使用cuda11.3.1 环境tensorflow 1.15
    tensorflow链接: https://pan.baidu.com/s/1PXelYOJ2yqWfWfY7qAL4wA 提取码: rpxv 复制这段内容后打开百度网盘手机App,操作更方便哦
    tf经过测试 , bert 加速3.x

2. ckpt convert pb

# -*- coding: utf-8 -*-
import os
import tensorflow as tf
import shutil
import tf2pb

# If you are not using fastertransformer, changing this config is not advised.

try:
    # cuda 11.3: pip install fastertransformer==4.0.0.113
    # cuda 11.6: pip install fastertransformer==4.0.0.116
    import fastertransformer
    convert_config = {
        "fastertransformer": {
            "floatx": "float32",
            "remove_padding": False,
            "int8_mode": 0,  # needs NVIDIA card support, not suggested
        }
    }
except ImportError:
    # fastertransformer is optional: without it, fall back to a plain
    # (non-accelerated) conversion. Only a missing module is expected here;
    # a bare `except:` would also hide unrelated fatal errors.
    convert_config = {}


# BertModel_module  加载官方bert模型和fastertransformer模型
#如果是正常pb, 可以直接使用官方modeling 模块 import modeling

def load_model_tensor(bert_config_file, max_seq_len, num_labels):
    """Build the BERT classification graph and return its tf2pb save config.

    Args:
        bert_config_file: path to the bert_config.json of the pretrained model.
        max_seq_len: fixed sequence length of the input placeholders.
        num_labels: number of output classes of the softmax head.

    Returns:
        dict with "input_tensor" / "output_tensor" mappings consumed by tf2pb.

    Raises:
        Exception: if tf2pb cannot provide a modeling module.
    """
    # tf2pb selects either the official bert modeling module or the
    # fastertransformer-backed one, depending on convert_config.
    modeling = tf2pb.get_modeling(convert_config)
    if modeling is None:
        raise Exception('tf2pb get_modeling failed')
    config = modeling.BertConfig.from_json_file(bert_config_file)

    def _build_classifier(cfg, training, ids, mask, type_ids, n_labels, one_hot_emb):
        """Build BERT plus a dense softmax classification head; return probabilities."""
        bert = modeling.BertModel(
            config=cfg,
            is_training=training,
            input_ids=ids,
            input_mask=mask,
            token_type_ids=type_ids,
            use_one_hot_embeddings=one_hot_emb)

        pooled = bert.get_pooled_output()
        hidden = pooled.shape[-1].value  # TF1: Dimension -> int via .value
        weights = tf.get_variable(
            "output_weights", [n_labels, hidden],
            dtype="float32",
            initializer=tf.truncated_normal_initializer(stddev=0.02))
        bias = tf.get_variable(
            "output_bias", [n_labels],
            dtype="float32",
            initializer=tf.zeros_initializer())
        scores = tf.nn.bias_add(tf.matmul(pooled, weights, transpose_b=True), bias)
        return tf.nn.softmax(scores, axis=-1)

    input_ids = tf.placeholder(tf.int32, (None, max_seq_len), 'input_ids')
    input_mask = tf.placeholder(tf.int32, (None, max_seq_len), 'input_mask')
    # A simple classification head is used here (no segment ids);
    # adapt to your own task as needed.
    probabilities = _build_classifier(
        config, False, input_ids, input_mask, None, num_labels, False)
    return {
        "input_tensor": {
            'input_ids': input_ids,
            'input_mask': input_mask
        },
        "output_tensor": {
            "pred_ids": probabilities
        },
    }

if __name__ == '__main__':

    # Trained checkpoint weights and output directory.
    weight_file = r'/home/tk/tk_nlp/script/ner/ner_output/bert/model.ckpt-2704'
    output_dir = r'/home/tk/tk_nlp/script/ner/ner_output/bert'

    bert_config_file = r'/data/nlp/pre_models/tf/bert/chinese_L-12_H-768_A-12/bert_config.json'
    if not os.path.exists(bert_config_file):
        raise Exception("bert_config does not exist")

    max_seq_len = 340
    num_labels = 16 * 4 + 1

    # Plain frozen pb.
    pb_config = {
        "ckpt_filename": weight_file,  # trained checkpoint weights
        "save_pb_file": os.path.join(output_dir, 'bert_inf.pb'),
    }
    # tf-serving SavedModel.
    pb_serving_config = {
        'use': False,  # serving export is disabled by default
        "ckpt_filename": weight_file,  # trained checkpoint weights
        "save_pb_path_serving": os.path.join(output_dir, 'serving'),  # tf_serving export path
        'serve_option': {
            'method_name': 'tensorflow/serving/predict',
            'tags': ['serve'],
        }
    }

    # Clean up outputs left over from a previous run.
    stale_pb = pb_config['save_pb_file']
    if stale_pb and os.path.exists(stale_pb):
        os.remove(stale_pb)

    stale_serving = pb_serving_config['save_pb_path_serving']
    if pb_serving_config['use'] and stale_serving and os.path.exists(stale_serving):
        shutil.rmtree(stale_serving)


    def convert2pb(is_save_serving):
        """Freeze the graph either to a plain pb or to a tf-serving SavedModel."""
        def create_network_fn():
            # Build the graph, then merge in the output options for the chosen format.
            cfg = load_model_tensor(
                bert_config_file=bert_config_file, max_seq_len=max_seq_len, num_labels=num_labels)
            cfg.update(pb_serving_config if is_save_serving else pb_config)
            return cfg

        if is_save_serving:
            ret = tf2pb.freeze_pb_serving(create_network_fn)
            if ret == 0:
                # Inspect the exported SavedModel.
                tf2pb.pb_serving_show(
                    pb_serving_config['save_pb_path_serving'],
                    pb_serving_config['serve_option']['tags'])
            else:
                print('tf2pb.freeze_pb_serving failed ', ret)
        else:
            ret = tf2pb.freeze_pb(create_network_fn)
            if ret == 0:
                tf2pb.pb_show(pb_config['save_pb_file'])
            else:
                print('tf2pb.freeze_pb failed ', ret)

    convert2pb(is_save_serving=False)
    if pb_serving_config['use']:
        convert2pb(is_save_serving=True)

3. h5 convert pb

import sys
import tensorflow as tf
import tf2pb
import os
from keras.models import Model,load_model

# NOTE(review): `bert_model` and `output_dir` are NOT defined in this snippet —
# `bert_model` is the Keras model constructed by your own source code, and
# `output_dir` is the directory holding the trained h5 weights.
weight_file = os.path.join(output_dir, 'best_model.h5')
bert_model.load_weights(weight_file , by_name=False)
# or: bert_model = load_model(weight_file)

print(bert_model.inputs)
# Rename the model output so the exported pb exposes a stable tensor name.
pred_ids = tf.identity(bert_model.output, "pred_ids")



config = {
    'model': bert_model,# the trained Keras model
    'input_tensor' : {
        "Input-Token": bert_model.inputs[0], # input tensor, e.g. bert.Input[0]
        "Input-Segment": bert_model.inputs[1], # input tensor, e.g. bert.Input[1]
    },
    'output_tensor' : {
        "pred_ids": pred_ids, # output tensor
    },
    'save_pb_file': r'/root/save_pb_file.pb', # pb filename
}

if os.path.exists(config['save_pb_file']):
    os.remove(config['save_pb_file'])
# Perform the conversion.
tf2pb.freeze_keras_pb(config)

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distributions

No source distribution files available for this release. See the tutorial on generating distribution archives.

Built Distribution

tf2pb-0.2.0-py3-none-any.whl (92.3 kB view details)

Uploaded Python 3

File details

Details for the file tf2pb-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: tf2pb-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 92.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.5.0 importlib_metadata/4.8.2 pkginfo/1.7.1 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.8.9

File hashes

Hashes for tf2pb-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 cc7f0b0d73794b28f1d06e412136c6e7af132051828c718b7a45ba7a2cafa21a
MD5 97c2424007511a680a535da9ac155dea
BLAKE2b-256 410dc4723ce2d480aac0e3b800d72247af776813966147f3627b928b659f282d

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page