KoBERT-Transformers
KoBERT & DistilKoBERT on 🤗 Huggingface Transformers 🤗
The KoBERT model is identical to the one in the official repo. This repo was created to support all of the Huggingface tokenizer APIs.
🚨 Important! 🚨
🚀 TL;DR
transformers must be v3.0 or higher! For the tokenizer, use kobert_transformers/tokenization_kobert.py from this repo!
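For example, pinning transformers to the supported range with pip (a sketch; the range mirrors the Dependencies section below):

pip3 install "transformers>=3,<5"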
1. Tokenizer compatibility
Starting with v2.9.0, Huggingface Transformers changed some of its tokenization-related APIs. The existing tokenization_kobert.py has been updated to match the newer versions.
2. Embedding padding_idx issue
From early on, padding_idx=0 was hard-coded in BertModel's BertEmbeddings. (See the code below.)
class BertEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
However, SentencePiece defaults to pad_token_id=1 and unk_token_id=0 (KoBERT uses the same convention), so a BertModel that keeps the hard-coded padding_idx=0 treats KoBERT's [UNK] token as padding instead of [PAD], which can produce unwanted results.
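To see why, here is a minimal sketch (not from the original docs) of PyTorch's padding_idx behavior: the row at padding_idx is zero-initialized and excluded from gradient updates, so with KoBERT's vocabulary it is the [UNK] row that gets frozen.

>>> import torch.nn as nn
>>> emb = nn.Embedding(8002, 768, padding_idx=0)
>>> emb.weight[0].abs().sum()  # row 0 is KoBERT's unk_token_id, yet it is zeroed and never updated
tensor(0., grad_fn=<SumBackward0>)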
Huggingface recently became aware of this issue and fixed it in v2.9.0 (related PR #3793): pad_token_id can now be set in the config, which resolves the problem.
class BertEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
However, v2.9.0 did not fix this issue for DistilBERT, ALBERT, and others, so I submitted a PR myself (related PR #3965); it was merged and released in v2.9.1.
The code below shows the difference between the old and new versions.
# Transformers v2.7.0
>>> from transformers import BertModel, DistilBertModel
>>> model = BertModel.from_pretrained("monologg/kobert")
>>> model.embeddings.word_embeddings
Embedding(8002, 768, padding_idx=0)
>>> model = DistilBertModel.from_pretrained("monologg/distilkobert")
>>> model.embeddings.word_embeddings
Embedding(8002, 768, padding_idx=0)
# Transformers v2.9.1
>>> from transformers import BertModel, DistilBertModel
>>> model = BertModel.from_pretrained("monologg/kobert")
>>> model.embeddings.word_embeddings
Embedding(8002, 768, padding_idx=1)
>>> model = DistilBertModel.from_pretrained("monologg/distilkobert")
>>> model.embeddings.word_embeddings
Embedding(8002, 768, padding_idx=1)
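To confirm which behavior your installed version gives, one quick check (a sketch; pad_token_id is read from the uploaded model config) is:

>>> from transformers import BertModel
>>> model = BertModel.from_pretrained("monologg/kobert")
>>> model.config.pad_token_id
1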
KoBERT / DistilKoBERT on 🤗 Transformers 🤗
Dependencies
- torch>=1.1.0
- transformers>=3,<5
How to Use
>>> from transformers import BertModel, DistilBertModel
>>> bert_model = BertModel.from_pretrained('monologg/kobert')
>>> distilbert_model = DistilBertModel.from_pretrained('monologg/distilkobert')
To use the tokenizer, copy the kobert_transformers/tokenization_kobert.py file into your project and import KoBertTokenizer.
- KoBERT and DistilKoBERT both use the same tokenizer.
- The original KoBERT had an issue where special tokens were not split correctly; that has been fixed and is reflected here. (Issue link)
>>> from tokenization_kobert import KoBertTokenizer
>>> tokenizer = KoBertTokenizer.from_pretrained('monologg/kobert')  # same for monologg/distilkobert
>>> tokenizer.tokenize("[CLS] 한국어 모델을 공유합니다. [SEP]")
['[CLS]', '▁한국', '어', '▁모델', '을', '▁공유', '합니다', '.', '[SEP]']
>>> tokenizer.convert_tokens_to_ids(['[CLS]', '▁한국', '어', '▁모델', '을', '▁공유', '합니다', '.', '[SEP]'])
[2, 4958, 6855, 2046, 7088, 1050, 7843, 54, 3]
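Because KoBertTokenizer follows the standard Huggingface tokenizer API, the higher-level methods also work; for instance, encode should prepend [CLS] and append [SEP] automatically (a sketch consistent with the ids above):

>>> tokenizer.encode("한국어 모델을 공유합니다.")
[2, 4958, 6855, 2046, 7088, 1050, 7843, 54, 3]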
Kobert-Transformers (Pip library)
- A Python library that wraps tokenization_kobert.py
- Provides KoBERT and DistilKoBERT in Huggingface Transformers library form
From v0.5.1 onward, transformers v3.0 or higher is installed by default. (Usable without issues up through transformers v4.0.)
Install Kobert-Transformers
pip3 install kobert-transformers
How to Use
>>> import torch
>>> from kobert_transformers import get_kobert_model, get_distilkobert_model
>>> model = get_kobert_model()
>>> model.eval()
>>> input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
>>> attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
>>> token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
>>> sequence_output, pooled_output = model(input_ids, attention_mask, token_type_ids)
>>> sequence_output[0]
tensor([[-0.2461,  0.2428,  0.2590,  ..., -0.4861, -0.0731,  0.0756],
        [-0.2478,  0.2420,  0.2552,  ..., -0.4877, -0.0727,  0.0754],
        [-0.2472,  0.2420,  0.2561,  ..., -0.4874, -0.0733,  0.0765]],
       grad_fn=<SelectBackward>)
>>> from kobert_transformers import get_tokenizer
>>> tokenizer = get_tokenizer()
>>> tokenizer.tokenize("[CLS] 한국어 모델을 공유합니다. [SEP]")
['[CLS]', '▁한국', '어', '▁모델', '을', '▁공유', '합니다', '.', '[SEP]']
>>> tokenizer.convert_tokens_to_ids(['[CLS]', '▁한국', '어', '▁모델', '을', '▁공유', '합니다', '.', '[SEP]'])
[2, 4958, 6855, 2046, 7088, 1050, 7843, 54, 3]
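get_distilkobert_model is imported above but not demonstrated; here is a minimal sketch reusing the tensors from the KoBERT example (note that DistilBertModel takes no token_type_ids):

>>> model = get_distilkobert_model()
>>> model.eval()
>>> outputs = model(input_ids, attention_mask)
>>> outputs[0].shape  # last hidden state: (batch_size, seq_len, hidden_size)
torch.Size([2, 3, 768])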