Skip to main content

Pytorch-lightning Code Blocks for Information Extraction

Project description



此项目为开源文本分类、实体抽取、关系抽取和事件抽取模型的训练和推理提供统一的框架,具有以下特性

  • ✨ 支持多种开源文本分类、实体抽取、关系抽取和事件抽取模型

  • 👑 支持百度 UIE 模型的训练和推理

  • 🚀 统一的训练和推理框架

  • 🎯 集成对抗训练方法,简便易用

📢 News

  • 【2023.6.21】 增加文本分类代码示例

  • 【2023.6.19】 增加 gplinker 事件抽取模型和代码示例

  • 【2023.6.15】 增加对抗训练功能和示例、增加 onerel 关系抽取模型

  • 【2023.6.14】 新增 UIE 模型代码示例


📦 安装

pip 安装

pip install --upgrade litie

源码安装

git clone https://github.com/xusenlinzy/lit-ie

pip install -r requirements.txt

🐼 模型

实体抽取

模型 论文 备注
softmax 全连接层序列标注并使用 BIO 解码
crf Conditional Random Fields: Probabilistic Models for Segmenting and Labeling Sequence Data 全连接层+条件随机场,并使用 BIO 解码
cascade-crf 先预测实体再预测实体类型
span 使用两个指针网络预测实体起始位置
global-pointer GlobalPointer:用统一的方式处理嵌套和非嵌套NEREfficient GlobalPointer:少点参数,多点效果
mrc A Unified MRC Framework for Named Entity Recognition. 将实体识别任务转换为阅读理解问题,输入为实体类型模板+句子,预测对应实体的起始位置
tplinker TPLinker: Single-stage Joint Extraction of Entities and Relations Through Token Pair Linking. 将实体识别任务转换为表格填充问题
lear Enhanced Language Representation with Label Knowledge for Span Extraction. 改进 MRC 方法效率问题,采用标签融合机制
w2ner Unified Named Entity Recognition as Word-Word Relation Classification. 统一解决嵌套实体、不连续实体的抽取问题
cnn An Embarrassingly Easy but Strong Baseline for Nested Named Entity Recognition. 改进 W2NER 方法,采用卷积网络提取实体内部token之间的关系

关系抽取

模型 论文 备注
casrel A Novel Cascade Binary Tagging Framework for Relational Triple Extraction. 两阶段关系抽取,先抽取出句子中的主语,再通过指针网络抽取出主语对应的关系和宾语
tplinker TPLinker: Single-stage Joint Extraction of Entities and Relations Through Token Pair Linking. 将关系抽取问题转换为主语-宾语的首尾连接问题
spn Joint Entity and Relation Extraction with Set Prediction Networks. 将关系抽取问题转为为三元组的集合预测问题,采用 Encoder-Decoder 框架
prgc PRGC: Potential Relation and Global Correspondence Based Joint Relational Triple Extraction. 先抽取句子的潜在关系类型,然后对于给定关系抽取出对应的主语-宾语对,最后通过全局对齐模块过滤
pfn A Partition Filter Network for Joint Entity and Relation Extraction. 采用类似 LSTM 的分区过滤机制,将隐藏层信息分成实体识别、关系识别和共享三部分,对与不同的任务利用不同的信息
grte A Novel Global Feature-Oriented Relational Triple Extraction Model based on Table Filling. 将关系抽取问题转换为单词对的分类问题,基于全局特征抽取模块循环优化单词对的向量表示
gplinker GPLinker:基于GlobalPointer的实体关系联合抽取

📚 数据

实体抽取

将数据集处理成以下 json 格式

{
  "text": "结果上周六他们主场0:3惨败给了中游球队瓦拉多利德,近7个多月以来西甲首次输球。", 
  "entities": [
    {
      "id": 0, 
      "entity": "瓦拉多利德", 
      "start_offset": 20, 
      "end_offset": 25, 
      "label": "organization"
    }, 
    {
      "id": 1, 
      "entity": "西甲", 
      "start_offset": 33, 
      "end_offset": 35, 
      "label": "organization"
    }
  ]
}

字段含义:

  • text: 文本内容

  • entities: 该文本所包含的所有实体

    • id: 实体 id

    • entity: 实体名称

    • start_offset: 实体开始位置

    • end_offset: 实体结束位置的下一位

    • label: 实体类型

关系抽取

将数据集处理成以下 json 格式

