BERT implemented with the Keras module of the TensorFlow package, running on TPU

Project description

This is a fork of CyberZHG/keras_bert that adds TPU support to Keras BERT.

Implementation of BERT. Official pre-trained models can be loaded for feature extraction and prediction.

Colab Demo

HighCWu/keras-bert-tpu

Install

pip install keras-bert-tpu
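
The models are built with the Keras module bundled in TensorFlow, so they can go through the usual TF 1.x Colab TPU workflow. A minimal sketch, assuming a Colab TPU runtime and the contrib-era conversion API; treat it as an assumption rather than this package's documented interface, and see the Colab demo above for the authoritative version:

import os
import tensorflow as tf
from keras_bert import get_base_dict, get_model

# Build a small BERT model (see the Usage section below for the parameters).
model = get_model(
    token_num=len(get_base_dict()),
    head_num=2,
    transformer_num=2,
    embed_dim=12,
    feed_forward_dim=48,
    seq_len=20,
    pos_num=20,
)

# COLAB_TPU_ADDR is set automatically inside a Colab TPU runtime.
tpu_model = tf.contrib.tpu.keras_to_tpu_model(
    model,
    strategy=tf.contrib.tpu.TPUDistributionStrategy(
        tf.contrib.cluster_resolver.TPUClusterResolver(
            tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])))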

Usage

Load Official Pre-trained Models

In the feature extraction demo, you should get the same extraction results as the official model. In the prediction demo, the missing word in a sentence can be predicted.
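
The usual entry point for loading an official checkpoint is load_trained_model_from_checkpoint. A minimal sketch, assuming an unpacked uncased_L-12_H-768_A-12 release (the paths below are placeholders):

from keras_bert import load_trained_model_from_checkpoint

config_path = 'uncased_L-12_H-768_A-12/bert_config.json'
checkpoint_path = 'uncased_L-12_H-768_A-12/bert_model.ckpt'

# `training=False` builds the feature-extraction variant of the model.
model = load_trained_model_from_checkpoint(config_path, checkpoint_path,
                                           training=False)
model.summary()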

Train & Use

import keras
from keras_bert import get_base_dict, get_model, gen_batch_inputs


# A toy input example
sentence_pairs = [
    [['all', 'work', 'and', 'no', 'play'], ['makes', 'jack', 'a', 'dull', 'boy']],
    [['from', 'the', 'day', 'forth'], ['my', 'arm', 'changed']],
    [['and', 'a', 'voice', 'echoed'], ['power', 'give', 'me', 'more', 'power']],
]


# Build token dictionary
token_dict = get_base_dict()  # A dict that contains some special tokens
for pairs in sentence_pairs:
    for token in pairs[0] + pairs[1]:
        if token not in token_dict:
            token_dict[token] = len(token_dict)
token_list = list(token_dict.keys())  # Used for selecting a random word


# Build & train the model
model = get_model(
    token_num=len(token_dict),
    head_num=5,
    transformer_num=12,
    embed_dim=25,
    feed_forward_dim=100,
    seq_len=20,
    pos_num=20,
    dropout_rate=0.05,
)
model.summary()

def _generator():
    # Yield `(inputs, outputs)` batches indefinitely for `fit_generator`.
    while True:
        yield gen_batch_inputs(
            sentence_pairs,
            token_dict,
            token_list,
            seq_len=20,
            mask_rate=0.3,
            swap_sentence_rate=1.0,
        )

model.fit_generator(
    generator=_generator(),
    steps_per_epoch=1000,
    epochs=100,
    validation_data=_generator(),
    validation_steps=100,
    callbacks=[
        keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)
    ],
)


# Use the trained model
inputs, output_layer = get_model(  # `output_layer` is the last feature extraction layer (the last transformer)
    token_num=len(token_dict),
    head_num=5,
    transformer_num=12,
    embed_dim=25,
    feed_forward_dim=100,
    seq_len=20,
    pos_num=20,
    dropout_rate=0.05,
    training=False,  # The input layers and output layer will be returned if `training` is `False`
)
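
The returned pieces can be wrapped into a plain Keras model for feature extraction. A short usage sketch (the dummy inputs are placeholders):

import numpy as np

# `inputs` is `[token_input, segment_input]`; `output_layer` is the output
# of the last transformer block.
feature_model = keras.models.Model(inputs=inputs, outputs=output_layer)

# Dummy token and segment ids for one padded sequence of length 20.
token_ids = np.zeros((1, 20))
segment_ids = np.zeros((1, 20))
features = feature_model.predict([token_ids, segment_ids])
print(features.shape)  # (1, 20, 25): one `embed_dim`-sized vector per position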

Custom Feature Extraction

import keras
from keras_bert import get_model


def _custom_layers(x, trainable=True):
    return keras.layers.LSTM(
        units=768,
        trainable=trainable,
        return_sequences=True,  # keep per-token outputs for the token-level heads
        name='LSTM',
    )(x)

model = get_model(
    token_num=200,
    embed_dim=768,
    custom_layers=_custom_layers,
)
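
Note that custom_layers replaces the default transformer stack between the embeddings and the output heads, so the pre-training heads are trained on top of the LSTM instead; the replacement has to return a full sequence (hence return_sequences=True) so a prediction can be made at every masked position.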

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

keras-bert-tpu-0.1.7.tar.gz (17.5 kB, Source)

File details

Details for the file keras-bert-tpu-0.1.7.tar.gz.

File metadata

  • Download URL: keras-bert-tpu-0.1.7.tar.gz
  • Upload date:
  • Size: 17.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.18.4 setuptools/36.5.0.post20170921 requests-toolbelt/0.8.0 tqdm/4.25.0 CPython/3.6.3

File hashes

Hashes for keras-bert-tpu-0.1.7.tar.gz
Algorithm    Hash digest
SHA256       8ecd1a95c82d9bca2d2e1f5b29e06a0a15eb4a5c468784f613b02b17e3124f1f
MD5          154f57505e1d08977d5731423c6b20cd
BLAKE2b-256  e0d14a5d9535ec23272ab70fe7dc89b391e6bfc19a09b38cfff3a21ffeb4ca07

See more details on using hashes here.
