Text Generation Model
Project description
TextGen
🌈 Implementation of Text Generation models.
textgen实现了多种文本生成模型,包括:UDA、GPT2、Seq2Seq、BART、T5等模型,开箱即用。
Guide
Question
文本生成,文本数据增强怎么做?
Solution
- UDA,非核心词替换
- EDA,简单数据增强技术:相似词、同义词替换,随机词插入、删除、替换
- 回译(BT, Back Translate):中文-英文-中文
- 生成模型:Seq2Seq,GPT2,T5,BART等
Feature
- UDA(非核心词替换):本项目参考Google的UDA(非核心词替换)算法,基于TF-IDF将句子中部分不重要词替换为同义词,产生新的文本,实现了文本扩增
- BT(回译):本项目基于百度翻译API实现了回译功能,先把中文句子翻译为英文,再把英文翻译为新的中文
- Seq2Seq:本项目基于PyTorch实现了Seq2Seq、ConvSeq2Seq、BART模型的训练和预测,可以用于文本翻译、对话生成、摘要生成等文本生成任务
- T5:本项目基于PyTorch实现了T5模型训练和预测,可以用于文本翻译、对话生成、对联生成等文本生成任务
- GPT2:本项目基于PyTorch实现了GTP2模型训练和预测,可以用于文章生成、对联生成等文本生成任务
- TGLS:本项目实现了TGLS无监督相似文本生成模型,是一种“先搜索后学习”的文本生成方法,通过反复迭代学习候选集,最终模型能生成与候选集较高质量的相似文本
Demo
Demo: https://huggingface.co/spaces/shibing624/textgen
run example: examples/gradio_demo.py to see the demo:
python examples/gradio_demo.py
Install
pip3 install torch # conda install pytorch
pip3 install -U textgen
or
pip3 install torch # conda install pytorch
git clone https://github.com/shibing624/textgen.git
cd textgen
python3 setup.py install
Usage
Text Augmentation(EDA、UDA文本数据增强)
example: examples/text_augmentation_demo.py
import sys
sys.path.append('..')
from textgen.augment import TextAugment
if __name__ == '__main__':
docs = ['主要研究机器学习、深度学习、计算机视觉、智能对话系统相关内容',
'晚上肚子好难受',
'你会武功吗,我不会',
'组装标题质量受限于广告主自提物料的片段质量,且表达丰富度有限',
'晚上一个人好孤单,想:找附近的人陪陪我.',
]
m = TextAugment(sentence_list=docs)
a = docs[0]
print(a)
b = m.augment(a, aug_ops='random-0.1')
print('random-0.1:', b)
b = m.augment(a, aug_ops='insert-0.1')
print('insert-0.1:', b)
# tfidf
b = m.augment(a, aug_ops='tfidf-0.2')
print('tfidf-0.2:', b)
b = m.augment(a, aug_ops='mix-0.1', similar_prob=0.1,
random_prob=0.4, delete_prob=0.3, insert_prob=0.2)
print('mix-0.1:', b)
b = m.augment(a, aug_ops='bt')
print('bt:', b)
output:
主要研究机器学习、深度学习、计算机视觉、智能对话系统相关内容
random-0.1: ('主要的机器学习、深度学习吗计算机视觉、好孤单对话系统相关内容', [('研究', '的', 2, 3), ('、', '吗', 12, 13), ('智能', '好孤单', 19, 22)])
insert-0.1: ('主要研究机器机器学习、深度学习、计算机视觉、智能对话对话系统系统相关内容', [('机器', '机器机器', 4, 8), ('对话', '对话对话', 24, 28), ('系统', '系统系统', 28, 32)])
tfidf-0.2: ('主要原因研究机器学习、深度学习、计算机硬件视觉、智能化对话系统相关内容', [('主要', '主要原因', 0, 4), ('计算机', '计算机硬件', 16, 21), ('智能', '智能化', 24, 27)])
mix-0.1: ('主要受限于机器学习、深度学习、计算机视觉、智能对话系统相关内容', [('研究', '受限于', 2, 5)])
bt: ('主要研究机器学习、深度学习、计算机视觉和智能对话系统', [])
Seq2Seq 模型
ConvSeq2Seq
训练并预测ConvSeq2Seq模型:
example: examples/seq2sesq/training_convseq2seq_model_demo.py
import argparse
from loguru import logger
import sys
sys.path.append('../..')
from textgen.seq2seq.conv_seq2seq_model import ConvSeq2SeqModel
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--train_file', default='../data/zh_dialog.tsv', type=str, help='Training data file')
parser.add_argument('--do_train', action='store_true', help='Whether to run training.')
parser.add_argument('--do_predict', action='store_true', help='Whether to run predict.')
parser.add_argument('--output_dir', default='./outputs/convseq2seq_zh/', type=str, help='Model output directory')
parser.add_argument('--max_seq_length', default=50, type=int, help='Max sequence length')
parser.add_argument('--num_epochs', default=200, type=int, help='Number of training epochs')
parser.add_argument('--batch_size', default=32, type=int, help='Batch size')
args = parser.parse_args()
logger.info(args)
if args.do_train:
logger.info('Loading data...')
model = ConvSeq2SeqModel(epochs=args.num_epochs, batch_size=args.batch_size,
model_dir=args.output_dir, max_length=args.max_seq_length)
model.train_model(args.train_file)
print(model.eval_model(args.train_file))
if args.do_predict:
model = ConvSeq2SeqModel(epochs=args.num_epochs, batch_size=args.batch_size,
model_dir=args.output_dir, max_length=args.max_seq_length)
sentences = ["什么是ai", "你是什么类型的计算机", "你知道热力学吗"]
print("inputs:", sentences)
print('outputs:', model.predict(sentences))
if __name__ == '__main__':
main()
output:
inputs: ["什么是ai", "你是什么类型的计算机", "你知道热力学吗"]
outputs: ['人工智能是工程和科学的分支,致力于构建思维的机器。', '我的程序运行在python,所以我在任何运脑上工作!', '我不能错热是一个疯狂的人工智能"200年。']
BART
训练并预测BART模型:
example: examples/seq2sesq/training_bartseq2seq_zh_demo.py
output:
inputs: ['什么是ai', '你是什么类型的计算机', '你知道热力学吗']
outputs: ['人工智能是工程和科学的分支,致力于构', '我的程序运行在python,所以我在任何电脑上', '什么是热力学吗?']
T5 模型
example: examples/T5/training_zh_t5_model_demo.py
import argparse
from loguru import logger
import pandas as pd
import sys
sys.path.append('../..')
from textgen.t5 import T5Model
def load_data(file_path):
data = []
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip('\n')
terms = line.split('\t')
if len(terms) == 2:
data.append(['QA', terms[0], terms[1]])
else:
logger.warning(f'line error: {line}')
return data
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--train_file', default='../data/zh_dialog.tsv', type=str, help='Training data file')
parser.add_argument('--model_type', default='t5', type=str, help='Transformers model type')
parser.add_argument('--model_name', default='Langboat/mengzi-t5-base', type=str, help='Transformers model or path')
parser.add_argument('--do_train', action='store_true', help='Whether to run training.')
parser.add_argument('--do_predict', action='store_true', help='Whether to run predict.')
parser.add_argument('--output_dir', default='./outputs/mengzi_t5_zh/', type=str, help='Model output directory')
parser.add_argument('--max_seq_length', default=50, type=int, help='Max sequence length')
parser.add_argument('--num_epochs', default=10, type=int, help='Number of training epochs')
parser.add_argument('--batch_size', default=32, type=int, help='Batch size')
args = parser.parse_args()
logger.info(args)
if args.do_train:
logger.info('Loading data...')
# train_data: Pandas DataFrame containing the 3 columns - `prefix`, `input_text`, `target_text`.
# - `prefix`: A string indicating the task to perform. (E.g. `"question"`, `"stsb"`)
# - `input_text`: The input text. `prefix` is prepended to form the full input. (<prefix>: <input_text>)
# - `target_text`: The target sequence
train_data = load_data(args.train_file)
logger.debug('train_data: {}'.format(train_data[:10]))
train_df = pd.DataFrame(train_data, columns=["prefix", "input_text", "target_text"])
eval_data = load_data(args.train_file)[:10]
eval_df = pd.DataFrame(eval_data, columns=["prefix", "input_text", "target_text"])
model_args = {
"reprocess_input_data": True,
"overwrite_output_dir": True,
"max_seq_length": args.max_seq_length,
"train_batch_size": args.batch_size,
"num_train_epochs": args.num_epochs,
"save_eval_checkpoints": False,
"save_model_every_epoch": False,
"evaluate_generated_text": True,
"evaluate_during_training": True,
"evaluate_during_training_verbose": True,
"use_multiprocessing": True,
"save_best_model": True,
"output_dir": args.output_dir,
"use_early_stopping": True,
}
# model_type: t5 model_name: Langboat/mengzi-t5-base
model = T5Model(args.model_type, args.model_name, args=model_args)
def count_matches(labels, preds):
logger.debug(f"labels: {labels[:10]}")
logger.debug(f"preds: {preds[:10]}")
match = sum([1 if label == pred else 0 for label, pred in zip(labels, preds)])
logger.debug(f"match: {match}")
return match
model.train_model(train_df, eval_data=eval_df, matches=count_matches)
print(model.eval_model(eval_df, matches=count_matches))
if args.do_predict:
model = T5Model(args.model_type, args.output_dir)
sentences = ["什么是ai", "你是什么类型的计算机", "你知道热力学吗"]
print("inputs:", sentences)
print("outputs:", model.predict(sentences))
if __name__ == '__main__':
main()
output:
inputs: ['什么是ai', '你是什么类型的计算机', '你知道热力学吗']
outputs: ['人工智能有两个广义的定义,任何拟人的机械,如在卡雷尔capeks', '我的程序运行在Python,所以我在任何电脑上工作!', '什么是热力学']
GPT2 模型
中文GPT2 - 文章生成
使用中文数据集(段落格式,\n
间隔),训练GPT2模型,可以用于诗歌生成、文章生成等任务。
example: examples/language_generation/training_zh_gpt2_demo.py
中文GPT2 - 对联生成
使用中文对联数据集(tsv格式,\t
间隔),自定义数据集读取Dataset,训练GPT2模型,可以用于对联生成、对话生成等任务。
example: examples/language_generation/training_couplet_gpt2_demo.py
- GPT2 vs T5:
- 都是从Transformer改进来的,T5同时有编码器和解码器,GPT2只有解码器
- T5的模型优势是处理给定输入,产出对应输出的任务,如翻译、对话、问答等
- GPT2的模型优势是自由创作,如写一篇短文
- T5的对联生成效果好于GPT2、GPT2的诗词生成效果好于T5
TGLS 模型(无监督相似文本生成模型)
无监督的中文电商评论生成:从电商评论中提取用户表达观点的短句并进行组合来生成仿真评论。
example: examples/unsup_generation_demo.py
import os
import sys
sys.path.append('..')
from textgen.unsup_generation import TglsModel
from textgen.unsup_generation.phrase import load_list
pwd_path = os.path.abspath(os.path.dirname(__file__))
samples = load_list(os.path.join(pwd_path, './data/ecommerce_comments.txt'))
docs_text = [
["挺好的,速度很快,也很实惠,不知效果如何",
"产品没得说,买了以后就降价,心情不美丽。",
"刚收到,包装很完整,不错",
"发货速度很快,物流也不错,同一时间买的两个东东,一个先到一个还在路上。这个水水很喜欢,不过盖子真的开了。盖不牢了现在。",
"包装的很好,是正品",
"被种草兰蔻粉水三百元一大瓶囤货,希望是正品好用,收到的时候用保鲜膜包裹得严严实实,只敢买考拉自营的护肤品",
],
['很温和,清洗的也很干净,不油腻,很不错,会考虑回购,第一次考拉买护肤品,满意',
'这款卸妆油我会无限回购的。即使我是油痘皮,也不会闷痘,同时在脸部按摩时,还能解决白头的脂肪粒的问题。用清水洗完脸后,非常的清爽。',
'自从用了fancl之后就不用其他卸妆了,卸的舒服又干净',
'买贵了,大润发才卖79。9。',
],
samples
]
m = TglsModel(docs_text)
r = m.generate(samples[:500])
print('size:', len(r))
for review in r:
print('\t' + review)
output:
美迪惠尔 N.M.F针剂水库保湿面膜有如下的20句评论,其中有10句是真实用户评论,10句是生成的评论,能看出来么?😂
还不错还不错还不错还不错。
东西到了,不知道好不好用。试用过后再来评价。到时看网评都还可以。
哺乳期唯一使用的护肤品,每天都是素颜,脸面全靠面膜吊着😄补水💦不粘腻一如既往的支持,喜欢💕
搞活动时买的面膜,不知道这个面膜是真是假敷在脸上面膜纸都有小水泡鼓起来。
很不错,非常补水,用过的都知道,性价比之王,好用又不贵,正品,用着放心,物流也很快。
面膜非常好用哦。面膜薄薄的。好像是蚕丝面膜啊。精华很多呢。敷在脸上很舒服。感觉挺保湿的,味道也挺好闻的。就是里面只有单纯的面膜直接敷脸上有点不好弄,哈哈哈
还可以保湿效果不错水润润的每天贴一片脸也不干了用完了在买点,不错还会继续回购的。
快递很快,东西很赞!想要得点考拉豆不容易,还要三十个字。时间宝贵,废话不说!用过了就知道了
挺好用的,朋友推荐来的
挺好用的,淡淡的,虽然不是很浓精华的感觉,但是效果也蛮好的。划算
不得不说美迪惠尔的面膜是我用过的最好的面膜之一😎补水效果非常好,没想到这么便宜的价格竟真的能买到真品。
保湿效果挺好的,面膜很好用。
期待好的产品。
一打开包装里面的精华刚刚好,用了补水补水效果不错,物流非常快。
皮肤很光滑😇比上去速度快三天就到了。
前两天皮肤干燥连续敷了两个晚上感觉还不错😂补水效果明显!可想而知精华液又多充足😍敷上以后凉凉的很舒服。
补水效果一般吧~但是我用的韩国背回来的面膜纸不算薄,希望好用会回购的,敷上脸感觉比较清爽~价格还不便宜。
希望好用,面膜用过了很好用,皮肤水嫩光滑白皙,补水不错,价格也合适。
就是精华液太少了,保湿效果不错。
面膜的补水效果非常好,保湿效果确实很赞,这个面膜相对于胶原蛋白和美白的那两款的面膜纸要厚一些,看着价格合适。
前10句是真实用户评论,后10句是生成的。
Contact
- Issue(建议) :
- 邮件我:xuming: xuming624@qq.com
- 微信我: 加我微信号:xuming624, 备注:姓名-公司名-NLP 进NLP交流群。
License
授权协议为 The Apache License 2.0,可免费用做商业用途。请在产品说明中附加textgen的链接和授权协议。
Contribute
项目代码还很粗糙,如果大家对代码有所改进,欢迎提交回本项目,在提交之前,注意以下两点:
- 在
tests
添加相应的单元测试 - 使用
python -m pytest
来运行所有单元测试,确保所有单测都是通过的
之后即可提交PR。
Reference
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.