transformers-keras
Transformer-based models implemented in TensorFlow 2.x (Keras).
Installation
pip install -U transformers-keras
Models
- Transformer [DELETED]
  - Attention Is All You Need.
  - Here is a tutorial from tensorflow: Transformer model for language understanding
- BERT
- ALBERT
BERT
Supported pretrained models:
- All the BERT models pretrained by google-research/bert
- All the BERT & RoBERTa models pretrained by ymcui/Chinese-BERT-wwm
Feature Extraction Examples:
import tensorflow as tf

from transformers_keras import Bert

# Load the pretrained model for direct inference
model = Bert.from_pretrained('/path/to/pretrained/bert/model')

input_ids = tf.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
# segment_ids and attention_mask inputs are optional
segment_ids, attention_mask = None, None
sequence_outputs, pooled_output = model(input_ids, segment_ids, attention_mask, training=False)
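For a quick sanity check you can print the output shapes; the hidden size below assumes a base-sized (768-dimensional) checkpoint and is only illustrative:

# Illustrative shapes, assuming a base-sized model (hidden size 768)
print(sequence_outputs.shape)  # (1, 10, 768): (batch, sequence length, hidden size)
print(pooled_output.shape)     # (1, 768): one pooled vector per example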
Also, you can optionally get the hidden states and attention weights of each encoder layer:
import tensorflow as tf

from transformers_keras import Bert

# Load the pretrained model for direct inference
model = Bert.from_pretrained(
    '/path/to/pretrained/bert/model',
    return_states=True,
    return_attention_weights=True)

input_ids = tf.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
# segment_ids and attention_mask inputs are optional
segment_ids, attention_mask = None, None
sequence_outputs, pooled_output, states, attn_weights = model(input_ids, segment_ids, attention_mask, training=False)
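How the extra outputs are packed can vary, so treat the following inspection as a sketch that assumes per-layer tensors from a 12-layer, 12-head base model:

# Sketch only: layer/head counts assume a 12-layer, 12-head base model
print(len(states))             # number of encoder layers, e.g. 12
print(states[-1].shape)        # last layer hidden states, e.g. (1, 10, 768)
print(attn_weights[-1].shape)  # last layer attention, e.g. (1, 12, 10, 10): (batch, heads, query, key)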
Fine-tuning Examples
import os

import tensorflow as tf

from transformers_keras import Bert

# Build a classification model on top of the pretrained encoder for fine-tuning
def build_bert_classify_model(pretrained_model_dir, trainable=True, **kwargs):
    input_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='input_ids')
    # segment_ids and attention_mask inputs are optional
    segment_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='segment_ids')
    bert = Bert.from_pretrained(pretrained_model_dir, **kwargs)
    bert.trainable = trainable
    sequence_outputs, pooled_output = bert(input_ids, segment_ids, None)
    outputs = tf.keras.layers.Dense(2, name='output')(pooled_output)
    model = tf.keras.Model(inputs=[input_ids, segment_ids], outputs=outputs)
    model.compile(loss='binary_crossentropy', optimizer='adam')
    return model
model = build_bert_classify_model(
    pretrained_model_dir=os.path.join(BASE_DIR, 'chinese_wwm_ext_L-12_H-768_A-12'),
    trainable=True)
model.summary()
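The compiled model can then be trained like any Keras model. Here is a minimal sketch using random token ids and two-class one-hot labels; the data is purely illustrative:

import numpy as np

# Toy batch: 2 examples of 10 token ids each (values are placeholders)
x = {
    'input_ids': np.random.randint(1, 100, size=(2, 10), dtype=np.int32),
    'segment_ids': np.zeros((2, 10), dtype=np.int32),
}
y = np.array([[1., 0.], [0., 1.]], dtype=np.float32)  # two-class one-hot labels
model.fit(x, y, epochs=1, batch_size=2)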
ALBERT
Supported pretrained models:
- All the ALBERT models pretrained by google-research/albert
Feature Extraction Examples
import tensorflow as tf

from transformers_keras import Albert

# Load the pretrained model for direct inference
model = Albert.from_pretrained('/path/to/pretrained/albert/model')

input_ids = tf.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
# segment_ids and attention_mask inputs are optional
segment_ids, attention_mask = None, None
sequence_outputs, pooled_output = model(input_ids, segment_ids, attention_mask, training=False)
Also, you can optionally get the hidden states and attention weights of each encoder layer:
import tensorflow as tf

from transformers_keras import Albert

# Load the pretrained model for direct inference
model = Albert.from_pretrained(
    '/path/to/pretrained/albert/model',
    return_states=True,
    return_attention_weights=True)

input_ids = tf.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
# segment_ids and attention_mask inputs are optional
segment_ids, attention_mask = None, None
sequence_outputs, pooled_output, states, attn_weights = model(input_ids, segment_ids, attention_mask, training=False)
Fine-tuning Examples
import os

import tensorflow as tf

from transformers_keras import Albert

# Build a classification model on top of the pretrained encoder for fine-tuning
def build_albert_classify_model(pretrained_model_dir, trainable=True, **kwargs):
    input_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='input_ids')
    # segment_ids and attention_mask inputs are optional
    segment_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='segment_ids')
    albert = Albert.from_pretrained(pretrained_model_dir, **kwargs)
    albert.trainable = trainable
    sequence_outputs, pooled_output = albert(input_ids, segment_ids, None)
    outputs = tf.keras.layers.Dense(2, name='output')(pooled_output)
    model = tf.keras.Model(inputs=[input_ids, segment_ids], outputs=outputs)
    model.compile(loss='binary_crossentropy', optimizer='adam')
    return model
model = build_albert_classify_model(
    pretrained_model_dir=os.path.join(BASE_DIR, 'albert_base'),
    trainable=True)
model.summary()
Advanced Usage
Here are some advanced usages:
- Skip loading weights from checkpoint
- Load other pretrained models
Skip loading weights from checkpoint
You can skip loading some weights from the checkpoint.
Examples:
from transformers_keras import Bert, Albert

ALBERT_MODEL_PATH = '/path/to/albert/model'
albert = Albert.from_pretrained(
    ALBERT_MODEL_PATH,
    # return_states=False,
    # return_attention_weights=False,
    skip_token_embedding=True,
    skip_position_embedding=True,
    skip_segment_embedding=True,
    skip_pooler=True,
    ...
)

BERT_MODEL_PATH = '/path/to/bert/model'
bert = Bert.from_pretrained(
    BERT_MODEL_PATH,
    # return_states=False,
    # return_attention_weights=False,
    skip_token_embedding=True,
    skip_position_embedding=True,
    skip_segment_embedding=True,
    skip_pooler=True,
    ...
)
All supported kwargs to skip loading weights:
- skip_token_embedding: skip loading token_embedding weights from ckpt
- skip_position_embedding: skip loading position_embedding weights from ckpt
- skip_segment_embedding: skip loading token_type_embedding weights from ckpt
- skip_embedding_layernorm: skip loading layer_norm weights of the embedding layer from ckpt
- skip_pooler: skip loading pooler weights of the pooler layer from ckpt (see the sketch below)
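As a sketch of when skipping is useful: if you only need token-level features (e.g. for sequence labelling), you could skip loading the pooler weights and ignore the pooled output. The pooler layer still exists in the model; its weights are simply not restored from the checkpoint:

import tensorflow as tf

from transformers_keras import Bert

# Sketch only: skip restoring pooler weights, then use sequence outputs only
bert = Bert.from_pretrained('/path/to/pretrained/bert/model', skip_pooler=True)
input_ids = tf.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
sequence_outputs, _ = bert(input_ids, None, None, training=False)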
Load other pretrained models
If you want to load models pretrained by other implementations, whose config and trainable weights differ slightly from the ones above, you can subclass AbstractAdapter to adapt these models:
from transformers_keras.adapters import AbstractAdapter
from transformers_keras import Bert, Albert

# load custom bert models
class MyBertAdapter(AbstractAdapter):

    def adapte_config(self, config_file, **kwargs):
        # adapt the model config here
        # you can refer to `transformers_keras.adapters.bert_adapter`
        pass

    def adapte_weights(self, model, config, ckpt, **kwargs):
        # adapt the model weights here
        # you can refer to `transformers_keras.adapters.bert_adapter`
        pass

bert = Bert.from_pretrained('/path/to/your/bert/model', adapter=MyBertAdapter())

# or, load custom albert models
class MyAlbertAdapter(AbstractAdapter):

    def adapte_config(self, config_file, **kwargs):
        # adapt the model config here
        # you can refer to `transformers_keras.adapters.albert_adapter`
        pass

    def adapte_weights(self, model, config, ckpt, **kwargs):
        # adapt the model weights here
        # you can refer to `transformers_keras.adapters.albert_adapter`
        pass

albert = Albert.from_pretrained('/path/to/your/albert/model', adapter=MyAlbertAdapter())
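For reference, here is a hedged sketch of what `adapte_config` might look like for a Google-style BERT checkpoint: it reads the `bert_config.json` that ships with such checkpoints and maps its keys to whatever names your model expects. The target key names below are assumptions made for illustration, not the library's actual schema; see `transformers_keras.adapters.bert_adapter` for the real mapping.

import json

from transformers_keras.adapters import AbstractAdapter

class SketchBertAdapter(AbstractAdapter):

    def adapte_config(self, config_file, **kwargs):
        # Read the original bert_config.json shipped with Google-style checkpoints
        with open(config_file) as f:
            original = json.load(f)
        # Hypothetical target key names; adjust them to what your model actually expects
        return {
            'vocab_size': original['vocab_size'],
            'hidden_size': original['hidden_size'],
            'num_layers': original['num_hidden_layers'],
            'num_attention_heads': original['num_attention_heads'],
        }

    def adapte_weights(self, model, config, ckpt, **kwargs):
        # Map checkpoint variable names to model weights here;
        # see `transformers_keras.adapters.bert_adapter` for the real implementation
        pass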