bert_pretty is a text encoder and result decoder
Project description
# -*- coding:utf-8 -*-
'''
BERT input encoding and result decoding
https://github.com/ssbuild/bert_pretty.git
'''
import numpy as np
# FullTokenizer is the official BERT tokenizer; you can substitute your own tokenization.
from bert_pretty import (FullTokenizer,
                         text_feature,
                         text_feature_char_level,
                         text_feature_word_level,
                         text_feature_char_level_input_ids_mask,
                         text_feature_word_level_input_ids_mask,
                         text_feature_char_level_input_ids_segment,
                         text_feature_word_level_input_ids_segment,
                         seqs_padding, rematch)
from bert_pretty.ner import load_label_bioes, load_label_bio, load_labels as ner_load_labels
from bert_pretty.ner import (ner_crf_decoding,
                             ner_pointer_decoding,
                             ner_pointer_decoding_with_mapping,
                             ner_pointer_double_decoding,
                             ner_pointer_double_decoding_with_mapping)
from bert_pretty.cls import cls_softmax_decoding, cls_sigmoid_decoding, load_labels as cls_load_labels
tokenizer = FullTokenizer(vocab_file=r'F:\pretrain\chinese_L-12_H-768_A-12\vocab.txt', do_lower_case=True)
text_list = ["你是谁123aa\ta嘂a","嘂adasd"]
def test():
    maxlen = 512
    do_lower_case = tokenizer.basic_tokenizer.do_lower_case
    # Tokenize each text, truncate, and wrap it with the [CLS]/[SEP] special tokens.
    inputs = [['[CLS]'] + tokenizer.tokenize(text)[:maxlen - 2] + ['[SEP]'] for text in text_list]
    # rematch maps each token back to its character positions in the raw text.
    mapping = [rematch(text, tokens, do_lower_case) for text, tokens in zip(text_list, inputs)]
    inputs = [tokenizer.convert_tokens_to_ids(tokens) for tokens in inputs]
    input_mask = [[1] * len(ids) for ids in inputs]
    input_segment = [[0] * len(ids) for ids in inputs]
    # Pad every sequence in the batch to a common length.
    input_ids = seqs_padding(inputs)
    input_mask = seqs_padding(input_mask)
    input_segment = seqs_padding(input_segment)
    input_ids = np.asarray(input_ids, dtype=np.int32)
    input_mask = np.asarray(input_mask, dtype=np.int32)
    input_segment = np.asarray(input_segment, dtype=np.int32)
    print('input_ids\n', input_ids)
    print('mapping\n', mapping)
    print('input_mask\n', input_mask)
    print('input_segment\n', input_segment)
    print('\n\n')
def test_charlevel():
    do_lower_case = tokenizer.basic_tokenizer.do_lower_case
    maxlen = 512
    # Lower-case the raw text up front when the tokenizer is uncased.
    if do_lower_case:
        inputs = [['[CLS]'] + tokenizer.tokenize(text.lower())[:maxlen - 2] + ['[SEP]'] for text in text_list]
    else:
        inputs = [['[CLS]'] + tokenizer.tokenize(text)[:maxlen - 2] + ['[SEP]'] for text in text_list]
    inputs = [tokenizer.convert_tokens_to_ids(tokens) for tokens in inputs]
    input_mask = [[1] * len(ids) for ids in inputs]
    input_segment = [[0] * len(ids) for ids in inputs]
    input_ids = seqs_padding(inputs)
    input_mask = seqs_padding(input_mask)
    input_segment = seqs_padding(input_segment)
    input_ids = np.asarray(input_ids, dtype=np.int32)
    input_mask = np.asarray(input_mask, dtype=np.int32)
    input_segment = np.asarray(input_segment, dtype=np.int32)
    print('input_ids\n', input_ids)
    print('input_mask\n', input_mask)
    print('input_segment\n', input_segment)
    print('\n\n')
# labels = ['标签1', '标签2']
# print(cls_load_labels(labels))
#
# print(load_label_bio(labels))
'''
def ner_crf_decoding(batch_text, id2label, batch_logits, trans=None, batch_mapping=None, with_dict=True)

NER CRF decoding: decodes a raw CRF sequence, or one that has already been decoded.
    batch_text     list of input texts
    id2label       labels, as a list or dict
    batch_logits   BERT prediction logits, shape (batch, seq_len, num_tags) or (batch, seq_len)
    trans          optional 2D transition matrix; enables transition-aware decoding
    batch_mapping  token-to-character mapping sequences
'''
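# A minimal, hypothetical sketch (not part of the original project) of calling
# ner_crf_decoding: random logits stand in for real model output, and the BIO
# tag set below is invented for illustration.
def test_ner_crf_decode():
    id2label = ['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC']  # hypothetical tag set
    num_tags = len(id2label)
    seq_len = 8  # illustrative sequence length
    np.random.seed(123)
    # Shape (batch, seq_len, num_tags), as documented above.
    batch_logits = np.random.rand(len(text_list), seq_len, num_tags)
    result = ner_crf_decoding(text_list, id2label, batch_logits, trans=None)
    print(result)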
'''
def ner_pointer_decoding(batch_text, id2label, batch_logits, threshold=1e-8, coordinates_minus=False, with_dict=True)
    batch_text     list of input texts
    id2label       labels, as a list or dict
    batch_logits   shape (batch, num_labels, seq_len, seq_len)
    threshold      score threshold
    coordinates_minus
'''
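# A hypothetical sketch of ner_pointer_decoding, again with random scores of
# the documented shape (batch, num_labels, seq_len, seq_len); the entity types
# are invented for illustration.
def test_ner_pointer_decode():
    id2label = ['PER', 'LOC']  # hypothetical entity types
    num_labels = len(id2label)
    seq_len = 8
    np.random.seed(123)
    batch_logits = np.random.rand(len(text_list), num_labels, seq_len, seq_len)
    # A high threshold keeps this random example from emitting a span for
    # nearly every coordinate pair.
    result = ner_pointer_decoding(text_list, id2label, batch_logits, threshold=0.9)
    print(result)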
'''
def ner_pointer_decoding_with_mapping(batch_text, id2label, batch_logits, batch_mapping, threshold=1e-8, coordinates_minus=False, with_dict=True)
    batch_text     list of input texts
    id2label       labels, as a list or dict
    batch_logits   shape (batch, num_labels, seq_len, seq_len)
    batch_mapping  token-to-character mapping sequences
    threshold      score threshold
    coordinates_minus
'''
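# A hypothetical sketch of the _with_mapping variant: it additionally takes the
# token-to-character mapping produced by rematch (as in test() above), so that
# decoded spans can be mapped back to positions in the raw text.
def test_ner_pointer_decode_with_mapping():
    do_lower_case = tokenizer.basic_tokenizer.do_lower_case
    tokens = [['[CLS]'] + tokenizer.tokenize(text) + ['[SEP]'] for text in text_list]
    batch_mapping = [rematch(text, toks, do_lower_case) for text, toks in zip(text_list, tokens)]
    id2label = ['PER', 'LOC']  # hypothetical entity types
    num_labels = len(id2label)
    seq_len = max(len(toks) for toks in tokens)
    np.random.seed(123)
    batch_logits = np.random.rand(len(text_list), num_labels, seq_len, seq_len)
    result = ner_pointer_decoding_with_mapping(text_list, id2label, batch_logits,
                                               batch_mapping, threshold=0.9)
    print(result)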
'''
def cls_softmax_decoding(batch_text, id2label, batch_logits, threshold=None)
    batch_text    list of input texts
    id2label      labels, as a list or dict
    batch_logits  shape (batch, num_classes)
    threshold     score threshold
'''
'''
def cls_sigmoid_decoding(batch_text, id2label, batch_logits, threshold=0.5)
    batch_text    list of input texts
    id2label      labels, as a list or dict
    batch_logits  shape (batch, num_classes)
    threshold     score threshold
'''
def test_cls_decode():
    num_label = 3
    np.random.seed(123)
    # Random logits stand in for real model output.
    batch_logits = np.random.rand(2, num_label)
    result = cls_softmax_decoding(text_list, ['标签1', '标签2', '标签3'], batch_logits, threshold=None)
    print(result)
    batch_logits = np.random.rand(2, num_label)
    print(batch_logits)
    result = cls_sigmoid_decoding(text_list, ['标签1', '标签2', '标签3'], batch_logits, threshold=0.5)
    print(result)
if __name__ == '__main__':
    test()
    test_charlevel()
    test_cls_decode()
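    # The hypothetical NER sketches above could be exercised the same way:
    # test_ner_crf_decode()
    # test_ner_pointer_decode()
    # test_ner_pointer_decode_with_mapping()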
Download files
Source Distributions
No source distribution files are available for this release.
Built Distribution
File details
Details for the file bert_pretty-0.1.0.post0-py3-none-any.whl.
File metadata
- Download URL: bert_pretty-0.1.0.post0-py3-none-any.whl
- Upload date:
- Size: 31.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.5.0 importlib_metadata/4.8.2 pkginfo/1.7.1 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.64.0 CPython/3.8.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | 52ff286e28f17f487c3486cced6c6a8db49a901f289370a18ee50cda52aa4b00
MD5 | 9008d9f1f3813bf27c6edbb88c574466
BLAKE2b-256 | c19c9acd5b3443ae4079de0c83bdfe79ee04e2fa9b82634f9092ff9fb85092e6