Transformer-XL implemented in Keras
Keras Transformer-XL
Unofficial implementation of Transformer-XL.
Install
```shell
pip install keras-transformer-xl
```
Usage
Load Pretrained Weights
Several configuration files can be found in the info directory.
```python
import os
from keras_transformer_xl import load_trained_model_from_checkpoint

# Directory holding config.json and the model.ckpt files
checkpoint_path = 'foo/bar/sota/enwiki8'
model = load_trained_model_from_checkpoint(
    config_path=os.path.join(checkpoint_path, 'config.json'),
    checkpoint_path=os.path.join(checkpoint_path, 'model.ckpt'),
)
model.summary()
```
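Because the loader reads `config.json` from the checkpoint directory, it can help to inspect that file before building the model. This is a minimal sketch using only the standard library; `load_config` is a hypothetical helper, not part of keras_transformer_xl, and the keys it returns depend entirely on the checkpoint.

```python
import json
import os


def load_config(checkpoint_dir):
    """Hypothetical helper: read config.json from a checkpoint directory.

    Returns the parsed JSON as a dict; the available keys depend on the
    checkpoint and are not defined here.
    """
    config_path = os.path.join(checkpoint_dir, 'config.json')
    with open(config_path, 'r', encoding='utf-8') as reader:
        return json.load(reader)
```

For example, `sorted(load_config('foo/bar/sota/enwiki8'))` lists the configuration keys of that checkpoint.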
About IO
The generated model has two inputs: the first is the token sequence and the second is the length of the memories.
You can use the MemorySequence wrapper for training and prediction:
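```python
import numpy as np
from tensorflow import keras
from keras_transformer_xl import MemorySequence, build_transformer_xl


class DummySequence(keras.utils.Sequence):
    """Toy data: 10 batches of 3 sequences whose length grows by 5 per batch."""

    def __init__(self):
        super().__init__()

    def __len__(self):
        return 10

    def __getitem__(self, index):
        # Inputs: token IDs of shape (batch_size, seq_len)
        # Targets: shape (batch_size, seq_len, num_token)
        return np.ones((3, 5 * (index + 1))), np.ones((3, 5 * (index + 1), 3))


model = build_transformer_xl(
    units=4,
    embed_dim=4,
    hidden_dim=4,
    num_token=3,
    num_block=3,
    num_head=2,
    batch_size=3,
    memory_len=20,
    target_len=10,
)
# The wrapper supplies both model inputs; use the same target_len as the model
seq = MemorySequence(
    model=model,
    sequence=DummySequence(),
    target_len=10,
)
# predict() takes the data source as its first argument
model.predict(seq, verbose=True)
```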
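To illustrate the kind of bookkeeping a wrapper like MemorySequence performs, here is a standalone NumPy sketch that splits each batch into segments of `target_len` tokens and pairs every segment with a memory length that grows by `target_len` per step, capped at `memory_len`. The helper `split_into_segments`, the zero padding of the final segment, and the `(batch_size, 1)` shape of the memory-length array are assumptions for illustration; this is not the library's implementation.

```python
import numpy as np


def split_into_segments(tokens, target_len, memory_len):
    """Hypothetical helper: yield (segment, memory_length) pairs.

    tokens: array of shape (batch_size, total_len).
    Each segment has shape (batch_size, target_len); the last one is
    right-padded with zeros (an assumption). memory_length has shape
    (batch_size, 1) and counts how many earlier tokens are available as
    memory, never exceeding memory_len.
    """
    batch_size, total_len = tokens.shape
    seen = 0  # tokens already processed in earlier segments
    for start in range(0, total_len, target_len):
        segment = tokens[:, start:start + target_len]
        if segment.shape[1] < target_len:
            pad = np.zeros((batch_size, target_len - segment.shape[1]), dtype=tokens.dtype)
            segment = np.concatenate([segment, pad], axis=1)
        memory_length = np.full((batch_size, 1), min(seen, memory_len))
        yield segment, memory_length
        seen += target_len
```

With `target_len=10` and `memory_len=20`, a batch of 35 tokens yields four segments whose memory lengths are 0, 10, 20 and 20: the memory fills up over the first segments and then stays at its maximum size.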
Download files
Source Distribution
File details
Details for the file keras-transformer-xl-0.14.0.tar.gz.
File metadata
- Download URL: keras-transformer-xl-0.14.0.tar.gz
- Upload date:
- Size: 14.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.32.2 CPython/3.7.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | e49661f0ee6d963bfe37743c3dd91fcf767e1852d7244465b1d5d4b55b357253
MD5 | 7374674ce42a756413d6cf651967cd9f
BLAKE2b-256 | a817a15fc07a0d78d687bb5d5d77bee86a430b95afa73f1981eb3bb1a010d26e
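A downloaded source distribution can be checked against the SHA256 digest listed above. This is a minimal sketch using only `hashlib`; the helper names `sha256_of_file` and `verify_download` are hypothetical and not part of any package.

```python
import hashlib

# SHA256 digest listed for keras-transformer-xl-0.14.0.tar.gz
EXPECTED_SHA256 = 'e49661f0ee6d963bfe37743c3dd91fcf767e1852d7244465b1d5d4b55b357253'


def sha256_of_file(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, 'rb') as reader:
        for chunk in iter(lambda: reader.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path, expected=EXPECTED_SHA256):
    """Return True if the file's SHA256 digest matches the expected one."""
    return sha256_of_file(path) == expected.lower()
```

For instance, `verify_download('keras-transformer-xl-0.14.0.tar.gz')` should return `True` for an intact download. pip can enforce the same check itself when the digest is pinned in a requirements file and `--require-hashes` is used.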