训练/使用Bert分类模型
Project description
About
一个用于训练和调用bert模型完成数据分类的python包。
Install
$ pip3 install -U berttextclassification
Director
- bert
- bert_lawlaw.py
- DataLawlaw.py
- train.py
- evaluate.py
- acc_and_f1.py
- BertModel.py
- predict.py
bert_lawlaw.py
main程序完成训练到评估,默认是只评估不进行训练。
DataLawlaw.py
定义了一个LawlawProcessor类,封装了一个,完成分类任务中数据相关的操作。
- LawlawProcessor类
- get_labels函数 五个设定好的分类标签
- load_train_examples函数 加载json格式的训练数据
- load_dev_examples函数 加载json格式的评估数据
train.py
对模型进行训练,并保存训练数据。
- train函数 根据设置的batch_size和epoch进行训练。将lr和loss写到tensorboard中,根据save_steps保存不同checkpoint的模型。
evaluate.py
对验证集进行评估,得到训练结果。
- evaluate函数 对选择的model进行评估,并给出acc_and_f1函数返回的三项分值参数。
acc_and_f1.py
计算bert模型预测后的分值。
- acc_and_f1函数 计算模型预测标签的三项分值参数acc,f1_score和acc_and_f1二者的平均值。
BertModel.py
对Bert模型的相关操作,包括设置种子、加载模型和保存模型。
- set_seed函数 设置种子数值。
- load_model函数 从output_dir加载model和tokenizer
- save_model函数 将model和tokenizer保存到output_dir
predict.py
对输入的文本进行分类预测。
- predict函数 使用加载的模型,对输入文本进行分类。
Usage
对模型进行训练,并保存模型
from BertModel import set_seed, init_model, load_model, save_model
from DataLawlaw import LawlawProcessor
from train import train
set_seed()
task = 'lawlaw'
processor = LawlawProcessor(task)
label_list = processor.get_labels()
model, tokenizer = init_model(task, len(label_list))
train_dataset = processor.load_train_examples(tokenizer, max_seq_length=128)
global_step, tr_loss = train(
model,
train_dataset,
train_batch_size=8,
num_train_epochs=5.0,
save_steps=1000,
output_dir='./outs'
)
save_model(model, tokenizer, output_dir='./outs')
对模型进行测试
from BertModel import set_seed, init_model, load_model, save_model
from DataLawlaw import LawlawProcessor
from evaluate import evaluate
set_seed()
task = 'lawlaw'
processor = LawlawProcessor(task)
model, tokenizer = load_model(output_dir='./outs')
eval_dataset = processor.load_dev_examples(tokenizer, max_seq_length=128)
result = evaluate(
model,
eval_dataset,
eval_batch_size=16,
output_dir='./outs'
)
print(result)
Contact us
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Close
Hashes for berttextclassification-0.0.5.tar.gz
Algorithm | Hash digest | |
---|---|---|
SHA256 | d78e8a60eccb64d7461f1e59aca92553de8ee7dae8f263aa1af8bc801ec3559b |
|
MD5 | b77513c7105476dc512235c7ddbe4a7b |
|
BLAKE2b-256 | 67b82b89281973ecceb01b6bfb7d85bfd6ade2237dcfe1380c3276abb34a339f |