A collection of helper methods to simplify optimization and inference of Hugging Face Transformers-based models

Transformers Inference Toolkit

The 🤗 Transformers library provides a great API for manipulating pre-trained NLP (as well as CV and audio) models. However, preparing 🤗 Transformers models for use in production usually requires additional effort. The purpose of transformers-inference-toolkit is to get rid of boilerplate code and to simplify the automatic optimization and inference of Hugging Face Transformers models.

Installation

Using pip:

pip install transformers-inference-toolkit

Optimization

The original 🤗 Transformers library includes the transformers.onnx package, which can be used to convert PyTorch or TensorFlow models into ONNX format. This Toolkit extends that functionality by letting the user automatically optimize the ONNX model graph. This is similar to what the 🤗 Optimum library does, but 🤗 Optimum currently has limited support for locally stored pre-trained models as well as for models of less popular architectures (for example, MPNet).

Aside from ONNX conversion, the Toolkit also supports resaving PyTorch models in half precision and setting up DeepSpeed Inference.

Prerequisite

The Toolkit expects your pretrained model (in PyTorch format) and tokenizer to be saved (using the save_pretrained() method) inside a common parent directory, in model and tokenizer folders respectively. This is how the file structure of the toxic-bert model should look:

toxic-bert
├── model
│   ├── config.json
│   └── pytorch_model.bin
└── tokenizer
    ├── merges.txt
    ├── special_tokens_map.json
    ├── tokenizer_config.json
    └── vocab.json
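
A directory layout like this can be produced with save_pretrained(). Below is a minimal sketch, assuming the model comes from the unitary/toxic-bert checkpoint on the Hugging Face Hub:

from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Assumption: unitary/toxic-bert is used as the source checkpoint
model = AutoModelForSequenceClassification.from_pretrained("unitary/toxic-bert")
tokenizer = AutoTokenizer.from_pretrained("unitary/toxic-bert")

# Save the model and tokenizer into the expected sibling folders
model.save_pretrained("toxic-bert/model")
tokenizer.save_pretrained("toxic-bert/tokenizer")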

How to use

Most of the popular Transformer model architectures (like BERT and its variations) can be converted with a single command:

from transformers_inference_toolkit import (
    Feature,
    OnnxModelType,
    OnnxOptimizationLevel,
    optimizer,
)

optimizer.pack_onnx(
    input_path="toxic-bert",
    output_path="toxic-bert-optimized",
    feature=Feature.SEQUENCE_CLASSIFICATION,
    for_gpu=True,
    fp16=True,
    optimization_level=OnnxOptimizationLevel.FULL,
)

If your model architecture is not supported out of the box (described here), you can try writing a custom OnnxConfig class:

from collections import OrderedDict
from transformers.onnx import OnnxConfig

class MPNetOnnxConfig(OnnxConfig):
    @property
    def default_onnx_opset(self):
        # Minimum ONNX opset version to use when exporting the model
        return 14

    @property
    def inputs(self):
        # Batch and sequence-length dimensions are dynamic for both inputs
        dynamic_axis = {0: "batch", 1: "sequence"}
        return OrderedDict(
            [
                ("input_ids", dynamic_axis),
                ("attention_mask", dynamic_axis),
            ]
        )

optimizer.pack_onnx(
    input_path="all-mpnet-base-v2",
    output_path="all-mpnet-base-v2-optimized",
    feature=Feature.DEFAULT,
    custom_onnx_config_cls=MPNetOnnxConfig,
)

ONNX is not the only option: it is also possible to resave the model for future inference simply using PyTorch (the optimizer.pack_transformers() method, with the force_fp16 argument to save in half precision) or DeepSpeed Inference (the optimizer.pack_deepspeed() method):

optimizer.pack_deepspeed(
    input_path="gpt-neo",
    output_path="gpt-neo-optimized",
    feature=Feature.CAUSAL_LM,
    replace_with_kernel_inject=True,
    mp_size=1,
)
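
For plain PyTorch repackaging, a call might look like the following sketch (force_fp16 is described above; the other arguments are assumed to mirror pack_onnx):

# Sketch: input_path/output_path/feature are assumed to match pack_onnx
optimizer.pack_transformers(
    input_path="toxic-bert",
    output_path="toxic-bert-fp16",
    feature=Feature.SEQUENCE_CLASSIFICATION,
    force_fp16=True,
)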

After calling the optimizer methods, the model and tokenizer will be saved at output_path. The output directory will also contain a metadata.json file that the Predictor object (described below) needs in order to load the model correctly:

toxic-bert-optimized
├── metadata.json
├── model
│   ├── config.json
│   └── model.onnx
└── tokenizer
    ├── special_tokens_map.json
    ├── tokenizer.json
    └── tokenizer_config.json

Prediction

After the model and tokenizer are packaged using one of the optimizer methods, it is possible to initialize a Predictor object:

>>> from transformers_inference_toolkit import Predictor
>>> 
>>> predictor = Predictor("toxic-bert-optimized", cuda=True)
>>> print(predictor("I hate this!"))
{'logits': array([[ 0.02940369, -7.0195312 , -4.7890625 , -6.0664062 , -5.625     ,
        -6.09375   ]], dtype=float32)}

The Predictor object can simply be called with tokenizer arguments (similar to 🤗 Transformers pipelines; the return_tensors argument can be omitted, and padding and truncation are True by default). For text generation tasks, the Predictor.generate() method (with generation arguments) can be used:

>>> predictor = Predictor("gpt-neo-optimized", cuda=True)
>>> predictor.generate(
...     "Tommy: Hi Mark!",
...     do_sample=True,
...     top_p=0.9,
...     num_return_sequences=3,
...     max_new_tokens=5,
... )
['Tommy: Hi Mark!\nMadelyn: Hello', 'Tommy: Hi Mark! It’s so', 'Tommy: Hi Mark! How are you?\n']
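
Because extra keyword arguments are passed through to the tokenizer, a batched classification call could look like the sketch below (assuming list inputs and tokenizer keyword arguments such as max_length are forwarded as described above):

>>> # Batched input; max_length is forwarded to the tokenizer
>>> predictor(["I hate this!", "I love this!"], max_length=32)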
