Skip to main content

nn-sdk是一个基于tensorflow(v1, v2)、onnx和tensorrt的神经网络推理开发包

Project description

nn-sdk是一个基于tensorflow(v1, v2)、onnx和tensorrt的神经网络推理开发包

'''
    前言:
        当前支持开发语言c/c++,python,java
        当前支持推理引擎tensorflow(v1,v2) onnxruntime tensorrt,fasttext 注:tensorrt 7,8测试通过(建议8),目前tensorrt只支持linux系统
        当前支持多子图,支持图多输入多输出, 支持pb [tensorflow 1,2] , ckpt [tensorflow] , trt [tensorrt] , fasttext
        当前支持fastertransformer pb [32精度 相对于传统tf,加速1.9x] ,安装 pip install tf2pb  , 进行模型转换
        tf2pb pb模型转换参考: https://pypi.org/project/tf2pb
        模型加密参考test_aes.py,目前支持tensorflow 1 pb模型、onnx模型、tensorrt模型、fasttext模型加密
        推荐环境ubuntu16 ubuntu18  ubuntu20 centos7 centos8 windows系列
        python (test_py.py) , c语言 (test.c) , java语言包 (nn_sdk.java)
        使用过程中遇到问题或者有建议可加qq group: 759163831
        更多使用参见: https://github.com/ssbuild

    python 推理demo
    config 字段介绍:
        aes: 加密参考test_aes.py
        engine: 推理引擎 0: tensorflow , 1: onnx , 2: tensorrt 3: fasttext
        log_level: 日志类型 0 fatal , 2 error , 4 info , 8 debug
        model_type: tensorflow 模型类型, 0 pb format , 1 ckpt format
        fastertransformer:  fastertransformer算子选项, 参考 https://pypi.org/project/tf2pb
        ConfigProto: tensorflow 显卡配置
        device_id: GPU id
        engine_version: 推理引擎主版本 tf 0,1  tensorrt 7 或者 8 , fasttext 0需正确配置
        graph: 多子图配置 
            node: 例子: tensorflow 1 input_ids:0 ,  tensorflow 2: input_ids , onnx: input_ids
            data_type: 节点的类型根据模型配置,对于c++/java支持 int int64 long longlong float double str
            shape:  尺寸维度
    更新详情:
    2021-10-7 增加 fasttext 向量和标签推理

'''
# -*- coding: utf-8 -*-
from nn_sdk import *
# Sequence length shared by the graph shapes in `config` and by the demo inputs
# below, so the two cannot drift apart.
seq_length = 256

# Engine configuration handed to csdk_object(). Only the sub-section matching
# "engine" ("tf", "onnx", "trt" or "fasttext") is relevant for a given run.
config = {
    "model_dir": r'/root/model.ckpt',
    "aes": {
        # Model decryption settings; see test_aes.py for how to encrypt a model.
        "use": False,
        "key": bytes([0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]),
        "iv": bytes([0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]),
    },
    "log_level": 4,  # fatal 1 , error 2 , info 4 , debug 8
    "engine": 0,  # 0 tensorflow , 1 onnx , 2 tensorrt , 3 fasttext
    "device_id": 0,
    "tf": {
        # ConfigProto has no effect with tensorflow 2
        "ConfigProto": {
            "log_device_placement": False,
            "allow_soft_placement": True,
            "gpu_options": {
                "allow_growth": True
            },
            "graph_options": {
                "optimizer_options": {
                    "global_jit_level": 1
                }
            },
        },
        "engine_version": 1,  # tensorflow version
        "model_type": 1,  # 0 pb , 1 ckpt
        # Only used when model_type is pb: keep use=False for a plain pb; for a
        # frozen saved_model set use=True and configure the tags.
        "saved_model": {
            "use": False,  # whether saved_model loading is enabled
            "tags": ['serve'],
            "signature_key": 'serving_default',
        },
        "fastertransformer": {
            "use": False,
            "cuda_version": "11.3",  # currently depends on tf2pb; supports 10.2 and 11.3
        }
    },
    "onnx": {
        "engine_version": 1,  # onnxruntime version
    },
    "trt": {
        "engine_version": 8,  # tensorrt version
        "enable_graph": 0,
    },
    "fasttext": {
        "engine_version": 0,  # fasttext major version
        "threshold": 0,  # threshold used when predicting the k labels
        "k": 1,  # number of labels to predict
        "dump_label": 1,  # output the internal labels so the caller can decode them
        "predict_label": 1,  # 1: return predicted labels , 0: return vectors
    },
    "graph": [
        {
            # For a Bert model the shape is [max_batch_size, max_seq_length].
            # max_batch_size is used by c++/java to allocate the input/output
            # buffers, and inputs must not exceed it. It has no effect in python,
            # where the size follows the caller's actual input; in python, enforce
            # any batch limit on the caller side.
            # For fasttext the node name can be anything, but it must be present.
            "input": [
                {"node": "input_ids:0", "data_type": "int64", "shape": [1, seq_length]},
                {"node": "input_mask:0", "data_type": "int64", "shape": [1, seq_length]}
            ],
            "output": [
                {"node": "pred_ids:0", "data_type": "int64", "shape": [1, seq_length]},
            ],
        }
    ]}

