Skip to main content

Transformer-based models implemented in TensorFlow 2.x (Keras)

Project description

transformers-keras

Python package PyPI version Python

Transformer-based models implemented in TensorFlow 2.x (Keras).

Installation

pip install -U transformers-keras

Models

BERT

Supported pretrained models:

import os

from transformers_keras import Bert

# This example assumes `import tensorflow as tf` and a BASE_DIR variable pointing
# at the directory that holds your pretrained checkpoints.
# input_ids, segment_ids and mask below are placeholders for your tokenized inputs.

# Used to predict directly
model = Bert.from_pretrained('/path/to/pretrained/bert/model')
# segment_ids and mask inputs are optional
model.predict((input_ids, segment_ids, mask))
# or
model(inputs=(input_ids, segment_ids, mask))


# Used for fine-tuning
def build_bert_classify_model(trainable=True, pretrained_model_dir=None):
    """Build and compile a 2-class classifier on top of a pretrained BERT.

    Args:
        trainable: whether the BERT weights are updated during training.
        pretrained_model_dir: directory of the pretrained BERT checkpoint.
            Defaults to ``os.path.join(BASE_DIR, 'chinese_wwm_ext_L-12_H-768_A-12')``.

    Returns:
        A compiled ``tf.keras.Model`` that takes ``[input_ids, segment_ids]``
        and outputs two raw logits per example.
    """
    if pretrained_model_dir is None:
        pretrained_model_dir = os.path.join(BASE_DIR, 'chinese_wwm_ext_L-12_H-768_A-12')

    input_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='input_ids')
    # segment_ids and mask inputs are optional
    segment_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='segment_ids')

    bert = Bert.from_pretrained(pretrained_model_dir)
    bert.trainable = trainable

    # BERT returns four outputs; only the second (the pooled output) feeds the head.
    _, pooled_output, _, _ = bert(inputs=(input_ids, segment_ids))
    # No activation: the head emits raw logits for the two classes.
    logits = tf.keras.layers.Dense(2, name='output')(pooled_output)
    model = tf.keras.Model(inputs=[input_ids, segment_ids], outputs=logits)
    # 'binary_cross_entropy' is not a Keras loss name. With a 2-unit logit head,
    # sparse categorical cross-entropy on integer labels (0/1) is the matching loss.
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer='adam',
    )
    return model


model = build_bert_classify_model()
model.summary()

ALBERT

Supported pretrained models:

import os

from transformers_keras import Albert

# This example assumes `import tensorflow as tf` and a BASE_DIR variable pointing
# at the directory that holds your pretrained checkpoints.
# input_ids, segment_ids and mask below are placeholders for your tokenized inputs.

# Used to predict directly (ALBERT checkpoints must be loaded with Albert, not Bert)
model = Albert.from_pretrained('/path/to/pretrained/albert/model')
# segment_ids and mask inputs are optional
model.predict((input_ids, segment_ids, mask))
# or
model(inputs=(input_ids, segment_ids, mask))


# Used for fine-tuning
def build_albert_classify_model(trainable=True, pretrained_model_dir=None):
    """Build and compile a 2-class classifier on top of a pretrained ALBERT.

    Args:
        trainable: whether the ALBERT weights are updated during training.
        pretrained_model_dir: directory of the pretrained ALBERT checkpoint.
            Defaults to ``os.path.join(BASE_DIR, 'albert_large_zh')``.

    Returns:
        A compiled ``tf.keras.Model`` that takes ``[input_ids, segment_ids]``
        and outputs two raw logits per example.
    """
    if pretrained_model_dir is None:
        pretrained_model_dir = os.path.join(BASE_DIR, 'albert_large_zh')

    input_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='input_ids')
    # segment_ids and mask inputs are optional
    segment_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='segment_ids')

    albert = Albert.from_pretrained(pretrained_model_dir)
    albert.trainable = trainable

    # ALBERT returns four outputs; only the second (the pooled output) feeds the head.
    _, pooled_output, _, _ = albert(inputs=(input_ids, segment_ids))
    # No activation: the head emits raw logits for the two classes.
    logits = tf.keras.layers.Dense(2, name='output')(pooled_output)
    model = tf.keras.Model(inputs=[input_ids, segment_ids], outputs=logits)
    # 'binary_cross_entropy' is not a Keras loss name. With a 2-unit logit head,
    # sparse categorical cross-entropy on integer labels (0/1) is the matching loss.
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer='adam',
    )
    return model


model = build_albert_classify_model()
model.summary()

Load other pretrained models

If you want to load models pretrained by other implementations, whose config and trainable weights differ slightly from the ones above, you can subclass `AbstractAdapter` to adapt these models:

from transformers_keras.adapters import AbstractAdapter
from transformers_keras import Bert, Albert

# Load custom BERT models through a user-defined adapter.
class MyBertAdapter(AbstractAdapter):
    """Template adapter for BERT checkpoints from other implementations.

    Both hooks below are placeholders that return None; fill them in, using
    `transformers_keras.adapters.bert_adapter` as a reference.
    """

    def adapte_config(self, config_file, **kwargs):
        """Hook for adapting the model config given by ``config_file`` (placeholder)."""
        ...

    def adapte_weights(self, model, config, ckpt, **kwargs):
        """Hook for adapting the weights in ``ckpt`` onto ``model`` (placeholder)."""
        ...


bert = Bert.from_pretrained('/path/to/your/bert/model', adapter=MyBertAdapter())

# Or load custom ALBERT models through a user-defined adapter.
class MyAlbertAdapter(AbstractAdapter):
    """Template adapter for ALBERT checkpoints from other implementations.

    Both hooks below are placeholders that return None; fill them in, using
    `transformers_keras.adapters.albert_adapter` as a reference.
    """

    def adapte_config(self, config_file, **kwargs):
        """Hook for adapting the model config given by ``config_file`` (placeholder)."""
        ...

    def adapte_weights(self, model, config, ckpt, **kwargs):
        """Hook for adapting the weights in ``ckpt`` onto ``model`` (placeholder)."""
        ...


albert = Albert.from_pretrained('/path/to/your/albert/model', adapter=MyAlbertAdapter())

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Files for transformers-keras, version 0.2.1
Filename, size File type Python version Upload date Hashes
Filename, size transformers_keras-0.2.1-py3-none-any.whl (28.1 kB) File type Wheel Python version py3 Upload date Hashes View
Filename, size transformers_keras-0.2.1.tar.gz (16.2 kB) File type Source Python version None Upload date Hashes View

Supported by

Pingdom Pingdom Monitoring Google Google Object Storage and Download Analytics Sentry Sentry Error logging AWS AWS Cloud computing DataDog DataDog Monitoring Fastly Fastly CDN DigiCert DigiCert EV certificate StatusPage StatusPage Status page