
an elegant bert4torch


bert4torch


Documentation | Torch4keras | Examples | build_MiniLLM_from_scratch

Table of Contents

1. Installation

Install the stable version

pip install bert4torch

Install the latest version

pip install git+https://github.com/Tongjilibo/bert4torch
  • Note: the pip release lags behind the development version on git; when using git clone, mind the import path and whether the weights need conversion
  • Test cases: git clone https://github.com/Tongjilibo/bert4torch, then modify the pretrained-model and data paths in example to launch the scripts
  • Training on your own data: modify the corresponding data-processing code blocks for your dataset
  • Development environment: originally developed on torch==1.10, now developed on torch 2.0; feedback is welcome if other versions run into incompatibilities

2. Features

  • LLM models: load open-source LLM weights such as chatglm, llama, baichuan, ziya, and bloom for inference and fine-tuning

  • Core features: load pretrained weights of bert, roberta, albert, xlnet, nezha, bart, RoFormer, RoFormer_V2, ELECTRA, GPT, GPT2, T5, GAU-alpha, ERNIE, etc. for further finetuning, with support for flexibly defining your own model on top of bert

  • Rich examples: includes solutions for llm, pretrain, sentence_classfication, sentence_embedding, sequence_labeling, relation_extraction, seq2seq, serving, and more

  • Experimentally verified: validated on public datasets; see the datasets used in examples

  • Easy-to-use tricks: common tricks are integrated and plug-and-play

  • Other features: load models from the transformers library for use alongside bert4torch; concise and efficient calling conventions; dynamic training progress bar; parameter counts printed with torchinfo; default Logger and Tensorboard for convenient logging of the training process; customizable fit procedure to satisfy advanced needs

  • Training process (a minimal compile/fit sketch follows the log below)

    2022-10-28 23:16:10 - Start Training
    2022-10-28 23:16:10 - Epoch: 1/2
    5000/5000 [==============================] - 13s 3ms/step - loss: 0.1351 - acc: 0.9601
    Evaluate: 100%|██████████████████████████████████████████████████| 2500/2500 [00:03<00:00, 798.09it/s] 
    test_acc: 0.98045. best_test_acc: 0.98045
    
    2022-10-28 23:16:27 - Epoch: 2/2
    5000/5000 [==============================] - 13s 3ms/step - loss: 0.0465 - acc: 0.9862
    Evaluate: 100%|██████████████████████████████████████████████████| 2500/2500 [00:03<00:00, 635.78it/s] 
    test_acc: 0.98280. best_test_acc: 0.98280
    
    2022-10-28 23:16:44 - Finish Training
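The log above comes from the keras-style training interface. Below is a minimal sketch of how such a run can be set up, modeled on the repo's sentence-classification examples; the placeholder paths, the dummy data, and the 2-class head are illustrative assumptions:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from bert4torch.models import build_transformer_model, BaseModel

class Model(BaseModel):
    def __init__(self, config_path, checkpoint_path):
        super().__init__()
        # with_pool=True additionally returns the pooled [CLS] output
        self.bert = build_transformer_model(config_path, checkpoint_path, with_pool=True)
        self.dense = nn.Linear(768, 2)  # hypothetical 2-class head

    def forward(self, token_ids, segment_ids):
        hidden_states, pooled_output = self.bert([token_ids, segment_ids])
        return self.dense(pooled_output)

# Dummy data purely for illustration: 32 samples of length 16
token_ids = torch.randint(1, 1000, (32, 16))
segment_ids = torch.zeros(32, 16, dtype=torch.long)
labels = torch.randint(0, 2, (32,))

def collate_fn(batch):
    t, s, y = map(torch.stack, zip(*batch))
    return [t, s], y  # fit() consumes (inputs, labels) batches

train_dataloader = DataLoader(TensorDataset(token_ids, segment_ids, labels),
                              batch_size=8, collate_fn=collate_fn)

model = Model('./model/bert4torch_config.json', './model/pytorch_model.bin')
model.compile(loss=nn.CrossEntropyLoss(),
              optimizer=optim.Adam(model.parameters(), lr=2e-5),
              metrics=['accuracy'])
model.fit(train_dataloader, epochs=2, steps_per_epoch=None)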
    
| Feature | bert4torch | transformers | Notes |
|---|---|---|---|
| Training progress bar | | | the progress bar prints loss and user-defined metrics |
| Distributed training dp/ddp | | | torch's built-in dp/ddp |
| Various callbacks | | | logging / tensorboard / earlystop / wandb, etc. |
| LLM inference with stream/batch output | | | shared across models, no per-model scripts to maintain |
| LLM fine-tuning | | | lora relies on the peft library (see the sketch after this table); p-tuning v2 is built in |
| Rich tricks | | | adversarial training and other plug-and-play tricks |
| Concise, readable code with ample room for customization | | | high code reuse, keras-style training code |
| Repository maintenance / influence / usage / compatibility | | | the repo is currently maintained by an individual |
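For LLM fine-tuning with lora (see the table above), bert4torch relies on the peft library. Below is a minimal sketch of the generic peft wrapping pattern; the target_modules names and the hyperparameter values are assumptions that must match the attention-projection module names of the model you actually load:

from bert4torch.models import build_transformer_model
from peft import LoraConfig, get_peft_model

base_model = build_transformer_model(checkpoint_path='./model')  # loaded as in section 5
# r / lora_alpha / lora_dropout are illustrative; target_modules is an assumption
lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1,
                         target_modules=['q', 'k', 'v'])
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()  # only the LoRA adapter weights remain trainable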

3. Quick Start
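A minimal end-to-end sketch, patterned after the repo's basic feature-extraction example: load a BERT checkpoint laid out as in section 5 and extract hidden states. The paths are placeholders, and the Tokenizer import follows the repo's examples:

import torch
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer

config_path = './model/bert4torch_config.json'  # placeholder paths
checkpoint_path = './model/pytorch_model.bin'
vocab_path = './model/vocab.txt'

tokenizer = Tokenizer(vocab_path, do_lower_case=True)
model = build_transformer_model(config_path, checkpoint_path)

# Encode a sentence and run a forward pass to get the hidden states
token_ids, segment_ids = tokenizer.encode('语言模型')
model.eval()
with torch.no_grad():
    hidden_states = model([torch.tensor([token_ids]), torch.tensor([segment_ids])])
print(hidden_states)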

4. Versions and Update History

4.1 Version History

| Update date | bert4torch | torch4keras | Release notes |
|---|---|---|---|
| 20240418 | 0.5.0 | 0.2.2 | fix a chatglm3 bug; fix a multi-file bug in save_pretrained; add CausalLMLoss; adjust the deepspeed argument-passing logic; fix a Text2Vec bug; improve the openai client; add get_weight_decay_optim_groups |
| 20240317 | 0.4.9.post2 | 0.2.1.post2 | add the get_weight_decay_optim_groups function; allow is_causal in attention; fix a repetition_penalty bug; split baichuan out of llama; fix a config_path bug; allow the num_key_value_heads parameter; torch4keras-v0.2.1.post2 feature updates |
| 20240221 | 0.4.8 | 0.2.0 | fastapi serving allows offloading to cpu when idle; build_transformer_model allows downloading from hf; add a FillMask pipeline; add SequenceClassificationTrainer |

More versions

4.2 Update History

More history

5. Pretrained Weights

  • Pretrained models can be loaded in several ways:
from bert4torch.models import build_transformer_model

# 1. Specify config_path only: initialize the model structure from scratch without loading pretrained weights
model = build_transformer_model('./model/bert4torch_config.json')

# 2. Specify checkpoint_path only:
## 2.1 Folder path: automatically finds the *.bin/*.safetensors weight files plus a bert4torch_config.json/config.json file under the path
model = build_transformer_model(checkpoint_path='./model')

## 2.2 File path/list: the file path(s) are the weight path(s); the config is looked up in the same directory
model = build_transformer_model(checkpoint_path='./pytorch_model.bin')

## 2.3 model_name: name of pretrained weights on hf; the hf weights and the bert4torch_config.json file are downloaded automatically
model = build_transformer_model(checkpoint_path='bert-base-chinese')

# 3. Specify both config_path and checkpoint_path (any combination of local paths and model_name):
config_path = './model/bert4torch_config.json'  # or 'bert-base-chinese'
checkpoint_path = './model/pytorch_model.bin'  # or 'bert-base-chinese'
model = build_transformer_model(config_path, checkpoint_path)
| Model category | Model name | Weight source | Weight link / checkpoint_path | config_path |
|---|---|---|---|---|
| bert | bert-base-chinese | google-bert | bert-base-chinese | bert-base-chinese |
| | chinese_L-12_H-768_A-12 | Google | github, tf, Tongjilibo/bert-chinese_L-12_H-768_A-12 | |
| | chinese-bert-wwm-ext | HFL | github, hfl/chinese-bert-wwm-ext | chinese-bert-wwm-ext |
| | bert-base-multilingual-cased | google-bert | bert-base-multilingual-cased | bert-base-multilingual-cased |
| | macbert | HFL | github, hfl/chinese-macbert-base, hfl/chinese-macbert-large | chinese-macbert-base, chinese-macbert-large |
| | wobert | Zhuiyi Technology | github, junnyu/wobert_chinese_base, junnyu/wobert_chinese_plus_base | wobert_chinese_base, wobert_chinese_plus_base |
| roberta | chinese-roberta-wwm-ext | HFL | github, hfl/chinese-roberta-wwm-ext, hfl/chinese-roberta-wwm-ext-large | chinese-roberta-wwm-ext, chinese-roberta-wwm-ext-large |
| | roberta-small/tiny | Zhuiyi Technology | github, Tongjilibo/chinese_roberta_L-4_H-312_A-12, Tongjilibo/chinese_roberta_L-6_H-384_A-12 | |
| | roberta-base | FacebookAI | roberta-base | roberta-base |
| | guwenbert | ethanyt | ethanyt/guwenbert-base | guwenbert-base |
| albert | albert | brightmart | github, torch, voidful/albert_chinese_tiny, voidful/albert_chinese_small, voidful/albert_chinese_base, voidful/albert_chinese_large, voidful/albert_chinese_xlarge, voidful/albert_chinese_xxlarge | albert_chinese_tiny, albert_chinese_small, albert_chinese_base, albert_chinese_large, albert_chinese_xlarge, albert_chinese_xxlarge |
| nezha | NEZHA | Huawei | github, torch, sijunhe/nezha-cn-base, sijunhe/nezha-cn-large, sijunhe/nezha-base-wwm, sijunhe/nezha-large-wwm | nezha-cn-base, nezha-cn-large, nezha-base-wwm, nezha-large-wwm |
| | nezha_gpt_dialog | bojone | github, Tongjilibo/nezha_gpt_dialog | |
| xlnet | chinese-xlnet | HFL | github, hfl/chinese-xlnet-base | chinese-xlnet-base |
| | transformer_xl | huggingface | transfo-xl/transfo-xl-wt103 | transfo-xl-wt103 |
| deberta | Erlangshen-DeBERTa-v2 | IDEA | IDEA-CCNL/Erlangshen-DeBERTa-v2-97M-Chinese, IDEA-CCNL/Erlangshen-DeBERTa-v2-320M-Chinese, IDEA-CCNL/Erlangshen-DeBERTa-v2-710M-Chinese | Erlangshen-DeBERTa-v2-97M-Chinese, Erlangshen-DeBERTa-v2-320M-Chinese, Erlangshen-DeBERTa-v2-710M-Chinese |
| electra | Chinese-ELECTRA | HFL | github, hfl/chinese-electra-base-discriminator | chinese-electra-base-discriminator |
| ernie | ernie | Baidu ERNIE | paddle, nghuyong/ernie-1.0-base-zh, nghuyong/ernie-3.0-base-zh | ernie-1.0-base-zh, ernie-3.0-base-zh |
| roformer | roformer | Zhuiyi Technology | github, junnyu/roformer_chinese_base | roformer_chinese_base |
| | roformer_v2 | Zhuiyi Technology | github, junnyu/roformer_v2_chinese_char_base | roformer_v2_chinese_char_base |
| simbert | simbert | Zhuiyi Technology | github, Tongjilibo/simbert-chinese-base, Tongjilibo/simbert-chinese-small, Tongjilibo/simbert-chinese-tiny | |
| | simbert_v2/roformer-sim | Zhuiyi Technology | github, junnyu/roformer_chinese_sim_char_base, junnyu/roformer_chinese_sim_char_ft_base, junnyu/roformer_chinese_sim_char_small, junnyu/roformer_chinese_sim_char_ft_small | roformer_chinese_sim_char_base, roformer_chinese_sim_char_ft_base, roformer_chinese_sim_char_small, roformer_chinese_sim_char_ft_small |
| gau | GAU-alpha | Zhuiyi Technology | github, Tongjilibo/chinese_GAU-alpha-char_L-24_H-768 | |
| uie | uie | Baidu | github, torch, Tongjilibo/uie-base | |
| gpt | CDial-GPT | thu-coai | github, thu-coai/CDial-GPT_LCCC-base, thu-coai/CDial-GPT_LCCC-large | CDial-GPT_LCCC-base, CDial-GPT_LCCC-large |
| | cmp_lm (2.6B) | Tsinghua | github, TsinghuaAI/CPM-Generate | CPM-Generate |
| | nezha_gen | huawei_noah | github, Tongjilibo/chinese_nezha_gpt_L-12_H-768_A-12 | |
| | gpt2-chinese-cluecorpussmall | UER | uer/gpt2-chinese-cluecorpussmall | gpt2-chinese-cluecorpussmall |
| | gpt2-ml | imcaspar | tf, torch, BaiduYun(84dh) | gpt2-ml_15g_corpus, gpt2-ml_30g_corpus |
| bart | bart_base_chinese | Fudan fnlp | github, v1.0, fnlp/bart-base-chinese | bart-base-chinese, bart-base-chinese-v1.0 |
| t5 | t5 | UER | uer/t5-small-chinese-cluecorpussmall, uer/t5-base-chinese-cluecorpussmall | t5-base-chinese-cluecorpussmall, t5-small-chinese-cluecorpussmall |
| | mt5 | Google | google/mt5-base | mt5-base |
| | t5_pegasus | Zhuiyi Technology | github, Tongjilibo/chinese_t5_pegasus_small, Tongjilibo/chinese_t5_pegasus_base | |
| | chatyuan v1&v2 | clue-ai | github, ClueAI/ChatYuan-large-v1, ClueAI/ChatYuan-large-v2 | ChatYuan-large-v1, ChatYuan-large-v2 |
| | PromptCLUE | clue-ai | github, ClueAI/PromptCLUE-base | PromptCLUE-base |
| chatglm | chatglm-6b | THUDM | github, THUDM/chatglm-6b, THUDM/chatglm-6b-int8, THUDM/chatglm-6b-int4, v0.1.0 | chatglm-6b, chatglm-6b-int8, chatglm-6b-int4, chatglm-6b-v0.1.0 |
| | chatglm2-6b | THUDM | github, THUDM/chatglm2-6b, THUDM/chatglm2-6b-int4, THUDM/chatglm2-6b-32k | chatglm2-6b, chatglm2-6b-int4, chatglm2-6b-32k |
| | chatglm3-6b | THUDM | github, THUDM/chatglm3-6b, THUDM/chatglm3-6b-32k | chatglm3-6b, chatglm3-6b-32k |
| llama | llama | meta | github | llama-7b, llama-13b |
| | llama-2 | meta | github, meta-llama/Llama-2-7b-hf, meta-llama/Llama-2-7b-chat-hf, meta-llama/Llama-2-13b-hf, meta-llama/Llama-2-13b-chat-hf | Llama-2-7b-hf, Llama-2-7b-chat-hf, Llama-2-13b-hf, Llama-2-13b-chat-hf |
| | chinese_llama_alpaca | HFL | github | chinese_alpaca_plus_7b, chinese_llama_plus_7b |
| | Belle_llama | LianjiaTech | github, BelleGroup/BELLE-LLaMA-7B-2M-enc | weight-merge instructions, BELLE-LLaMA-7B-2M-enc |
| | Ziya | IDEA-CCNL | IDEA-CCNL/Ziya-LLaMA-13B-v1, IDEA-CCNL/Ziya-LLaMA-13B-v1.1, IDEA-CCNL/Ziya-LLaMA-13B-Pretrain-v1 | Ziya-LLaMA-13B-v1, Ziya-LLaMA-13B-v1.1 |
| | Baichuan | baichuan-inc | github, baichuan-inc/Baichuan-7B, baichuan-inc/Baichuan-13B-Base, baichuan-inc/Baichuan-13B-Chat | Baichuan-7B, Baichuan-13B-Base, Baichuan-13B-Chat |
| | Baichuan2 | baichuan-inc | github, baichuan-inc/Baichuan2-7B-Base, baichuan-inc/Baichuan2-7B-Chat, baichuan-inc/Baichuan2-13B-Base, baichuan-inc/Baichuan2-13B-Chat | Baichuan2-7B-Base, Baichuan2-7B-Chat, Baichuan2-13B-Base, Baichuan2-13B-Chat |
| | vicuna | lmsys | lmsys/vicuna-7b-v1.5 | vicuna-7b-v1.5 |
| | Yi | 01-ai | github, 01-ai/Yi-6B, 01-ai/Yi-6B-200K | Yi-6B, Yi-6B-200K |
| bloom | bloom | bigscience | bigscience/bloom-560m, bigscience/bloomz-560m | bloom-560m, bloomz-560m |
| Qwen | Qwen | Alibaba Cloud | github, Qwen/Qwen-1_8B, Qwen/Qwen-1_8B-Chat, Qwen/Qwen-7B, Qwen/Qwen-7B-Chat | Qwen-1_8B, Qwen-1_8B-Chat, Qwen-7B, Qwen-7B-Chat |
| InternLM | InternLM | Shanghai AI Laboratory | github, internlm/internlm-chat-7b, internlm/internlm-7b | internlm-7b, internlm-chat-7b |
| Falcon | Falcon | tiiuae | hf, tiiuae/falcon-rw-1b, tiiuae/falcon-7b, tiiuae/falcon-7b-instruct | falcon-rw-1b, falcon-7b, falcon-7b-instruct |
| moe | deepseek-moe | deepseek | github, deepseek-ai/deepseek-moe-16b-base, deepseek-ai/deepseek-moe-16b-chat | deepseek-moe-16b-base, deepseek-moe-16b-chat |
| embedding | text2vec-base-chinese | shibing624 | shibing624/text2vec-base-chinese | text2vec-base-chinese |
| | m3e | moka-ai | moka-ai/m3e-base | m3e-base |
| | bge | BAAI | BAAI/bge-large-en-v1.5, BAAI/bge-large-zh-v1.5 | bge-large-en-v1.5, bge-large-zh-v1.5 |
| | gte | thenlper | thenlper/gte-large-zh, thenlper/gte-base-zh | gte-base-zh, gte-large-zh |

*Note:

  1. Entries in highlighted format (e.g. bert-base-chinese) can be downloaded directly over the network via build_transformer_model()
  2. Use a domestic (China) mirror site to speed up downloads:
    • HF_ENDPOINT=https://hf-mirror.com python your_script.py
    • export HF_ENDPOINT=https://hf-mirror.com, then run your python code
    • set it at the top of your python code:
    import os
    os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"
    

6. Acknowledgements

  • Thanks to Jianlin Su's bert4keras; many parts of this implementation reference the bert4keras source code. Heartfelt thanks for his selfless contribution;
  • Thanks also to the bert4pytorch project, whose guidance gave me the idea and approach for reproducing bert4keras in pytorch.

7. Citation

@misc{bert4torch,
  title={bert4torch},
  author={Bo Li},
  year={2022},
  howpublished={\url{https://github.com/Tongjilibo/bert4torch}},
}

8. Others

  • WeChat & Star History Chart
    (images: WeChat ID, WeChat group, Star History Chart)
