Transformer-XL implemented in Keras
Keras Transformer-XL
Unofficial implementation of Transformer-XL.
Install
pip install keras-transformer-xl
Usage
Load Pretrained Weights
Several configuration files can be found in the info directory.
import os
from keras_transformer_xl import load_trained_model_from_checkpoint

checkpoint_path = 'foo/bar/sota/enwiki8'
model = load_trained_model_from_checkpoint(
    config_path=os.path.join(checkpoint_path, 'config.json'),
    checkpoint_path=os.path.join(checkpoint_path, 'model.ckpt')
)
model.summary()
About IO
Suppose the number of transformer blocks is n. The last n inputs of the model are the memory inputs, and the last n outputs are the new data to be memorized.
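The following is a minimal NumPy sketch of the caching idea behind those memory inputs and outputs, assuming one memory tensor per block with shape (batch, length, units). The helper name update_memory and all shapes are hypothetical and chosen for illustration; this is not the package's internal code.

```python
import numpy as np


def update_memory(memory, new_hidden, memory_len):
    """Append new hidden states and keep only the most recent steps.

    Hypothetical helper for illustration; assumes memory_len > 0.
    """
    combined = np.concatenate([memory, new_hidden], axis=1)
    return combined[:, -memory_len:, :]


num_block = 3                       # n transformer blocks -> n memory tensors
batch_size, units, memory_len = 2, 4, 6

# Start with empty memories, one per block.
memories = [np.zeros((batch_size, 0, units)) for _ in range(num_block)]

# Feed two segments of length 5; the ones stand in for each block's new outputs.
for segment_len in (5, 5):
    new_outputs = [np.ones((batch_size, segment_len, units)) for _ in range(num_block)]
    memories = [
        update_memory(mem, out, memory_len)
        for mem, out in zip(memories, new_outputs)
    ]

print([mem.shape for mem in memories])
```

After the second segment each memory holds only the latest memory_len steps, which is why the memory tensors have to be passed back in alongside the next batch.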
You can use the MemorySequence wrapper for training and prediction:
import keras
import numpy as np
from keras_transformer_xl import MemorySequence, build_transformer_xl, fit_generator, predict_generator


class DummySequence(keras.utils.Sequence):

    def __init__(self):
        pass

    def __len__(self):
        return 10

    def __getitem__(self, index):
        return np.ones((3, 5 * (index + 1))), np.ones((3, 5 * (index + 1), 3))


model = build_transformer_xl(
    units=4,
    embed_dim=4,
    hidden_dim=4,
    num_token=3,
    num_block=3,
    num_head=2,
)
seq = MemorySequence(
    units=4,
    model=model,
    sequence=DummySequence(),
    target_len=10,
    memory_len=20,
)

fit_generator(model, seq, epochs=2, validation_data=seq)
predict_generator(model, seq, verbose=True)
Use tensorflow.python.keras
Add TF_KERAS=1 to the environment variables to use tensorflow.python.keras.
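One way to set the variable from Python is sketched below. It assumes the package reads TF_KERAS when it is first imported, so the assignment has to come before any import of keras_transformer_xl; exporting the variable in the shell before starting Python should have the same effect.

```python
import os

# Set the flag for the current process before the package is imported.
# Assumption: keras_transformer_xl checks TF_KERAS at import time.
os.environ['TF_KERAS'] = '1'

# Import the package only after the variable is set, for example:
# from keras_transformer_xl import build_transformer_xl
print(os.environ['TF_KERAS'])
```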
Download files
Source Distribution
keras-transformer-xl-0.4.0.tar.gz (16.6 kB)
Hashes for keras-transformer-xl-0.4.0.tar.gz
Algorithm | Hash digest
---|---
SHA256 | 22379366aa160885735c374d66289e1ea4a1335bc3111ca9b80e098298c04993
MD5 | fd12023521452649acdfa8b3602bc154
BLAKE2b-256 | fb8e667803d88b7f6b0950f00db07d9a050032d6f27f8199f6ce19c5fc08f01a