PyTorch-based knowledge distillation toolkit for natural language processing tasks
Project description
TextBrewer
(当前版本: 0.1.6)
TextBrewer 是一个基于PyTorch的、为NLP中的知识蒸馏任务设计的工具包。
TextBrewer 的主要特点有:
- 方便灵活:适用于多种模型结构(主要面向Transfomer结构)
- 易于扩展:诸多蒸馏参数可调,支持增加自定义损失等模块
- 非侵入式:无需对教师与学生模型本身结构进行修改
- 支持典型的NLP任务:文本分类、阅读理解、序列标注等
其主要模块与功能分为3块:
-
Distillers:进行蒸馏的核心部件。不同的distiller提供不同的蒸馏模式。有 GeneralDistiller, MultiTeacherDistiller, MultiTaskDistiller 等
-
Configurations and presets:训练与蒸馏方法的配置,以及预定义蒸馏策略
-
Utilities:模型参数分析显示等辅助工具
要开始知识蒸馏,用户需要提供的有:
- 已训练好的教师模型, 待蒸馏的学生模型
- 训练数据与必要的实验配置
在多个典型NLP任务上,TextBrewer都能取得较好的压缩效果。相关实验见蒸馏效果。
安装
安装要求
- Python >= 3.6
- PyTorch >= 1.1.0
- TensorboardX or Tensorboard
安装方式
- 从PyPI安装:
pip install textbrewer
- 从源码文件夹安装:
pip install ./textbrewer
工作流程
-
Stage 1 : 在开始蒸馏之前,需要做一些准备工作 :
- 训练教师模型
- 定义与初始化学生模型(随机初始化,或载入预训练权重)
- 构造蒸馏用数据集的dataloader,训练学生模型用的optimizer和learning rate scheduler
-
Stage 2 : 开始蒸馏:
- 初始化distiller,构造训练配置(TrainingConfig)和蒸馏配置(DistillationConfig)
- 定义adaptors 和 callback,分别用与适配模型输入输出和训练过程中的回调
- 调用distiller的train方法开始蒸馏
用户应先实现 stage 1 ,得到训练好的教师模型;TextBrewer 主要负责 Stage 2的蒸馏工作。
(TextBrewer中也提供了用于 stage 1 的 BasicTrainer,用于训练教师模型)
下面展示一个简单的例子,包含了TextBrewer的基本用法
快速开始
以蒸馏基于BERT的文本分类模型为例,数据为随机数据
(一些概念的进一步解释和相关参数配置见详细文档)
- Stage 1 : 准备工作
import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig
import torch
from torch.utils.data import Dataset,DataLoader
# 用transformers构建教师与学生模型
# 这段代码中用的transformers == 2.3
from transformers import BertForSequenceClassification, BertConfig, AdamW
# 运行设备
device = torch.device('cpu')
# 定义模型
# bert_config 是 12层BERT-base的配置
# bert_config_T3 是截至3层的BERT
bert_config = BertConfig.from_json_file('bert_config/bert_config.json')
bert_config_T3 = BertConfig.from_json_file('bert_config/bert_config_T3.json')
# 使用hidden_states作为中间输出特征
bert_config.output_hidden_states = True
bert_config_T3.output_hidden_states = True
# 定义教师模型
teacher_model = BertForSequenceClassification(bert_config) #, num_labels = 2
# 教师模型应当被合理初始化,如载入预训练权重,并在下游任务上训练
# 这里出于演示目的,省略相关步骤
# 定义学生模型
student_model = BertForSequenceClassification(bert_config_T3) #, num_labels = 2
teacher_model.to(device=device)
student_model.to(device=device)
#支持 DataParallel并行
#teacher_model = torch.nn.DataParallel(teacher_model)
#student_model = torch.nn.DataParallel(student_model)
# 定义字典形式的Dataset,字典的key和model的forward方法的参数名匹配
# 也可用PyTorch自带的TensorDataset构造数据集,但要注意元素顺序和model的forward的参数顺序一致
class DictDataset(Dataset):
def __init__(self, all_input_ids, all_attention_mask, all_labels):
assert len(all_input_ids)==len(all_attention_mask)==len(all_labels)
self.all_input_ids = all_input_ids
self.all_attention_mask = all_attention_mask
self.all_labels = all_labels
def __getitem__(self, index):
return {'input_ids': self.all_input_ids[index],
'attention_mask': self.all_attention_mask[index],
'labels': self.all_labels[index]}
def __len__(self):
return self.all_input_ids.size(0)
# 准备一些随机的数据
all_input_ids = torch.randint(low=0,high=100,size=(100,128))
all_attention_mask = torch.ones_like(all_input_ids)
all_labels = torch.randint(low=0,high=2,size=(100,))
dataset = DictDataset(all_input_ids, all_attention_mask, all_labels)
eval_dataset = DictDataset(all_input_ids, all_attention_mask, all_labels)
dataloader = DataLoader(dataset,batch_size=32)
# 初始化Optimizer和learning rate scheduler.
# Learning rate scheduler 可以为None.
optimizer = AdamW(student_model.parameters(), lr=1e-4)
scheduler = None
- Stage 2 : 使用TextBrewer蒸馏:
######蒸馏相关准备#########
# 展示模型参数量的统计
print("\nteacher_model's parametrers:")
_ = textbrewer.utils.display_parameters(teacher_model,max_level=3)
print("student_model's parametrers:")
_ = textbrewer.utils.display_parameters(student_model,max_level=3)
# 定义 adaptor, 用于解释模型的输出,这里返回了logits和hidden_states
def simple_adaptor(batch, model_outputs):
# model_outputs 的第二个元素是softmax之前的logits
# model_outputs 的第三个元素是hidden states,
# model_outputs[2][i] 是模型第i层的hidden state;
# model_outputs[2][0] 是模型的embedding。
# 具体的输出见transformers的相关说明
return {'logits': model_outputs[1],
'hidden': model_outputs[2]}
#定义回调函数, 也可以为None
#model和step分别是学生模型和当前训练步数
#这里的例子为:callback用于在验证集上测试模型
def predict(model, eval_dataset, step, device):
# eval_dataset: 验证数据集
model.eval()
pred_logits = []
label_ids =[]
dataloader = DataLoader(eval_dataset,batch_size=32)
for batch in dataloader:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['labels']
with torch.no_grad():
logits, _ = model(input_ids=input_ids, attention_mask=attention_mask)
cpu_logits = logits.detach().cpu()
for i in range(len(cpu_logits)):
pred_logits.append(cpu_logits[i].numpy())
label_ids.append(labels[i])
model.train()
pred_logits = np.array(pred_logits)
label_ids = np.array(label_ids)
y_p = pred_logits.argmax(axis=-1)
accuracy = (y_p==label_ids).sum()/len(label_ids)
print ("Number of examples: ",len(y_p))
print ("Acc: ", accuracy)
from functools import partial
# 填充多余的参数
callback_fun = partial(predict, eval_dataset=eval_dataset, device=device)
# 初始化配置
# 训练配置
train_config = TrainingConfig(device=device)
# 蒸馏配置
distill_config = DistillationConfig(
temperature=8, # 温度
hard_label_weight=0, # hard_label_loss的权重
kd_loss_type='ce', # kd_loss 设为 交叉熵
intermediate_matches=[ # 中间层特征匹配策略
{'layer_T':0, 'layer_S':0, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1},
{'layer_T':8, 'layer_S':2, 'feature':'hidden', 'loss': 'hidden_mse', 'weight' : 1},
{'layer_T':[0,0], 'layer_S':[0,0], 'feature':'hidden', 'loss': 'nst', 'weight': 1},
{'layer_T':[8,8], 'layer_S':[2,2], 'feature':'hidden', 'loss': 'nst', 'weight': 1}])
print ("train_config:")
print (train_config)
print ("distill_config:")
print (distill_config)
#初始化distiller
distiller = GeneralDistiller(
train_config=train_config, distill_config = distill_config,
model_T = teacher_model, model_S = student_model,
adaptor_T = simple_adaptor, adaptor_S = simple_adaptor)
# 开始蒸馏
distiller.train(optimizer, scheduler, dataloader, num_epochs=1, callback=callback_fun)
蒸馏效果
我们在多个中英文文本分类、阅读理解、序列标注数据集上进行了蒸馏实验。实验的配置和效果如下。
英文数据集
Dataset | Task type | Metrics | #Train | #Dev |
---|---|---|---|---|
MNLI | 文本分类 | Acc | 393K | 20K |
SQuAD1.1 | 阅读理解 | EM/F1 | 88K | 11K |
CoNLL-2003 | 序列标注 | F1 | 23K | 6K |
-
MNL是句对三分类(entailment,neutral,contradictory)任务。
-
SQuAD1.1是抽取式阅读理解任务,要求从篇章中抽取片段作为问题的答案。
-
CoNLL-2003是命名实体识别任务,需要标记出句子中每个词的实体类型。
中文数据集
Dataset | Task type | Metrics | #Train | #Dev |
---|---|---|---|---|
XNLI | 文本分类 | Acc | 393K | 2.5K |
LCQMC | 文本分类 | Acc | 239K | 8.8K |
CMRC2018 | 阅读理解 | EM/F1 | 10K | 3.4K |
DRCD | 阅读理解 | EM/F1 | 27K | 3.5K |
- XNLI 是MNLI的中文翻译版本,同样为3分类任务。
- LCQMC由哈工大深圳研究生院智能计算研究中心发布的句对二分类任务, 判断两个句子的语义是否相同。
- CMRC 2018是哈工大讯飞联合实验室发布的中文机器阅读理解数据集。 形式与SQuAD相同。
- DRCD由中国台湾台达研究院发布的基于繁体中文的抽取式阅读理解数据集。其形式与SQuAD相同。
模型
对于英文任务,教师模型为BERT-base; 对于中文任务,教师模型为HFL发布的RoBERTa-wwm-ext
我们测试了不同的学生模型,除了BiGRU都是和BERT一样的多层transformer结构。模型的参数如下
Model | #Layers | Hidden_size | Feed-forward size | #Params | Relative size |
---|---|---|---|---|---|
BERT-base (Teacher) | 12 | 768 | 3072 | 108M | 100% |
RoBERTa-wwm (Teacher) | 12 | 768 | 3072 | 108M | 100% |
T6 | 6 | 768 | 3072 | 65M | 60% |
T3 | 3 | 768 | 3072 | 44M | 41% |
T3-small | 3 | 384 | 1536 | 17M | 16% |
T4-Tiny | 4 | 312 | 1200 | 14M | 13% |
BiGRU | - | 768 | - | 31M | 29% |
参数量的统计包括了embedding层,但不包括最终适配各个任务的输出层
蒸馏配置
distill_config = DistillationConfig(
temperature = 8,
intermediate_matches = matches)
# 其他参数为默认值
不同的模型用的matches我们采用了以下配置:
Model | matches |
---|---|
BiGRU | None |
T6 | L6_hidden_mse + L6_hidden_smmd |
T3 | L3_hidden_mse + L3_hidden_smmd (英文任务上) 或 L3_hidden_mse (中文任务上) |
T3-small | L3n_hidden_mse + L3_hidden_smmd |
T4-Tiny | L4t_hidden_mse + L4_hidden_smmd |
各种matches的定义在exmaple/matches/matches.py文件中。均使用GeneralDistiller进行蒸馏。
训练配置
蒸馏用的学习率 lr=1e-4(除非特殊说明)。训练30~60轮。学习率衰减模式和BERT一致(10%的warmup,90%的linear decay)
英文实验结果
Model | MNLI (m/mm Acc) | SQuAD (EM/F1) | CoNLL-2003 (F1) |
---|---|---|---|
BERT-base | 83.7 / 84.0 | 81.5 / 88.6 | 91.1 |
BiGRU | - | - | 85.3 |
T6 | 83.5 / 84.0 (30) | 88.1 / 80.8 (50) | 90.7 |
T3 | 81.8 / 82.7 (40) | 76.4 / 84.9 (50) | 87.5 (30) |
T3-small | 81.3 / 81.7 (40) | 72.3 / 81.4 (50) | 57.4 |
T4-tiny | 82.0 / 82.6 (40) | 73.7 / 82.5 (50) | 54.7 |
+ 数据增强 | - | 75.2 / 84.0 (50) | 79.6 |
说明:
- 括号内为训练轮数。
- SQuAD任务上用的是NewsQA作为增强数据;CoNLL-2003上用的是HotpotQA的篇章作为增强数据。数据增强对蒸馏效果有明显的提升作用。
中文实验结果
Model | XNLI (Acc) | LCQMC (Acc) | CMRC2018 (EM/F1) | DRCD (EM/F1) |
---|---|---|---|---|
RoBERTa-wwm | 79.9 | 89.4 | 68.8 / 86.4 | 86.5 / 92.5 |
T3 | 78.4 (30) | 89.0 (30) | 63.4 / 82.4 (50) | 76.7 / 85.2 (60) |
+ 数据增强 | 66.4 / 84.2 (50) | 78.2 / 86.4 (60) | ||
T3-small | 76.0 (30) | 88.1 (30) | 24.4 / 48.1 (50) | 42.2 / 63.2 (40) |
+ 数据增强 | - | - | 58.0 / 79.3 (50) | 65.5 / 78.6 (60) |
T4-tiny | 76.2 (30) | 88.4 (30) | - | - |
+ 数据增强 | - | - | 61.8 / 81.8 (50) | 73.3 / 83.5 (60) |
说明:
- 括号内为训练轮数。
- 蒸馏CMRC2018和DRCD上的模型时学习率分别为1.5e-4和7e-5。
- 蒸馏到T3的实验中,XNLI和LCQMC使用的matches是L3_hidden_mse;其他所有实验使用的均是L3_hidden_mse + L3_hidden_smmd。
- 针对CMRC2018和DRCD数据集的蒸馏实验,不采用学习率衰减:学习率从增长到指定值后一直保持不变。
- 在使用了数据增强的实验中,CMRC2018和DRCD互作为增强数据。可以发现在训练集较小且模型随机初始化的情况下,数据增强的提升作用明显。
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.