Skip to main content

boost inference speed of GPT models in transformers by onnxruntime

Project description

fastgpt

PyPI - Python Version PyPI PyPI GitHub license badge Blog Codecov

fastgpt 是什么

  • fastgpt是一个基于transformersonnxruntimepython库,可以无缝衔接的使用 onnxruntime 量化后的 transfromers GPT 模型做文本生成任务,提高推理速度、降低资源成本。

fastgpt 的背景

  • GPT模型是通过序列文本预测下一个词的训练任务得到的预训练模型,可以在文本生成任务上达到很好的效果。
  • transformers库是近些年最火的做预训练模型的 python 库,在其背后的社区,网友、组织分享开源了各式各样的预训练模型,尤其是截止 2022 年 6 月 23 日,社区的开源文本生成模型多达到5068个。
  • onnx是由微软,亚马逊 ,Facebook 和 IBM 等公司共同开发的,针对机器学习所设计的开放式的文件格式,经过 onnxruntime 量化压缩的预训练模型,在 cpu 硬件上推理速度在各开源框架的对比中首屈一指。
  • 然而,通过transformers官方的 onnx 接口转换、onnx 量化 API,却没有做好 GPT 模型转换的兼容问题,经常转换失败。而手动进行 onnx 转换需要自定义很多配置,对于新手不很友好。
  • fastgpt库,就是为了无缝衔接 transformers 库调用 GPT 模型转换 onnx 格式推理,使用者可以在仅修改两行代码的情况下,使用 onnx 量化后的 GPT 模型,做 transformers 库的文本生成函数。
  • 原 transformers 代码:
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
  • fastgpt 代码:
from fastgpt import CausalLMModelForOnnxGeneration
model = CausalLMModelForOnnxGeneration.from_pretrained("distilgpt2")
  • 在 fastgpt 这一行代码中,会执行以下流程
    1. transformers hub 的模型下载
    2. pytorch 模型推理,输出 logits
    3. onnx 格式转换
    4. onnx 格式模型推理,输出 logits,进行对比差异
    5. onnx 量化
    6. onnx 量化格式模型推理,输出 logits,进行对比差异
    7. 把兼容 transformers 文本生成函数的 onnx 格式 GPT 模型,包装到 model 中

安装

pip install fastgpt

快速 demo

from transformers import AutoTokenizer
from fastgpt import CausalLMModelForOnnxGeneration
model = CausalLMModelForOnnxGeneration.from_pretrained("distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")

prompt_text = "Natural language processing (NLP) is the ability of a computer program to understand human language as it is spoken and written"
input_ids = tokenizer(
    prompt_text, return_tensors="pt", add_special_tokens=False
).input_ids

generated_ids = model.generate(   # 这里完全兼容transformers的generate函数
    input_ids,
    max_length=64 + input_ids.shape[1],
    decoder_start_token_id=tokenizer.cls_token_id,
    eos_token_id=tokenizer.sep_token_id,
    output_scores=True,
    temperature=1,
    repetition_penalty=1.0,
    top_k=50,
    top_p=0.9,
    do_sample=True,
    num_return_sequences=1,
    length_penalty=2.0,
    early_stopping=True,
)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
print("=" * 20)

fastgpt 的优点

  1. 兼容 transformers: 基于 transformers 库的文本生成函数,功能非常丰富。fastgpt 在 onnx 格式模型上,兼容该函数。
  2. 兼容 cache: 在文本生成的一个个 token 生成过程中的past_key_value需要在 GPT 模型上持续迭代输入,fastgpt 已经通过 onnx 格式做好衔接。
  3. 代码修改低成本:代码替换原版 transformers 仅需修改两行代码。
  4. onnx 格式占内存小:对于 distilgpt2 模型,torch 版318MB, onnx 量化版243MB
  5. cpu 上速度更快: 用时约降低 33%

生成速度评测(ms)

  • 生成长度 4 评测
模型框架 beam:1 beam:2 beam:3 beam:4
torch 290.779 475.693 560.458 648.756
fastgpt 195.265 292.272 378.933 466.14

  • 生成长度 8 评测
模型框架 beam:1 beam:2 beam:3 beam:4
torch 482.199 817.065 905.646 1052.983
fastgpt 341.735 471.028 583.264 713.009

  • 生成长度 16 评测
模型框架 beam:1 beam:2 beam:3 beam:4
torch 878.338 1518.198 1619.336 1813.197
fastgpt 635.157 838.787 1009.497 1210.047

  • 生成长度 32 评测
模型框架 beam:1 beam:2 beam:3 beam:4
torch 1661.819 2854.889 3081.585 3436.284
fastgpt 1238.585 1599.724 1921.785 2256.674

  • 生成长度 64 评测
模型框架 beam:1 beam:2 beam:3 beam:4
torch 3257.929 4274.201 4256.85 4677.168
fastgpt 2510.484 3081.851 2697.296 3150.157

model name : Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz cpu cores : 2

详见 GITHUB ACTIONS 的cml 报告

补充

  1. 对于CodeGen系列代码生成模型,官方是不支持transformers库的,因此不能直接用fastgpt加载,请转至example/codegen进行onnx量化和文本生成

感谢

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

fastgpt-0.0.4.tar.gz (11.9 kB view details)

Uploaded Source

File details

Details for the file fastgpt-0.0.4.tar.gz.

File metadata

  • Download URL: fastgpt-0.0.4.tar.gz
  • Upload date:
  • Size: 11.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for fastgpt-0.0.4.tar.gz
Algorithm Hash digest
SHA256 a27d3ede01d4d31a31022bfbfa8c7339817aebae6cec0479fde0227954c9ed02
MD5 10bee3bc96465e188fa8fcf160980eac
BLAKE2b-256 70f40a5c21b44eb3c9f09dad6c10fc9952223ca36273ab31ab9916ca13c3a107

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page