Skip to main content

Named Entity Recognition

Project description

Usage Sample ''''''''''''

.. code:: python

import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from nerx import NER, Collator
from model_wrapper.dataset import DictDataset
from model_wrapper import ClassifyModelWrapper

pretrained_path = "nghuyong/ernie-3.0-base-zh"
classes = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'PADDING']
num_classes = len(classes)
    
def f(data):
    return 5 < len(data['tokens']) <= 512 - 2

dataset_dict = load_from_disk('/kaggle/input/peoples-daily-ner-data/peoples_daily_ner')
train_set = dataset_dict['train'].remove_columns(['id']).filter(f, cache_file_name='/kaggle/working/train.cache')
val_set = dataset_dict['validation'].remove_columns(['id']).filter(f, cache_file_name='/kaggle/working/val.cache')
test_set = dataset_dict['test'].remove_columns(['id']).filter(f, cache_file_name='/kaggle/working/test.cache')    
train_set = DictDataset(train_set, 'tokens', 'ner_tags')
val_set = DictDataset(val_set, 'tokens', 'ner_tags')

model = NER(pretrained_path, num_classes=num_classes, num_train_layers=2)
wrapper = ClassifyModelWrapper(model)
tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
history = wrapper.train(train_set, val_set, collate_fn=Collator(tokenizer, num_classes - 1))
wrapper.save_state_dict(mode='best')

def display(tags, text, classes):
    padding_idx = len(classes) - 1
    start_index, start_tag = -1, -1
    for i, tag in enumerate(tags):
        if tag == padding_idx:
            if start_index != -1:
                print(f"{start_index}-{i}", ' ', classes[start_tag].split('-')[1], ' ', ''.join(text[start_index:i]))
            break    
        if 0 < tag:
            if start_index == -1 and 0 < tag:
                start_index, start_tag = i, tag
                continue
                    
            if start_tag != tag - 1 and start_tag != tag:
                print(f"{start_index}-{i}", ' ', classes[start_tag].split('-')[1], ' ', ''.join(text[start_index:i]))
                start_index, start_tag = i, tag 
        else:
            if start_index > -1:
                print(f"{start_index}-{i}", ' ', classes[start_tag].split('-')[1], ' ', ''.join(text[start_index:i]))
                start_index, start_tag = -1, -1

def test(data, model):
    M, N = 50, 30
    text, label = data['tokens'], data['ner_tags']
    tokens = tokenizer.batch_encode_plus([text],
                                    max_length=256,
                                    padding=True,
                                    truncation=True,
                                    return_tensors='pt',
                                    return_token_type_ids=False,
                                    is_split_into_words=True)
    model.eval()
    with torch.inference_mode():
        result = model(tokens)[0]
    print('=' * M, "原文", '=' * M)
    print(''.join(text))
    print('-' * N, "标注",'-' * N)
    display(label, text, classes)
    print('-' * N, "预测",'-' * N)
    display(result, text, classes)    

for i in range(20):
    test(test_set[i], model)

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

NERX-0.2.1.tar.gz (8.4 kB view details)

Uploaded Source

File details

Details for the file NERX-0.2.1.tar.gz.

File metadata

  • Download URL: NERX-0.2.1.tar.gz
  • Upload date:
  • Size: 8.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.18

File hashes

Hashes for NERX-0.2.1.tar.gz
Algorithm Hash digest
SHA256 13f7dd5fa35c8be0d515560f700f43503f68c19209209730b7d76cf206d8f2e2
MD5 382524412988308ea55b70fe968f87a8
BLAKE2b-256 00235760e5aa822169cce6b56d95e05c897199a4c986deb315f4a23d5e2bc557

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page