Transformer-XL implemented in Keras
Project description
Keras Transformer-XL
Unofficial implementation of Transformer-XL.
Install
pip install keras-transformer-xl
Usage
Load Pretrained Weights
Several configuration files can be found in the info directory.
import os
from keras_transformer_xl import load_trained_model_from_checkpoint

checkpoint_path = 'foo/bar/sota/enwiki8'
model = load_trained_model_from_checkpoint(
    config_path=os.path.join(checkpoint_path, 'config.json'),
    checkpoint_path=os.path.join(checkpoint_path, 'model.ckpt'),
)
model.summary()
About IO
The generated model has two inputs; the second input holds the lengths of the memories.
You can use the MemorySequence wrapper for training and prediction:
from tensorflow import keras
import numpy as np
from keras_transformer_xl import MemorySequence, build_transformer_xl


class DummySequence(keras.utils.Sequence):
    """A toy data source yielding token inputs and one-hot targets."""

    def __init__(self):
        pass

    def __len__(self):
        return 10

    def __getitem__(self, index):
        return np.ones((3, 5 * (index + 1))), np.ones((3, 5 * (index + 1), 3))


model = build_transformer_xl(
    units=4,
    embed_dim=4,
    hidden_dim=4,
    num_token=3,
    num_block=3,
    num_head=2,
    batch_size=3,
    memory_len=20,
    target_len=10,
)
# Wrap the model and the data source so memories are tracked between batches.
seq = MemorySequence(
    model=model,
    sequence=DummySequence(),
    target_len=10,
)
model.predict(seq, verbose=True)
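The same wrapper can be passed to fit for training. The lines below are a minimal sketch, assuming the wrapped sequence yields one-hot next-token targets (as DummySequence above does); the optimizer and loss are illustrative choices, not something prescribed by the library:

# Minimal training sketch (assumption: targets are one-hot over num_token classes,
# matching the (batch, length, num_token) arrays produced by DummySequence).
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
)
model.fit(seq, epochs=1)  # MemorySequence supplies the memory-length input automatically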
Download files
Download the file for your platform.
Source Distribution
Hashes for keras-transformer-xl-0.14.0.tar.gz
Algorithm | Hash digest
---|---
SHA256 | e49661f0ee6d963bfe37743c3dd91fcf767e1852d7244465b1d5d4b55b357253
MD5 | 7374674ce42a756413d6cf651967cd9f
BLAKE2b-256 | a817a15fc07a0d78d687bb5d5d77bee86a430b95afa73f1981eb3bb1a010d26e