# One batch row of dummy ids / mask values, each seq_length long.
input_ids = [[1] * seq_length]
input_mask = [[1] * seq_length]
sdk_inf = csdk_object(config)
if sdk_inf.valid():
    try:
        # presumably the index of the entry in config["graph"] to run — TODO confirm
        net_stage = 0
        ret, out = sdk_inf.process(net_stage, input_ids, input_mask)
        print(ret)
        print(out)
    finally:
        # Release the engine even if process() raises.
        sdk_inf.close()

nn_sdk.java java包demo,配置参考python

package nn_sdk;
//包名必须是nn_sdk
/**
 * Java demo for the nn-sdk native library.
 *
 * The native sdk_process_cc call receives an instance of this class; the public
 * array fields are named after the graph nodes in the JSON config
 * ("input_ids:0" / "pred_ids:0"), so rename or add fields when the graph changes.
 */
public class nn_sdk {
	// Number of elements in each buffer; matches "shape":[1, 20] in the config below.
	private static final int BUFFER_LEN = 1 * 20;

	// Input/output buffers. The names follow the graph config; adjust them to the graph in use.
	public float [] input_ids = null;// graph input
	//public float [] input_mask = null;// graph input
	public float[] pred_ids =   null;// receives the inference result

	public nn_sdk() {
		// Allocate and initialize the buffers for the configured graph inputs/outputs.
		input_ids = new float[BUFFER_LEN];
		pred_ids =  new float[BUFFER_LEN];
		for(int i =0;i<BUFFER_LEN;i++) {
			input_ids[i] = 1;
			pred_ids[i] = 0;
		}
	}
	
	// Native inference API
	public native static int  sdk_init_cc();
	public native static int  sdk_uninit_cc();
	public native static long sdk_new_cc(String json);
	public native static int  sdk_delete_cc(long handle);
	public native static int sdk_process_cc(long handle, int net_state, nn_sdk data);
	
	static {
		// Absolute path of the native library: engine_csdk.pyd on windows, engine_csdk.so on linux.
		System.load("E:\\algo_text\\nn_csdk\\build_py36\\Release\\engine_csdk.pyd");
	}

	public static void main(String[] args){  
		System.out.println("java main...........");
		
		nn_sdk instance = new nn_sdk();
		sdk_init_cc();
		try {
			// Same configuration layout as the python demo. Commas are placed so the
			// text is strict JSON: none after the last member of "tf", and one
			// between the "tf" object and "graph".
			String json =  "{" + "\"model_dir\": \"E:/algo_text/nn_csdk/nn_csdk/py_test_ckpt/model.ckpt\"," + "\n" +
			"\"log_level\": 4," + "\n" +
			"\"engine\": 0," + "\n" +
			"\"device_id\": 0," + "\n" +
			"\"tf\":{ " + "\n" +
				"\"ConfigProto\": {" + "\n" +
					"\"log_device_placement\":false," + "\n" +
					"\"allow_soft_placement\":true," + "\n" +
					"\"gpu_options\":{\"allow_growth\": true}" + "\n" +
				"}," + "\n" +
				"\"engine_version\": 1," + "\n" +
				"\"model_type\":1" + "\n" +
			"}," + "\n" +
			"\"graph\": [" + "\n" +
				"{" + "\n" +
					"\"input\": [{\"node\":\"input_ids:0\", \"data_type\":\"float\", \"shape\":[1, 20]}]," + "\n" +
					"\"output\" : [{\"node\":\"pred_ids:0\", \"data_type\":\"float\", \"shape\":[1, 20]}]" + "\n" +
				"}" + "\n" +
			"]" + "\n" +
			"}";

			System.out.println(json);

			long handle = sdk_new_cc(json);
			System.out.println(handle);
			try {
				int code = sdk_process_cc(handle,0,instance);
				System.out.printf("sdk_process_cc %d \n" ,code);
				if(code == 0) {
					for(int i = 0;i<BUFFER_LEN ; i++) {
						System.out.printf("%f ",instance.pred_ids[i]);
					}
					System.out.println();
				}
			} finally {
				// Always release the native handle, even if processing throws.
				sdk_delete_cc(handle);
			}
		} finally {
			sdk_uninit_cc();
		}
		System.out.println("end");
	}
}

