nn_sdk 推理 TensorFlow 1 / TensorFlow 2 pb 格式的 NLP 模型:输入张量 [input_ids, input_mask],输出张量 [pred_ids]
Project description
nn-sdk 是一个基于 TensorFlow (v1, v2) 和 ONNX Runtime 的神经网络推理开发包。
# -*- coding: utf-8 -*-
from nn_sdk.py_engine_csdk import csdk_object
# Overview:
# - Supported client languages: C/C++, Python, Java.
# - Supported inference engines: TensorFlow (v1, v2) and ONNX Runtime.
# - Supports multiple sub-graphs and multi-input / multi-output graphs.
#   Model formats: TensorFlow 1 pb, TensorFlow 2 pb, TensorFlow ckpt.
# - Bindings: Python (demo.py), C (nn_sdk.h), Java (nn_sdk.java).
#
# Python inference demo. Config fields:
#   engine:            inference engine; 0 = TensorFlow, 1 = ONNX Runtime.
#   log_level:         0 = fatal, 2 = error, 4 = info, 8 = debug.
#   model_type:        TensorFlow only; 0 = pb format, 1 = ckpt format.
#   ConfigProto:       TensorFlow only.
#   graph_inf_version: TensorFlow version [0, 1], or 1 for ONNX Runtime.
#   graph:             list of sub-graph configs (multi-sub-graph support).
#                      The demo selects one with net_stage = 0, presumably
#                      an index into this list.
#   node:              tensor name, e.g. TensorFlow 1: "input_ids:0",
#                      TensorFlow 2: "input_ids", ONNX: "input_ids".
#   data_type:         node dtype as defined by the model; supported values:
#                      int, int64, long, longlong, float.
#   shape:             node shape.
# The Python API may omit data_type and shape, e.g. {"node": "input_ids:0"}.
# The Java and C packages require both fields.
#
# This demo loads a TensorFlow checkpoint (model_type = 1, which only applies
# to TensorFlow), so engine must be 0. The graph's output is the prediction
# tensor pred_ids:0, not one of the input tensors.
config = {
    "model_dir": r"E:/algo_text/nn_csdk/nn_csdk/py_test_ckpt/model.ckpt",
    "log_level": 4,
    "engine": 0,
    "model_type": 1,
    "ConfigProto": {
        "log_device_placement": False,
        "allow_soft_placement": True,
        "gpu_options": {
            "allow_growth": True,
        },
    },
    "graph_inf_version": 1,
    "graph": [
        {
            "input": [
                {"node": "input_ids:0", "data_type": "float", "shape": [1, 256]},
                {"node": "input_mask:0", "data_type": "float", "shape": [1, 256]},
            ],
            "output": [
                {"node": "pred_ids:0", "data_type": "float", "shape": [1, 256]},
            ],
        },
    ],
}
# Dummy inputs. seq_length must match the [1, 256] shapes declared for the
# input nodes in config["graph"][0]["input"].
seq_length = 256
input_ids = [[1.0] * seq_length]
# Both input nodes are declared with data_type "float", so the mask is built
# from floats as well, consistent with input_ids.
input_mask = [[1.0] * seq_length]

sdk_inf = csdk_object(config)
if sdk_inf.valid():
    try:
        # Sub-graph to run; presumably an index into config["graph"].
        net_stage = 0
        ret, out = sdk_inf.process(net_stage, input_ids, input_mask)
        print(ret)
        print(out)
    finally:
        # Release the native engine even if process() raises.
        sdk_inf.close()
nn-sdk Java 包 demo(nn_sdk.java),配置字段说明请参考上面的 Python demo
package nn_sdk;

import java.util.Arrays;

// The package (and class) name must be nn_sdk. Presumably the native library's
// JNI entry points are bound to nn_sdk.nn_sdk, so renaming would break linkage.
public class nn_sdk {
    // Sequence length of the demo model. It sizes the I/O buffers and also fills
    // the "shape":[1, SEQ_LEN] entries of the JSON config in main(), so the two
    // cannot drift apart.
    private static final int SEQ_LEN = 20;

    // I/O buffers. Each field name must match a node name in the graph config
    // (input_ids:0 -> input_ids, pred_ids:0 -> pred_ids). Adjust the fields to
    // match the graph being run.
    public float[] input_ids = null; // graph input
    // Add one field per additional input node, e.g.:
    // public float[] input_mask = null;
    public float[] pred_ids = null; // receives the inference result

