Transformer-XL implemented in Keras
Keras Transformer-XL
Unofficial implementation of Transformer-XL.
Install
pip install keras-transformer-xl
Usage
Load Pretrained Weights
Several configuration files can be found in the info directory.
import os
from keras_transformer_xl import load_trained_model_from_checkpoint

checkpoint_path = 'foo/bar/sota/enwiki8'
model = load_trained_model_from_checkpoint(
    config_path=os.path.join(checkpoint_path, 'config.json'),
    checkpoint_path=os.path.join(checkpoint_path, 'model.ckpt'),
)
model.summary()
About IO
The generated model has two inputs; the second one is the lengths of the memories. You can use the MemorySequence wrapper for training and prediction:
import keras
import numpy as np
from keras_transformer_xl import MemorySequence, build_transformer_xl


class DummySequence(keras.utils.Sequence):

    def __init__(self):
        pass

    def __len__(self):
        return 10

    def __getitem__(self, index):
        return np.ones((3, 5 * (index + 1))), np.ones((3, 5 * (index + 1), 3))


model = build_transformer_xl(
    units=4,
    embed_dim=4,
    hidden_dim=4,
    num_token=3,
    num_block=3,
    num_head=2,
    batch_size=3,
    memory_len=20,
    target_len=10,
)
seq = MemorySequence(
    model=model,
    sequence=DummySequence(),
    target_len=10,
)

model.predict(seq, verbose=True)
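As a minimal illustration of the two-input layout described above, the shapes of the two inputs can be sketched with plain NumPy. The variable names here are illustrative, not taken from the library; the shapes follow the build_transformer_xl arguments in the snippet (batch_size=3, target_len=10).

```python
# Hypothetical sketch of the model's two inputs: the first carries the
# token IDs of the current segment, the second the lengths of the
# memories accumulated so far. Names are illustrative only.
import numpy as np

batch_size, target_len = 3, 10
tokens = np.ones((batch_size, target_len))   # first input: token IDs
memory_length = np.zeros((batch_size, 1))    # second input: memory lengths

print(tokens.shape)         # (3, 10)
print(memory_length.shape)  # (3, 1)
```

The MemorySequence wrapper fills in the second input automatically, which is why the DummySequence above only needs to yield tokens and targets.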
Use tf.keras
Add TF_KERAS=1 to environment variables to use tensorflow.python.keras.
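As a minimal sketch, the variable can also be set from within Python, assuming the flag is consulted when keras_transformer_xl is first imported:

```python
# Minimal sketch: set TF_KERAS before the first import of
# keras_transformer_xl, assuming the flag is read at import time.
import os

os.environ['TF_KERAS'] = '1'  # use tensorflow.python.keras as the backend
print(os.environ['TF_KERAS'])  # 1
```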
Hashes for keras-transformer-xl-0.12.0.tar.gz

Algorithm | Hash digest
---|---
SHA256 | a5c35c22b2150daf7dcffe8c203dda345dc938f0870bb19e0004b29dd37ae1ef
MD5 | 3b7b1e336e27395eb23794acc60c5017
BLAKE2b-256 | fe322e8edd93cc3a687d5a63da5f81758e174aedfca948695caad6d9eb166845