Skip to main content

an elegant bert4torch

Project description

bert4torch

licence GitHub release PyPI PyPI - Downloads GitHub stars GitHub Issues contributions welcome Generic badge

Documentation | Torch4keras | Examples

1. 下载安装

安装稳定版

pip install bert4torch

安装最新版

pip install git+https://github.com/Tongjilibo/bert4torch
  • 注意事项:pip包的发布慢于git上的开发版本,git clone注意引用路径,注意权重是否需要转换
  • 测试用例git clone https://github.com/Tongjilibo/bert4torch,修改example中的预训练模型文件路径和数据路径即可启动脚本
  • 自行训练:针对自己的数据,修改相应的数据处理代码块
  • 开发环境:原使用torch==1.10版本进行开发,现已切换到torch2.0开发,如其他版本遇到不适配,欢迎反馈

2. 功能

  • LLM模型: 加载chatglm、llama、 baichuan、ziya、bloom等开源大模型权重进行推理和微调

  • 核心功能:加载bert、roberta、albert、xlnet、nezha、bart、RoFormer、RoFormer_V2、ELECTRA、GPT、GPT2、T5、GAU-alpha、ERNIE等预训练权重继续进行finetune、并支持在bert基础上灵活定义自己模型

  • 丰富示例:包含llmpretrainsentence_classficationsentence_embeddingsequence_labelingrelation_extractionseq2seqserving等多种解决方案

  • 实验验证:已在公开数据集实验验证,使用如下examples数据集

  • 易用trick:集成了常见的trick,即插即用

  • 其他特性加载transformers库模型一起使用;调用方式简洁高效;有训练进度条动态展示;配合torchinfo打印参数量;默认Logger和Tensorboard简便记录训练过程;自定义fit过程,满足高阶需求

  • 训练过程

    2022-10-28 23:16:10 - Start Training
    2022-10-28 23:16:10 - Epoch: 1/2
    5000/5000 [==============================] - 13s 3ms/step - loss: 0.1351 - acc: 0.9601
    Evaluate: 100%|██████████████████████████████████████████████████| 2500/2500 [00:03<00:00, 798.09it/s] 
    test_acc: 0.98045. best_test_acc: 0.98045
    
    2022-10-28 23:16:27 - Epoch: 2/2
    5000/5000 [==============================] - 13s 3ms/step - loss: 0.0465 - acc: 0.9862
    Evaluate: 100%|██████████████████████████████████████████████████| 2500/2500 [00:03<00:00, 635.78it/s] 
    test_acc: 0.98280. best_test_acc: 0.98280
    
    2022-10-28 23:16:44 - Finish Training
    
功能 bert4torch transformers 备注
训练进度条 进度条打印loss和定义的metrics
分布式训练dp/ddp torch自带dp/ddp
各类callbacks 日志/tensorboard/earlystop/wandb等
大模型推理,stream/batch输出 各个模型是通用的,无需单独维护脚本
大模型微调 lora依赖peft库,pv2自带
丰富tricks 对抗训练等tricks即插即用
代码简洁易懂,自定义空间大 代码复用度高, keras代码训练风格
仓库的维护能力/影响力/使用量/兼容性 目前仓库个人维护

3. 快速上手

4. 版本历史

更新日期 bert4torch torch4keras 版本说明
20240204 0.4.7 0.1.9 修改save_pretrained用于保存文件夹, 增加GenerateSpeed用于统计token生成速度,修复t5在use_states=True时候的错误, 修改层次编码的bug, 增加deepseek_moe模型,修复generation并发错误,优化大模型耗时
20240116 0.4.6 0.1.8 bug修复,增加save_pretrained用于保存transformer格式的权重, 增加部分embedding模型
20240111 0.4.5 0.1.7 training时候不生成past_key_values, 增加streamlit的example, 修复句向量max时的bug, batch_generate合并到generate, 修改generation的默认参数名(兼容过去的参数名), 多轮对话中可保留past_key_values, 把attention中的mask补齐逻辑移到apply_embedding中, 增加uiepipeline,增加PtuningV2Trainer

更多版本

5. 更新历史:

更多历史