    public nn_sdk() {
        // Allocate buffers for the [1, SEQ_LEN] nodes declared in the config.
        input_ids = new float[1 * SEQ_LEN];
        pred_ids = new float[1 * SEQ_LEN]; // Java zero-initializes new arrays
        Arrays.fill(input_ids, 1.0f);
    }

    // Native entry points implemented by the engine_csdk library loaded below.
    public native static int sdk_init_cc();
    public native static int sdk_uninit_cc();
    public native static long sdk_new_cc(String json);
    public native static int sdk_delete_cc(long handle);
    public native static int sdk_process_cc(long handle, int net_state, nn_sdk data);

    static {
        // Absolute path of the native library: engine_csdk.pyd on Windows,
        // engine_csdk.so on Linux.
        System.load("E:\\algo_text\\nn_csdk\\build_py36\\Release\\engine_csdk.pyd");
    }

    public static void main(String[] args) {
        System.out.println("java main...........");
        nn_sdk instance = new nn_sdk();
        sdk_init_cc();
        try {
            // The Java binding requires data_type and shape for every node.
            String json = "{" + "\"model_dir\": \"E:/algo_text/nn_csdk/nn_csdk/py_test_ckpt/model.ckpt\"," + "\n" +
                    "\"log_level\":4," + "\n" +
                    "\"engine\":0," + "\n" +
                    "\"model_type\":1, " + "\n" +
                    "\"ConfigProto\": {" + "\n" +
                    "\"log_device_placement\":false," + "\n" +
                    "\"allow_soft_placement\":true," + "\n" +
                    "\"gpu_options\":{\"allow_growth\": true}" + "\n" +
                    "}," + "\n" +
                    "\"graph_inf_version\": 1," + "\n" +
                    "\"graph\": [" + "\n" +
                    "{" + "\n" +
                    "\"input\": [{\"node\":\"input_ids:0\", \"data_type\":\"float\", \"shape\":[1, " + SEQ_LEN + "]}]," + "\n" +
                    "\"output\" : [{\"node\":\"pred_ids:0\", \"data_type\":\"float\", \"shape\":[1, " + SEQ_LEN + "]}]" + "\n" +
                    "}" + "\n" +
                    "]" + "\n" +
                    "}";
            System.out.println(json);
            long handle = sdk_new_cc(json);
            System.out.println(handle);
            try {
                // net_state 0: presumably the first (here the only) entry of "graph".
                int code = sdk_process_cc(handle, 0, instance);
                System.out.printf("sdk_process_cc %d \n", code);
                if (code == 0) { // presumably 0 means success
                    for (float value : instance.pred_ids) {
                        System.out.printf("%f ", value);
                    }
                    System.out.println();
                }
            } finally {
                // Always free the engine handle, even if processing throws.
                sdk_delete_cc(handle);
            }
        } finally {
            sdk_uninit_cc();
        }
        System.out.println("end");
    }
}
nn-sdk C 包(头文件 nn_sdk.h)
/*
 * nn_sdk.h - C interface of nn-sdk.
 *
 * Unlike the Python binding, the JSON config passed to sdk_new_cc() must
 * specify "data_type" and "shape" for every input/output node.
 */
/* Guard renamed from __CC_SDK_H__: identifiers with a leading double
 * underscore are reserved for the implementation in C and C++. */
#ifndef CC_SDK_H_
#define CC_SDK_H_

#include <stdio.h>

/* On Windows, build the library with _CC_SDK_EXPORT defined to export the
 * API; consumers leave it undefined and import it instead. */
#ifdef _WIN32
#  ifdef _CC_SDK_EXPORT
#    define CC_SDK_EXPORT __declspec(dllexport)
#  else
#    define CC_SDK_EXPORT __declspec(dllimport)
#  endif
#else
#  define CC_SDK_EXPORT
#endif /* _WIN32 */