{
  "text": "查尔斯·阿兰基斯(Charles Aránguiz),1989年4月17日出生于智利圣地亚哥,智利职业足球运动员,司职中场,效力于德国足球甲级联赛勒沃库森足球俱乐部", 
  "spo_list": [
    {
      "predicate": "出生地",
      "object": "圣地亚哥", 
      "subject": "查尔斯·阿兰基斯"
    }, 
    {
      "predicate": "出生日期",
      "object": "1989年4月17日",
      "subject": "查尔斯·阿兰基斯"
    }
  ]
}

字段含义:

  • text: 文本内容

  • spo_list: 该文本所包含的所有关系三元组

    • subject: 主体名称

    • object: 客体名称

    • predicate: 主体和客体之间的关系

事件抽取

将数据集处理成以下 json 格式

{
  "text": "油服巨头哈里伯顿裁员650人 因美国油气开采活动放缓",
  "id": "f2d936214dc2cb1b873a75ee29a30ec9",
  "event_list": [
    {
      "event_type": "组织关系-裁员",
      "trigger": "裁员",
      "trigger_start_index": 8,
      "arguments": [
        {
          "argument_start_index": 0,
          "role": "裁员方",
          "argument": "油服巨头哈里伯顿"
        },
        {
          "argument_start_index": 10,
          "role": "裁员人数",
          "argument": "650人"
        }
      ],
      "class": "组织关系"
    }
  ]
}

字段含义:

  • text: 文本内容

  • event_list: 该文本所包含的所有事件

    • event_type: 事件类型

    • trigger: 触发词

    • trigger_start_index: 触发词开始位置

    • arguments: 论元

      • role: 论元角色

      • argument: 论元名称

      • argument_start_index: 论元名称开始位置

🚀 模型训练

实体抽取

import os

from litie.arguments import (
    DataTrainingArguments,
    TrainingArguments,
)
from litie.models import AutoNerModel

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'

training_args = TrainingArguments(
    other_learning_rate=2e-3,
    num_train_epochs=20,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    output_dir="outputs/crf",
)

data_args = DataTrainingArguments(
    dataset_name="datasets/cmeee",
    train_file="train.json",
    validation_file="dev.json",
    preprocessing_num_workers=16,
)

# 1. create model
model = AutoNerModel(
    task_model_name="crf",
    model_name_or_path="hfl/chinese-roberta-wwm-ext",
    training_args=training_args,
)

# 2. finetune model
model.finetune(data_args)

训练脚本详见 named_entity_recognition

关系抽取

import os

from litie.arguments import (
    DataTrainingArguments,
    TrainingArguments,
)
from litie.models import AutoRelationExtractionModel

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'

training_args = TrainingArguments(
    num_train_epochs=20,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    output_dir="outputs/gplinker",
)

data_args = DataTrainingArguments(
    dataset_name="datasets/duie",
    train_file="train.json",
    validation_file="dev.json",
    preprocessing_num_workers=16,
)

# 1. create model
model = AutoRelationExtractionModel(
    task_model_name="gplinker",
    model_name_or_path="hfl/chinese-roberta-wwm-ext",
    training_args=training_args,
)

# 2. finetune model
model.finetune(data_args, num_sanity_val_steps=0)

训练脚本详见 relation_extraction

事件抽取

import os
import json

from litie.arguments import DataTrainingArguments, TrainingArguments
from litie.models import AutoEventExtractionModel

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'

schema_path = "datasets/duee/schema.json"

labels = []
with open("datasets/duee/schema.json") as f:
    for l in f:
        l = json.loads(l)
        t = l["event_type"]
        for r in ["触发词"] + [s["role"] for s in l["role_list"]]:
            labels.append(f"{t}@{r}")

training_args = TrainingArguments(
    num_train_epochs=200,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    output_dir="outputs/gplinker",
)

data_args = DataTrainingArguments(
    dataset_name="datasets/duee",
    train_file="train.json",
    validation_file="dev.json",
    preprocessing_num_workers=16,
    train_max_length=128,
)

# 1. create model
model = AutoEventExtractionModel(
    task_model_name="gplinker",
    model_name_or_path="hfl/chinese-roberta-wwm-ext",
    training_args=training_args,
)

# 2. finetune model
model.finetune(
    data_args,
    labels=labels,
    num_sanity_val_steps=0,
    monitor="val_argu_f1",
    check_val_every_n_epoch=20,
)

训练脚本详见 event_extraction

📊 模型推理

实体抽取

from litie.pipelines import NerPipeline

task_model = "crf"
model_name_or_path = "path of crf model"
pipeline = NerPipeline(task_model, model_name_or_path=model_name_or_path)