nn-sdk c包

/*
 * Public C interface of the nn-sdk inference library.
 * Usable from both C and C++ (the declarations get C linkage under C++).
 */
#ifndef __CC_SDK_H__
#define __CC_SDK_H__
#include <stdio.h>

/*
 * CC_SDK_EXPORT marks the public symbols. On Windows it expands to
 * dllexport when the library itself is built (_CC_SDK_EXPORT defined) and to
 * dllimport for consumers; on other platforms it expands to nothing.
 */
#ifdef _WIN32
#ifdef _CC_SDK_EXPORT
#define CC_SDK_EXPORT __declspec(dllexport)
#else
#define CC_SDK_EXPORT __declspec(dllimport)
#endif
#else
#define CC_SDK_EXPORT
#endif

#ifdef __cplusplus
extern "C" {
#endif
	/* Opaque handle returned by sdk_new_cc and consumed by the other calls. */
	typedef long long SDK_HANDLE_CC;

	/* Initialize the SDK; the demo treats a return value of 0 as success. */
	CC_SDK_EXPORT int sdk_init_cc(void);

	/* Counterpart of sdk_init_cc; called once at the end of the demo. */
	CC_SDK_EXPORT int sdk_uninit_cc(void);

	/* Create an engine instance from a JSON configuration string. */
	CC_SDK_EXPORT SDK_HANDLE_CC sdk_new_cc(const char* json);

	/* Destroy an instance created by sdk_new_cc. */
	CC_SDK_EXPORT int sdk_delete_cc(SDK_HANDLE_CC handle);

	/*
	 * Run inference on the sub-graph selected by net_stage.
	 * final_result:      array of caller-allocated output buffers.
	 * input_buffer_list: array of caller-allocated input buffers.
	 * The demo treats a return value of 0 as success.
	 * NOTE(review): buffer element types/sizes presumably follow the
	 * "data_type"/"shape" entries of the JSON graph config — confirm.
	 */
	CC_SDK_EXPORT int sdk_process_cc(SDK_HANDLE_CC handle, void** final_result, int net_stage, void** input_buffer_list);

#ifdef __cplusplus
}
#endif


#endif

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "nn_sdk.h"

int main(){
    if (0 != sdk_init_cc()) {
		return -1;
	}
    printf("配置参考 python.........\n");
	const char* json_data = "{\n\
    \"model_dir\": \"/root/model.ckpt\",\n\
    \"log_level\":8, \n\
     \"device_id\":0, \n\
    \"tf\":{ \n\
         \"ConfigProto\": {\n\
            \"log_device_placement\":0,\n\
            \"allow_soft_placement\":1,\n\
            \"gpu_options\":{\"allow_growth\": 1}\n\
        },\n\
        \"engine_version\": 1,\n\
        \"model_type\":1 ,\n\
    },\n\
    \"graph\": [\n\
        {\n\
            \"input\": [{\"node\":\"input_ids:0\", \"data_type\":\"float\", \"shape\":[1, 10]}],\n\
            \"output\" : [{\"node\":\"pred_ids:0\", \"data_type\":\"float\", \"shape\":[1, 10]}]\n\
        }\n\
    ]\n\
}";
	printf("%s\n", json_data);
	auto handle = sdk_new_cc(json_data);
	const int INPUT_NUM = 1;
	const int OUTPUT_NUM = 1;
	const int M = 1;
	const int N = 10;
	int *input[INPUT_NUM] = { 0 };
	float* result[OUTPUT_NUM] = { 0 };
	int element_input_size = sizeof(int);
	int element_output_size = sizeof(float);
	for (int i = 0; i < OUTPUT_NUM; ++i) {
		result[i] = (float*)malloc(M * N * element_output_size);
		memset(result[i], 0, M * N * element_output_size);
	}
	for(int i =0;i<INPUT_NUM;++i){
		input[i] = (int*)malloc(M * N * element_input_size);
		memset(input[i], 0, M * N * element_input_size);
		for (int j = 0; j < N; ++j) {
			input[i][j] = i;
		}
	}

	int code = sdk_process_cc(handle, (void**)result, 0, (void**)input);
	if (code == 0) {
		printf("result\n");
		for (int i = 0; i < N; ++i) {
			printf("%f ", result[0][i]);
		}
		printf("\n");
	}
	for (int i = 0; i < INPUT_NUM; ++i) {
		free(input[i]);
	}
	for (int i = 0; i < OUTPUT_NUM; ++i) {
		free(result[i]);
	}
	sdk_delete_cc(handle);
	sdk_uninit_cc();
	return 0;
}

