Skip to main content

an elegant bert4torch

Project description

bert4torch

licence GitHub release PyPI PyPI - Downloads GitHub stars GitHub Issues contributions welcome Generic badge

Documentation | Torch4keras | Examples

1. 下载安装

安装稳定版

pip install bert4torch

安装最新版

pip install git+https://github.com/Tongjilibo/bert4torch
  • 注意事项:pip包的发布慢于git上的开发版本,git clone注意引用路径,注意权重是否需要转换
  • 测试用例git clone https://github.com/Tongjilibo/bert4torch,修改example中的预训练模型文件路径和数据路径即可启动脚本
  • 自行训练:针对自己的数据,修改相应的数据处理代码块
  • 开发环境:原使用torch==1.10版本进行开发,现已切换到torch2.0开发,如其他版本遇到不适配,欢迎反馈

2. 功能

  • LLM模型: 加载chatglm、llama、 baichuan、ziya、bloom等开源大模型权重进行推理和微调

  • 核心功能:加载bert、roberta、albert、xlnet、nezha、bart、RoFormer、RoFormer_V2、ELECTRA、GPT、GPT2、T5、GAU-alpha、ERNIE等预训练权重继续进行finetune、并支持在bert基础上灵活定义自己模型

  • 丰富示例:包含llmpretrainsentence_classficationsentence_embeddingsequence_labelingrelation_extractionseq2seqserving等多种解决方案

  • 实验验证:已在公开数据集实验验证,使用如下examples数据集

  • 易用trick:集成了常见的trick,即插即用

  • 其他特性加载transformers库模型一起使用;调用方式简洁高效;有训练进度条动态展示;配合torchinfo打印参数量;默认Logger和Tensorboard简便记录训练过程;自定义fit过程,满足高阶需求

  • 训练过程

    2022-10-28 23:16:10 - Start Training
    2022-10-28 23:16:10 - Epoch: 1/2
    5000/5000 [==============================] - 13s 3ms/step - loss: 0.1351 - acc: 0.9601
    Evaluate: 100%|██████████████████████████████████████████████████| 2500/2500 [00:03<00:00, 798.09it/s] 
    test_acc: 0.98045. best_test_acc: 0.98045
    
    2022-10-28 23:16:27 - Epoch: 2/2
    5000/5000 [==============================] - 13s 3ms/step - loss: 0.0465 - acc: 0.9862
    Evaluate: 100%|██████████████████████████████████████████████████| 2500/2500 [00:03<00:00, 635.78it/s] 
    test_acc: 0.98280. best_test_acc: 0.98280
    
    2022-10-28 23:16:44 - Finish Training
    
功能 bert4torch transformers 备注
训练进度条 进度条打印loss和定义的metrics
分布式训练dp/ddp torch自带dp/ddp
各类callbacks 日志/tensorboard/earlystop/wandb等
大模型推理,stream/batch输出 各个模型是通用的,无需单独维护脚本
大模型微调 lora依赖peft库,pv2自带
丰富tricks 对抗训练等tricks即插即用
代码简洁易懂,自定义空间大 代码复用度高, keras代码训练风格
仓库的维护能力/影响力/使用量/兼容性 目前仓库个人维护

3. 快速上手

4. 版本历史

