Chinese Punctuation Restoration

Training dataset: p208p2002/ZH-Wiki-Punctuation-Restore-Dataset

Six punctuation marks are supported: , 、 。 ? ! ;
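The usage example further down loads a token-classification model and produces one (token, label) pair per token, which `decode_pred` then turns back into text. As a rough illustration of that idea, here is a minimal sketch. The label scheme is an assumption made for this example ("O" for no punctuation, otherwise the mark itself); the labels in the real model's `id2label` may differ, and `insert_punctuation` is a hypothetical helper, not part of zhpr.

```python
# Hypothetical label scheme (an assumption, not taken from the model):
# "O" means no punctuation follows the token; any other label is the
# punctuation mark to insert after it.
NO_PUNCT = "O"


def insert_punctuation(tagged_tokens):
    """Join (token, label) pairs into text, appending each predicted mark."""
    pieces = []
    for token, label in tagged_tokens:
        pieces.append(token)
        if label != NO_PUNCT:
            pieces.append(label)
    return "".join(pieces)


tagged = [
    ("自", "O"), ("由", "O"), ("內", "O"), ("容", "、"),
    ("自", "O"), ("由", "O"), ("編", "O"), ("輯", "。"),
]
restored = insert_punctuation(tagged)
print(restored)
```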

Installation

# pip install torch pytorch-lightning
pip install zhpr

Usage

import torch
from zhpr.predict import DocumentDataset, merge_stride, decode_pred
from transformers import AutoModelForTokenClassification, AutoTokenizer
from torch.utils.data import DataLoader

def predict_step(batch, model, tokenizer):
    """Return, for each sequence in the batch, a list of (token, predicted label) pairs."""
    batch_out = []
    batch_input_ids = batch

    encodings = {'input_ids': batch_input_ids}
    with torch.no_grad():  # inference only, so no gradients are needed
        output = model(**encodings)

    predicted_token_class_id_batch = output['logits'].argmax(-1)
    for predicted_token_class_ids, input_ids in zip(predicted_token_class_id_batch, batch_input_ids):
        input_ids = input_ids.tolist()
        tokens = tokenizer.convert_ids_to_tokens(input_ids)

        # Find where padding starts so that tokens and predictions
        # can both be truncated to the unpadded length.
        try:
            input_id_pad_start = input_ids.index(tokenizer.pad_token_id)
        except ValueError:
            # No pad token found: the sequence fills the whole window.
            input_id_pad_start = len(input_ids)
        tokens = tokens[:input_id_pad_start]

        # Map predicted class ids to label names and drop the padded positions.
        predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids]
        predicted_tokens_classes = predicted_tokens_classes[:input_id_pad_start]

        batch_out.append(list(zip(tokens, predicted_tokens_classes)))
    return batch_out

if __name__ == "__main__":
    # Each window holds up to 256 tokens; consecutive windows start 200 tokens apart.
    window_size = 256
    step = 200
    text = "維基百科是維基媒體基金會運營的一個多語言的百科全書目前是全球網路上最大且最受大眾歡迎的參考工具書名列全球二十大最受歡迎的網站特點是自由內容自由編輯與自由著作權"
    dataset = DocumentDataset(text, window_size=window_size, step=step)
    dataloader = DataLoader(dataset=dataset, shuffle=False, batch_size=5)

    model_name = 'p208p2002/zh-wiki-punctuation-restore'
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Register the uppercase letters A-Z and the space character as
    # additional special tokens, then resize the embedding matrix to match.
    new_tokens = [chr(x) for x in range(ord('A'), ord('Z') + 1)]
    new_tokens.append(" ")
    tokenizer.add_special_tokens({'additional_special_tokens': new_tokens})
    model.resize_token_embeddings(len(tokenizer))

    # Run the model over every window and collect the per-window predictions.
    model_pred_out = []
    for batch in dataloader:
        batch_out = predict_step(batch, model, tokenizer)
        for out in batch_out:
            model_pred_out.append(out)

    # Merge the overlapping windows, then decode the labels back into text.
    merge_pred_result = merge_stride(model_pred_out, step)
    merge_pred_result_decoded = decode_pred(merge_pred_result)
    merge_pred_result_decoded = ''.join(merge_pred_result_decoded)
    print(merge_pred_result_decoded)

Output:

維基百科是維基媒體基金會運營的一個多語言的百科全書,目前是全球網路上最大且最受大眾歡迎的參考工具書,名列全球二十大最受歡迎的網站,特點是自由內容、自由編輯與自由著作權。
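
The example above splits the text into windows of `window_size` tokens that start `step` tokens apart, so with `step` smaller than `window_size` neighbouring windows overlap, and the results are merged afterwards. The sketch below illustrates that general sliding-window technique in plain Python. It is not the implementation of `DocumentDataset` or `merge_stride`, whose internals are not shown here; `sliding_windows` and `merge_windows` are hypothetical helpers, and keeping the first `step` items of each non-final window is just one possible merge rule.

```python
def sliding_windows(items, window_size, step):
    """Split items into windows of up to window_size, starting step apart.

    Hypothetical helper for illustration; not part of zhpr.
    """
    if not 0 < step <= window_size:
        raise ValueError("step must be in the range 1..window_size")
    windows = []
    start = 0
    while True:
        windows.append(items[start:start + window_size])
        if start + window_size >= len(items):
            break  # this window already reaches the end of the input
        start += step
    return windows


def merge_windows(windows, step):
    """Rebuild the sequence from overlapping windows.

    Keeps the first `step` items of every window except the last one,
    which is kept whole (one possible merge rule, assumed for this sketch).
    """
    merged = []
    for i, window in enumerate(windows):
        is_last = i == len(windows) - 1
        merged.extend(window if is_last else window[:step])
    return merged


chars = list("abcdefghij")
windows = sliding_windows(chars, window_size=4, step=3)
restored = merge_windows(windows, step=3)
print(["".join(w) for w in windows])
print("".join(restored))
```

The overlap gives tokens near a window boundary more surrounding context in at least one of the windows that contain them, at the cost of running the model on some tokens twice.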

File details

Details for the file g_zhpr-0.1.3.tar.gz.

File metadata

  • Download URL: g_zhpr-0.1.3.tar.gz
  • Upload date:
  • Size: 4.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.9

File hashes

Hashes for g_zhpr-0.1.3.tar.gz

Algorithm    Hash digest
SHA256       576c75079d03da1d4157cc734cbb8795be023107297d3c5c8e884ea1de9765a6
MD5          ee6b15323ed88c497e10e7b7d086afda
BLAKE2b-256  7ecb7e39ab4937fae3e50bd3f8f18bae70709bd7c1708ec0cb14881618d2903b


File details

Details for the file g_zhpr-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: g_zhpr-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 5.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.9

File hashes

Hashes for g_zhpr-0.1.3-py3-none-any.whl

Algorithm    Hash digest
SHA256       4cc36c409b647c4d6cb85877b17d78d66f4c6ade60f04011c1b9bc2d8b75ad5a
MD5          b0e3c7e938e850e0e4e414e244aa89c9
BLAKE2b-256  fad963b8fc205302ae88faf4bfc9b8871cccf468bd3388cd1aec5266960a23d9

