Skip to main content

nn_sdk推理tf1 tf2 pb nlp模型 , input tensor[input_ids , input_mask], output tensor[pred_ids]

Project description

nn-sdk 是一个基于 TensorFlow 1 / TensorFlow 2 的神经网络推理开发包

# -*- coding: utf-8 -*-
from nn_sdk.py_tf_csdk import csdk_object
# Python inference demo.
# The SDK can be used from Python (demo.py), C (nn_sdk.h) and Java (nn_sdk.java);
# the Python usage is shown below.
# Multiple sub-graphs are supported, and a graph may have several inputs and outputs.
# Supported model formats: TensorFlow 1 pb, TensorFlow 2 pb, TensorFlow ckpt.
# net_stage is the index (0..n) of the sub-graph to run.
config = {
    "model_dir": r'./model.ckpt',
    "log_level": 4,  # 0 fatal , 2 error , 4 info , 8 debug
    "model_type": 1,  # 0 pb format   if 1 ckpt format
    "ConfigProto": {
        "log_device_placement": False,
        "allow_soft_placement": True,
        "gpu_options": {
            "allow_growth": True
        },
    },
    "graph_inf_version": 1,  # the format of tensorflow pb model [1,2]
    "graph": [
        {
            # TF1 node names look like "input_ids:0"; TF2 node names look like "input_ids".
            # data_type is one of: int int64 long longlong float double.
            # The Python interface may omit data_type and shape (they are used by the
            # C interface), e.g. {"node": "input_ids:0"}.
            "input": [
                {"node": "input_ids:0", "data_type": "float", "shape": [1, 256]},
                {"node": "input_mask:0", "data_type": "float", "shape": [1, 256]}
            ],
            # The output must name the prediction node, not an input node.
            "output": [
                {"node": "pred_ids:0", "data_type": "float", "shape": [1, 256]},
            ],
        }
    ]}

seq_length = 256
input_ids = [[10.] * seq_length]
input_mask = [[1] * seq_length]
sdk_inf = csdk_object(config)
if sdk_inf.valid():
    try:
        net_stage = 0  # run the first (and only) sub-graph configured above
        # Inputs are passed in the same order as the "input" list of the graph config.
        ret, out = sdk_inf.process(net_stage, input_ids, input_mask)
        print(ret)
        print(out)
    finally:
        # Release the SDK object even if inference raised.
        sdk_inf.close()

nn-sdk java

package nn_sdk;
//包名必须是nn_sdk
public class nn_sdk {
	// Number of elements in each demo buffer; matches the [1, 20] shape in the JSON config below.
	private static final int SEQ_LEN = 20;

	// Input/output memory buffers. The field names are the same as the node names
	// in the graph config; adjust these fields to fit the graph being used.
	public float[] input_ids;   // input of the inference graph
	//public float [] input_mask = null;// input of the inference graph
	public float[] pred_ids;    // holds the inference result

	/** Allocates the input/output buffers of the configured graph and fills them with demo values. */
	public nn_sdk() {
		input_ids = new float[SEQ_LEN];
		pred_ids = new float[SEQ_LEN];
		java.util.Arrays.fill(input_ids, 1f);
		java.util.Arrays.fill(pred_ids, 0f);
	}

	// Inference functions implemented by the native library.
	public native static int  sdk_init_cc();
	public native static int  sdk_uninit_cc();
	public native static long sdk_new_cc(String json);
	public native static int  sdk_delete_cc(long handle);
	public native static int sdk_process_cc(long handle, int net_state, nn_sdk data);

	static {
		// Absolute path of the dynamic library: tf_csdk.pyd on Windows, tf_csdk.so on Linux.
		System.load("E:\\algo_text\\nn_csdk\\build_py36\\Release\\tf_csdk.pyd");
	}