4.1 版本历史
更新日期 bert4torch版本 torch4keras版本 版本说明
20230902 0.3.4 0.1.3 修复gradient_checkpoint在低版本torch时仅支持位置参数的问题, 增加trainer.py, 增加PPOTrainerTrl以及相应的三阶段rlhf训练+dpo训练
20230812 0.3.3 0.1.2 增加大模型deepspeed的使用,增加Qwen模型(增加ntk和logn_attn),generation的end_id支持多个token_id,修复多文件权重加载资源占用问题
20230804 0.3.2 0.1.1 修改依赖的torch4keras, 主要是进度条和logger, tensorboard的同步
20230726 0.3.1.post2 0.1.0.post2 修改baichuan的alibi逻辑,增加bloom, 简化decoder架构代码(gpt, llama, chatglm均继承decoder)
20230716 0.3.0 0.0.9 修改models和layers为文件夹方便扩展, 增加flash_attention参数控制,修改skip_init逻辑减少显存占用,generation增加repetition_penalty,修复chatglm的pv2的bug,generation支持transformers的tokenize,增加ziya,Baichuan
20230705 0.2.9 0.0.8 使用accelerate来实现skip_init精简代码, 修复add_trainer的代码提示, 增加chatglm的load_in_8bit+lora/qlora的训练, 修复grad_chechpoint, 增加chinese_llama_alpaca, torch2.0默认使用scaled_dot_product_attention加速, 增加chatglm2-6b+pv2+lora微调
20230518 0.2.8 0.0.7.post3 1)新增模型: 增加chatglm-6b/llama-7b/BELLE_llama/vicuna/moss/苏神、uer的roberta-small/Tiny模型以及ChatYuan v2模型/fnlp的bart2.0, 增加量化模块并适配llama,增加skip_init参数加快加载, 增加stream输出/网页demo, 增加ptuning_v2和lora;
2)generation: 生成式解码新增SeqGeneration和Seq2SeqGeneration,单向decoder模型和encoder decoder模型解码增加cache, 增加batch_generate()/stream_generate功能;
3)其他: 修改rope为不使用max_position,修复model.half()类型不一致问题,支持加载多个权重文件, gpt系列默认不加softmax,增加苏神Tiger的pytorch实现, 增加了对attention_key_size的入参支持,把_token_pad_ids重命名为pad_token_ids, tokenizor中重命名部分字段
20230310 0.2.7.post2 0.0.6 增加lion优化器, 修复albert_unshared加载权重, 修复lm系列(gpt, seq2seq)存在的forward参数不对的问题,修复GlobalPointer使用rope的bug
20230213 0.2.7 0.0.6 修复random_sample()的bug,适配v0.0.6的torch4keras:增加resume_from_checkpoint和save_to_checkpoint;增加add_trainer方法,重构了Trainer(BaseModel)的实现,增加了AccelerateCallback
20221231 0.2.6 0.0.5 build_transformer_model需显式指定add_trainer才从BaseModel继承, 增加guwenbert, macbert,text2vec-bert-chinese, wobert预训练模型,允许position_ids从padding开始, transformer.configs支持点操作,可以使用torch4keras的Trainer(net)来初始化, 修复tokenizer的切分subtoken的bug, 允许embedding_size!=hidden_size
20221127 0.2.5 0.0.4 对抗训练从compile转为使用Callback来实现,修复1.7.1版本兼容bug, uie模型内置
20221120 0.2.4 0.0.3.post2 删除SpTokenizer基类中的rematch, 增加deberta_v2模型
20221023 0.2.3 0.0.2 虚拟对抗VAT在多个ouput时支持指定,把Trainer抽象到torch4keras中,修复DP和DDP出现resume_epoch不存在的bug, tokenizer的never_split去除None, transformer_xl的bug, 增加gradient_checkpoint
20220922 0.2.2 —— 修复t5的norm_mode问题,允许hidden_size不整除num_attention_heads,支持多个schedule(如同时ema+warmup)
20220905 0.2.1 —— 兼容torch<=1.7.1的torch.div无rounding_mode,增加自定义metrics,支持断点续训,增加默认Logger和Tensorboard日志
20220823 0.2.0 —— 兼容torch1.9.0的缺失take_along_dim,修复bart中位置向量514的问题,修复Sptokenizer对符号不转换,打印Epoch开始的时间戳,增加parallel_apply
20220808 0.1.9 —— 增加mixup/manifold_mixup/temporal_ensembling策略,修复pgd策略param.grad为空的问题,修改tokenizer支持批量
20220717 0.1.8 —— 修复原来CRF训练中loss陡增的问题,修复xlnet的token_type_ids输入显存占用大的问题
20220710 0.1.7 —— 增加EarlyStop,CRF中自带转bool类型
20220605 0.1.6 —— 增加transformer_xl、xlnet、t5_pegasus模型,prompt、预训练等示例,支持增加embedding输入,EMA策略,修复tokenizer和sinusoid的bug
20220504 0.1.5 —— 增加GAU-alpha,混合梯度,梯度裁剪,单机多卡(DP、DDP)
20220421 0.1.4 —— 增加了VAT,修复了linux下apply_embedding返回项有问题的情况
20220409 0.1.3 —— 初始版本

5. 更新历史:

5.1 即将更新
  • 接入更多开源大模型
