# transformers-keras

Transformer-based models implemented in TensorFlow 2.x (Keras).
## Installation

```bash
pip install -U transformers-keras
```
## Models

- Transformer, from the paper *Attention Is All You Need*. TensorFlow also has a tutorial on this architecture: *Transformer model for language understanding*.
- BERT
- ALBERT
## Transformer

Train a new Transformer model:

```python
from transformers_keras import TransformerTextFileDatasetBuilder
from transformers_keras import TransformerDefaultTokenizer
from transformers_keras import TransformerRunner

src_tokenizer = TransformerDefaultTokenizer(vocab_file='testdata/vocab_src.txt')
tgt_tokenizer = TransformerDefaultTokenizer(vocab_file='testdata/vocab_tgt.txt')
dataset_builder = TransformerTextFileDatasetBuilder(src_tokenizer, tgt_tokenizer)

model_config = {
    'num_encoder_layers': 2,
    'num_decoder_layers': 2,
    'src_vocab_size': src_tokenizer.vocab_size,
    'tgt_vocab_size': tgt_tokenizer.vocab_size,
}

runner = TransformerRunner(model_config, dataset_builder, model_dir='/tmp/transformer')

train_files = [('testdata/train.src.txt', 'testdata/train.tgt.txt')]
runner.train(train_files, epochs=10, callbacks=None)
```
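The tokenizers above are built from plain-text vocab files. As an illustration only (this is my assumption about the format, not the library's actual implementation), a minimal tokenizer over a one-token-per-line vocab file might look like:

```python
# Hypothetical sketch: one token per line, line number = token id.
# Not the library's code; just illustrates the idea behind a
# vocab-file-backed tokenizer.

class WhitespaceTokenizer:
    def __init__(self, vocab):
        # vocab: list of tokens, e.g. the lines of vocab_src.txt
        self.token2id = {tok: i for i, tok in enumerate(vocab)}
        self.id2token = {i: tok for i, tok in enumerate(vocab)}
        self.vocab_size = len(vocab)
        # fall back to an <unk> id for out-of-vocabulary tokens
        self.unk_id = self.token2id.get('<unk>', 0)

    def encode(self, text):
        return [self.token2id.get(t, self.unk_id) for t in text.split()]

    def decode(self, ids):
        return ' '.join(self.id2token.get(i, '<unk>') for i in ids)

vocab = ['<pad>', '<unk>', 'hello', 'world']
tok = WhitespaceTokenizer(vocab)
print(tok.encode('hello world'))  # -> [2, 3]
print(tok.decode([2, 3]))         # -> hello world
```

The `vocab_size` attribute is what the model config above reads to size the embedding layers.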
## BERT

Use your own data to pretrain a BERT model:

```python
from transformers_keras import BertTFRecordDatasetBuilder
from transformers_keras import BertRunner

dataset_builder = BertTFRecordDatasetBuilder()

model_config = {
    'num_layers': 6,
    'vocab_size': 100,  # Caution: use the correct vocab_size
}

runner = BertRunner(model_config, dataset_builder, model_dir='/tmp/bert')

train_files = ['testdata/bert_custom_pretrain.tfrecord']
runner.train(train_files, epochs=10, callbacks=None)
```

Tips:

- You need to prepare your data in TFRecord format first. You can use this script: create_pretraining_data.py
- You can subclass `transformers_keras.tokenizers.BertTFRecordDatasetBuilder` to parse custom TFRecord examples as you need.
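Assuming `create_pretraining_data.py` refers to the script from the google-research/bert repository, a typical invocation looks like the following (all file paths are placeholders for your own data):

```shell
python create_pretraining_data.py \
  --input_file=./corpus.txt \
  --output_file=./bert_custom_pretrain.tfrecord \
  --vocab_file=./vocab.txt \
  --do_lower_case=True \
  --max_seq_length=128 \
  --max_predictions_per_seq=20 \
  --masked_lm_prob=0.15 \
  --dupe_factor=5
```

The resulting `.tfrecord` file is what you pass in `train_files` above; make sure `max_seq_length` and the vocab file match the model config you train with.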
## ALBERT

You should process your data into TFRecord format first. Modify the script `transformers_keras/utils/bert_tfrecord_custom_generator.py` as you need.

```python
from transformers_keras import BertTFRecordDatasetBuilder
from transformers_keras import AlbertRunner

# ALBERT uses the same data format as BERT
dataset_builder = BertTFRecordDatasetBuilder()

model_config = {
    'num_layers': 6,
    'num_groups': 1,
    'num_layers_each_group': 1,
    'vocab_size': 100,  # Caution: use the correct vocab_size
}

runner = AlbertRunner(model_config, dataset_builder, model_dir='/tmp/albert')

train_files = ['testdata/bert_custom_pretrain.tfrecord']
runner.train(train_files, epochs=10, callbacks=None)
```
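The extra `num_groups` / `num_layers_each_group` settings reflect ALBERT's cross-layer parameter sharing: layers within a group reuse one set of weights, so the number of distinct weight sets is independent of depth. A back-of-the-envelope sketch (my own estimate, not the library's code; a transformer encoder layer has roughly 12 × hidden_size² weights, ignoring biases and embeddings):

```python
# Rough parameter-count estimate for an encoder stack, with and
# without ALBERT-style cross-layer sharing.

def encoder_param_estimate(num_layers, hidden_size,
                           num_groups=None, num_layers_each_group=None):
    # ~4*h^2 for the attention projections + ~8*h^2 for the feed-forward
    per_layer = 12 * hidden_size ** 2
    if num_groups is None:
        # BERT-style: every layer owns its own weights
        return num_layers * per_layer
    # ALBERT-style: only num_groups * num_layers_each_group distinct
    # weight sets exist, reused across the whole stack
    return num_groups * num_layers_each_group * per_layer

bert = encoder_param_estimate(6, 768)
albert = encoder_param_estimate(6, 768, num_groups=1, num_layers_each_group=1)
print(bert // albert)  # -> 6: sharing shrinks this 6-layer stack ~6x
```

With the config above (`num_groups=1`, `num_layers_each_group=1`), all six layers share a single weight set, which is why ALBERT models are much smaller than BERT models of the same depth.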