Skip to main content

Named Entity Recognition

Project description

Usage Sample
''''''''''''

.. code:: python

import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from nerx import NER, Collator
from model_wrapper.dataset import DictDataset
from model_wrapper import ClassifyModelWrapper

# Hugging Face model id of the pretrained encoder used for both the model and the tokenizer.
pretrained_path = "nghuyong/ernie-3.0-base-zh"
# Tag vocabulary: 'O', then a B-/I- pair per entity type, then 'PADDING'.
# 'PADDING' must stay last: the code below uses len(classes) - 1 as the padding index.
classes = [
    'O',
    *(f'{prefix}-{kind}' for kind in ('PER', 'ORG', 'LOC') for prefix in ('B', 'I')),
    'PADDING',
]
num_classes = len(classes)
    
def f(data):
    """Return True when ``data['tokens']`` has more than 5 and at most 510 items.

    The upper bound is written as ``512 - 2``; presumably this leaves room for
    two special tokens in a 512-token window (confirm against the tokenizer).
    """
    token_count = len(data['tokens'])
    if token_count <= 5:
        return False
    return token_count <= 512 - 2

# Load the dataset splits, drop the 'id' column, and filter every split with f().
# NOTE(review): `load_from_disk` is not imported anywhere in this snippet, so this
# line raises NameError as written. It is presumably `datasets.load_from_disk`
# (Hugging Face `datasets`); add `from datasets import load_from_disk` to the imports.
dataset_dict = load_from_disk('/kaggle/input/peoples-daily-ner-data/peoples_daily_ner')
train_set = dataset_dict['train'].remove_columns(['id']).filter(f, cache_file_name='/kaggle/working/train.cache')
val_set = dataset_dict['validation'].remove_columns(['id']).filter(f, cache_file_name='/kaggle/working/val.cache')
test_set = dataset_dict['test'].remove_columns(['id']).filter(f, cache_file_name='/kaggle/working/test.cache')    
# Only the train/validation splits are wrapped in DictDataset; test_set is left
# as-is because the preview code below indexes it and reads 'tokens'/'ner_tags' directly.
train_set = DictDataset(train_set, 'tokens', 'ner_tags')
val_set = DictDataset(val_set, 'tokens', 'ner_tags')

# Build the model, wrap it for training, and train on the train/validation splits.
model = NER(pretrained_path, num_classes=num_classes, num_train_layers=2)
wrapper = ClassifyModelWrapper(model)
tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
# num_classes - 1 is the index of 'PADDING' in `classes`; presumably Collator uses
# it as the label padding id (confirm against nerx.Collator).
history = wrapper.train(train_set, val_set, collate_fn=Collator(tokenizer, num_classes - 1))
# Presumably persists the best checkpoint seen during training - TODO confirm
# against model_wrapper's save_state_dict documentation.
wrapper.save_state_dict(mode='best')

def display(tags, text, classes):
    """Print every entity span encoded by a BIO tag sequence.

    One line is printed per entity, in the form
    ``"<start>-<end>   <TYPE>   <surface text>"`` where ``end`` is exclusive.

    Args:
        tags: Iterable of tag indices into ``classes`` (plain ints or anything
            ``int()`` accepts, e.g. 0-d tensors). Values ``<= 0`` are treated
            as "outside any entity".
        text: Sequence of tokens aligned with ``tags``; the slice covering an
            entity is joined with ``''`` to produce the surface text.
        classes: Label names. Every label with index > 0 other than the last
            must look like ``"B-XXX"`` / ``"I-XXX"``; the last label is the
            padding tag and ends the scan.

    Compared with the previous implementation this one:
      * prints an entity that is still open when ``tags`` ends without a
        padding tag (gold labels have no padding, so a sentence-final entity
        used to be dropped);
      * starts a new entity on every ``B-`` tag and on every change of entity
        type, instead of relying on ``tag - 1`` arithmetic, which merged
        adjacent entities of the same type.
    """
    padding_idx = len(classes) - 1

    def emit(start, stop, kind):
        # Three print arguments separated by ' ' keep the original output layout.
        print(f"{start}-{stop}", ' ', kind, ' ', ''.join(text[start:stop]))

    start_index, open_kind = -1, None
    end = 0  # exclusive end of the scanned (non-padding) prefix of `tags`
    for i, tag in enumerate(tags):
        tag = int(tag)
        if tag == padding_idx:
            break
        end = i + 1

        if tag <= 0:
            # Outside tag: close whatever entity is open.
            if start_index != -1:
                emit(start_index, i, open_kind)
                start_index, open_kind = -1, None
            continue

        label = classes[tag]
        kind = label.split('-')[1]
        # A B- tag or a different entity type always ends the current span.
        if start_index != -1 and (label.startswith('B-') or kind != open_kind):
            emit(start_index, i, open_kind)
            start_index, open_kind = -1, None
        if start_index == -1:
            # As before, a span may also start on an I- tag.
            start_index, open_kind = i, kind

    # Flush an entity that runs up to the padding tag or to the end of `tags`.
    if start_index != -1:
        emit(start_index, end, open_kind)

def test(data, model):
    """Print one sample's text, its gold entities, and the model's predicted entities.

    Args:
        data: Mapping with ``'tokens'`` (the pre-split tokens) and
            ``'ner_tags'`` (the gold tag indices).
        model: Callable model; it is switched to eval mode and called once on
            the tokenizer output.

    Relies on the module-level ``tokenizer``, ``classes`` and ``display``.
    """
    heavy_rule = '=' * 50
    light_rule = '-' * 30
    text = data['tokens']
    label = data['ner_tags']

    # Encode the single pre-split sentence as a batch of one.
    encode_options = dict(
        max_length=256,
        padding=True,
        truncation=True,
        return_tensors='pt',
        return_token_type_ids=False,
        is_split_into_words=True,
    )
    tokens = tokenizer.batch_encode_plus([text], **encode_options)

    model.eval()
    with torch.inference_mode():
        outputs = model(tokens)
    # First element of the model output; presumably the predicted tag ids for
    # the only sentence in the batch - confirm against NER.forward.
    result = outputs[0]

    print(heavy_rule, "原文", heavy_rule)
    print(''.join(text))
    print(light_rule, "标注", light_rule)
    display(label, text, classes)
    print(light_rule, "预测", light_rule)
    display(result, text, classes)

# Print text, gold entities and predicted entities for the first 20 test samples.
for sample_no in range(20):
    sample = test_set[sample_no]
    test(sample, model)

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

NERX-0.2.0.tar.gz (8.3 kB view details)

Uploaded Source

File details

Details for the file NERX-0.2.0.tar.gz.

File metadata

  • Download URL: NERX-0.2.0.tar.gz
  • Upload date:
  • Size: 8.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.18

File hashes

Hashes for NERX-0.2.0.tar.gz
Algorithm Hash digest
SHA256 aeecc1cf3890faabc8d0a9d033b904805b699890e78006016a01efcb2aa24e79
MD5 602f55aaf2236c5a559894ae96374b1b
BLAKE2b-256 413c9ddbe57872cdabe8b52eca6e96828c139ab4d35659c12fffdbadcca4b625

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page