aes加密示例

# -*- coding: UTF-8 -*-

import sys
#sys.path.append(r'E:\algo_text\nn_csdk\cmake_py36\Release')
from nn_sdk.engine_csdk import sdk_aes_encode_decode

def test_string():
    """Encrypt a short byte string, decrypt the result, and print both outcomes.

    Each call to sdk_aes_encode_decode takes a request dict whose "mode" is
    0 for encryption and 1 for decryption, and returns a (code, data) pair.
    """
    key_bytes = bytes([0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15])
    iv_bytes = bytes([0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15])

    # Encryption pass over five sample bytes.
    encode_request = dict(mode=0, key=key_bytes, iv=iv_bytes, data=bytes([1,2,3,5,255]))
    code, encrypt = sdk_aes_encode_decode(encode_request)
    print(code, encrypt)

    # Decryption pass: feed the ciphertext back with the same key and iv.
    decode_request = dict(mode=1, key=key_bytes, iv=iv_bytes, data=encrypt)
    code, plain = sdk_aes_encode_decode(decode_request)
    print(code, plain)

def test_encode_file(in_filename,out_filename):

    with open(in_filename,mode='rb') as f:
        data = f.read()
    if len(data) == 0 :
        return -1
    data1 = {
        "mode": 0,  # 0 加密 , 1 解密
        "key": bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]),
        "iv": bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]),
        "data": bytes(data)
    }

    code, encrypt = sdk_aes_encode_decode(data1)
    if code != 0:
        return code
    with open(out_filename, mode='wb') as f:
        f.write(encrypt)
    return code
def test_decode_file(in_filename,out_filename):
    with open(in_filename, mode='rb') as f:
        data = f.read()
    if len(data) == 0:
        return -1
    data1 = {
        "mode": 1,  # 0 加密 , 1 解密
        "key": bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]),
        "iv": bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]),
        "data": bytes(data)
    }

    code, plain = sdk_aes_encode_decode(data1)
    if code != 0:
        return code
    with open(out_filename, mode='wb') as f:
        f.write(plain)
    return code

if __name__ == "__main__":
    # Run the file round-trip only when executed as a script, so importing this
    # module does not touch these machine-specific paths.
    encode_code = test_encode_file(r'C:\Users\acer\Desktop\img\a.txt',r'C:\Users\acer\Desktop\img\a.txt.encode')
    if encode_code == 0:
        test_decode_file(r'C:\Users\acer\Desktop\img\a.txt.encode',r'C:\Users\acer\Desktop\img\a.txt.decode')
    else:
        # Decoding would read a missing or stale .encode file, so stop here.
        print("encode failed with code", encode_code)

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distributions

No source distribution files available for this release. See tutorial on generating distribution archives.

Built Distributions

nn_sdk-1.8.2-cp39-cp39-win_amd64.whl (191.4 kB view hashes)

Uploaded CPython 3.9 Windows x86-64

nn_sdk-1.8.2-cp39-cp39-manylinux2014_x86_64.whl (7.3 MB view hashes)

Uploaded CPython 3.9

nn_sdk-1.8.2-cp38-cp38-win_amd64.whl (192.1 kB view hashes)

Uploaded CPython 3.8 Windows x86-64

nn_sdk-1.8.2-cp38-cp38-manylinux2014_x86_64.whl (7.3 MB view hashes)

Uploaded CPython 3.8

nn_sdk-1.8.2-cp37-cp37m-win_amd64.whl (193.1 kB view hashes)

Uploaded CPython 3.7m Windows x86-64

nn_sdk-1.8.2-cp37-cp37m-manylinux2014_x86_64.whl (7.4 MB view hashes)

Uploaded CPython 3.7m

nn_sdk-1.8.2-cp36-cp36m-win_amd64.whl (193.1 kB view hashes)

Uploaded CPython 3.6m Windows x86-64

nn_sdk-1.8.2-cp36-cp36m-manylinux2014_x86_64.whl (7.4 MB view hashes)

Uploaded CPython 3.6m

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page