Skip to main content

BERT implemented in Keras

Project description

Travis Coverage

Implementation of the BERT. Official pre-trained models could be loaded for feature extraction.

Install

pip install keras-bert

Usage

Load Official Pre-trained Models

See load model demo. You should be able to get the same feature extraction result as the official model.

Train & Use

from keras_bert import get_base_dict, get_model, gen_batch_inputs


# A toy input example
sentence_pairs = [
    [['all', 'work', 'and', 'no', 'play'], ['makes', 'jack', 'a', 'dull', 'boy']],
    [['from', 'the', 'day', 'forth'], ['my', 'arm', 'changed']],
    [['and', 'a', 'voice', 'echoed'], ['power', 'give', 'me', 'more', 'power']],
]


# Build token dictionary
token_dict = get_base_dict()  # A dict that contains some special tokens
for pairs in sentence_pairs:
    for token in pairs[0] + pairs[1]:
        if token not in token_dict:
            token_dict[token] = len(token_dict)
token_list = list(token_dict.keys())  # Used for selecting a random word


# Build & train the model
model = get_model(
    token_num=len(token_dict),
    head_num=5,
    transformer_num=12,
    embed_dim=25,
    feed_forward_dim=100,
    seq_len=20,
    pos_num=20,
    dropout=0.05,
)
model.summary()

def _generator():
    while True:
        yield gen_batch_inputs(
            sentence_pairs,
            token_dict,
            token_list,
            seq_len=20,
            mask_rate=0.3,
            swap_sentence_rate=1.0,
        )

model.fit_generator(
    generator=_generator(),
    steps_per_epoch=1000,
    epochs=100,
    validation_data=_generator(),
    validation_steps=100,
    callbacks=[
        keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)
    ],
)


# Use the trained model
inputs, output_layer = get_model(  # `output_layer` is the last feature extraction layer (the last transformer)
    token_num=len(token_dict),
    head_num=5,
    transformer_num=12,
    embed_dim=25,
    feed_forward_dim=100,
    seq_len=20,
    pos_num=20,
    dropout=0.05,
    training=False,  # The input layers and output layer will be returned if `training` is `False`
)

Custom Feature Extraction

def _custom_layers(x, trainable=True):
    return keras.layers.LSTM(
        units=768,
        trainable=trainable,
        name='LSTM',
    )(x)

model = get_model(
    token_num=200,
    embed_dim=768,
    custom_layers=_custom_layers,
)

Project details


Release history Release notifications | RSS feed

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

keras-bert-0.16.0.tar.gz (9.8 kB view details)

Uploaded Source

File details

Details for the file keras-bert-0.16.0.tar.gz.

File metadata

  • Download URL: keras-bert-0.16.0.tar.gz
  • Upload date:
  • Size: 9.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.11.0 pkginfo/1.4.2 requests/2.18.4 setuptools/28.8.0 requests-toolbelt/0.8.0 tqdm/4.24.0 CPython/3.6.4

File hashes

Hashes for keras-bert-0.16.0.tar.gz
Algorithm Hash digest
SHA256 d736e27b5e3c184180f7de70eddc88ab19e227b67ff96090aaa119f910aaf4ef
MD5 6e2f324cbac7c91ea89eb9748edb631a
BLAKE2b-256 f4464e6bc99a3fd5f150c93e2cef60206caf9616c46fb2a65b600f0361387a47

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page