
Terry toolkit tkitAutoMask

Project description

tkitAutoMask

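Automatic mask construction. Includes a collection of dynamic masks: upper and lower triangular masks, dynamic spans, and the default probabilistic masking.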

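  • Upper triangular: enables left-to-right prediction, i.e. unidirectional attention, for text continuation.

  • Span: masks several consecutive tokens, which is better suited to infilling (completion).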
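The two mask shapes above can be illustrated in plain Python. This is a minimal sketch, not tkitAutoMask's internals: `causal_blocked_mask` and `span_mask_positions` are hypothetical helper names, and we assume the convention that `True` in the attention mask means "blocked", so the blocked entries form the strict upper triangle.

```python
import random
from typing import List


def causal_blocked_mask(seq_len: int) -> List[List[bool]]:
    """Upper-triangular mask: blocked[i][j] is True when position i must NOT
    attend to position j, i.e. when j lies in the future (j > i).

    This yields left-to-right (unidirectional) attention for continuation.
    """
    return [[j > i for j in range(seq_len)] for i in range(seq_len)]


def span_mask_positions(seq_len: int, span_len: int, rng: random.Random) -> List[int]:
    """Pick one contiguous span of span_len positions to mask (hypothetical helper).

    Masking consecutive tokens forces the model to fill a whole gap,
    which suits infilling.
    """
    if not 0 < span_len <= seq_len:
        raise ValueError("span_len must be in 1..seq_len")
    start = rng.randint(0, seq_len - span_len)  # randint's upper bound is inclusive
    return list(range(start, start + span_len))


if __name__ == "__main__":
    for row in causal_blocked_mask(4):
        print([int(b) for b in row])
    print(span_mask_positions(10, 3, random.Random(0)))
```

With `seq_len=4`, the printed rows are `[0, 1, 1, 1]`, `[0, 0, 1, 1]`, `[0, 0, 0, 1]`, `[0, 0, 0, 0]`: each position can only see itself and earlier positions.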

Planned: template-based prediction masks.

pip install tkitAutoMask


import torch
from tkitAutoMask import autoMask
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("uer/chinese_roberta_L-2_H-128")

tomask = autoMask(
    mask_token_id=tokenizer.mask_token_id,  # the token id reserved for masking
    pad_token_id=-100,  # the token id for padding
    mask_prob=0.05,  # plain masking ratio: masking probability for masked language modeling
    replace_prob=0.90,  # ~10% probability that a token will not be masked but still included in the loss, as detailed in the paper
    # other tokens to exclude from masking; include [CLS] and [SEP] here
    # (BERT tokenizers define sep_token_id; eos_token_id is unset for them)
    mask_ignore_token_ids=[tokenizer.cls_token_id, tokenizer.sep_token_id],
)

# a batch of 10 random sequences, each 10 token ids long
x = torch.randint(0, 20000, (10, 10))
for _ in range(10):
    masked_x, labels = tomask(x)  # masked inputs and MLM labels
    print(labels)
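For intuition about `mask_prob`, `replace_prob` and `mask_ignore_token_ids`, here is a minimal pure-Python sketch of the conventional MLM masking recipe that these parameter names and comments suggest. It is an assumption about their meaning, not the library's source: each eligible token is selected with probability `mask_prob`; a selected token is replaced by `mask_token_id` with probability `replace_prob` and otherwise left unchanged, but it keeps its label either way; unselected positions get the label -100. `mlm_mask` is a hypothetical name.

```python
import random
from typing import List, Sequence, Set, Tuple

IGNORE_INDEX = -100  # label for positions that the loss should skip


def mlm_mask(
    ids: Sequence[int],
    mask_token_id: int,
    mask_prob: float,
    replace_prob: float,
    ignore_ids: Set[int],
    rng: random.Random,
) -> Tuple[List[int], List[int]]:
    """Return (masked_ids, labels) for one sequence (hypothetical helper)."""
    masked: List[int] = []
    labels: List[int] = []
    for tok in ids:
        # special tokens such as [CLS]/[SEP] are never selected
        selected = tok not in ignore_ids and rng.random() < mask_prob
        if selected:
            # Usually replace with [MASK]; otherwise keep the original token,
            # which still counts in the loss because its label is kept.
            masked.append(mask_token_id if rng.random() < replace_prob else tok)
            labels.append(tok)
        else:
            masked.append(tok)
            labels.append(IGNORE_INDEX)
    return masked, labels


if __name__ == "__main__":
    ids = [101, 2769, 2682, 1391, 2828, 102]  # made-up ids; 101/102 stand in for [CLS]/[SEP]
    masked, labels = mlm_mask(ids, mask_token_id=103, mask_prob=0.3,
                              replace_prob=0.9, ignore_ids={101, 102},
                              rng=random.Random(42))
    print(masked)
    print(labels)
```

Keeping some selected tokens unchanged while still scoring them means the model cannot assume that only `[MASK]` positions need predicting.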

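labels: shape [batch_size, seq_length]; these are the labels for the MLM task. Positions whose tokens were not masked are set to -100, and only masked positions carry their original token id, which is the reverse of the input (where the masked positions hold [MASK]). For example, if the original sentence is "I want to eat an apple" and the word "eat" is masked, the model input is "I want to [MASK] an apple" and the corresponding labels are [-100, -100, -100, <id of "eat">, -100, -100]. Why -100 rather than some other number? Because torch.nn.CrossEntropyLoss uses ignore_index=-100 by default, so targets labelled -100 contribute nothing to the loss.

Sample output of the tomask loop in the usage example (one labels tensor per call). Successive calls appear to apply different mask types: most mask a few scattered tokens, while some show triangular or banded, span-like patterns: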

tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  6238,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  7321,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 11728,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  3641,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100, 14913,  -100,  -100,  -100,  -100],
        [ -100,  8332,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 11952,  -100],
        [ -100,  -100,  -100,  -100, 12768,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,    77],
        [ -100,  -100, 16031,  -100,  -100,  -100,  -100,  -100,  -100,  -100]])
tensor([[ -100,  -100,  1312,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  7849],
        [ 9007,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  1822],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 17593],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100, 13736,  -100,  -100],
        [ -100,  -100,  -100, 16620,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100, 18083,  -100,  -100],
        [ -100,  -100,  -100, 15338,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100, 12984,  -100,  -100,  -100,  -100,  -100,  -100]])
tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  4867],
        [ -100, 15820,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ 9007,  1684,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ 4373, 13507,  -100,  -100,  -100,  -100,  -100, 19849,  -100,  -100],
        [19143, 19690, 16235,  -100,  -100, 14913,  -100,  -100,  -100,  -100],
        [18837,  8332, 13231, 16312,  -100,  -100,  8517,  -100,  -100,  -100],
        [ 1567,   928,   268, 16620, 16337,  2932,  -100,  -100,  -100,  -100],
        [ 9537,  1362, 16203, 10865, 12768, 10351,  -100,  -100,  -100,  4658],
        [12488, 17234,  4130, 15338,  4766,  6458, 15765,  -100,  -100,  -100],
        [19972,   457, 16031, 12984, 14118,  4127, 13889, 13456,  -100,  -100]])
tensor([[ 2649,  3837,  1312, 12421, 15558,  -100,  -100,  -100,  -100,  -100],
        [ -100, 15820,  2654,  3647, 13259,  6178,  -100,  -100,  -100,  7849],
        [ 9007,  -100, 17864,   360,  4748, 10698,  3624,  -100,  -100,  -100],
        [ -100, 13507,  -100,  5198,  4845, 18414,  3641, 19849,  -100,  -100],
        [ -100,  -100,  -100, 17247,  7694, 14913,  4696,  3476,  7539,  -100],
        [ -100,  -100,  -100,  -100,  -100,  5739,  8517, 13736,  8122, 16682],
        [ -100,  -100,  -100,  -100, 16337,  -100, 12610,  6181, 11952,  4669],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100, 18083, 14632,  4658],
        [ -100,  -100,  -100, 15338,  -100,  -100,  -100,  -100, 10558,    77],
        [ -100,  -100, 16031,  -100,  -100,  -100,  -100,  -100,  -100, 12816]])
tensor([[ -100,  -100,  -100,  -100, 15558,  -100,  -100,  -100,  -100,  -100],
        [ -100, 15820,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100, 17864,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  4845,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  7694,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100, 16312,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100, 12610,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  4658],
        [12488,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100, 13456,  -100,  -100]])
tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  4867],
        [ -100,  -100,  2654,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ 9007,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  3641,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100, 14913,  -100,  -100,  -100,  -100],
        [ -100,  -100, 13231,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,   268,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 14632,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100, 15765,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100, 14118,  -100,  -100,  -100,  -100,  -100]])
tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  7519,  -100,  -100,  -100],
        [15670,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  1684,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  1822],
        [ -100, 19690,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100, 13231,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  4669],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  4658],
        [12488,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100, 16031,  -100,  -100,  -100,  -100,  -100,  -100,  -100]])
tensor([[ 2649,  3837,  1312,  -100,  -100,   976,  -100,  -100,  -100,  -100],
        [ -100, 15820,  2654,  3647,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100, 17864,   360,  4748,  -100,  3624,  -100,  -100,  -100],
        [ 4373,  -100,  -100,  5198,  4845, 18414,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  7694, 14913,  4696,  -100,  7539,  -100],
        [ -100,  -100,  -100,  -100,  -100,  5739,  8517, 13736,  -100,  -100],
        [ -100,   928,  -100,  -100,  -100,  -100, 12610,  6181, 11952,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100, 18083, 14632,  4658],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100, 19026, 10558,    77],
        [ -100,   457,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 12816]])
tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  7519,  -100,  -100,  -100],
        [ -100,  -100,  2654,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  4748,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  7381,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  7539,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  8122,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100, 12610,  -100,  -100,  -100],
        [ -100,  -100, 16203,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  6458,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  4127,  -100,  -100,  -100,  -100]])
tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  6238,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  7849],
        [ -100,  -100,  -100,  -100,  4748,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100, 18414,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100, 14913,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100, 16312,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,   928,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  -100, 19242,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  6458,  -100,  -100,  -100,  -100],
        [ -100,  -100,  -100,  -100,  -100,  4127,  -100,  -100,  -100,  -100]])
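To see why the -100 entries in these labels do not affect training, here is a small pure-Python re-implementation of mean cross-entropy with an ignore index. It is a sketch for illustration only (`cross_entropy` is a hypothetical name); in practice you would use torch.nn.CrossEntropyLoss, whose default ignore_index is -100 and whose default "mean" reduction averages over the non-ignored targets only.

```python
import math
from typing import List, Sequence

IGNORE_INDEX = -100


def cross_entropy(logits: Sequence[Sequence[float]], labels: Sequence[int],
                  ignore_index: int = IGNORE_INDEX) -> float:
    """Mean cross-entropy over the positions whose label != ignore_index."""
    losses: List[float] = []
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue  # unmasked position: contributes nothing to the loss
        # log-softmax normaliser via the log-sum-exp trick for numerical stability
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        losses.append(log_z - row[label])
    if not losses:
        return float("nan")  # mean over zero positions is undefined
    return sum(losses) / len(losses)


if __name__ == "__main__":
    logits = [[2.0, 0.5, 0.1], [0.2, 1.5, 0.3], [0.1, 0.2, 3.0]]
    labels = [-100, 1, -100]  # only position 1 was masked
    print(cross_entropy(logits, labels))  # same as the loss of position 1 alone
```

Changing the logits at positions labelled -100 leaves the result unchanged, which is exactly why unmasked positions are given that label.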

Other tests

https://colab.research.google.com/drive/1CvkoJ1pZQDRWGPA-5IzJufvocBM-RVT2#scrollTo=UwkociF5ZF-d

https://colab.research.google.com/drive/1kNHD0I0wH3WBpJXPdgZqs0MZTRnGD-ok#scrollTo=6M1ZXRsuxZAa

unilm_mask attention implementation: https://colab.research.google.com/drive/11IDalP2xNYWzF4gIz6T3yTjp53UqzkOe#scrollTo=gFeycxpykrCx
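The notebook itself is not reproduced here. As a hedged point of reference, the sketch below builds the sequence-to-sequence attention mask commonly associated with UniLM, in plain Python. `unilm_seq2seq_mask` is a hypothetical name, and we assume all source tokens come before all target tokens, marked by segment ids 0 and 1.

```python
from typing import List, Sequence


def unilm_seq2seq_mask(segment_ids: Sequence[int]) -> List[List[int]]:
    """allowed[i][j] == 1 if token i may attend to token j.

    Assumes segment_ids is 0 for source tokens followed by 1 for target tokens.
    - Source tokens attend to every source token (bidirectional) but never to the target.
    - Target tokens attend to all source tokens and to target tokens up to themselves.
    """
    n = len(segment_ids)
    return [
        [
            1 if segment_ids[j] == 0 or (segment_ids[i] == 1 and j <= i) else 0
            for j in range(n)
        ]
        for i in range(n)
    ]


if __name__ == "__main__":
    for row in unilm_seq2seq_mask([0, 0, 1, 1]):
        print(row)
```

For `[0, 0, 1, 1]` the rows are `[1, 1, 0, 0]`, `[1, 1, 0, 0]`, `[1, 1, 1, 0]`, `[1, 1, 1, 1]`: a bidirectional block for the source and a causal (lower-triangular) block for the target. Under the same layout assumption, an equivalent vectorised form takes the cumulative sum `c` of `segment_ids` and allows attention exactly when `c[j] <= c[i]`.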

For more details, see

dev.md



Download files


Source Distribution

tkitAutoMask-0.0.0.316483908.tar.gz (16.5 kB)


Built Distribution

tkitAutoMask-0.0.0.316483908-py3-none-any.whl (14.3 kB)


File details

Details for the file tkitAutoMask-0.0.0.316483908.tar.gz.

File metadata

  • Download URL: tkitAutoMask-0.0.0.316483908.tar.gz
  • Upload date:
  • Size: 16.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/34.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.9 tqdm/4.63.1 importlib-metadata/4.11.3 keyring/23.5.0 rfc3986/2.0.0 colorama/0.4.4 CPython/3.8.12

File hashes

Hashes for tkitAutoMask-0.0.0.316483908.tar.gz
Algorithm Hash digest
SHA256 d0c0562e9f655d46c2b9aa18d7398c9e4aa184b82ae061a53d3f8cc2a8dd9279
MD5 d40147512bc7a24e3f02d44fb1c3a014
BLAKE2b-256 143b1fd6d9c6a0ee1df726b3bb1d124039299227f64c3eceb9ca35a0c142f663


File details

Details for the file tkitAutoMask-0.0.0.316483908-py3-none-any.whl.

File metadata

  • Download URL: tkitAutoMask-0.0.0.316483908-py3-none-any.whl
  • Upload date:
  • Size: 14.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/34.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.9 tqdm/4.63.1 importlib-metadata/4.11.3 keyring/23.5.0 rfc3986/2.0.0 colorama/0.4.4 CPython/3.8.12

File hashes

Hashes for tkitAutoMask-0.0.0.316483908-py3-none-any.whl
Algorithm Hash digest
SHA256 61b0ccd2f914444cb052f5239a5be3e42a3b25e134a0a3f43a576719abae7ff9
MD5 73d97f120dbc8d102082917341506ca6
BLAKE2b-256 dc3c0b7e419bce979dffc28bebbde7c445b7115c090ceee6338996b7a46388d7

