Terry toolkit tkitAutoMask,
Project description
tkitAutoMask
自动构建掩码:加入多种动态掩码合集,包括上下三角、动态片段,以及默认的按概率掩码。
- 上三角:实现类似从左到右的预测,即单向注意力,用于续写。
- 片段:连续多个 mask,更适合解决补全任务。
未来尝试加入 模板预测掩码
pip install tkitAutoMask
# Usage example: build an autoMask instance and print the MLM labels it
# produces for a random batch of token ids.
import torch
from transformers import BertTokenizer

from tkitAutoMask import autoMask

tokenizer = BertTokenizer.from_pretrained("uer/chinese_roberta_L-2_H-128")

tomask = autoMask(
    # the token id reserved for masking
    mask_token_id=tokenizer.mask_token_id,
    # the token id for padding; -100 is also the value shown for unmasked
    # positions in the printed labels
    pad_token_id=-100,
    # plain masking ratio: masking probability for masked language modeling
    mask_prob=0.05,
    # ~10% probability that token will not be masked, but included in loss,
    # as detailed in the paper
    replace_prob=0.90,
    # other tokens to exclude from masking: [CLS] and [SEP]. The tokenizer's
    # sep_token_id is used because the original eos_token_id does not refer to
    # [SEP] (NOTE(review): BERT tokenizers presumably define no EOS token --
    # confirm).
    mask_ignore_token_ids=[tokenizer.cls_token_id, tokenizer.sep_token_id],
)

# Random batch of 10 sequences x 10 token ids drawn from [0, 20000).
x = torch.randint(0, 20000, (10, 10))

# Each call draws a fresh mask, so the printed labels differ per iteration.
for _ in range(10):
    # b holds the labels that are printed; a is presumably the masked input
    # ids -- TODO confirm against the tkitAutoMask source.
    a, b = tomask(x)
    print(b)
