boost inference speed of GPT models in transformers by onnxruntime
Project description
fastgpt
fastgpt 是什么
- fastgpt是一个基于transformers和onnxruntime的python库,可以无缝衔接的使用 onnxruntime 量化后的 transfromers GPT 模型做文本生成任务,提高推理速度、降低资源成本。
fastgpt 的背景
- GPT模型是通过序列文本预测下一个词的训练任务得到的预训练模型,可以在文本生成任务上达到很好的效果。
- transformers库是近些年最火的做预训练模型的 python 库,在其背后的社区,网友、组织分享开源了各式各样的预训练模型,尤其是截止 2022 年 6 月 23 日,社区的开源文本生成模型多达到5068个。
- onnx是由微软,亚马逊 ,Facebook 和 IBM 等公司共同开发的,针对机器学习所设计的开放式的文件格式,经过 onnxruntime 量化压缩的预训练模型,在 cpu 硬件上推理速度在各开源框架的对比中首屈一指。
- 然而,通过transformers官方的 onnx 接口转换、onnx 量化 API,却没有做好 GPT 模型转换的兼容问题,经常转换失败。而手动进行 onnx 转换需要自定义很多配置,对于新手不很友好。
- fastgpt库,就是为了无缝衔接 transformers 库调用 GPT 模型转换 onnx 格式推理,使用者可以在仅修改两行代码的情况下,使用 onnx 量化后的 GPT 模型,做 transformers 库的文本生成函数。
- 原 transformers 代码:
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
- fastgpt 代码:
from fastgpt import CausalLMModelForOnnxGeneration
model = CausalLMModelForOnnxGeneration.from_pretrained("distilgpt2")
- 在 fastgpt 这一行代码中,会执行以下流程
- transformers hub 的模型下载
- pytorch 模型推理,输出 logits
- onnx 格式转换
- onnx 格式模型推理,输出 logits,进行对比差异
- onnx 量化
- onnx 量化格式模型推理,输出 logits,进行对比差异
- 把兼容 transformers 文本生成函数的 onnx 格式 GPT 模型,包装到 model 中
安装
pip install fastgpt
快速 demo
from transformers import AutoTokenizer
from fastgpt import CausalLMModelForOnnxGeneration
model = CausalLMModelForOnnxGeneration.from_pretrained("distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
prompt_text = "Natural language processing (NLP) is the ability of a computer program to understand human language as it is spoken and written"
input_ids = tokenizer(
prompt_text, return_tensors="pt", add_special_tokens=False
).input_ids
generated_ids = model.generate( # 这里完全兼容transformers的generate函数
input_ids,
max_length=64 + input_ids.shape[1],
decoder_start_token_id=tokenizer.cls_token_id,
eos_token_id=tokenizer.sep_token_id,
output_scores=True,
temperature=1,
repetition_penalty=1.0,
top_k=50,
top_p=0.9,
do_sample=True,
num_return_sequences=1,
length_penalty=2.0,
early_stopping=True,
)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
print("=" * 20)
fastgpt 的优点
- 兼容 transformers: 基于 transformers 库的文本生成函数,功能非常丰富。fastgpt 在 onnx 格式模型上,兼容该函数。
- 兼容 cache: 在文本生成的一个个 token 生成过程中的
past_key_value
需要在 GPT 模型上持续迭代输入,fastgpt 已经通过 onnx 格式做好衔接。 - 代码修改低成本:代码替换原版 transformers 仅需修改两行代码。
- onnx 格式占内存小:对于 distilgpt2 模型,torch 版
318MB
, onnx 量化版243MB
- cpu 上速度更快: 用时约降低 33%
生成速度评测(ms)
- 生成长度 4 评测
模型框架 | beam:1 | beam:2 | beam:3 | beam:4 |
---|---|---|---|---|
torch | 290.779 | 475.693 | 560.458 | 648.756 |
fastgpt | 195.265 | 292.272 | 378.933 | 466.14 |
- 生成长度 8 评测
模型框架 | beam:1 | beam:2 | beam:3 | beam:4 |
---|---|---|---|---|
torch | 482.199 | 817.065 | 905.646 | 1052.983 |
fastgpt | 341.735 | 471.028 | 583.264 | 713.009 |
- 生成长度 16 评测
模型框架 | beam:1 | beam:2 | beam:3 | beam:4 |
---|---|---|---|---|
torch | 878.338 | 1518.198 | 1619.336 | 1813.197 |
fastgpt | 635.157 | 838.787 | 1009.497 | 1210.047 |
- 生成长度 32 评测
模型框架 | beam:1 | beam:2 | beam:3 | beam:4 |
---|---|---|---|---|
torch | 1661.819 | 2854.889 | 3081.585 | 3436.284 |
fastgpt | 1238.585 | 1599.724 | 1921.785 | 2256.674 |
- 生成长度 64 评测
模型框架 | beam:1 | beam:2 | beam:3 | beam:4 |
---|---|---|---|---|
torch | 3257.929 | 4274.201 | 4256.85 | 4677.168 |
fastgpt | 2510.484 | 3081.851 | 2697.296 | 3150.157 |
model name : Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz cpu cores : 2
详见 GITHUB ACTIONS 的cml 报告
补充
- 对于CodeGen系列代码生成模型,官方是不支持
transformers
库的,因此不能直接用fastgpt
加载,请转至example/codegen进行onnx量化和文本生成
感谢
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
fastgpt-0.0.7.tar.gz
(11.9 kB
view details)
File details
Details for the file fastgpt-0.0.7.tar.gz
.
File metadata
- Download URL: fastgpt-0.0.7.tar.gz
- Upload date:
- Size: 11.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.9.13
File hashes
Algorithm | Hash digest | |
---|---|---|
SHA256 | d4ba23ec786d73cc7642f616323d3af7a483c9e9559f44c306223d1d211d99fd |
|
MD5 | 29d7ae4d631691216be8f44f6c4defdb |
|
BLAKE2b-256 | 04559192d2531a08d79847d4e6d2f165b6716bf2d6f0d6898b3eb846174186e2 |