	/** Demo entry point: init the SDK, create a handle from a JSON config, run stage 0, print the result. */
	public static void main(String[] args) {
		System.out.println("java main...........");

		nn_sdk model = new nn_sdk();
		sdk_init_cc();

		// The config is assembled line by line; lines are joined with '\n' and the
		// last line carries no trailing newline.
		String[] configLines = {
			"{\"model_dir\": \"E:/algo_text/nn_csdk/nn_csdk/py_test_ckpt/model.ckpt\",",
			"\"log_level\":4,",
			"\"model_type\":1, ",
			"\"ConfigProto\": {",
			"\"log_device_placement\":false,",
			"\"allow_soft_placement\":true,",
			"\"gpu_options\":{\"allow_growth\": true}",
			"},",
			"\"graph_inf_version\": 1,",
			"\"graph\": [",
			"{",
			"\"input\": [{\"node\":\"input_ids:0\", \"data_type\":\"float\", \"shape\":[1, 20]}],",
			"\"output\" : [{\"node\":\"pred_ids:0\", \"data_type\":\"float\", \"shape\":[1, 20]}]",
			"}",
			"]",
			"}"
		};
		StringBuilder builder = new StringBuilder();
		for (int k = 0; k < configLines.length; k++) {
			if (k > 0) {
				builder.append('\n');
			}
			builder.append(configLines[k]);
		}
		String json = builder.toString();

		System.out.println(json);

		long sdkHandle = sdk_new_cc(json);
		System.out.println(sdkHandle);

		// Run sub-graph 0; the native side reads and writes the buffers on `model`.
		int status = sdk_process_cc(sdkHandle, 0, model);
		System.out.printf("sdk_process_cc %d \n", status);
		if (status == 0) {
			int idx = 0;
			while (idx < SEQ_LEN) {
				System.out.printf("%f ", model.pred_ids[idx]);
				idx++;
			}
			System.out.println();
		}

		sdk_delete_cc(sdkHandle);
		sdk_uninit_cc();
		System.out.println("end");
	}
}

nn-sdk c

#ifndef __CC_SDK_H__
#define __CC_SDK_H__

#include <stdio.h>


/*
 * CC_SDK_EXPORT marks the public SDK symbols.
 * On Windows it expands to dllexport when the library itself is being built
 * (_CC_SDK_EXPORT defined) and to dllimport for consumers; elsewhere it is empty.
 */
#ifdef _WIN32
#ifdef _CC_SDK_EXPORT
#define CC_SDK_EXPORT __declspec(dllexport)
#else
#define CC_SDK_EXPORT __declspec(dllimport)
#endif
#else
#define CC_SDK_EXPORT
#endif // _WIN32



#ifdef __cplusplus
extern "C" {
#endif

	/* Opaque handle identifying one SDK instance created by sdk_new_cc(). */
	typedef long long SDK_HANDLE_CC;

	/* Initializes the SDK. NOTE(review): presumably must be called once before any other call - confirm. */
	CC_SDK_EXPORT int sdk_init_cc(void);

	/* Shuts the SDK down; counterpart of sdk_init_cc(). */
	CC_SDK_EXPORT int sdk_uninit_cc(void);

	/* Creates an SDK instance from a JSON configuration string and returns its handle. */
	CC_SDK_EXPORT SDK_HANDLE_CC sdk_new_cc(const char* json);

	/* Destroys an instance previously returned by sdk_new_cc(). */
	CC_SDK_EXPORT int sdk_delete_cc(SDK_HANDLE_CC handle);

	/*
	 * Runs inference on sub-graph `net_stage` of the instance `handle`.
	 * NOTE(review): input_buffer_list / final_result are presumably arrays of buffer
	 * pointers ordered like the "input" / "output" lists of the graph config, and a
	 * return value of 0 presumably means success - confirm against the implementation.
	 */
	CC_SDK_EXPORT int sdk_process_ex_cc(SDK_HANDLE_CC handle, void** final_result, int net_stage, void**input_buffer_list);

#ifdef __cplusplus
}
#endif


#endif

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distributions

No source distribution files available for this release. See tutorial on generating distribution archives.

Built Distributions

nn_sdk-1.2.9-cp38-cp38-win_amd64.whl (62.3 kB view hashes)

Uploaded CPython 3.8 Windows x86-64

nn_sdk-1.2.9-cp38-cp38-manylinux2010_x86_64.whl (158.1 kB view hashes)

Uploaded CPython 3.8 manylinux: glibc 2.12+ x86-64

nn_sdk-1.2.9-cp37-cp37m-win_amd64.whl (62.1 kB view hashes)

Uploaded CPython 3.7m Windows x86-64

nn_sdk-1.2.9-cp37-cp37m-manylinux2010_x86_64.whl (157.9 kB view hashes)

Uploaded CPython 3.7m manylinux: glibc 2.12+ x86-64

nn_sdk-1.2.9-cp36-cp36m-win_amd64.whl (62.0 kB view hashes)

Uploaded CPython 3.6m Windows x86-64

nn_sdk-1.2.9-cp36-cp36m-manylinux2010_x86_64.whl (157.8 kB view hashes)

Uploaded CPython 3.6m manylinux: glibc 2.12+ x86-64

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page