print(pipeline("结果上周六他们主场0:3惨败给了中游球队瓦拉多利德,近7个多月以来西甲首次输球。"))

web demo

from litie.ui import NerPlayground

NerPlayground().launch()

关系抽取

from litie.pipelines import RelationExtractionPipeline

task_model = "gplinker"
model_name_or_path = "path of gplinker model"
pipeline = RelationExtractionPipeline(task_model, model_name_or_path=model_name_or_path)

print(pipeline("查尔斯·阿兰基斯(Charles Aránguiz),1989年4月17日出生于智利圣地亚哥,智利职业足球运动员,司职中场,效力于德国足球甲级联赛勒沃库森足球俱乐部"))

web demo

from litie.ui import RelationExtractionPlayground

RelationExtractionPlayground().launch()

事件抽取

from litie.pipelines import EventExtractionPipeline

task_model = "gplinker"
model_name_or_path = "path of gplinker model"
pipeline = EventExtractionPipeline(task_model, model_name_or_path=model_name_or_path)

print(pipeline("油服巨头哈里伯顿裁员650人 因美国油气开采活动放缓"))

web demo

from litie.ui import EventExtractionPlayground

EventExtractionPlayground().launch()

🍭 通用信息抽取

  • UIE(Universal Information Extraction):Yaojie Lu等人在ACL-2022中提出了通用信息抽取统一框架 UIE

  • 该框架实现了实体抽取、关系抽取、事件抽取、情感分析等任务的统一建模,并使得不同任务间具备良好的迁移和泛化能力。

  • PaddleNLP借鉴该论文的方法,基于 ERNIE 3.0 知识增强预训练模型,训练并开源了首个中文通用信息抽取模型 UIE

  • 该模型可以支持不限定行业领域和抽取目标的关键信息抽取,实现零样本快速冷启动,并具备优秀的小样本微调能力,快速适配特定的抽取目标。

模型训练

模型训练脚本详见 uie

模型推理

👉 命名实体识别
from pprint import pprint
from litie.pipelines import UIEPipeline

# 实体识别
schema = ['时间', '选手', '赛事名称'] 
# uie-base模型已上传至huggingface,可自动下载,其他模型只需提供模型名称将自动进行转换
uie = UIEPipeline("xusenlin/uie-base", schema=schema)
pprint(uie("2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌!")) # Better print results using pprint

output:

[
  {
    "时间": [
      {
        "end": 6,
        "probability": 0.98573786,
        "start": 0,
        "text": "2月8日上午"
      }
    ],
    "赛事名称": [
      {
        "end": 23,
        "probability": 0.8503085,
        "start": 6,
        "text": "北京冬奥会自由式滑雪女子大跳台决赛"
      }
    ],
    "选手": [
      {
        "end": 31,
        "probability": 0.8981544,
        "start": 28,
        "text": "谷爱凌"
      }
    ]
  }
]
👉 实体关系抽取
from pprint import pprint
from litie.pipelines import UIEPipeline

# 关系抽取
schema = {'竞赛名称': ['主办方', '承办方', '已举办次数']}
# uie-base模型已上传至huggingface,可自动下载,其他模型只需提供模型名称将自动进行转换
uie = UIEPipeline("xusenlin/uie-base", schema=schema)
pprint(uie("2022语言与智能技术竞赛由中国中文信息学会和中国计算机学会联合主办,百度公司、中国中文信息学会评测工作委员会和中国计算机学会自然语言处理专委会承办,已连续举办4届,成为全球最热门的中文NLP赛事之一。")) # Better print results using pprint

output:

[
  {
    "竞赛名称": [
      {
        "end": 13,
        "probability": 0.78253937,
        "relations": {
          "主办方": [
            {
              "end": 22,
              "probability": 0.8421704,
              "start": 14,
              "text": "中国中文信息学会"
            },
            {
              "end": 30,
              "probability": 0.75807965,
              "start": 23,
              "text": "中国计算机学会"
            }
          ],
          "已举办次数": [
            {
              "end": 82,
              "probability": 0.4671307,
              "start": 80,
              "text": "4届"
            }
          ],
          "承办方": [
            {
              "end": 55,
              "probability": 0.700049,
              "start": 40,
              "text": "中国中文信息学会评测工作委员会"
            },
            {
              "end": 72,
              "probability": 0.61934763,
              "start": 56,
              "text": "中国计算机学会自然语言处理专委会"
            },
            {
              "end": 39,
              "probability": 0.8292698,
              "start": 35,
              "text": "百度公司"
            }
          ]
        },
        "start": 0,
        "text": "2022语言与智能技术竞赛"
      }
    ]
  }
]
👉 事件抽取
from pprint import pprint
from litie.pipelines import UIEPipeline

