A small seq2seq punctuator tool based on DistilBERT

Distilbert-punctuator

Introduction

Distilbert-punctuator is a Python package that provides a BERT-based punctuator (a fine-tuned model of the pretrained huggingface DistilBertForTokenClassification) with the following three components:

  • data process: functions for processing the user's own data to prepare it for training, for users who prefer to fine-tune the model with their own data.

  • training: a training pipeline with validation. Users can fine-tune their own punctuator with this pipeline.

  • inference: an easy-to-use interface for running a trained punctuator.

  • If users don't want to train a punctuator themselves, two pre-fine-tuned models are available on the huggingface model hub (see the model pages on the huggingface website):

    • English model
    • Simplified Chinese model

Installation

  • Installing the package from pypi for direct use of the punctuator: pip install distilbert-punctuator
  • Installing the package with the option to do data processing: pip install distilbert-punctuator[data_process]
  • Installing the package with the option to train and validate your own model: pip install distilbert-punctuator[training]
  • For development and contribution
    • clone the repo
    • make install

Data Process

Component for pre-processing the training data. To use this component, install the package as pip install distilbert-punctuator[data_process]

The package provides a simple pipeline for generating NER-format training data.

Example

examples/data_sample.py
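
To make the NER format concrete, here is a self-contained sketch (not the package's actual implementation) of how a punctuated sentence becomes a token sequence plus one tag per token; the tag names match the default English mapping shown in the Inference section below.

    # Illustrative only: convert punctuated text into NER-format training data.
    PUNCT2TAG = {",": "COMMA", ".": "PERIOD", "?": "QUESTIONMARK", "!": "EXLAMATIONMARK"}

    def to_ner_format(text):
        tokens, tags = [], []
        for word in text.split():
            if word[-1] in PUNCT2TAG:
                tokens.append(word[:-1].lower())
                tags.append(PUNCT2TAG[word[-1]])
            else:
                tokens.append(word.lower())
                tags.append("O")  # NORMAL_TOKEN_TAG
        return tokens, tags

    print(to_ner_format("Hello, how are you today?"))
    # (['hello', 'how', 'are', 'you', 'today'], ['COMMA', 'O', 'O', 'O', 'QUESTIONMARK'])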

Train

Component providing a training pipeline for fine-tuning a pretrained DistilBertForTokenClassification model from huggingface. The latest version implements R-Drop enhanced training (R-Drop github repo, paper of R-Drop).
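
For intuition, here is a minimal PyTorch sketch of the R-Drop idea, two stochastic forward passes of the same batch regularized with a symmetric KL term; it follows the paper, not necessarily this package's exact implementation.

    import torch.nn.functional as F

    def r_drop_loss(logits1, logits2, labels, alpha):
        # logits1/logits2: two forward passes of the same batch, shape
        # (num_tokens, num_classes); dropout makes them differ.
        ce = F.cross_entropy(logits1, labels) + F.cross_entropy(logits2, labels)
        # Symmetric KL divergence pulls the two predicted distributions together.
        p = F.log_softmax(logits1, dim=-1)
        q = F.log_softmax(logits2, dim=-1)
        kl = (F.kl_div(p, q, log_target=True, reduction="batchmean")
              + F.kl_div(q, p, log_target=True, reduction="batchmean"))
        return ce + alpha * kl / 2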

Example

examples/english_train_sample.py

Training_arguments:

Arguments required for the training pipeline; a usage sketch follows the argument list.

  • basic arguments

    • training_corpus(List[List[str]]): list of sequences for training; the longest sequence must be no longer than the pretrained LM's max_position_embeddings (512)
    • validation_corpus(List[List[str]]): list of sequences for validation, with the same length limit
    • training_tags(List[List[int]]): tags(int) for training
    • validation_tags(List[List[int]]): tags(int) for validation
    • model_name_or_path(str): name or path of pre-trained model
    • tokenizer_name(str): name of pretrained tokenizer
  • training arguments

    • epoch(int): number of epochs
    • batch_size(int): batch size
    • model_storage_dir(str): fine-tuned model storage path
    • label2id(Dict): the label-to-id mapping for the tags
    • early_stop_count(int): number of epochs after which to stop training early if the validation loss does not improve; default 3
    • gpu_device(int): specific gpu card index; default is taken from the CUDA_VISIBLE_DEVICES environment variable
    • warm_up_steps(int): warm-up steps
    • r_drop(bool): whether to train with R-Drop
    • r_alpha(int): alpha value for the KL divergence term in the loss; default is 0
    • plot_steps(int): interval in steps at which training status is recorded to tensorboard
    • tensorboard_log_dir(Optional[str]): the tensorboard logs output directory, default is "runs"
  • model arguments

    • addtional_model_config(Optional[Dict]): additional configuration for model

You can also train your own NER models with the trainer provided in this repo. The example can be found in notebooks/R-drop NER.ipynb
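
Below is a hedged sketch of wiring the arguments above into the training pipeline; the names TrainingArguments, TrainingPipeline, and run() are assumptions modeled on examples/english_train_sample.py, which remains the authoritative usage.

    from dbpunctuator.training import TrainingArguments, TrainingPipeline  # assumed names

    # Toy data; real corpora and tags come from the data process component.
    label2id = {"O": 0, "COMMA": 1, "PERIOD": 2, "QUESTIONMARK": 3, "EXLAMATIONMARK": 4}
    corpus = [["hello", "how", "are", "you", "today"]]
    tags = [[1, 0, 0, 0, 3]]

    training_args = TrainingArguments(
        training_corpus=corpus,
        validation_corpus=corpus,
        training_tags=tags,
        validation_tags=tags,
        model_name_or_path="distilbert-base-uncased",
        tokenizer_name="distilbert-base-uncased",
        epoch=20,
        batch_size=8,
        model_storage_dir="models/punctuator",
        label2id=label2id,
        early_stop_count=3,
        r_drop=True,
        r_alpha=1,
    )
    TrainingPipeline(training_args).run()  # run() is an assumption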

Evaluation

Validation of the fine-tuned model.

Example

examples/train_sample.py

Validation_arguments:

  • evaluation_corpus(List[List[str]]): list of sequences for evaluation; the longest sequence must be no longer than the pretrained LM's max_position_embeddings (512)
  • evaluation_tags(List[List[int]]): tags(int) for evaluation (the ground truth)
  • model_name_or_path(str): name or path of fine-tuned model
  • tokenizer_name(str): name of tokenizer
  • batch_size(int): batch size
  • label2id(Optional[Dict]): label-to-id mapping; by default it is read from the model config. Pass this argument if your model's config doesn't include label2id
  • gpu_device(int): specific gpu card index; default is taken from the CUDA_VISIBLE_DEVICES environment variable
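
A hedged sketch of running the validation; EvaluationArguments, EvaluationPipeline, and run() are assumed names, with examples/train_sample.py as the authoritative reference.

    from dbpunctuator.training import EvaluationArguments, EvaluationPipeline  # assumed names

    evaluation_args = EvaluationArguments(
        evaluation_corpus=[["hello", "how", "are", "you", "today"]],  # toy data
        evaluation_tags=[[1, 0, 0, 0, 3]],  # ground-truth tags
        model_name_or_path="models/punctuator",  # the fine-tuned model from training
        tokenizer_name="distilbert-base-uncased",
        batch_size=8,
    )
    EvaluationPipeline(evaluation_args).run()  # run() is an assumption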

Inference

Component providing an inference interface for using the punctuator.

Architecture

 +----------------------+              (child process)
 |   user application   |             +-------------------+
 |                      | <---------->| punctuator server |
 |   +inference object  |             +-------------------+
 +----------------------+

The punctuator is deployed in a child process that communicates with the main process through a pipe connection. The user can therefore initialize an inference object and call its punctuation function whenever needed; the punctuator never blocks the main process except while actually punctuating. A graceful-shutdown mechanism is built in, so the user doesn't need to worry about shutting it down.
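
The pattern is easy to see in a self-contained sketch, with a dummy punctuator standing in for the real model; the package's actual server and protocol may differ.

    from multiprocessing import Pipe, Process

    def punctuator_server(conn):
        # Child process: would load the model once, then serve requests
        # from the pipe until the parent sends None.
        while True:
            texts = conn.recv()
            if texts is None:  # graceful-shutdown signal
                break
            conn.send([t.capitalize() + "." for t in texts])  # dummy "punctuation"

    if __name__ == "__main__":
        parent_conn, child_conn = Pipe()
        server = Process(target=punctuator_server, args=(child_conn,))
        server.start()
        parent_conn.send(["how are you"])
        print(parent_conn.recv())  # ['How are you.']
        parent_conn.send(None)  # shut the server down gracefully
        server.join()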

Example

examples/inference_sample.py

Inference_arguments

Arguments required for the inference pipeline; a usage sketch follows the argument list.

  • model_name_or_path(str): name or path of pre-trained model
  • tokenizer_name(str): name of pretrained tokenizer
  • tag2punctuator(Dict[str, tuple]): tag-to-punctuator mapping. dbpunctuator.utils provides two default mappings for English and Chinese:
    NORMAL_TOKEN_TAG = "O"
    DEFAULT_ENGLISH_TAG_PUNCTUATOR_MAP = {
        NORMAL_TOKEN_TAG: ("", False),
        "COMMA": (",", False),
        "PERIOD": (".", True),
        "QUESTIONMARK": ("?", True),
        "EXLAMATIONMARK": ("!", True),
    }
    
    DEFAULT_CHINESE_TAG_PUNCTUATOR_MAP = {
        NORMAL_TOKEN_TAG: ("", False),
        "C_COMMA": (",", False),
        "C_PERIOD": ("。", True),
        "C_QUESTIONMARK": ("? ", True),
        "C_EXLAMATIONMARK": ("! ", True),
        "C_DUNHAO": ("、", False),
    }
    
    For your own fine-tuned model with different tags, pass in your own mapping.
  • tag2id_storage_path(Optional[str]): tag2id storage path; by default it is read from the model config. Pass this argument if your model's config doesn't include tag2id
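
Putting the arguments together, a hedged usage sketch; the names InferenceArguments, Inference, and punctuation() are assumptions modeled on examples/inference_sample.py.

    from dbpunctuator.inference import Inference, InferenceArguments  # assumed names
    from dbpunctuator.utils import DEFAULT_ENGLISH_TAG_PUNCTUATOR_MAP

    args = InferenceArguments(
        model_name_or_path="models/punctuator",  # or a pre-fine-tuned hub model
        tokenizer_name="distilbert-base-uncased",
        tag2punctuator=DEFAULT_ENGLISH_TAG_PUNCTUATOR_MAP,
    )
    inference = Inference(inference_args=args)  # spawns the punctuator server
    print(inference.punctuation(["how are you its been a long time since we met"]))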