5.2 更新历史
  • 20230902:修复gradient_checkpoint在低版本torch时仅支持位置参数的问题, 增加trainer.py, 增加PPOTrainerTrl以及相应的三阶段rlhf训练+dpo训练
  • 20230812:增加llama-2的微调, 增加大模型deepspeed的使用,增加Qwen模型(增加ntk和logn_attn),generation的end_id支持多个token_id,修复多文件权重加载资源占用问题
  • 20230726:修改baichuan的alibi逻辑,增加bloom, 简化decoder架构代码(gpt, llama, chatglm均继承decoder)
  • 20230716:修改models和layers为文件夹方便扩展, 增加flash_attention参数控制,增加chatglm-api示例,修改skip_init逻辑减少显存占用,generation增加repetition_penalty,修复chatglm的pv2的bug,generation支持transformers的tokenize,增加ziya,Baichuan
  • 20230705:使用accelerate来实现skip_init精简代码, 修复add_trainer的代码提示, 增加chatglm的load_in_8bit+lora/qlora的训练, 修复grad_chechpoint, 增加chinese_llama_alpaca, torch2.0默认使用scaled_dot_product_attention加速, 增加chatglm2-6b+pv2+lora微调
  • 20230518:增加vicuna的集成, 增加batch_generate()功能, 把_token_pad_ids重命名为pad_token_ids, tokenizor中重命名部分字段
  • 20230408:增加苏神Tiger的pytorch实现, 集成苏神、uer的roberta-small/Tiny模型以及ChatYuan v2模型, 增加了对attention_key_size的入参支持,单向decoder模型和encoder decoder模型解码增加cache, 更新fnlp的bart2.0, 增加chatglm-6b预训练模型推理, 集成BELLE_llama模型, 增加量化模块并适配llama,增加skip_init参数加快加载, 增加stream输出/网页demo, 增加ptuning_v2,增加moss模型的int4/int8推理
  • 20230326:增加llama-7b预训练模型, 修改rope为不使用max_position, 增加prompt_clue和nezha_gpt_dialog的finetune示例(skykiseki用户),修复model.half()类型不一致问题,生成式解码新增SeqGeneration和Seq2SeqGeneration, 支持加载多个权重文件, gpt系列默认不加softmax
  • 20230310:增加lion优化器, 修改dp和ddp示例更易用,增加PromptCLUE模型, 修复albert_unshared加载权重, 增加uer-gpt2-chinese预训练模型,修复lm系列(gpt, seq2seq)存在的forward参数不对的问题,修复GlobalPointer使用rope的bug
  • 20230212:兼容accelerate包, 增加ChatYuan v1模型,修复random_sample()的bug
  • 20221230:增加macbert,text2vec-bert-chinese, wobert模型,增加LEAR的ner示例, 增加PGRC、SPN4RE的关系提取示例,transformer.configs支持点操作,可以使用torch4keras的Trainer(net)来初始化, 修复tokenizer的切分subtoken的bug, 允许embedding_size!=hidden_size
  • 20221127:增加deberta_v2模型, 对抗训练从compile转为使用Callback来实现,修复1.7.1版本兼容bug, uie模型内置, 增加triton示例, build_transformer_model需显式指定add_trainer才从BaseModel继承, 增加guwenbert预训练模型,允许position_ids从padding开始
  • 20221102:增加CNN_Nested_NER示例, 删除SpTokenizer基类中的rematch
  • 20221022:修复DP和DDP出现resume_epoch不存在的bug, tokenizer的never_split去除None, transformer_xl的bug, 增加gradient_checkpoint
  • 20221011:虚拟对抗VAT在多个ouput时支持指定,增加elasticsearch示例, 把Trainer抽象到torch4keras中供更多项目使用,把梯度累积移到compile中
  • 20220920:增加TensorRT示例,支持多个schedule(如同时ema+warmup),sanic+onnx部署
  • 20220910:增加默认Logger和Tensorboard日志,ONNX推理,增加ERNIE模型,修复t5的norm_mode问题,允许hidden_size不整除num_attention_heads
  • 20220828:增加nl2sql示例,增加自定义metrics,支持断点续训
  • 20220821:增加W2NER和DiffCSE示例,打印Epoch开始的时间戳,增加parallel_apply,兼容torch<=1.7.1的torch.div无rounding_mode
  • 20220814:增加有监督句向量、关系抽取、文本生成实验指标,兼容torch<1.9.0的缺失take_along_dim,修复bart中位置向量514的问题,修复Sptokenizer对符号不转换
  • 20220727:增加mixup/manifold_mixup/temporal_ensembling策略,修复pgd策略param.grad为空的问题,修改tokenizer支持批量,增加uie示例
  • 20220716:修复原来CRF训练中loss陡增的问题,修复xlnet的token_type_ids输入显存占用大的问题
  • 20220710:增加金融中文FAQ示例,天池新闻分类top1案例,增加EarlyStop,CRF中自带转bool类型
  • 20220629:增加ner的实验,测试crf不同初始化的效果,bert-whitening中文实验
  • 20220613:增加seq2seq+前缀树,增加SimCSE/ESimCSE/PromptBert等无监督语义相似度的中文实验
  • 20220605:增加PromptBert、PET、P-tuning示例,修改tokenizer对special_tokens分词错误的问题,增加t5_pegasus
  • 20220529:transformer_xl、xlnet模型,修改sinusoid位置向量被init_weight的bug,EMA,sohu情感分类示例
  • 20220517:增加预训练代码,支持增加embedding输入(如词性,word粒度embedding)
  • 20220501:增加了混合梯度,梯度裁剪,单机多卡训练(DP、DDP)
  • 20220425:增加了VAT、GAU-alpha等示例,增加了梯度累积,自定义fit()示例
  • 20220415:增加了ner_mrc、ner_span、roformer_v2、roformer-sim等示例
  • 20220405:增加了GPLinker、TPlinker、SimBERT等示例
  • 20220329:增加了CoSENT、R-Drop、UDA等示例
  • 20220322:添加GPT、GPT2、T5模型
  • 20220312:初版提交

