Named Entity Recognition

Project description

Usage Sample
''''''''''''

.. code:: python

    from datasets import load_from_disk
    from torch.utils.data import Dataset
    from transformers import AutoTokenizer
    from nerx import NER, Collator
    from model_wrapper import ClassifyModelWrapper

    pretrained_path = "nghuyong/ernie-3.0-base-zh"

    def has_valid_length(data):
        # Keep examples with more than 5 tokens that still fit in 512 positions
        # once 2 positions are set aside (512 - 2).
        return 5 < len(data['tokens']) <= 512 - 2

    class PairDataset(Dataset):
        """Yields (tokens, ner_tags) pairs from a Hugging Face dataset split."""

        def __init__(self, dataset):
            self.dataset = dataset

        def __getitem__(self, index):
            data = self.dataset[index]
            return data['tokens'], data['ner_tags']

        def __len__(self):
            return len(self.dataset)

    # Load the saved splits, drop the unused 'id' column and filter by length.
    dataset_dict = load_from_disk('/kaggle/input/peoples-daily-ner-data/peoples_daily_ner')
    train_set = dataset_dict['train'].remove_columns(['id']).filter(has_valid_length, cache_file_name='/kaggle/working/train.cache')
    val_set = dataset_dict['validation'].remove_columns(['id']).filter(has_valid_length, cache_file_name='/kaggle/working/val.cache')
    # The test split is prepared here but not used in the training run below.
    test_set = dataset_dict['test'].remove_columns(['id']).filter(has_valid_length, cache_file_name='/kaggle/working/test.cache')

    train_set = PairDataset(train_set)
    val_set = PairDataset(val_set)

    model = NER(pretrained_path, num_classes=8, num_train_layers=2)
    wrapper = ClassifyModelWrapper(model)
    tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
    # Label sequences are padded with id 7 when batches are collated.
    history = wrapper.train(train_set, val_set, collate_fn=Collator(tokenizer, label_padding_id=7))
    wrapper.save_state_dict(mode='best')
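
The sample relies on two conventions that it does not spell out: the length filter ``5 < len(tokens) <= 512 - 2`` and the ``label_padding_id=7`` passed to ``Collator``. The sketch below illustrates one plausible reading of both in plain Python. It is not the ``nerx`` implementation: the helper names ``keep_example`` and ``pad_labels`` are hypothetical, and it assumes that the ``- 2`` reserves room for two special tokens and that label sequences are right-padded to the longest one in the batch.

```python
from typing import List, Sequence

MAX_LEN = 512          # sequence limit taken from the sample's filter
SPECIAL_TOKENS = 2     # assumption: two positions reserved for special tokens
LABEL_PADDING_ID = 7   # same value the sample passes as label_padding_id


def keep_example(tokens: Sequence[str]) -> bool:
    """Mirror the sample's filter: more than 5 tokens, at most 510."""
    return 5 < len(tokens) <= MAX_LEN - SPECIAL_TOKENS


def pad_labels(batch: Sequence[Sequence[int]],
               padding_id: int = LABEL_PADDING_ID) -> List[List[int]]:
    """Right-pad every label sequence to the longest one in the batch."""
    longest = max(len(labels) for labels in batch)
    return [list(labels) + [padding_id] * (longest - len(labels))
            for labels in batch]


# Two label sequences of different lengths become one rectangular batch.
padded = pad_labels([[1, 2, 0], [3, 0, 0, 0, 4]])
print(padded)
```

If the padding id sits one past the last real tag id, as the combination of ``num_classes=8`` and ``label_padding_id=7`` suggests, padded positions can be told apart from real tags. That reading is an inference from the sample, not something the package states.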

Download files


Source Distribution

NERX-0.1.1.tar.gz (7.5 kB view details)

Uploaded Source

File details

Details for the file NERX-0.1.1.tar.gz.

File metadata

  • Download URL: NERX-0.1.1.tar.gz
  • Upload date:
  • Size: 7.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.18

File hashes

Hashes for NERX-0.1.1.tar.gz
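=========== ================================================================
Algorithm   Hash digest
=========== ================================================================
SHA256      4c265388e8e26adc053e5f51726596635937e47e6fa669fed18ca9879f1f01c9
MD5         485e933f4ba59f0831225c13c3a2c77e
BLAKE2b-256 bce5d9152e72d3a6e9aef00c47a56732f3a1cd3c226a41221b1eb19c65c8a524
=========== ================================================================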

