XLNet implemented in Keras
Keras XLNet
Unofficial implementation of XLNet.
Install
pip install keras-xlnet
Usage
Load Pretrained Checkpoints
import os
from keras_xlnet import load_trained_model_from_checkpoint

# Directory containing the pretrained XLNet checkpoint files
checkpoint_path = '.../xlnet_cased_L-24_H-1024_A-16'
model = load_trained_model_from_checkpoint(
    config_path=os.path.join(checkpoint_path, 'xlnet_config.json'),
    checkpoint_path=os.path.join(checkpoint_path, 'xlnet_model.ckpt'),
    batch_size=16,
    memory_len=512,
    target_len=128,
    in_train_phase=False,  # False: model for fine-tuning; True: model for LM training
)
model.summary()
The arguments batch_size, memory_len and target_len are the maximum sizes used to initialize the memories. If in_train_phase is True, the model for training a language model is returned; otherwise, a model for fine-tuning is returned.
About I/O
When in_train_phase is False

3 inputs:
- IDs of tokens, with shape (batch_size, target_len).
- IDs of segments, with shape (batch_size, target_len).
- Length of memories, with shape (batch_size, 1).

1 output:
- The feature for each token, with shape (batch_size, target_len, units).
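The following sketch builds NumPy arrays with the three input shapes listed above for the fine-tuning model. It is illustrative only: random token IDs stand in for real tokenizer output, the vocabulary bound of 32000 is an assumption, all-zero segment IDs assume a single-segment input, and a memory length of 0 is assumed to mean that no memory has been cached yet.

```python
import numpy as np

# Sizes mirror the load_trained_model_from_checkpoint call above.
batch_size, target_len = 16, 128
rng = np.random.default_rng(0)

# Input 1: token IDs, shape (batch_size, target_len).
# Random IDs below an assumed vocabulary size of 32000 stand in for real data.
token_ids = rng.integers(0, 32000, size=(batch_size, target_len), dtype=np.int32)

# Input 2: segment IDs, shape (batch_size, target_len).
# All zeros assumes every example is a single segment.
segment_ids = np.zeros((batch_size, target_len), dtype=np.int32)

# Input 3: length of memories, shape (batch_size, 1).
# Zero is assumed to mean "nothing cached yet", e.g. for the first chunk.
memory_length = np.zeros((batch_size, 1), dtype=np.int32)

inputs = [token_ids, segment_ids, memory_length]
for name, array in zip(["token_ids", "segment_ids", "memory_length"], inputs):
    print(name, array.shape)
```

With a model loaded with in_train_phase=False, the arrays would be passed in this order, for example features = model.predict(inputs), and the result should have the documented shape (batch_size, target_len, units).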
When in_train_phase is True

4 inputs:
- IDs of tokens, with shape (batch_size, target_len).
- IDs of segments, with shape (batch_size, target_len).
- Length of memories, with shape (batch_size, 1).
- Masks of tokens, with shape (batch_size, target_len).

1 output:
- The probability of each token in each position, with shape (batch_size, target_len, num_token).
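In the training phase there is a fourth input, a mask over tokens, and the output is a probability distribution over the vocabulary at every position. The sketch below builds the four inputs and reads predicted token IDs from an array of the documented output shape. It assumes that a mask value of 1 marks a position to be predicted, which this page does not state, and it uses a random softmax tensor as a stand-in for model.predict, so it runs without a checkpoint. The small sizes are illustrative only.

```python
import numpy as np

# Small illustrative sizes keep the stand-in output cheap to build. For the
# real model, use the batch_size and target_len passed when loading and the
# checkpoint's vocabulary size as num_token.
batch_size, target_len, num_token = 2, 8, 100
rng = np.random.default_rng(0)

# Inputs 1-3 have the same shapes as in the fine-tuning case.
token_ids = rng.integers(0, num_token, size=(batch_size, target_len), dtype=np.int32)
segment_ids = np.zeros((batch_size, target_len), dtype=np.int32)
memory_length = np.zeros((batch_size, 1), dtype=np.int32)

# Input 4: masks of tokens, shape (batch_size, target_len).
# Assumption: 1 marks a position to predict, 0 a visible token.
# The 15% masking rate is arbitrary and only for illustration.
token_masks = (rng.random((batch_size, target_len)) < 0.15).astype(np.int32)
inputs = [token_ids, segment_ids, memory_length, token_masks]

# Stand-in for model.predict(inputs): random scores turned into
# probabilities with a softmax over the last (num_token) axis.
scores = rng.normal(size=(batch_size, target_len, num_token))
probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
probs /= probs.sum(axis=-1, keepdims=True)

# Most likely token ID at each position, shape (batch_size, target_len).
predicted_ids = probs.argmax(axis=-1)
# Predictions at the masked positions only, as a 1-D array.
masked_predictions = predicted_ids[token_masks == 1]
print(predicted_ids.shape, masked_predictions.shape)
```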
Source Distribution
keras-xlnet-0.0.3.tar.gz (11.2 kB)
File details
Details for the file keras-xlnet-0.0.3.tar.gz.
File metadata
- Download URL: keras-xlnet-0.0.3.tar.gz
- Upload date:
- Size: 11.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.13.0 pkginfo/1.4.2 requests/2.18.4 setuptools/41.0.1 requests-toolbelt/0.8.0 tqdm/4.24.0 CPython/3.6.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 279921effed38f44e0296acd159e4f06722147fb649882ba44a6da96b034764c
MD5 | b0ed80e7814f1b3e8640e97c4e7d5132
BLAKE2b-256 | 387e7c6de4f5a31158f422f6b6edf9efec90d3c3d82b4215aecf2bae57ab89bf
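The digests above can be used to check that a downloaded copy of the source distribution is intact. A minimal sketch using Python's standard hashlib module follows; the helper names sha256_of_file and verify_download are hypothetical, and only the SHA256 digest from the table is used.

```python
import hashlib

# SHA256 digest listed above for keras-xlnet-0.0.3.tar.gz.
EXPECTED_SHA256 = "279921effed38f44e0296acd159e4f06722147fb649882ba44a6da96b034764c"


def sha256_of_file(path, chunk_size=65536):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        # Read fixed-size chunks until read() returns an empty bytes object.
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path, expected=EXPECTED_SHA256):
    """Return True if the file's SHA256 digest matches the expected one."""
    return sha256_of_file(path) == expected.lower()
```

For example, verify_download("keras-xlnet-0.0.3.tar.gz") should return True for an unmodified copy of the archive. Reading in chunks keeps memory use constant regardless of file size.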