XLNet implemented in Keras
Keras XLNet
Unofficial implementation of XLNet. The embedding extraction and embedding extraction with memory demos show how to get the outputs of the last Transformer layer from pre-trained checkpoints.
Install
```bash
pip install keras-xlnet
```
Usage
Fine-tuning on GLUE
Demos are available for the following tasks:
Task Name | Metrics | Approximate Results on Dev Set
---|---|---
CoLA | Matthews Corr. | 52
SST-2 | Accuracy | 93
MRPC | Accuracy / F1 | 86/89
STS-B | Pearson Corr. / Spearman Corr. | 86/87
QQP | Accuracy / F1 | 90/86
MNLI | Accuracy (matched / mismatched) | 84/84
QNLI | Accuracy | 86
RTE | Accuracy | 64
WNLI | Accuracy | 56
(On WNLI, the model predicts only 0s.)
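CoLA is scored with the Matthews correlation coefficient rather than accuracy, and the WNLI note is easier to interpret with both metrics side by side. The following is a minimal pure-Python sketch for binary 0/1 labels; the function names are illustrative and not part of keras-xlnet, and the GLUE demos may compute their metrics differently (for example with a third-party library).

```python
import math


def matthews_corr(y_true, y_pred):
    """Matthews correlation coefficient for binary labels encoded as 0/1."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # Common convention: return 0 when a marginal count is 0, which is what
    # happens when a model predicts the same class for every example.
    return (tp * tn - fp * fn) / denom if denom else 0.0


def accuracy(y_true, y_pred):
    """Fraction of positions where the prediction equals the label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


# A constant predictor of 0s: accuracy equals the share of 0 labels, MCC is 0.
labels = [0, 0, 1, 0, 1]
all_zeros = [0] * len(labels)
print(accuracy(labels, all_zeros))       # 0.6
print(matthews_corr(labels, all_zeros))  # 0.0
```

This is why a WNLI accuracy of 56 says little on its own: a model that always outputs 0 scores the share of 0 labels in the dev set.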
Load Pretrained Checkpoints
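```python
import os

from keras_xlnet import Tokenizer, load_trained_model_from_checkpoint

# Directory of the downloaded pre-trained checkpoint.
checkpoint_path = '.../xlnet_cased_L-24_H-1024_A-16'

# SentencePiece tokenizer shipped with the checkpoint.
tokenizer = Tokenizer(os.path.join(checkpoint_path, 'spiece.model'))

model = load_trained_model_from_checkpoint(
    config_path=os.path.join(checkpoint_path, 'xlnet_config.json'),
    checkpoint_path=os.path.join(checkpoint_path, 'xlnet_model.ckpt'),
    batch_size=16,         # maximum batch size used to initialize memories
    memory_len=512,        # maximum number of cached positions from earlier segments
    target_len=128,        # maximum length of the current segment
    in_train_phase=False,  # False: fine-tuning model; True: language-model training
)
model.summary()
```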
The arguments `batch_size`, `memory_len` and `target_len` are the maximum sizes used to initialize the memories. If `in_train_phase` is `True`, the model returned is the one used for training a language model; otherwise, a model for fine-tuning is returned.
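To illustrate what `memory_len` bounds, here is a conceptual sketch of a Transformer-XL-style memory, the recurrence mechanism XLNet inherits: hidden states of earlier segments are cached, the current segment attends to the cache plus itself, and the cache is truncated to the newest `memory_len` positions. `SegmentMemory` is a hypothetical helper written for illustration; it is not how keras-xlnet stores its memories internally.

```python
import numpy as np


class SegmentMemory:
    """Conceptual Transformer-XL-style memory for one layer (hypothetical helper)."""

    def __init__(self, batch_size, memory_len, units):
        self.memory_len = memory_len
        # Empty at first: no previous segment has been seen yet.
        self.states = np.zeros((batch_size, 0, units), dtype=np.float32)

    def context(self, segment):
        """Positions the current segment can attend to: cached states + itself."""
        return np.concatenate([self.states, segment], axis=1)

    def update(self, segment):
        """Append the segment's states and keep at most memory_len of the newest."""
        merged = np.concatenate([self.states, segment], axis=1)
        start = max(merged.shape[1] - self.memory_len, 0)
        self.states = merged[:, start:, :]


batch_size, memory_len, target_len, units = 2, 6, 4, 8
memory = SegmentMemory(batch_size, memory_len, units)
for step in range(3):
    segment = np.full((batch_size, target_len, units), step, dtype=np.float32)
    # step, cached positions, attention span
    print(step, memory.states.shape[1], memory.context(segment).shape[1])
    memory.update(segment)
# 0 0 4
# 1 4 8
# 2 6 10
```

The cached length grows by `target_len` per step until it reaches `memory_len`, so the attention span never exceeds `memory_len + target_len`. The "length of memories" input described below presumably tells the model how many cached positions are valid, but that interpretation is an assumption here.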
About I/O
When memories are used, `shuffle` should be `False` in `fit` or `fit_generator`: the memories carry state from one batch to the next, so batches must arrive in their original order.
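A minimal sketch of preparing ordered data for a model with memories, assuming one long token stream: the stream is split into `batch_size` parallel rows, and each row is cut into consecutive `target_len` segments so that row `b` of one batch continues row `b` of the previous batch. Both helpers are hypothetical, and the memory-length values (tokens seen so far, capped at `memory_len`) are an assumption about what that input expects.

```python
import numpy as np


def make_ordered_segments(token_ids, batch_size, target_len):
    """Cut one long token stream into batches whose rows continue each other.

    Returns shape (num_steps, batch_size, target_len): row b of step t + 1
    starts right where row b of step t ended.
    """
    token_ids = np.asarray(token_ids)
    stream_len = len(token_ids) // batch_size
    num_steps = stream_len // target_len
    # One contiguous stream per batch row; drop tokens that do not fill a segment.
    streams = token_ids[: batch_size * stream_len].reshape(batch_size, stream_len)
    streams = streams[:, : num_steps * target_len]
    return streams.reshape(batch_size, num_steps, target_len).transpose(1, 0, 2)


def memory_lengths(num_steps, batch_size, target_len, memory_len):
    """Hypothetical helper: positions of memory available before each step."""
    lengths = np.minimum(np.arange(num_steps) * target_len, memory_len)
    return np.repeat(lengths, batch_size).reshape(-1, 1)


segments = make_ordered_segments(np.arange(20), batch_size=2, target_len=4)
x_tokens = segments.reshape(-1, 4)  # step 0 rows first, then step 1 rows, ...
x_memory = memory_lengths(len(segments), batch_size=2, target_len=4, memory_len=512)
# The row order must be preserved, e.g.:
# model.fit([x_tokens, ..., x_memory], ..., batch_size=2, shuffle=False)
```

With `shuffle=True`, Keras would mix rows from different steps into the same batch, and the cached memory would no longer belong to the text that precedes the current segment.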
`in_train_phase` is `False`
3 inputs:
- IDs of tokens, with shape `(batch_size, target_len)`.
- IDs of segments, with shape `(batch_size, target_len)`.
- Length of memories, with shape `(batch_size, 1)`.
1 output:
- The feature for each token, with shape `(batch_size, target_len, units)`.
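A minimal sketch of packing already-tokenized examples into the three fine-tuning inputs listed above. `build_inputs` is a hypothetical helper, and several details are assumptions: every example is a single segment (segment ID 0), padding goes on the right, no memory is carried between independent examples, and `pad_id` has to be looked up in the checkpoint's SentencePiece vocabulary (the value 5 below is only a placeholder).

```python
import numpy as np


def build_inputs(token_id_lists, target_len, pad_id):
    """Pack tokenized examples into the 3 inputs of the fine-tuning model."""
    batch_size = len(token_id_lists)
    token_ids = np.full((batch_size, target_len), pad_id, dtype=np.int32)
    segment_ids = np.zeros((batch_size, target_len), dtype=np.int32)
    for row, ids in enumerate(token_id_lists):
        ids = list(ids)[:target_len]  # truncate examples longer than target_len
        token_ids[row, : len(ids)] = ids
    # Independent examples carry no memory over, so each memory length is 0.
    memory_length = np.zeros((batch_size, 1), dtype=np.int32)
    return [token_ids, segment_ids, memory_length]


inputs = build_inputs([[10, 11, 12], [20, 21, 22, 23, 24, 25]], target_len=4, pad_id=5)
print([x.shape for x in inputs])  # [(2, 4), (2, 4), (2, 1)]
# With a model loaded with a matching target_len (hypothetical here):
# features = model.predict(inputs)  # shape (2, target_len, units)
```

For the `xlnet_cased_L-24_H-1024_A-16` checkpoint, `units` is presumably the hidden size of 1024 indicated by `H-1024` in its name.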
`in_train_phase` is `True`
4 inputs:
- IDs of tokens, with shape `(batch_size, target_len)`.
- IDs of segments, with shape `(batch_size, target_len)`.
- Length of memories, with shape `(batch_size, 1)`.
- Masks of tokens, with shape `(batch_size, target_len)`.
1 output:
- The probability of each token in each position, with shape `(batch_size, target_len, num_token)`.
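A minimal sketch of assembling the four training-phase inputs and reading the output. Both helpers are hypothetical, and the mask convention (1 marks a token the model should predict) is an assumption, as are the zero segment IDs and memory lengths. Because the output holds a distribution over `num_token` at every position, `argmax` over the last axis gives the most probable token ID; a hand-written array stands in for `model.predict` so the sketch runs without a checkpoint.

```python
import numpy as np


def build_train_inputs(token_ids, masks):
    """Assemble the 4 inputs of the training-phase model (assumed conventions)."""
    token_ids = np.asarray(token_ids, dtype=np.int32)
    masks = np.asarray(masks, dtype=np.int32)
    segment_ids = np.zeros_like(token_ids)  # single-segment input
    memory_length = np.zeros((token_ids.shape[0], 1), dtype=np.int32)  # no memory
    return [token_ids, segment_ids, memory_length, masks]


def predicted_tokens(probs, masks):
    """Most probable token ID at each masked position, -1 elsewhere.

    probs has shape (batch_size, target_len, num_token).
    """
    best = np.argmax(probs, axis=-1)  # (batch_size, target_len)
    return np.where(np.asarray(masks) == 1, best, -1)


inputs = build_train_inputs([[7, 8, 9]], masks=[[1, 0, 1]])
# Stand-in for model.predict(inputs) with num_token = 5.
probs = np.zeros((1, 3, 5))
probs[0, 0, 2] = probs[0, 1, 4] = probs[0, 2, 1] = 1.0
print(predicted_tokens(probs, inputs[3]).tolist())  # [[2, -1, 1]]
```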