
Project description

transformers-keras


Transformer-based models implemented in TensorFlow 2.x (Keras).

Installation

pip install -U transformers-keras
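
A quick way to check that the installation works is to import the package alongside TensorFlow (this assumes a TensorFlow 2.x environment is already available):

# Sanity check: both imports should succeed and TensorFlow should report a 2.x version
import tensorflow as tf

from transformers_keras import Albert, Bert

print(tf.__version__)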

Models

BERT

Supported pretrained models:

Feature Extraction Examples

import tensorflow as tf

from transformers_keras import Bert

# Load a pretrained model and use it directly for prediction
model = Bert.from_pretrained('/path/to/pretrained/bert/model')
input_ids = tf.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
# segment_ids and attention_mask inputs are optional
segment_ids, attention_mask = None, None
sequence_outputs, pooled_output = model(input_ids, segment_ids, attention_mask, training=False)
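
As with most BERT implementations, sequence_outputs holds one hidden vector per input token and pooled_output holds one vector per sequence. A quick way to check the shapes and pick out a sentence-level vector:

# Expected shapes (hidden_size depends on the pretrained config):
#   sequence_outputs -> (batch_size, sequence_length, hidden_size)
#   pooled_output    -> (batch_size, hidden_size)
print(sequence_outputs.shape, pooled_output.shape)

# One common choice of sentence embedding is the vector at the first ([CLS]) position
cls_vector = sequence_outputs[:, 0, :]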

You can also optionally get the hidden states and attention weights of each encoder layer:

import tensorflow as tf

from transformers_keras import Bert

# Also return the hidden states and attention weights of each encoder layer
model = Bert.from_pretrained(
    '/path/to/pretrained/bert/model',
    return_states=True,
    return_attention_weights=True)
input_ids = tf.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
# segment_ids and attention_mask inputs are optional
segment_ids, attention_mask = None, None
sequence_outputs, pooled_output, states, attn_weights = model(input_ids, segment_ids, attention_mask, training=False)
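
The exact layout of states and attn_weights depends on the model configuration; a quick, layout-agnostic way to inspect them is:

# Print the shape of every tensor nested inside the extra outputs,
# whether they come back as single stacked tensors or as per-layer lists
for name, value in [('states', states), ('attn_weights', attn_weights)]:
    print(name, tf.nest.map_structure(lambda t: t.shape, value))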

Fine-tuning Examples

import os

import tensorflow as tf

from transformers_keras import Bert

# Build a binary classifier on top of a pretrained BERT for fine-tuning
def build_bert_classify_model(pretrained_model_dir, trainable=True, **kwargs):
    input_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='input_ids')
    # segment_ids and attention_mask inputs are optional
    segment_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='segment_ids')

    bert = Bert.from_pretrained(pretrained_model_dir, **kwargs)
    bert.trainable = trainable

    sequence_outputs, pooled_output = bert(input_ids, segment_ids, None)
    # The Dense layer produces logits, so the loss is configured with from_logits=True
    outputs = tf.keras.layers.Dense(2, name='output')(pooled_output)
    model = tf.keras.Model(inputs=[input_ids, segment_ids], outputs=outputs)
    model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), optimizer='adam')
    return model

# BASE_DIR is the local directory that contains the pretrained checkpoint folders
model = build_bert_classify_model(
    pretrained_model_dir=os.path.join(BASE_DIR, 'chinese_wwm_ext_L-12_H-768_A-12'),
    trainable=True)
model.summary()
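
To sanity-check that the model trains end to end, you can fit it on a small batch of random data. The shapes and token ids below are purely illustrative; a real setup would feed tokenized text and genuine labels:

import numpy as np

# Toy batch: 32 "sentences" of 128 random token ids each (illustrative only)
num_examples, seq_len = 32, 128
toy_input_ids = np.random.randint(1, 1000, size=(num_examples, seq_len)).astype(np.int32)
toy_segment_ids = np.zeros((num_examples, seq_len), dtype=np.int32)
# Two-unit one-hot targets to match the Dense(2) output layer
toy_labels = np.eye(2)[np.random.randint(0, 2, size=num_examples)].astype(np.float32)

model.fit(
    x={'input_ids': toy_input_ids, 'segment_ids': toy_segment_ids},
    y=toy_labels,
    batch_size=8,
    epochs=1)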

ALBERT

Supported pretrained models:

Feature Extraction Examples

import tensorflow as tf

from transformers_keras import Albert

# Load a pretrained model and use it directly for prediction
model = Albert.from_pretrained('/path/to/pretrained/albert/model')
input_ids = tf.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
# segment_ids and attention_mask inputs are optional
segment_ids, attention_mask = None, None
sequence_outputs, pooled_output = model(input_ids, segment_ids, attention_mask, training=False)

You can also optionally get the hidden states and attention weights of each encoder layer:

import tensorflow as tf

from transformers_keras import Albert

# Also return the hidden states and attention weights of each encoder layer
model = Albert.from_pretrained(
    '/path/to/pretrained/albert/model',
    return_states=True,
    return_attention_weights=True)
input_ids = tf.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
# segment_ids and attention_mask inputs are optional
segment_ids, attention_mask = None, None
sequence_outputs, pooled_output, states, attn_weights = model(input_ids, segment_ids, attention_mask, training=False)

Fine-tuning Examples

import os

import tensorflow as tf

from transformers_keras import Albert

# Build a binary classifier on top of a pretrained ALBERT for fine-tuning
def build_albert_classify_model(pretrained_model_dir, trainable=True, **kwargs):
    input_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='input_ids')
    # segment_ids and attention_mask inputs are optional
    segment_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='segment_ids')

    albert = Albert.from_pretrained(pretrained_model_dir, **kwargs)
    albert.trainable = trainable

    sequence_outputs, pooled_output = albert(input_ids, segment_ids, None)
    # The Dense layer produces logits, so the loss is configured with from_logits=True
    outputs = tf.keras.layers.Dense(2, name='output')(pooled_output)
    model = tf.keras.Model(inputs=[input_ids, segment_ids], outputs=outputs)
    model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), optimizer='adam')
    return model

# BASE_DIR is the local directory that contains the pretrained checkpoint folders
model = build_albert_classify_model(
    pretrained_model_dir=os.path.join(BASE_DIR, 'albert_base'),
    trainable=True)
model.summary()
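
Because the fine-tuned model wraps custom subclassed layers, a simple and portable way to persist it is to save only the weights and rebuild the architecture when restoring; a minimal sketch (the checkpoint path is just an example):

# Save only the trained weights (avoids custom-object serialization issues)
model.save_weights('/tmp/albert_classifier_ckpt')

# Later: rebuild the same architecture and restore the weights
restored = build_albert_classify_model(
    pretrained_model_dir=os.path.join(BASE_DIR, 'albert_base'),
    trainable=True)
restored.load_weights('/tmp/albert_classifier_ckpt')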

Advanced Usage

Here are some advanced usages:

  • Skip loading weights from checkpoint
  • Load other pretrained models

Skip loading weights from checkpoint

You can skip loading some weights from the checkpoint.

Examples:

from transformers_keras import Bert, Albert

ALBERT_MODEL_PATH = '/path/to/albert/model'
albert = Albert.from_pretrained(
    ALBERT_MODEL_PATH,
    # return_states=False,
    # return_attention_weights=False,
    skip_token_embedding=True,
    skip_position_embedding=True,
    skip_segment_embedding=True,
    skip_pooler=True,
    # ...
    )

BERT_MODEL_PATH = '/path/to/bert/model'
bert = Bert.from_pretrained(
    BERT_MODEL_PATH,
    # return_states=False,
    # return_attention_weights=False,
    skip_token_embedding=True,
    skip_position_embedding=True,
    skip_segment_embedding=True,
    skip_pooler=True,
    # ...
    )

All supported kwargs for skipping weight loading:

  • skip_token_embedding, skip loading token_embedding weights from the ckpt
  • skip_position_embedding, skip loading position_embedding weights from the ckpt
  • skip_segment_embedding, skip loading token_type_embedding weights from the ckpt
  • skip_embedding_layernorm, skip loading the layer_norm weights of the embedding layer from the ckpt
  • skip_pooler, skip loading the weights of the pooler layer from the ckpt

Load other pretrained models

If you want to load models pretrained by other implementations, whose config and trainable weights differ slightly from the above, you can subclass AbstractAdapter to adapt these models:

from transformers_keras.adapters import AbstractAdapter
from transformers_keras import Bert, Albert

# Load custom BERT models
class MyBertAdapter(AbstractAdapter):

    def adapte_config(self, config_file, **kwargs):
        # adapt the model config here
        # you can refer to `transformers_keras.adapters.bert_adapter`
        pass

    def adapte_weights(self, model, config, ckpt, **kwargs):
        # adapt the model weights here
        # you can refer to `transformers_keras.adapters.bert_adapter`
        pass

bert = Bert.from_pretrained('/path/to/your/bert/model', adapter=MyBertAdapter())

# Or, load custom ALBERT models
class MyAlbertAdapter(AbstractAdapter):

    def adapte_config(self, config_file, **kwargs):
        # adapt the model config here
        # you can refer to `transformers_keras.adapters.albert_adapter`
        pass

    def adapte_weights(self, model, config, ckpt, **kwargs):
        # adapt the model weights here
        # you can refer to `transformers_keras.adapters.albert_adapter`
        pass

albert = Albert.from_pretrained('/path/to/your/albert/model', adapter=MyAlbertAdapter())


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

transformers_keras-0.2.5.tar.gz (27.2 kB)

Built Distribution

transformers_keras-0.2.5-py3-none-any.whl (35.4 kB)

File details

Details for the file transformers_keras-0.2.5.tar.gz.

File metadata

  • Download URL: transformers_keras-0.2.5.tar.gz
  • Upload date:
  • Size: 27.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.0.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.9.5

File hashes

Hashes for transformers_keras-0.2.5.tar.gz

  • SHA256: cc121b3421380590f22dd01425016664472da2ae87f29f5ed5074f66337319ce
  • MD5: 5a4d8d20d98c0112757a4d10ed76ad99
  • BLAKE2b-256: 96b8f3545d3d040faa926bc700157817cff87c169d9ae4ba72a0ee80bb72b129


File details

Details for the file transformers_keras-0.2.5-py3-none-any.whl.

File metadata

  • Download URL: transformers_keras-0.2.5-py3-none-any.whl
  • Upload date:
  • Size: 35.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.0.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.9.5

File hashes

Hashes for transformers_keras-0.2.5-py3-none-any.whl

  • SHA256: 2ba85338d5e10e60fbca567598202c76a21d8be11f070e7be7c9de805c9acb16
  • MD5: a9a862477b5b762dfe842c7e39113f88
  • BLAKE2b-256: 9552ad89ed89540642e8d565b8854b2c3537c071e0a89ad287dc70f339e93331

