
NLHappy

📌   Introduction

nlhappy is a natural language processing library that integrates data processing, model training, and text-processing pipeline construction, and ships with SOTA solutions for a range of tasks. It aims to make everyday NLP work more pleasant.

Its main dependencies are PyTorch Lightning, Hydra, Spacy, and wandb.

🚀   Installation

Install nlhappy

It is recommended to first install PyTorch and the matching CUDA build from the official PyTorch website.

# install via pip
pip install --upgrade pip
pip install --upgrade nlhappy

Optional extras

Installing wandb is recommended for visualizing training logs:

pip install wandb
wandb login

Once training starts, you can follow it live on the wandb website.

⚡   Quick Start

Datasets for the task examples can be downloaded from the CBLUE website.

Text Classification

Build the dataset

from nlhappy.datamodules import TextClassificationDataModule
from nlhappy.utils.make_dataset import DatasetDict, Dataset
import srsly

# Using the CBLUE short-text classification dataset CHIP-CTC as an example

# Inspect the expected dataset format
example = TextClassificationDataModule.show_one_example()
# example : {'label': '新闻', 'text': '怎么给这个图片添加超级链接呢?'}

# Build the dataset in that format
train_data = list(srsly.read_json('assets/CHIP-CTC/CHIP-CTC_train.json'))
val_data = list(srsly.read_json('assets/CHIP-CTC/CHIP-CTC_dev.json'))
def convert_to_dataset(data):
    dataset = {'text':[], 'label':[]}
    for d in data:
        dataset['text'].append(d['text'])
        dataset['label'].append(d['label'])
    return Dataset.from_dict(dataset)
train_ds = convert_to_dataset(train_data)
val_ds = convert_to_dataset(val_data)
dataset_dict = DatasetDict({'train':train_ds, 'validation':val_ds})

# Save the dataset; ./datasets is nlhappy's default dataset directory
dataset_dict.save_to_disk('./datasets/CHIP-CTC')
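
To verify the save, the splits can be reloaded with load_from_disk. This is a minimal sketch assuming DatasetDict is the Hugging Face datasets class that nlhappy re-exports:

from nlhappy.utils.make_dataset import DatasetDict

# reload the saved splits and inspect one record
dataset_dict = DatasetDict.load_from_disk('./datasets/CHIP-CTC')
print(dataset_dict)              # split names and sizes
print(dataset_dict['train'][0])  # {'text': ..., 'label': ...}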

Train the model

  • Via the command line or a bash script, quick and convenient
# any pretrained model on the HuggingFace hub is supported via datamodule.plm
nlhappy \
datamodule=text_classification \
datamodule.dataset=CHIP-CTC \
datamodule.plm=hfl/chinese-roberta-wwm-ext \
datamodule.batch_size=32 \
model=bert_tc \
model.lr=3e-5 \
seed=1234 \
trainer=default
# training defaults to a single GPU (device 0); common overrides:
# trainer.devices=[1]             # use a different GPU
# trainer.precision=16            # half precision on a single GPU
# logger=wandb                    # log to wandb
# trainer=ddp trainer.devices=4   # multi-GPU training

Model prediction

from nlhappy.models import BertTextClassification

# load the trained checkpoint
ckpt = 'logs/path/***.ckpt'
model = BertTextClassification.load_from_checkpoint(ckpt)
text = '研究开始前30天内,接受过其他临床方案治疗;'
scores = model.predict(text=text, device='cpu')

# export to an ONNX model
model.to_onnx('path/tc.onnx')
model.tokenizer.save_pretrained('path/tokenizer')
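
The exported model can then be served without nlhappy. Below is a minimal sketch using onnxruntime; it assumes the graph's input names match the tokenizer's output keys, which you should confirm with session.get_inputs() for your export:

import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('path/tokenizer')
session = ort.InferenceSession('path/tc.onnx')

# tokenize to numpy tensors and feed only the inputs the graph declares
encoded = tokenizer('研究开始前30天内,接受过其他临床方案治疗;', return_tensors='np')
input_names = {i.name for i in session.get_inputs()}
feed = {k: v for k, v in encoded.items() if k in input_names}
outputs = session.run(None, feed)
print(outputs[0])  # raw scores; the exact shape depends on the export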

Entity Extraction

nlhappy supports flat, nested, and discontinuous entity extraction. The example below uses GlobalPointer, a model that can handle nested entities.

Build the dataset

# Using the CBLUE entity recognition dataset CMeEE
from nlhappy.datamodules import EntityExtractionDataModule
from nlhappy.utils.make_dataset import Dataset, DatasetDict
import srsly

# Inspect the expected dataset format
example = EntityExtractionDataModule.show_one_example()
# example : {"text":"这是一个长颈鹿","entities":[{"indexes":[4,5,6],"label":"动物", "text":"长颈鹿"}]}

# Convert the raw data into this format
train_data = list(srsly.read_json('assets/CMeEE/CMeEE_train.json'))
val_data = list(srsly.read_json('assets/CMeEE/CMeEE_dev.json'))
def convert_to_dataset(data):
    ds = {'text':[], 'entities':[]}
    for d in data:
        ents = []
        ds['text'].append(d['text'])
        for e in d['entities']:
            if len(e['entity']) > 0:
                ent = {}
                ent['text'] = e['entity']
                # character offsets of the entity inside the text
                # (CMeEE marks spans with start_idx/end_idx, end inclusive)
                ent['indexes'] = list(range(e['start_idx'], e['end_idx'] + 1))
                ent['label'] = e['type']
                ents.append(ent)
        ds['entities'].append(ents)
    return Dataset.from_dict(ds)
train_ds = convert_to_dataset(train_data)
val_ds = convert_to_dataset(val_data)
dataset_dict = DatasetDict({'train':train_ds, 'validation':val_ds})
dataset_dict.save_to_disk('./datasets/CMeEE')
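
A quick sanity check against the format shown by show_one_example: every index in an entity's indexes should point at the matching character of its text.

# verify that each entity's indexes recover its surface text
sample = train_ds[0]
for ent in sample['entities']:
    recovered = ''.join(sample['text'][i] for i in ent['indexes'])
    assert recovered == ent['text'], (recovered, ent['text'])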

Train the model

  • Write a training script
nlhappy \
datamodule=entity_extraction \
datamodule.dataset=CMeEE \
datamodule.plm=hfl/chinese-roberta-wwm-ext \
datamodule.batch_size=16 \
model=globalpointer \
model.lr=3e-5 \
seed=123 \
trainer=default
# trainer=default runs on a single GPU; common overrides:
# trainer=ddp trainer.devices=4   # multi-GPU training
# trainer.precision=16            # half precision
# logger=wandb                    # if wandb is installed

Model inference

from nlhappy.models import GlobalPointer

# load the trained checkpoint
ckpt = 'logs/path/***.ckpt'
model = GlobalPointer.load_from_checkpoint(ckpt)
ents = model.predict('some text')

# export to an ONNX model for downstream deployment
model.to_onnx('path/ee.onnx')
model.tokenizer.save_pretrained('path/tokenizer')

Relation Extraction TODO
Event Extraction TODO
Summarization TODO
Translation TODO

Paper Reproduction
