
Thrift for Deep Learning


What is it?

thrift4DL is an abbreviation of Thrift for Deep Learning. It is built to serve deep learning models in production as a batching server.


  • Easy to use: requires only a little knowledge and skill to start a deep learning server in production mode
  • Scalable: scales easily across multiple GPUs and multiple clients without you having to worry about concurrency
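The idea behind a batching server can be sketched with Python's standard library: incoming requests are queued, a worker groups them until either a size limit or a time limit is reached, and the whole group goes through the model in one call. This is a minimal sketch, not thrift4DL's implementation. The names `collect_batch` and `batching_worker` are hypothetical, and treating the two limits like the `batch_infer_size` and `batch_group_timeout` arguments of `TModelPoolServer` (with the timeout in seconds, counted from the first request) is an assumption.

```python
import queue
import threading
import time


def collect_batch(requests_q, max_batch_size, timeout_s):
    """Block for the first request, then gather more until the batch is
    full or timeout_s seconds have passed since the first one arrived."""
    batch = [requests_q.get()]  # wait for at least one request
    deadline = time.monotonic() + timeout_s
    while len(batch) < max_batch_size:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            batch.append(requests_q.get(timeout=remaining))
        except queue.Empty:
            break
    return batch


def batching_worker(requests_q, predict_batch, max_batch_size, timeout_s):
    """Run forever: group queued (input, reply_q) pairs and predict them together."""
    while True:
        batch = collect_batch(requests_q, max_batch_size, timeout_s)
        inputs = [item for item, _ in batch]
        outputs = predict_batch(inputs)  # one model call for the whole batch
        for (_, reply_q), out in zip(batch, outputs):
            reply_q.put(out)  # hand each caller its own result


if __name__ == "__main__":
    requests_q = queue.Queue()
    # Stand-in "model" that doubles each input; a real server would call the GPU model here.
    worker = threading.Thread(
        target=batching_worker,
        args=(requests_q, lambda xs: [x * 2 for x in xs], 100, 1.0),
        daemon=True,
    )
    worker.start()
    reply_q = queue.Queue()
    requests_q.put((21, reply_q))
    print(reply_q.get())  # 42, returned once the 1 s grouping window closes
```

The trade-off is latency against throughput: a lone request waits up to the timeout, but under load the model processes many inputs per call, which is what makes GPU inference efficient.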


You can install the package from PyPI: pip install thrift4DL

or from source:

git clone

python setup.py install


Before going further, please take a look at the example folder.

Server Side

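import os
import json

import numpy as np

# VisionHandler and TModelPoolServer come from thrift4DL; MnistModel and
# decode_image come from the MNIST example. See the example folder for the
# exact import paths.

NUM_MODELS = 2  # hypothetical value: number of model replicas to run


class ServerHandler(VisionHandler):
    def get_env(self, gpu_id, mem_fraction):
        """Initialize the environment: environment variables, package imports, ...

        gpu_id: int or str
            GPU ID in int or str format
        mem_fraction: float
            Fraction of GPU memory to occupy
        """
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        # Import TensorFlow only after CUDA_VISIBLE_DEVICES is set,
        # so that it sees only the selected GPU
        import tensorflow as tf
        return tf

    def get_model(self, model_path, env_params):
        # env_params is whatever get_env returned (here, the tf module)
        tf = env_params
        model = MnistModel(tf, model_path)
        return model

    def preprocess(self, model, input):
        """Preprocess a single request into a (1, 28, 28, 1) array."""
        input = json.loads(input)
        input = decode_image(input['image'])
        input = model.preprocessing(input)
        if input.shape != (1, 28, 28, 1):
            raise ValueError("Wrong input shape")
        return input

    def predict(self, model, input):
        """Predict a batch: input is a list of preprocessed arrays."""
        input = np.vstack(input)  # stack into shape (N, 28, 28, 1)
        result = model.predict(input)
        results = []
        for val in result:
            pred_num = np.argmax(val)
            pred_score = np.max(val)
            results.append({"pred_num": int(pred_num),
                            "pred_score": float(pred_score)})
        return results  # one result per request, in the same order


server = TModelPoolServer(host='', port='9090',
                          model_path='mnist.pb', gpu_ids=[6]*NUM_MODELS,
                          batch_infer_size=100, batch_group_timeout=1,
                          )  # further arguments are cut off in the original; see the example folder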
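Because `preprocess` runs once per request and `predict` receives a list of the preprocessed arrays, the batching logic of the handler can be checked without starting a server. The sketch below assumes that contract (as read from the example above, not from the library's documentation) and swaps in a hypothetical `StubModel` for `MnistModel`, so it needs only NumPy.

```python
import numpy as np


class StubModel:
    """Hypothetical stand-in for MnistModel, used only to exercise the handler logic."""

    def predict(self, batch):
        # batch has shape (N, 28, 28, 1); put all probability mass on
        # class round(mean pixel value) so the output is easy to check
        labels = np.rint(batch.mean(axis=(1, 2, 3))).astype(int) % 10
        return np.eye(10)[labels]  # shape (N, 10), one row of scores per image


def predict(model, inputs):
    """Same logic as ServerHandler.predict: one model call for the whole batch."""
    batch = np.vstack(inputs)  # list of (1, 28, 28, 1) arrays -> (N, 28, 28, 1)
    scores = model.predict(batch)
    return [{"pred_num": int(np.argmax(row)), "pred_score": float(np.max(row))}
            for row in scores]


# Three "preprocessed" requests, each shaped (1, 28, 28, 1) as preprocess() guarantees.
inputs = [np.full((1, 28, 28, 1), v, dtype=np.float32) for v in (3, 7, 0)]
results = predict(StubModel(), inputs)
print(results)
# [{'pred_num': 3, 'pred_score': 1.0}, {'pred_num': 7, 'pred_score': 1.0}, {'pred_num': 0, 'pred_score': 1.0}]
```

The point to preserve in a real handler is ordering: the i-th result must correspond to the i-th input, since the server routes each result back to the request it came from.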


The framework supports two ways to send requests to the server.

  • Using the RESTful API:
import requests
import numpy as np
import json

def test_api(image_hex, host, port):
    url = ''  # endpoint URL elided in the original; build it from host and port
    js = {"image": image_hex}
    response = requests.post(url, json=js)
    return response
  • TCP/IP:

You can use Thrift to generate client code in the language of your choice from the service definition in thrift4DL/thrift_file/service3.thrift, then use the generated client as usual.

If you use Python, a ready-made client class is also available (see the example folder).

import json

from thrift4DL.client import VisionClient
import numpy as np
from zaailabcorelib.ztools import encode_image, load_image

# Initialize a client
client = VisionClient(host='', port='9090')

# Load and encode image to hex
img_arr = load_image('../example/mnist/num5.png')
img_hex = encode_image(img_arr)

# Request to server
js = json.dumps({"image": img_hex})
result = client.predict(js)
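Both request paths send the same JSON body, {"image": <hex string>}. How encode_image from zaailabcorelib.ztools builds that hex string is not shown here, so the sketch below assumes it is simply the hex form of the encoded image file's bytes. `make_payload` and `read_payload` are hypothetical helpers, not part of thrift4DL.

```python
import json


def make_payload(image_bytes):
    """Client side (assumed format): wrap raw encoded image bytes as a hex string in JSON."""
    return json.dumps({"image": image_bytes.hex()})


def read_payload(payload):
    """Server side (assumed format): recover the raw image bytes from the JSON body."""
    return bytes.fromhex(json.loads(payload)["image"])


# Stand-in bytes; in practice these might come from open("num5.png", "rb").read().
fake_png = b"\x89PNG\r\n\x1a\n" + bytes(range(8))
payload = make_payload(fake_png)
print(payload)  # {"image": "89504e470d0a1a0a0001020304050607"}
```

Hex encoding keeps binary image data JSON-safe at the cost of doubling its size; base64 would be more compact, but the hex form matches the "encode image to hex" step above.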



Download files


Files for thrift4DL, version 3.2.0:

  • thrift4DL-3.2.0-py3.6.egg (46.8 kB): Egg, Python 3.6
  • thrift4DL-3.2.0-py3-none-any.whl (20.6 kB): Wheel, py3