#ifdef __cplusplus
extern "C" {
#endif

/* Engine handle returned by sdk_new_cc(); 64-bit to fit a native pointer. */
typedef long long SDK_HANDLE_CC;

/* (void) parameter lists: in C, an empty () declares a function with
 * unspecified arguments rather than no arguments. */

/* Library-wide setup/teardown; presumably called once before any other API
 * and once after the last one. */
CC_SDK_EXPORT int sdk_init_cc(void);
CC_SDK_EXPORT int sdk_uninit_cc(void);

/* Create an inference engine from a JSON config string; returns its handle. */
CC_SDK_EXPORT SDK_HANDLE_CC sdk_new_cc(const char* json);

/* Destroy an engine created by sdk_new_cc(). */
CC_SDK_EXPORT int sdk_delete_cc(SDK_HANDLE_CC handle);

/*
 * Run one inference pass.
 *   handle            - engine from sdk_new_cc().
 *   final_result      - output buffer list (presumably one entry per output
 *                       node, in config order; TODO confirm ownership).
 *   net_stage         - sub-graph to run (presumably an index into "graph").
 *   input_buffer_list - input buffers (presumably one per input node, in
 *                       config order, sized per the declared shape/data_type).
 * Returns a status code (presumably 0 on success).
 */
CC_SDK_EXPORT int sdk_process_ex_cc(SDK_HANDLE_CC handle, void** final_result, int net_stage, void** input_buffer_list);

#ifdef __cplusplus
}
#endif

#endif /* CC_SDK_H_ */
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distributions
No source distribution files available for this release. See tutorial on generating distribution archives.
Built Distributions
Close
Hashes for nn_sdk-1.3.2-cp38-cp38-win_amd64.whl
Algorithm | Hash digest
---|---
SHA256 | 1309e9c3a6db6799a8b860364e52f330d5e81ff0bfd48c1d2b4f45173de62405
MD5 | 4662e673ef81c5843c67d6615e3ef007
BLAKE2b-256 | 9aafe987e57448fd81dbc72d344b641430beb06113e48cafa0e7a591a35eaab2
Close
Hashes for nn_sdk-1.3.2-cp38-cp38-manylinux2010_x86_64.whl
Algorithm | Hash digest
---|---
SHA256 | 1b370c18e98f2c46d8f29e4ef3ef3fac864e6ebbe404eca60cf072f491382a92
MD5 | c7bda042b02369439398ee73e1e1a271
BLAKE2b-256 | dc54484ceabf798ef9c43628dc066c0cf6fd345f11b7a621fdaf745af6b9f43f
Close
Hashes for nn_sdk-1.3.2-cp37-cp37m-win_amd64.whl
Algorithm | Hash digest
---|---
SHA256 | 11524035aaf229c7e69f2daf1c351af0f60b878507adc86ef0eabbf9028ad1ac
MD5 | cb9ffd500bbe4df175a0ca2662a534ef
BLAKE2b-256 | 2498cc94d5b68a9356d75ce71ca96106e27ea66ab5d617e0cb2adeab8b6f6a3f
Close
Hashes for nn_sdk-1.3.2-cp37-cp37m-manylinux2010_x86_64.whl
Algorithm | Hash digest
---|---
SHA256 | 8f3d969f7f015d4100174403e5b7bbbca556e15f92f85dde62c5c58fe2f0152d
MD5 | a759fe3cb0b9787cb5f2009a7cb19413
BLAKE2b-256 | 6387bb529833b1c0bb0fe5c2a85738df2271151ef1ecfe035dcc1cefad0abc14
Close
Hashes for nn_sdk-1.3.2-cp36-cp36m-win_amd64.whl
Algorithm | Hash digest
---|---
SHA256 | 292d091e2f9091de38be2fb24ea6c700f1d7bb3fbcb8095be4ae23265936380b
MD5 | ae6b11f89ed794af715337bc7f38cd71
BLAKE2b-256 | 6b75dd09643b89f5079a6a09d9e30561be81906734f642e8d9bf045424eecf83
Close
Hashes for nn_sdk-1.3.2-cp36-cp36m-manylinux2010_x86_64.whl
Algorithm | Hash digest
---|---
SHA256 | e32ee69336f2338672b60491f6b897a01983e4f2d5df22fbb665674e6a267fb9
MD5 | 3ff41a1baaaaa90da31cac9dff8b9471
BLAKE2b-256 | d76cea1b7d1ab684f15f00002868ad7b4c3880e471ae72cb87605e8bcaa2585e