# 事件抽取
schema = {"地震触发词": ["地震强度", "时间", "震中位置", "震源深度"]}
# uie-base模型已上传至huggingface,可自动下载,其他模型只需提供模型名称将自动进行转换
uie = UIEPipeline("xusenlin/uie-base", schema=schema)
pprint(uie("中国地震台网正式测定:5月16日06时08分在云南临沧市凤庆县(北纬24.34度,东经99.98度)发生3.5级地震,震源深度10千米。")) # Better print results using pprint

output:

[
  {
    "地震触发词": [
      {
        "end": 58,
        "probability": 0.99774253,
        "relations": {
          "地震强度": [
            {
              "end": 56,
              "probability": 0.9980802,
              "start": 52,
              "text": "3.5级"
            }
          ],
          "时间": [
            {
              "end": 22,
              "probability": 0.98533,
              "start": 11,
              "text": "5月16日06时08分"
            }
          ],
          "震中位置": [
            {
              "end": 50,
              "probability": 0.7874015,
              "start": 23,
              "text": "云南临沧市凤庆县(北纬24.34度,东经99.98度)"
            }
          ],
          "震源深度": [
            {
              "end": 67,
              "probability": 0.9937973,
              "start": 63,
              "text": "10千米"
            }
          ]
        },
        "start": 56,
        "text": "地震"
      }
    ]
  }
]
👉 评论观点抽取
from pprint import pprint
from litie.pipelines import UIEPipeline

# 评论观点抽取
schema = {'评价维度': ['观点词', '情感倾向[正向,负向]']}
# uie-base模型已上传至huggingface,可自动下载,其他模型只需提供模型名称将自动进行转换
uie = UIEPipeline("xusenlin/uie-base", schema=schema)
pprint(uie("店面干净,很清静,服务员服务热情,性价比很高,发现收银台有排队")) # Better print results using pprint

output:

[
  {
    "评价维度": [
      {
        "end": 20,
        "probability": 0.98170394,
        "relations": {
          "情感倾向[正向,负向]": [
            {
              "probability": 0.9966142773628235,
              "text": "正向"
            }
          ],
          "观点词": [
            {
              "end": 22,
              "probability": 0.95739645,
              "start": 21,
              "text": "高"
            }
          ]
        },
        "start": 17,
        "text": "性价比"
      },
      {
        "end": 2,
        "probability": 0.9696847,
        "relations": {
          "情感倾向[正向,负向]": [
            {
              "probability": 0.9982153177261353,
              "text": "正向"
            }
          ],
          "观点词": [
            {
              "end": 4,
              "probability": 0.9945317,
              "start": 2,
              "text": "干净"
            }
          ]
        },
        "start": 0,
        "text": "店面"
      }
    ]
  }
]
👉 情感分类
from pprint import pprint
from litie.pipelines import UIEPipeline

# 事件抽取
schema = '情感倾向[正向,负向]'
# uie-base模型已上传至huggingface,可自动下载,其他模型只需提供模型名称将自动进行转换
uie = UIEPipeline("xusenlin/uie-base", schema=schema)
pprint(uie("这个产品用起来真的很流畅,我非常喜欢")) # Better print results using pprint

output:

[
  {
    "情感倾向[正向,负向]": [
      {
        "probability": 0.9990023970603943,
        "text": "正向"
      }
    ]
  }
]

📜 License

此项目为 Apache 2.0 许可证授权,有关详细信息,请参阅 LICENSE 文件。

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distributions

No source distribution files available for this release.See tutorial on generating distribution archives.

Built Distribution

litie-0.2.5-py3-none-any.whl (226.4 kB view details)

Uploaded Python 3

File details

Details for the file litie-0.2.5-py3-none-any.whl.

File metadata

  • Download URL: litie-0.2.5-py3-none-any.whl
  • Upload date:
  • Size: 226.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.8.13

File hashes

Hashes for litie-0.2.5-py3-none-any.whl
Algorithm Hash digest
SHA256 01c73336938c704c1df74c3be6319af9a6bd971ae468dad07798cd62ff6f395a
MD5 6413835a66ce17bf2fb20ca22eeeebb1
BLAKE2b-256 8443a16636bdf3b4ccd939499e7b4a9ef30d8341bd1c36d5a76e2f9e541f264b

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page