transformers-keras
Transformer-based models implemented in TensorFlow 2.x (Keras).
Installation
pip install -U transformers-keras
Models
- Transformer (deleted)
  - Attention Is All You Need.
  - Here is a tutorial from TensorFlow: Transformer model for language understanding
- BERT
- ALBERT
BERT
Supported pretrained models:
- All the BERT models pretrained by google-research/bert
- All the BERT & RoBERTa models pretrained by ymcui/Chinese-BERT-wwm
Feature Extraction Examples
from transformers_keras import Bert
# Used to predict directly
model = Bert.from_pretrained('/path/to/pretrained/bert/model')
# segment_ids and mask inputs are optional
sequence_outputs, pooled_output = model.predict((input_ids, segment_ids, mask))
# or
sequence_outputs, pooled_output = model(inputs=(input_ids, segment_ids, mask))
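The snippet above assumes `input_ids`, `segment_ids`, and `mask` are already batched integer arrays. A minimal sketch of building them from tokenized sequences with plain NumPy (the token IDs here are made up; real IDs come from the pretrained model's vocabulary):

```python
import numpy as np

def pad_batch(sequences, pad_id=0):
    """Pad variable-length token-ID lists into one rectangular batch.

    Returns (input_ids, segment_ids, mask) as int32 arrays, where the
    mask is 1 for real tokens and 0 for padding.
    """
    max_len = max(len(s) for s in sequences)
    input_ids = np.full((len(sequences), max_len), pad_id, dtype=np.int32)
    for i, seq in enumerate(sequences):
        input_ids[i, :len(seq)] = seq
    segment_ids = np.zeros_like(input_ids)         # single-sentence input: all segment 0
    mask = (input_ids != pad_id).astype(np.int32)  # 1 wherever a real token sits
    return input_ids, segment_ids, mask

input_ids, segment_ids, mask = pad_batch([[101, 2023, 102], [101, 2023, 2003, 1037, 102]])
```

For sentence-pair tasks, `segment_ids` would instead be 0 for the first sentence and 1 for the second.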
Also, you can optionally get the hidden states and attention weights of each encoder layer:
from transformers_keras import Bert
# Used to predict directly
model = Bert.from_pretrained(
    '/path/to/pretrained/bert/model',
    return_states=True,
    return_attention_weights=True)
# segment_ids and mask inputs are optional
sequence_outputs, pooled_output, states, attn_weights = model.predict((input_ids, segment_ids, mask))
# or
sequence_outputs, pooled_output, states, attn_weights = model(inputs=(input_ids, segment_ids, mask))
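With `return_attention_weights=True`, each layer's attention weights are softmax-normalized scores, so every query row sums to 1. A small NumPy sketch of the scaled dot-product attention that produces such a matrix (shapes and names are illustrative, not the library's internals):

```python
import numpy as np

def scaled_dot_product_attention(q, k, v):
    """q, k, v: (seq_len, d) arrays. Returns (output, attention_weights)."""
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)                   # (seq_len, seq_len) similarity scores
    scores -= scores.max(axis=-1, keepdims=True)    # stabilize the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # each row is a distribution over keys
    return weights @ v, weights

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(4, 8)) for _ in range(3))
out, w = scaled_dot_product_attention(q, k, v)
# w has shape (4, 4) and each row sums to 1
```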
Fine-tuning Examples
import os

import tensorflow as tf

from transformers_keras import Bert

# Used for fine-tuning
def build_bert_classify_model(pretrained_model_dir, trainable=True, **kwargs):
    input_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='input_ids')
    # segment_ids and mask inputs are optional
    segment_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='segment_ids')
    bert = Bert.from_pretrained(pretrained_model_dir, **kwargs)
    bert.trainable = trainable
    sequence_outputs, pooled_output = bert(inputs=(input_ids, segment_ids))
    outputs = tf.keras.layers.Dense(2, activation='softmax', name='output')(pooled_output)
    model = tf.keras.Model(inputs=[input_ids, segment_ids], outputs=outputs)
    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')
    return model

model = build_bert_classify_model(
    pretrained_model_dir=os.path.join(BASE_DIR, 'chinese_wwm_ext_L-12_H-768_A-12'),
    trainable=True)
model.summary()
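The classification head above maps the pooled output to two class scores, which a cross-entropy loss then compares against integer labels. A standalone NumPy sketch of that loss computation (the logits and labels are made up for illustration):

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax, shifted for numerical stability."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def cross_entropy(logits, labels):
    """Mean negative log-likelihood of the true class; labels are int class IDs."""
    probs = softmax(logits)
    return -np.mean(np.log(probs[np.arange(len(labels)), labels]))

logits = np.array([[2.0, -1.0], [0.5, 0.5]])  # 2-unit head output for a batch of 2
labels = np.array([0, 1])
loss = cross_entropy(logits, labels)
```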
ALBERT
Supported pretrained models:
- All the ALBERT models pretrained by google-research/albert
Feature Extraction Examples
from transformers_keras import Albert
# Used to predict directly
model = Albert.from_pretrained('/path/to/pretrained/albert/model')
# segment_ids and mask inputs are optional
sequence_outputs, pooled_output = model.predict((input_ids, segment_ids, mask))
# or
sequence_outputs, pooled_output = model(inputs=(input_ids, segment_ids, mask))
Also, you can optionally get the hidden states and attention weights of each encoder layer:
from transformers_keras import Albert
# Used to predict directly
model = Albert.from_pretrained(
    '/path/to/pretrained/albert/model',
    return_states=True,
    return_attention_weights=True)
# segment_ids and mask inputs are optional
sequence_outputs, pooled_output, states, attn_weights = model.predict((input_ids, segment_ids, mask))
# or
sequence_outputs, pooled_output, states, attn_weights = model(inputs=(input_ids, segment_ids, mask))
Fine-tuning Examples
import os

import tensorflow as tf

from transformers_keras import Albert

# Used for fine-tuning
def build_albert_classify_model(pretrained_model_dir, trainable=True, **kwargs):
    input_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='input_ids')
    # segment_ids and mask inputs are optional
    segment_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='segment_ids')
    albert = Albert.from_pretrained(pretrained_model_dir, **kwargs)
    albert.trainable = trainable
    sequence_outputs, pooled_output = albert(inputs=(input_ids, segment_ids))
    outputs = tf.keras.layers.Dense(2, activation='softmax', name='output')(pooled_output)
    model = tf.keras.Model(inputs=[input_ids, segment_ids], outputs=outputs)
    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')
    return model

model = build_albert_classify_model(
    pretrained_model_dir=os.path.join(BASE_DIR, 'albert_base'),
    trainable=True)
model.summary()
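ALBERT's main structural difference from BERT is cross-layer parameter sharing: one set of encoder weights is applied repeatedly, so depth does not grow the parameter count. A toy NumPy sketch of that idea (a single matrix stands in for ALBERT's shared encoder block, which is really attention plus a feed-forward network):

```python
import numpy as np

rng = np.random.default_rng(0)
hidden = 16
# One shared weight matrix stands in for ALBERT's single shared encoder block.
shared_w = rng.normal(size=(hidden, hidden)) * 0.1

def encoder_stack(x, num_layers):
    """Apply the same shared layer num_layers times, as ALBERT does."""
    for _ in range(num_layers):
        x = np.tanh(x @ shared_w)  # toy layer standing in for attention + FFN
    return x

x = rng.normal(size=(2, hidden))
out = encoder_stack(x, num_layers=12)
# Parameter count stays that of one layer regardless of depth.
```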
Advanced Usage
Here are some advanced usages:
- Skip loading weights from checkpoint
- Load other pretrained models
Skip loading weights from checkpoint
You can skip loading some weights from the checkpoint.
Examples:
from transformers_keras import Bert, Albert
ALBERT_MODEL_PATH = '/path/to/albert/model'
albert = Albert.from_pretrained(
    ALBERT_MODEL_PATH,
    # return_states=False,
    # return_attention_weights=False,
    skip_token_embedding=True,
    skip_position_embedding=True,
    skip_segment_embedding=True,
    skip_pooler=True,
    ...
)

BERT_MODEL_PATH = '/path/to/bert/model'
bert = Bert.from_pretrained(
    BERT_MODEL_PATH,
    # return_states=False,
    # return_attention_weights=False,
    skip_token_embedding=True,
    skip_position_embedding=True,
    skip_segment_embedding=True,
    skip_pooler=True,
    ...
)
All supported kwargs to skip loading weights:
- skip_token_embedding, skip loading token_embedding weights from ckpt
- skip_position_embedding, skip loading position_embedding weights from ckpt
- skip_segment_embedding, skip loading token_type_embedding weights from ckpt
- skip_embedding_layernorm, skip loading layer_norm weights of the embedding layer from ckpt
- skip_pooler, skip loading pooler weights of the pooler layer from ckpt
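Conceptually, each skip flag just tells the loader to omit one group of checkpoint variables during assignment. A toy, standalone sketch of that filtering (the variable names and function are invented for illustration, not the library's actual loading code):

```python
def select_weights(ckpt_vars, skip_token_embedding=False, skip_pooler=False):
    """Return the subset of checkpoint variables that should be assigned."""
    skipped_prefixes = []
    if skip_token_embedding:
        skipped_prefixes.append('embeddings/token_embedding')
    if skip_pooler:
        skipped_prefixes.append('pooler')
    return {name: value for name, value in ckpt_vars.items()
            if not any(name.startswith(p) for p in skipped_prefixes)}

ckpt = {'embeddings/token_embedding': 1, 'encoder/layer_0/kernel': 2, 'pooler/dense': 3}
kept = select_weights(ckpt, skip_token_embedding=True, skip_pooler=True)
# kept == {'encoder/layer_0/kernel': 2}
```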
Load other pretrained models
If you want to load models pretrained by other implementations, whose config and trainable weights differ slightly from the above, you can subclass AbstractAdapter to adapt these models:
from transformers_keras.adapters import AbstractAdapter
from transformers_keras import Bert, Albert
# load custom bert models
class MyBertAdapter(AbstractAdapter):

    def adapte_config(self, config_file, **kwargs):
        # adapt the model config here
        # you can refer to `transformers_keras.adapters.bert_adapter`
        pass

    def adapte_weights(self, model, config, ckpt, **kwargs):
        # adapt the model weights here
        # you can refer to `transformers_keras.adapters.bert_adapter`
        pass

bert = Bert.from_pretrained('/path/to/your/bert/model', adapter=MyBertAdapter())

# or, load custom albert models
class MyAlbertAdapter(AbstractAdapter):

    def adapte_config(self, config_file, **kwargs):
        # adapt the model config here
        # you can refer to `transformers_keras.adapters.albert_adapter`
        pass

    def adapte_weights(self, model, config, ckpt, **kwargs):
        # adapt the model weights here
        # you can refer to `transformers_keras.adapters.albert_adapter`
        pass

albert = Albert.from_pretrained('/path/to/your/albert/model', adapter=MyAlbertAdapter())
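The core of a weight adapter is usually a mapping from checkpoint variable names to the model's own naming scheme. A hedged, standalone sketch of such a rename table (the names here are invented for illustration; the real ones depend on both the source checkpoint and `transformers_keras` internals):

```python
def map_variable_name(ckpt_name):
    """Translate one checkpoint variable name into a target naming scheme."""
    renames = [
        ('bert/embeddings/word_embeddings', 'bert/embedding/token_embedding'),
        ('bert/encoder/layer_', 'bert/encoder/layer/'),
    ]
    for old, new in renames:
        ckpt_name = ckpt_name.replace(old, new)
    return ckpt_name

mapped = map_variable_name('bert/encoder/layer_0/attention/self/query/kernel')
# mapped == 'bert/encoder/layer/0/attention/self/query/kernel'
```

An `adapte_weights` implementation would apply such a mapping while iterating over the checkpoint, then assign each mapped variable to the corresponding model weight.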