Transformer-based models implemented in tensorflow 2.x(Keras)
Project description
transformers-keras
Transformer-based models implemented in tensorflow 2.x(Keras).
Installation
pip install -U transformers-keras
Models
- Transformer
- Attention Is All You Need.
- Here is a tutorial from tensorflow: Transformer model for language understanding
- BERT
- ALBERT
BERT
Supported pretrained models:
- All the BERT models pretrained by google-research/bert
- All the BERT & RoBERTa models pretrained by ymcui/Chinese-BERT-wwm
from transformers_keras import Bert
# Used to predict directly
model = Bert.from_pretrained('/path/to/pretrained/bert/model')
# segment_ids and mask inputs are optional
model.predict((input_ids, segment_ids, mask))
# or
model(inputs=(input_ids, segment_ids, mask))
# Used to fine-tuning
def build_bert_classify_model(trainable=True):
input_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='input_ids')
# segment_ids and mask inputs are optional
segment_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='segment_ids')
bert = Bert.from_pretrained(os.path.join(BASE_DIR, 'chinese_wwm_ext_L-12_H-768_A-12'))
bert.trainable = trainable
_, pooled_output, _, _ = bert(inputs=(input_ids, segment_ids))
outputs = tf.keras.layers.Dense(2, name='output')(pooled_output)
model = tf.keras.Model(inputs=[input_ids, segment_ids], outputs=outputs)
model.compile(loss='binary_crossentropy', optimizer='adam')
return model
model = build_bert_classify_model()
model.summary()
ALBERT
Supported pretrained models:
- All the ALBERT models pretrained by google-research/albert
from transformers_keras import Albert
# Used to predict directly
model = Albert.from_pretrained('/path/to/pretrained/albert/model')
# segment_ids and mask inputs are optional
model.predict((input_ids, segment_ids, mask))
# or
model(inputs=(input_ids, segment_ids, mask))
# Used to fine-tuning
def build_albert_classify_model(trainable=True):
input_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='input_ids')
# segment_ids and mask inputs are optional
segment_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='segment_ids')
albert = Albert.from_pretrained(os.path.join(BASE_DIR, 'albert_large_zh'))
albert.trainable = trainable
_, pooled_output, _, _ = albert(inputs=(input_ids, segment_ids))
outputs = tf.keras.layers.Dense(2, name='output')(pooled_output)
model = tf.keras.Model(inputs=[input_ids, segment_ids], outputs=outputs)
model.compile(loss='binary_crossentropy', optimizer='adam')
return model
model = build_albert_classify_model()
model.summary()
Load other pretrained models
If you want to load models pretrained by other implementations, whose config and trainable weights are a little different from the previous ones, you can subclass AbstractAdapter
to adapt these models:
from transformers_keras.adapters import AbstractAdapter
from transformers_keras import Bert, Albert
# load custom bert models
class MyBertAdapter(AbstractAdapter):
def adapte_config(self, config_file, **kwargs):
# adapte model config here
# you can refer to `transformers_keras.adapters.bert_adapter`
pass
def adapte_weights(self, model, config, ckpt, **kwargs):
# adapte model weights here
# you can refer to `transformers_keras.adapters.bert_adapter`
pass
bert = Bert.from_pretrained('/path/to/your/bert/model', adapter=MyBertAdapter())
# or, load custom albert models
class MyAlbertAdapter(AbstractAdapter):
def adapte_config(self, config_file, **kwargs):
# adapte model config here
# you can refer to `transformers_keras.adapters.albert_adapter`
pass
def adapte_weights(self, model, config, ckpt, **kwargs):
# adapte model weights here
# you can refer to `transformers_keras.adapters.albert_adapter`
pass
albert = Albert.from_pretrained('/path/to/your/albert/model', adapter=MyAlbertAdapter())
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
transformers_keras-0.2.1.tar.gz
(16.2 kB
view hashes)
Built Distribution
Close
Hashes for transformers_keras-0.2.1-py3-none-any.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | be778ccdfa695aa751dcb73f02a97743e55a6fab8ef68bad7269ee67a69bed7a |
|
MD5 | c351a5e7a715e56ea3cef0c758176beb |
|
BLAKE2b-256 | 3272873a921ef77c26206f7f997e0a1841e36f4a6f375966c691774370aeae08 |