labels:形状为 [batch_size, seq_length],代表 MLM 任务的标签。注意这里对于原本未被遮盖的词设置为 -100,只有被遮盖的词才会有它们对应的 id,和任务设置是反过来的。例如,原始句子是 I want to [MASK] an apple,这里把单词 eat 遮住后输入模型,对应的 label 设置为 [-100, -100, -100, 【eat对应的id】, -100, -100]。为什么要设置为 -100 而不是其他数?因为 torch.nn.CrossEntropyLoss 默认的 ignore_index=-100,也就是说标签为 -100 的位置不会参与 loss 计算。
tensor([[ -100, -100, -100, -100, -100, -100, -100, -100, 6238, -100],
[ -100, -100, -100, -100, -100, -100, -100, 7321, -100, -100],
[ -100, -100, -100, -100, -100, -100, -100, -100, 11728, -100],
[ -100, -100, -100, -100, -100, -100, 3641, -100, -100, -100],
[ -100, -100, -100, -100, -100, 14913, -100, -100, -100, -100],
[ -100, 8332, -100, -100, -100, -100, -100, -100, -100, -100],
[ -100, -100, -100, -100, -100, -100, -100, -100, 11952, -100],
[ -100, -100, -100, -100, 12768, -100, -100, -100, -100, -100],
[ -100, -100, -100, -100, -100, -100, -100, -100, -100, 77],
[ -100, -100, 16031, -100, -100, -100, -100, -100, -100, -100]])
tensor([[ -100, -100, 1312, -100, -100, -100, -100, -100, -100, -100],
[ -100, -100, -100, -100, -100, -100, -100, -100, -100, 7849],
[ 9007, -100, -100, -100, -100, -100, -100, -100, -100, -100],
[ -100, -100, -100, -100, -100, -100, -100, -100, -100, 1822],
[ -100, -100, -100, -100, -100, -100, -100, -100, -100, 17593],
[ -100, -100, -100, -100, -100, -100, -100, 13736, -100, -100],
[ -100, -100, -100, 16620, -100, -100, -100, -100, -100, -100],
[ -100, -100, -100, -100, -100, -100, -100, 18083, -100, -100],
[ -100, -100, -100, 15338, -100, -100, -100, -100, -100, -100],
[ -100, -100, -100, 12984, -100, -100, -100, -100, -100, -100]])
tensor([[ -100, -100, -100, -100, -100, -100, -100, -100, -100, 4867],
[ -100, 15820, -100, -100, -100, -100, -100, -100, -100, -100],
[ 9007, 1684, -100, -100, -100, -100, -100, -100, -100, -100],
[ 4373, 13507, -100, -100, -100, -100, -100, 19849, -100, -100],
[19143, 19690, 16235, -100, -100, 14913, -100, -100, -100, -100],
[18837, 8332, 13231, 16312, -100, -100, 8517, -100, -100, -100],
[ 1567, 928, 268, 16620, 16337, 2932, -100, -100, -100, -100],
[ 9537, 1362, 16203, 10865, 12768, 10351, -100, -100, -100, 4658],
[12488, 17234, 4130, 15338, 4766, 6458, 15765, -100, -100, -100],
[19972, 457, 16031, 12984, 14118, 4127, 13889, 13456, -100, -100]])
tensor([[ 2649, 3837, 1312, 12421, 15558, -100, -100, -100, -100, -100],
[ -100, 15820, 2654, 3647, 13259, 6178, -100, -100, -100, 7849],
[ 9007, -100, 17864, 360, 4748, 10698, 3624, -100, -100, -100],
[ -100, 13507, -100, 5198, 4845, 18414, 3641, 19849, -100, -100],
[ -100, -100, -100, 17247, 7694, 14913, 4696, 3476, 7539, -100],
[ -100, -100, -100, -100, -100, 5739, 8517, 13736, 8122, 16682],
[ -100, -100, -100, -100, 16337, -100, 12610, 6181, 11952, 4669],
[ -100, -100, -100, -100, -100, -100, -100, 18083, 14632, 4658],
[ -100, -100, -100, 15338, -100, -100, -100, -100, 10558, 77],
[ -100, -100, 16031, -100, -100, -100, -100, -100, -100, 12816]])
tensor([[ -100, -100, -100, -100, 15558, -100, -100, -100, -100, -100],
[ -100, 15820, -100, -100, -100, -100, -100, -100, -100, -100],
[ -100, -100, 17864, -100, -100, -100, -100, -100, -100, -100],
[ -100, -100, -100, -100, 4845, -100, -100, -100, -100, -100],
[ -100, -100, -100, -100, 7694, -100, -100, -100, -100, -100],
[ -100, -100, -100, 16312, -100, -100, -100, -100, -100, -100],
[ -100, -100, -100, -100, -100, -100, 12610, -100, -100, -100],
[ -100, -100, -100, -100, -100, -100, -100, -100, -100, 4658],
[12488, -100, -100, -100, -100, -100, -100, -100, -100, -100],
[ -100, -100, -100, -100, -100, -100, -100, 13456, -100, -100]])
tensor([[ -100, -100, -100, -100, -100, -100, -100, -100, -100, 4867],
[ -100, -100, 2654, -100, -100, -100, -100, -100, -100, -100],
[ 9007, -100, -100, -100, -100, -100, -100, -100, -100, -100],
[ -100, -100, -100, -100, -100, -100, 3641, -100, -100, -100],
[ -100, -100, -100, -100, -100, 14913, -100, -100, -100, -100],
[ -100, -100, 13231, -100, -100, -100, -100, -100, -100, -100],
[ -100, -100, 268, -100, -100, -100, -100, -100, -100, -100],
[ -100, -100, -100, -100, -100, -100, -100, -100, 14632, -100],
[ -100, -100, -100, -100, -100, -100, 15765, -100, -100, -100],
[ -100, -100, -100, -100, 14118, -100, -100, -100, -100, -100]])
tensor([[ -100, -100, -100, -100, -100, -100, 7519, -100, -100, -100],
[15670, -100, -100, -100, -100, -100, -100, -100, -100, -100],
[ -100, 1684, -100, -100, -100, -100, -100, -100, -100, -100],
[ -100, -100, -100, -100, -100, -100, -100, -100, -100, 1822],
[ -100, 19690, -100, -100, -100, -100, -100, -100, -100, -100],
[ -100, -100, 13231, -100, -100, -100, -100, -100, -100, -100],
[ -100, -100, -100, -100, -100, -100, -100, -100, -100, 4669],
[ -100, -100, -100, -100, -100, -100, -100, -100, -100, 4658],
[12488, -100, -100, -100, -100, -100, -100, -100, -100, -100],
[ -100, -100, 16031, -100, -100, -100, -100, -100, -100, -100]])
tensor([[ 2649, 3837, 1312, -100, -100, 976, -100, -100, -100, -100],
[ -100, 15820, 2654, 3647, -100, -100, -100, -100, -100, -100],
[ -100, -100, 17864, 360, 4748, -100, 3624, -100, -100, -100],
[ 4373, -100, -100, 5198, 4845, 18414, -100, -100, -100, -100],
[ -100, -100, -100, -100, 7694, 14913, 4696, -100, 7539, -100],
[ -100, -100, -100, -100, -100, 5739, 8517, 13736, -100, -100],
[ -100, 928, -100, -100, -100, -100, 12610, 6181, 11952, -100],
[ -100, -100, -100, -100, -100, -100, -100, 18083, 14632, 4658],
[ -100, -100, -100, -100, -100, -100, -100, 19026, 10558, 77],
[ -100, 457, -100, -100, -100, -100, -100, -100, -100, 12816]])
tensor([[ -100, -100, -100, -100, -100, -100, 7519, -100, -100, -100],
[ -100, -100, 2654, -100, -100, -100, -100, -100, -100, -100],
[ -100, -100, -100, -100, 4748, -100, -100, -100, -100, -100],
[ -100, -100, -100, -100, -100, -100, -100, -100, 7381, -100],
[ -100, -100, -100, -100, -100, -100, -100, -100, 7539, -100],
[ -100, -100, -100, -100, -100, -100, -100, -100, 8122, -100],
[ -100, -100, -100, -100, -100, -100, 12610, -100, -100, -100],
[ -100, -100, 16203, -100, -100, -100, -100, -100, -100, -100],
[ -100, -100, -100, -100, -100, 6458, -100, -100, -100, -100],
[ -100, -100, -100, -100, -100, 4127, -100, -100, -100, -100]])
tensor([[ -100, -100, -100, -100, -100, -100, -100, -100, 6238, -100],
[ -100, -100, -100, -100, -100, -100, -100, -100, -100, 7849],
[ -100, -100, -100, -100, 4748, -100, -100, -100, -100, -100],
[ -100, -100, -100, -100, -100, 18414, -100, -100, -100, -100],
[ -100, -100, -100, -100, -100, 14913, -100, -100, -100, -100],
[ -100, -100, -100, 16312, -100, -100, -100, -100, -100, -100],
[ -100, 928, -100, -100, -100, -100, -100, -100, -100, -100],
[ -100, -100, -100, -100, -100, -100, 19242, -100, -100, -100],
[ -100, -100, -100, -100, -100, 6458, -100, -100, -100, -100],
[ -100, -100, -100, -100, -100, 4127, -100, -100, -100, -100]])
其他测试
https://colab.research.google.com/drive/1CvkoJ1pZQDRWGPA-5IzJufvocBM-RVT2#scrollTo=UwkociF5ZF-d
https://colab.research.google.com/drive/1kNHD0I0wH3WBpJXPdgZqs0MZTRnGD-ok#scrollTo=6M1ZXRsuxZAa
unilm_mask注意力写法 https://colab.research.google.com/drive/11IDalP2xNYWzF4gIz6T3yTjp53UqzkOe#scrollTo=gFeycxpykrCx
详细参考
dev.md
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Hashes for tkitAutoMask-0.0.0.316483908.tar.gz
Algorithm | Hash digest |
---|---|
SHA256 | d0c0562e9f655d46c2b9aa18d7398c9e4aa184b82ae061a53d3f8cc2a8dd9279 |
MD5 | d40147512bc7a24e3f02d44fb1c3a014 |
BLAKE2b-256 | 143b1fd6d9c6a0ee1df726b3bb1d124039299227f64c3eceb9ca35a0c142f663 |
Hashes for tkitAutoMask-0.0.0.316483908-py3-none-any.whl
Algorithm | Hash digest |
---|---|
SHA256 | 61b0ccd2f914444cb052f5239a5be3e42a3b25e134a0a3f43a576719abae7ff9 |
MD5 | 73d97f120dbc8d102082917341506ca6 |
BLAKE2b-256 | dc3c0b7e419bce979dffc28bebbde7c445b7115c090ceee6338996b7a46388d7 |