Transformer-XL implemented in Keras
Keras Transformer-XL
Unofficial implementation of Transformer-XL.
Install
```bash
pip install keras-transformer-xl
```
Usage
Load Pretrained Weights
Several configuration files can be found in the info directory.
```python
import os
from keras_transformer_xl import load_trained_model_from_checkpoint

# Directory holding config.json and the model.ckpt checkpoint files.
checkpoint_path = 'foo/bar/sota/enwiki8'
model = load_trained_model_from_checkpoint(
    config_path=os.path.join(checkpoint_path, 'config.json'),
    checkpoint_path=os.path.join(checkpoint_path, 'model.ckpt')
)
model.summary()
```
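A TensorFlow checkpoint path such as model.ckpt is a prefix rather than a single file. In the V2 checkpoint format the data lives in files like model.ckpt.index and model.ckpt.data-00000-of-00001. The sketch below uses a hypothetical helper, resolve_checkpoint, which is not part of keras-transformer-xl. It builds both paths and checks that they exist before loading, so a wrong directory fails with a clear error. It assumes the config.json plus model.ckpt layout used above.

```python
import os


def resolve_checkpoint(checkpoint_dir):
    """Hypothetical helper: return (config_path, checkpoint_prefix) for a checkpoint directory.

    Assumes config.json sits next to a V2 TensorFlow checkpoint whose prefix is
    model.ckpt (stored on disk as model.ckpt.index, model.ckpt.data-*, ...).
    """
    config_path = os.path.join(checkpoint_dir, 'config.json')
    checkpoint_prefix = os.path.join(checkpoint_dir, 'model.ckpt')
    if not os.path.isfile(config_path):
        raise FileNotFoundError('missing config file: ' + config_path)
    # The prefix itself is not a file; check for the .index file TensorFlow writes instead.
    if not os.path.isfile(checkpoint_prefix + '.index'):
        raise FileNotFoundError('missing checkpoint index: ' + checkpoint_prefix + '.index')
    return config_path, checkpoint_prefix
```

You can pass the returned pair to load_trained_model_from_checkpoint as config_path and checkpoint_path.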
About IO
The generated model has two inputs; the second one holds the lengths of the memories. You can use the MemorySequence wrapper for training and prediction:
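```python
import keras
import numpy as np
from keras_transformer_xl import MemorySequence, build_transformer_xl


class DummySequence(keras.utils.Sequence):
    """Toy data source: batch `index` holds 3 sequences of length 5 * (index + 1)."""

    def __init__(self):
        super().__init__()

    def __len__(self):
        return 10

    def __getitem__(self, index):
        # Inputs: token IDs of shape (batch, length); targets: (batch, length, num_token).
        return np.ones((3, 5 * (index + 1))), np.ones((3, 5 * (index + 1), 3))


model = build_transformer_xl(
    units=4,
    embed_dim=4,
    hidden_dim=4,
    num_token=3,
    num_block=3,
    num_head=2,
    batch_size=3,
    memory_len=20,
    target_len=10,
)
# Wrap the data so the model also receives its memory-length input.
seq = MemorySequence(
    model=model,
    sequence=DummySequence(),
    target_len=10,
)
# predict() takes the data as its first argument, not the model.
model.predict(seq, verbose=True)
```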
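The second input tells the model how many past timesteps are held in memory for the current segment. The NumPy sketch below shows one way to cut a long batch into target_len segments and pair each segment with a memory length capped at memory_len. Several details are assumptions made for illustration and are not MemorySequence's actual code: the function name split_with_memory_lengths, the zero padding of the last segment, and the (batch, 1) shape of the length array.

```python
import numpy as np


def split_with_memory_lengths(tokens, target_len, memory_len):
    """Hypothetical sketch: yield (segment, memory_length) pairs for one batch.

    tokens: array of shape (batch, length).
    Each segment has shape (batch, target_len); the last one is zero-padded.
    memory_length has shape (batch, 1) and counts how many earlier timesteps
    can be kept in memory, capped at memory_len.
    """
    batch, length = tokens.shape
    num_segments = -(-length // target_len)  # ceiling division
    padded = np.zeros((batch, num_segments * target_len), dtype=tokens.dtype)
    padded[:, :length] = tokens
    for i in range(num_segments):
        segment = padded[:, i * target_len:(i + 1) * target_len]
        seen = i * target_len  # timesteps processed before this segment
        memory_length = np.full((batch, 1), min(seen, memory_len), dtype='int32')
        yield segment, memory_length


# Usage: a batch of 3 sequences of length 25, with target_len=10 and memory_len=20.
tokens = np.arange(3 * 25).reshape(3, 25)
for segment, memory_length in split_with_memory_lengths(tokens, target_len=10, memory_len=20):
    print(segment.shape, memory_length[0, 0])
# (3, 10) 0
# (3, 10) 10
# (3, 10) 20
```

With the settings from the example above (memory_len=20, target_len=10), a batch of length 50 would give memory lengths 0, 10, 20, 20, 20. Once the memory is full, it stays at memory_len.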
Use tf.keras
Add TF_KERAS=1 to the environment variables to use tensorflow.python.keras.
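The variable should be set before keras_transformer_xl is first imported, since the backend is presumably chosen at import time. Below is a minimal sketch that sets the flag from Python. The is_tf_keras_enabled helper is hypothetical. It only illustrates a common way to read such a flag, and the package's own check may differ.

```python
import os

# Set the flag before the first `import keras_transformer_xl`
# (from a shell, the equivalent is `TF_KERAS=1 python your_script.py`).
os.environ['TF_KERAS'] = '1'


def is_tf_keras_enabled(environ=os.environ):
    """Hypothetical helper: any value other than '0' enables tf.keras; unset means disabled."""
    return environ.get('TF_KERAS', '0') != '0'


print(is_tf_keras_enabled())  # → True
# import keras_transformer_xl  # import only after the flag is set
```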
Source Distribution
Hashes for keras-transformer-xl-0.11.0.tar.gz

Algorithm | Hash digest
---|---
SHA256 | 621be3fef1668ac391207e04d2dd960550b0c782bf6dec3359d2a9b30a420506
MD5 | bd55b6e0f4eb67b51a7807ba0d057897
BLAKE2b-256 | f767a0d30945ed5dc272a2a5d376b828a161ef507898cff59bd5f2cdb21d112f