Transformer implemented in Keras

Keras Transformer

Implementation of the Transformer for translation-like tasks.

Install

pip install keras-transformer

Usage

Train

import keras
import numpy as np
from keras_transformer import get_custom_objects, get_model, decode


# Build a small toy token dictionary
tokens = 'all work and no play makes jack a dull boy'.split(' ')
token_dict = {
    '<PAD>': 0,
    '<START>': 1,
    '<END>': 2,
}
for token in tokens:
    if token not in token_dict:
        token_dict[token] = len(token_dict)

# Generate toy data: for each split point i, the encoder sees the first i tokens
# and the decoder must produce the remaining ones.
# Every sequence is padded to len(tokens) + 2 ids.
encoder_inputs_no_padding = []
encoder_inputs, decoder_inputs, decoder_outputs = [], [], []
for i in range(1, len(tokens) - 1):
    encode_tokens, decode_tokens = tokens[:i], tokens[i:]
    encode_tokens = ['<START>'] + encode_tokens + ['<END>'] + ['<PAD>'] * (len(tokens) - len(encode_tokens))
    # Targets are the decoder inputs shifted left by one position (teacher forcing)
    output_tokens = decode_tokens + ['<END>', '<PAD>'] + ['<PAD>'] * (len(tokens) - len(decode_tokens))
    decode_tokens = ['<START>'] + decode_tokens + ['<END>'] + ['<PAD>'] * (len(tokens) - len(decode_tokens))
    # Convert tokens to ids; each target id is wrapped in a list,
    # which gives the targets a trailing axis of size 1
    encode_tokens = list(map(lambda x: token_dict[x], encode_tokens))
    decode_tokens = list(map(lambda x: token_dict[x], decode_tokens))
    output_tokens = list(map(lambda x: [token_dict[x]], output_tokens))
    # Keep an unpadded copy (<START> + i tokens + <END>) for decoding later
    encoder_inputs_no_padding.append(encode_tokens[:i + 2])
    encoder_inputs.append(encode_tokens)
    decoder_inputs.append(decode_tokens)
    decoder_outputs.append(output_tokens)

# Build the model
model = get_model(
    token_num=len(token_dict),
    embed_dim=30,
    encoder_num=3,
    decoder_num=2,
    head_num=3,
    hidden_dim=120,
    attention_activation='relu',
    feed_forward_activation='relu',
    dropout_rate=0.05,
    embed_weights=np.random.random((len(token_dict), 30)),  # shape: (token_num, embed_dim)
)
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.sparse_categorical_crossentropy,
    metrics={},
    # Note: Keras 2.2.3 and 2.2.4 have a bug that raises an "Incompatible shapes" error
    # when any accuracy metric is used together with sparse_categorical_crossentropy.
    # Use keras<=2.2.2 to get validation accuracy.
)
model.summary()

# Train the model; the 8 toy pairs are repeated 1000 times to form a larger training set
model.fit(
    x=[np.asarray(encoder_inputs * 1000), np.asarray(decoder_inputs * 1000)],
    y=np.asarray(decoder_outputs * 1000),
    epochs=5,
)
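
The padding and shifting in the data loop above are easy to get wrong. The stdlib-only sketch below re-creates that loop so it can be inspected without Keras; the helper name `build_toy_pairs` is hypothetical and not part of `keras_transformer`. It checks two properties the toy data should have: every sequence is `len(tokens) + 2` ids long, and each target sequence is the decoder input shifted left by one position. Unlike the training code, it keeps target ids flat instead of wrapping each one in its own list.

```python
# Hypothetical helper that mirrors the README's data loop using only the stdlib.
def build_toy_pairs(sentence):
    """Return (token_dict, encoder_inputs, decoder_inputs, decoder_outputs)."""
    tokens = sentence.split(' ')
    token_dict = {'<PAD>': 0, '<START>': 1, '<END>': 2}
    for token in tokens:
        if token not in token_dict:
            token_dict[token] = len(token_dict)

    enc, dec_in, dec_out = [], [], []
    for i in range(1, len(tokens) - 1):
        src, tgt = tokens[:i], tokens[i:]
        pad_src = ['<PAD>'] * (len(tokens) - len(src))
        pad_tgt = ['<PAD>'] * (len(tokens) - len(tgt))
        enc.append([token_dict[t] for t in ['<START>'] + src + ['<END>'] + pad_src])
        dec_in.append([token_dict[t] for t in ['<START>'] + tgt + ['<END>'] + pad_tgt])
        # Targets drop <START>, so they are the decoder inputs shifted left by one;
        # the extra <PAD> keeps them the same length as the inputs.
        dec_out.append([token_dict[t] for t in tgt + ['<END>', '<PAD>'] + pad_tgt])
    return token_dict, enc, dec_in, dec_out


token_dict, enc, dec_in, dec_out = build_toy_pairs(
    'all work and no play makes jack a dull boy')

# Every sequence has len(tokens) + 2 = 12 ids.
assert all(len(seq) == 12 for seq in enc + dec_in + dec_out)
# Teacher forcing: the target at step t is the decoder input at step t + 1.
assert all(y[:-1] == x[1:] for x, y in zip(dec_in, dec_out))
print(len(token_dict), len(enc))  # 13 8
```

The 13 ids are the three special tokens plus the ten distinct words, which is why the embedding matrix in the example has 13 rows.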

Predict

decoded = decode(
    model,
    encoder_inputs_no_padding,
    start_token=token_dict['<START>'],
    end_token=token_dict['<END>'],
    pad_token=token_dict['<PAD>'],
    max_len=100,
)
# Map ids back to tokens; the [1:-1] slice drops the <START> and <END> tokens
token_dict_rev = {v: k for k, v in token_dict.items()}
for i in range(len(decoded)):
    print(' '.join(map(lambda x: token_dict_rev[x], decoded[i][1:-1])))
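
`decode` generates each output sequence one token at a time. Its exact implementation is not shown here, so the following is only a generic greedy-decoding sketch under stated assumptions: `next_token_scores` is a hypothetical callable standing in for the trained model, and at each step the highest-scoring id is appended until `end_token` appears or the sequence reaches `max_len` ids (counting `<START>`). The library's own `decode` may differ, for example in how it batches inputs or handles padding.

```python
# Generic greedy decoding, sketched without any ML framework.
# `next_token_scores` is a hypothetical stand-in for the trained model: given the
# encoder ids and the decoder ids so far, it returns one score per vocabulary id.
def greedy_decode(next_token_scores, encoder_ids, start_token, end_token, max_len):
    decoded = [start_token]
    while len(decoded) < max_len:
        scores = next_token_scores(encoder_ids, decoded)
        # Pick the id with the highest score (argmax)
        best = max(range(len(scores)), key=scores.__getitem__)
        decoded.append(best)
        if best == end_token:
            break
    return decoded


def copy_scorer(encoder_ids, decoded, vocab_size=5):
    # Toy "model" for illustration: it puts all its score on the encoder id
    # at the same position, so decoding reproduces the encoder sequence.
    target = encoder_ids[len(decoded)]
    return [1.0 if v == target else 0.0 for v in range(vocab_size)]


print(greedy_decode(copy_scorer, [1, 3, 4, 2], start_token=1, end_token=2, max_len=10))
# [1, 3, 4, 2]
```

The result starts with `<START>` (id 1) and ends with `<END>` (id 2), so slicing it with `[1:-1]`, as in the printing loop above, leaves only the generated tokens `[3, 4]`.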

Download files

Source Distribution

keras-transformer-0.19.0.tar.gz (7.0 kB)

File details

Details for the file keras-transformer-0.19.0.tar.gz.

File metadata

  • Download URL: keras-transformer-0.19.0.tar.gz
  • Upload date:
  • Size: 7.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.20.1 setuptools/40.7.1 requests-toolbelt/0.8.0 tqdm/4.24.0 CPython/3.6.4

File hashes

Hashes for keras-transformer-0.19.0.tar.gz
Algorithm Hash digest
SHA256 e8a8cf2142f5106d258187327baac35277f11d532287b8461291bb1ea2e5b08b
MD5 c3e3516d7447780a1b72b2f1b159a97e
BLAKE2b-256 4d614ffb5d3f8fc50f1dd33132af5869f3779052f3e18b0829cc95d4ad2dce7d
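
To check a downloaded archive against the SHA256 digest listed above, a minimal sketch using Python's standard `hashlib` module could look like this. The function names are illustrative, and the expected value is copied from the hash table.

```python
import hashlib

# SHA256 digest listed above for keras-transformer-0.19.0.tar.gz.
EXPECTED_SHA256 = 'e8a8cf2142f5106d258187327baac35277f11d532287b8461291bb1ea2e5b08b'


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA-256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_digest(path, expected=EXPECTED_SHA256):
    """True if the file's SHA-256 digest equals the expected hex string."""
    return sha256_of(path) == expected.lower()
```

After downloading, `matches_published_digest('keras-transformer-0.19.0.tar.gz')` should return `True`. pip can enforce the same check at install time through its hash-checking mode (`--require-hashes`), by adding `--hash=sha256:` followed by the digest to the requirement line.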