6. 预训练权重

6.1 已支持权重
模型分类 权重来源 权重链接 备注(若有)
bert 谷歌原版bert(即bert-base-chinese) tftorch tf转pytorch命令转换脚本
bert 哈工大chinese-bert-wwm-ext tf/torchtorch
bert-base-multilingual-cased huggingface torch 转换脚本
macbert 哈工大chinese-macbert-base/large tf/torchtorch
roberta 哈工大chinese-roberta-wwm-ext tf/torchtorch
roberta-small/tiny 追一科技 & UER tftorch 转换脚本
roberta-base (english) huggingface torch 转换脚本
deberta_v2 IDEA Erlangshen-DeBERTa-v2 torch 转换脚本
guwenbert 古文bert torch 转换脚本
xlnet 哈工大xlnet tf/torch config
electra 哈工大electra tftorch
macbert 哈工大macbert tftorch
albert brightmart tftorchtorch
ernie 百度文心 paddletorch
roformer 追一科技 tftorch
roformer_v2 追一科技 tftorch
simbert 追一科技 tftorch_base 转换脚本
simbert_v2/roformer-sim 追一科技 tftorch
gau-alpha 追一科技 tf 转换脚本
wobert 追一科技 tftorch_basetorch_plus_base
nezha 华为 tftorch
gpt thu-coai/CDial-GPT torch 转换脚本
gpt2 清华26亿 cmp_lm torch 转换脚本
gpt2 中文GPT2_ML模型 tftorch 转换脚本
gpt2 UER torch 转换脚本
t5 UER torch config
mt5 谷歌 torch config
t5_pegasus 追一科技 tf 转换脚本
bart 复旦 torch, v1.0, v2.0 转换脚本
text2vec text2vec-base-chinese torch
chatyuan v1&v2 clue-ai torch config
PromptCLUE clue-ai torch config
chatglm-6b THUDM github, v0.1.0, v1.1.0, int8, int4 转换脚本
chatglm2-6b THUDM github, v2, int4 转换脚本
llama facebook github 转换脚本
llama-2 facebook github, 7b, 7b-chat, 13b, 13b-chat 转换脚本
chinese_llama_alpaca Yiming Cui github 转换脚本
vicuna FastChat torch 转换脚本
Belle_llama LianjiaTech github, 7B-2M-enc 合成说明转换脚本
Ziya IDEA-CCNL v1, v1.1, pretrain-v1 转换脚本
Baichuan baichuan-inc 7B, 13B-Base, 13B-Chat 转换脚本
bloom bigscience bloom-560m, bloomz-560m 转换脚本
Qwen 阿里云 github, 7B, 7B-Chat 转换脚本

7. 鸣谢

  • 感谢苏神实现的bert4keras,本实现有不少地方参考了bert4keras的源码,在此衷心感谢大佬的无私奉献;
  • 其次感谢项目bert4pytorch,也是在该项目的指引下给了我用pytorch来复现bert4keras的想法和思路。

8. 引用

@misc{bert4torch,
  title={bert4torch},
  author={Bo Li},
  year={2022},
  howpublished={\url{https://github.com/Tongjilibo/bert4torch}},
}

9. 其他

  • Wechat Discussions
pic
微信号
pic
微信群
  • Star History Chart
pic
Star History Chart

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

bert4torch-0.3.4.tar.gz (119.0 kB view hashes)

Uploaded Source

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page