transformers-keras
Transformer-based models implemented in TensorFlow 2.x (Keras).
Installation
pip install -U transformers-keras
Models
- Transformer [DELETED]
  - Attention Is All You Need.
  - Here is a tutorial from TensorFlow: Transformer model for language understanding
- BERT
- ALBERT
BERT
Supported pretrained models:
- All the BERT models pretrained by google-research/bert
- All the BERT & RoBERTa models pretrained by ymcui/Chinese-BERT-wwm
Feature Extraction Examples:
import tensorflow as tf

from transformers_keras import Bert

# Used to predict directly
model = Bert.from_pretrained('/path/to/pretrained/bert/model')
# segment_ids and attention_mask inputs are optional
input_ids = tf.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
segment_ids, attention_mask = None, None
sequence_outputs, pooled_output = model(input_ids, segment_ids, attention_mask, training=False)
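For orientation, here is a minimal sketch of inspecting the two outputs, assuming a BERT-base checkpoint (hidden size 768); the shapes shown in the comments follow from the 10-token example above:

# A minimal sketch, assuming a BERT-base checkpoint with hidden size 768.
# sequence_outputs holds one vector per input token, pooled_output one vector per sequence.
print(sequence_outputs.shape)  # (1, 10, 768)
print(pooled_output.shape)     # (1, 768)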
Also, you can optionally get the hidden states and attention weights of each encoder layer:
import tensorflow as tf

from transformers_keras import Bert

# Used to predict directly
model = Bert.from_pretrained(
    '/path/to/pretrained/bert/model',
    return_states=True,
    return_attention_weights=True)
input_ids = tf.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
segment_ids, attention_mask = None, None
sequence_outputs, pooled_output, states, attn_weights = model(input_ids, segment_ids, attention_mask, training=False)
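As a rough illustration (an assumption rather than the library's documented contract, and the exact container types may vary by version), each entry of states and attn_weights corresponds to one encoder layer and can be inspected like this:

# A minimal sketch; assumes states and attn_weights hold one entry per encoder layer.
for i, (state, weight) in enumerate(zip(states, attn_weights)):
    print('layer {}: hidden states {}, attention weights {}'.format(i, state.shape, weight.shape))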
Fine-tuning Examples
import os

import tensorflow as tf

from transformers_keras import Bert

# Used for fine-tuning
def build_bert_classify_model(pretrained_model_dir, trainable=True, **kwargs):
    input_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='input_ids')
    # segment_ids and attention_mask inputs are optional
    segment_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='segment_ids')
    bert = Bert.from_pretrained(pretrained_model_dir, **kwargs)
    bert.trainable = trainable
    sequence_outputs, pooled_output = bert(input_ids, segment_ids, None)
    outputs = tf.keras.layers.Dense(2, name='output')(pooled_output)
    model = tf.keras.Model(inputs=[input_ids, segment_ids], outputs=outputs)
    model.compile(loss='binary_crossentropy', optimizer='adam')
    return model

# BASE_DIR is the directory that holds the pretrained checkpoints
model = build_bert_classify_model(
    pretrained_model_dir=os.path.join(BASE_DIR, 'chinese_wwm_ext_L-12_H-768_A-12'),
    trainable=True)
model.summary()
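The compiled model can then be trained like any other Keras model. The snippet below is a sketch using random toy data; the shapes and the two-class one-hot labels are illustrative assumptions, and real input_ids would come from a WordPiece tokenizer:

import numpy as np

# Hypothetical toy batch of 8 sequences of length 16.
toy_input_ids = np.random.randint(1, 100, size=(8, 16)).astype(np.int32)
toy_segment_ids = np.zeros((8, 16), dtype=np.int32)
toy_labels = np.eye(2)[np.random.randint(0, 2, size=(8,))].astype(np.float32)

model.fit(
    x={'input_ids': toy_input_ids, 'segment_ids': toy_segment_ids},
    y=toy_labels,
    batch_size=4,
    epochs=1)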
ALBERT
Supported pretrained models:
- All the ALBERT models pretrained by google-research/albert
Feature Extraction Examples
import tensorflow as tf

from transformers_keras import Albert

# Used to predict directly
model = Albert.from_pretrained('/path/to/pretrained/albert/model')
input_ids = tf.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
segment_ids, attention_mask = None, None
sequence_outputs, pooled_output = model(input_ids, segment_ids, attention_mask, training=False)
Also, you can optionally get the hidden states and attention weights of each encoder layer:
import tensorflow as tf

from transformers_keras import Albert

# Used to predict directly
model = Albert.from_pretrained(
    '/path/to/pretrained/albert/model',
    return_states=True,
    return_attention_weights=True)
# segment_ids and attention_mask inputs are optional
input_ids = tf.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
segment_ids, attention_mask = None, None
sequence_outputs, pooled_output, states, attn_weights = model(input_ids, segment_ids, attention_mask, training=False)
Fine-tuning Examples
import os

import tensorflow as tf

from transformers_keras import Albert

# Used for fine-tuning
def build_albert_classify_model(pretrained_model_dir, trainable=True, **kwargs):
    input_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='input_ids')
    # segment_ids and attention_mask inputs are optional
    segment_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='segment_ids')
    albert = Albert.from_pretrained(pretrained_model_dir, **kwargs)
    albert.trainable = trainable
    sequence_outputs, pooled_output = albert(input_ids, segment_ids, None)
    outputs = tf.keras.layers.Dense(2, name='output')(pooled_output)
    model = tf.keras.Model(inputs=[input_ids, segment_ids], outputs=outputs)
    model.compile(loss='binary_crossentropy', optimizer='adam')
    return model

# BASE_DIR is the directory that holds the pretrained checkpoints
model = build_albert_classify_model(
    pretrained_model_dir=os.path.join(BASE_DIR, 'albert_base'),
    trainable=True)
model.summary()
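After fine-tuning, the weights can be persisted and restored with the standard Keras APIs. A minimal sketch reusing the builder defined above (the output path is hypothetical):

# Hypothetical path; this persists only the weights, not the Python code that builds the model.
model.save_weights('/tmp/albert_classifier_weights')

# Rebuild the same architecture, then restore the fine-tuned weights.
restored = build_albert_classify_model(
    pretrained_model_dir=os.path.join(BASE_DIR, 'albert_base'),
    trainable=True)
restored.load_weights('/tmp/albert_classifier_weights')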
Advanced Usage
Here are some advanced usages:
- Skip loading weights from checkpoint
- Load other pretrained models
Skip loading weights from checkpoint
You can skip loading some weights from the checkpoint.
Examples:
from transformers_keras import Bert, Albert

ALBERT_MODEL_PATH = '/path/to/albert/model'
albert = Albert.from_pretrained(
    ALBERT_MODEL_PATH,
    # return_states=False,
    # return_attention_weights=False,
    skip_token_embedding=True,
    skip_position_embedding=True,
    skip_segment_embedding=True,
    skip_pooler=True,
    ...
)

BERT_MODEL_PATH = '/path/to/bert/model'
bert = Bert.from_pretrained(
    BERT_MODEL_PATH,
    # return_states=False,
    # return_attention_weights=False,
    skip_token_embedding=True,
    skip_position_embedding=True,
    skip_segment_embedding=True,
    skip_pooler=True,
    ...
)
All supported kwargs to skip loading weights:
- skip_token_embedding, skip loading token_embedding weights from the checkpoint
- skip_position_embedding, skip loading position_embedding weights from the checkpoint
- skip_segment_embedding, skip loading token_type_embedding weights from the checkpoint
- skip_embedding_layernorm, skip loading layer_norm weights of the embedding layer from the checkpoint
- skip_pooler, skip loading pooler weights of the pooler layer from the checkpoint
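For example (a minimal sketch with a hypothetical path), a token-level task does not use the pooled output, so the pooler weights can be skipped while everything else is loaded:

from transformers_keras import Bert

# Hypothetical path; only the pooler weights are skipped, all other weights are loaded.
bert = Bert.from_pretrained(
    '/path/to/pretrained/bert/model',
    skip_pooler=True)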
Load other pretrained models
If you want to load models pretrained by other implementations, whose config and trainable weights differ slightly from the above, you can subclass AbstractAdapter to adapt these models:
from transformers_keras.adapters import AbstractAdapter
from transformers_keras import Bert, Albert

# load custom bert models
class MyBertAdapter(AbstractAdapter):

    def adapte_config(self, config_file, **kwargs):
        # adapt the model config here
        # you can refer to `transformers_keras.adapters.bert_adapter`
        pass

    def adapte_weights(self, model, config, ckpt, **kwargs):
        # adapt the model weights here
        # you can refer to `transformers_keras.adapters.bert_adapter`
        pass

bert = Bert.from_pretrained('/path/to/your/bert/model', adapter=MyBertAdapter())

# or, load custom albert models
class MyAlbertAdapter(AbstractAdapter):

    def adapte_config(self, config_file, **kwargs):
        # adapt the model config here
        # you can refer to `transformers_keras.adapters.albert_adapter`
        pass

    def adapte_weights(self, model, config, ckpt, **kwargs):
        # adapt the model weights here
        # you can refer to `transformers_keras.adapters.albert_adapter`
        pass

albert = Albert.from_pretrained('/path/to/your/albert/model', adapter=MyAlbertAdapter())
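As an illustration only, here is one way adapte_config might be filled in, assuming (this is an assumption, not the library's documented contract) that it returns a plain dict of model hyperparameters parsed from the checkpoint's JSON config file; the key names below are guesses modelled on the standard BERT config format and should be checked against transformers_keras.adapters.bert_adapter:

import json

from transformers_keras.adapters import AbstractAdapter

class MyBertAdapter(AbstractAdapter):

    def adapte_config(self, config_file, **kwargs):
        # Assumption: return a dict of model hyperparameters; verify the expected
        # keys against transformers_keras.adapters.bert_adapter before relying on this.
        with open(config_file) as f:
            config = json.load(f)
        return {
            'vocab_size': config['vocab_size'],
            'hidden_size': config['hidden_size'],
            'num_attention_heads': config['num_attention_heads'],
        }

    def adapte_weights(self, model, config, ckpt, **kwargs):
        # Assumption: copy checkpoint variables into the model's weights, e.g. by
        # listing them with tf.train.list_variables(ckpt) and reading each with
        # tf.train.load_variable(ckpt, name); the name mapping itself is model-specific.
        pass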
File details
Details for the file transformers_keras-0.2.5.tar.gz.
File metadata
- Download URL: transformers_keras-0.2.5.tar.gz
- Upload date:
- Size: 27.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.1 importlib_metadata/4.0.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.9.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | cc121b3421380590f22dd01425016664472da2ae87f29f5ed5074f66337319ce
MD5 | 5a4d8d20d98c0112757a4d10ed76ad99
BLAKE2b-256 | 96b8f3545d3d040faa926bc700157817cff87c169d9ae4ba72a0ee80bb72b129
File details
Details for the file transformers_keras-0.2.5-py3-none-any.whl.
File metadata
- Download URL: transformers_keras-0.2.5-py3-none-any.whl
- Upload date:
- Size: 35.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.1 importlib_metadata/4.0.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.9.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 2ba85338d5e10e60fbca567598202c76a21d8be11f070e7be7c9de805c9acb16
MD5 | a9a862477b5b762dfe842c7e39113f88
BLAKE2b-256 | 9552ad89ed89540642e8d565b8854b2c3537c071e0a89ad287dc70f339e93331