6. 预训练权重

  • 若无说明则使用权重自带的pytorch_model.binconfig.json
模型分类 模型名称 权重来源 权重链接 备注(若有)
bert bert-base-chinese 谷歌bert的torch版 torch config
chinese_L-12_H-768_A-12 谷歌 github, tf 转换命令, config
chinese-bert-wwm-ext HFL tf/torchtorch
bert-base-multilingual-cased huggingface torch config
macbert HFL tf/torchtorch
wobert 追一科技 tftorch_basetorch_plus_base
roberta chinese-roberta-wwm-ext HFL tf/torchtorch
roberta-small/tiny 追一科技 & UER tftorch 转换脚本
roberta-base-english huggingface torch config
guwenbert ethanyt torch config
albert albert brightmart tftorchtorch
nezha NEZHA 华为 tftorch
xlnet chinese-xlnet HFL tf/torch config
deberta Erlangshen-DeBERTa-v2 IDEA torch
electra Chinese-ELECTRA HFL tftorch
ernie ernie 百度文心 paddletorch
roformer roformer 追一科技 tftorch config
roformer_v2 追一科技 tftorch config
simbert simbert 追一科技 tftorch_base 转换脚本 config
simbert_v2/roformer-sim 追一科技 tfbaseft_basesmallft_small 转换脚本, config
gau GAU-alpha 追一科技 tf 转换脚本
uie uie 百度 github, torch 转换脚本
gpt CDial-GPT thu-coai torch config
gpt2 cmp_lm(26亿) 清华 github, torch config
gpt2-chinese-cluecorpussmall UER torch config
gpt2-ml imcaspar tftorch config
bart bart_base_chinese 复旦fnlp torch, v1.0, v2.0 config
t5 t5 UER torch base, small
mt5 谷歌 torch config
t5_pegasus 追一科技 tf base, small
chatyuan v1&v2 clue-ai torch config
PromptCLUE clue-ai torch config
chatglm chatglm-6b THUDM github, v0.1.0, v1.1.0, int8, int4 config
chatglm2-6b THUDM github, v2, int4, 32k config
chatglm3-6b THUDM github, v3, 32k config
llama llama facebook github config
llama-2 facebook github, 7b, 7b-chat, 13b, 13b-chat config
chinese_llama_alpaca HFL github config
Belle_llama LianjiaTech github, 7B-2M-enc 合成说明config
Ziya IDEA-CCNL v1, v1.1, pretrain-v1 config
Baichuan baichuan-inc github, 7B, 13B-Base, 13B-Chat config
Baichuan2 baichuan-inc github, 7B-Base, 7B-Chat, 13B-Base, 13B-Chat config
vicuna lmsys 7b-v1.5 config
Yi 01-ai github, 6B, 6B-200K config
bloom bloom bigscience bloom-560m, bloomz-560m config
Qwen Qwen 阿里云 github, 1.8B, 1.8B-Chat, 7B, 7B-Chat config
InternLM InternLM 上海人工智能实验室 github, 7B-Chat, 7B config
Falcon Falcon tiiuae hf, RW-1B, 7B, 7B-Instruct config
moe deeoseek-moe deepseek github, moe-16b-base, moe-16b-chat config
embedding text2vec-base-chinese shibing624 base base
m3e moka-ai base base
bge BAAI large-en-v1.5, large-zh-v1.5 large-en-v1.5, large-zh-v1.5
gte thenlper large-zh, base-zh large-zh, base-zh

7. 鸣谢

  • 感谢苏神实现的bert4keras,本实现有不少地方参考了bert4keras的源码,在此衷心感谢大佬的无私奉献;
  • 其次感谢项目bert4pytorch,也是在该项目的指引下给了我用pytorch来复现bert4keras的想法和思路。

8. 引用

@misc{bert4torch,
  title={bert4torch},
  author={Bo Li},
  year={2022},
  howpublished={\url{https://github.com/Tongjilibo/bert4torch}},
}

9. 其他

  • Wechat & Star History Chart
pic
微信号
pic
微信群
pic
Star History Chart

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

bert4torch-0.4.7.tar.gz (155.0 kB view hashes)

Uploaded Source

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page