
XLNet implemented in Keras


Keras XLNet



Unofficial implementation of XLNet. The "embedding extraction" and "embedding extraction with memory" demos show how to get the outputs of the last transformer layer from pre-trained checkpoints.


pip install keras-xlnet


Fine-tuning on GLUE

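Each task has a demo that uses the base model:

| Task  | Metric(s)                        | Approximate Result on Dev Set |
|-------|----------------------------------|-------------------------------|
| CoLA  | Matthews Corr.                   | 52                            |
| SST-2 | Accuracy                         | 93                            |
| MRPC  | Accuracy / F1                    | 86 / 89                       |
| STS-B | Pearson Corr. / Spearman Corr.   | 86 / 87                       |
| QQP   | Accuracy / F1                    | 90 / 86                       |
| MNLI  | Accuracy (matched / mismatched)  | 84 / 84                       |
| QNLI  | Accuracy                         | 86                            |
| RTE   | Accuracy                         | 64                            |
| WNLI  | Accuracy                         | 56                            |

(On the WNLI dataset, the model predicts only 0s.)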
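The table reports standard classification scores. The following pure-Python sketch (written for this explanation, not taken from keras-xlnet; the function names are hypothetical) computes accuracy, binary F1, and the Matthews correlation coefficient for 0/1 labels. It also shows why the WNLI note matters: a model that predicts only 0s gets an accuracy equal to the share of 0 labels, while its F1 for class 1 and its Matthews correlation are both 0.

```python
import math
from typing import Sequence


def confusion(labels: Sequence[int], preds: Sequence[int]):
    """Count (tp, tn, fp, fn) for binary labels where 1 is the positive class."""
    tp = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 1)
    tn = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 0)
    fp = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 0)
    return tp, tn, fp, fn


def accuracy(labels: Sequence[int], preds: Sequence[int]) -> float:
    """Fraction of examples whose prediction equals the label."""
    return sum(1 for y, p in zip(labels, preds) if y == p) / len(labels)


def f1_score(labels: Sequence[int], preds: Sequence[int]) -> float:
    """Harmonic mean of precision and recall for class 1."""
    tp, _, fp, fn = confusion(labels, preds)
    if tp == 0:
        return 0.0  # no true positives: precision or recall is 0 (or undefined)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def matthews_corr(labels: Sequence[int], preds: Sequence[int]) -> float:
    """Matthews correlation coefficient; 0.0 when the denominator is zero."""
    tp, tn, fp, fn = confusion(labels, preds)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return 0.0 if denom == 0 else (tp * tn - fp * fn) / denom


# A constant all-zeros predictor, as in the WNLI note.
labels = [0, 0, 1, 0, 1]
preds = [0] * len(labels)
print(accuracy(labels, preds))       # 0.6
print(f1_score(labels, preds))       # 0.0
print(matthews_corr(labels, preds))  # 0.0
```

Returning 0.0 when the Matthews denominator is zero follows a common convention (scikit-learn does the same).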

Load Pretrained Checkpoints

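import os
from keras_xlnet import Tokenizer, load_trained_model_from_checkpoint, ATTENTION_TYPE_BI

# Directory containing the downloaded pre-trained checkpoint
checkpoint_path = '.../xlnet_cased_L-24_H-1024_A-16'

tokenizer = Tokenizer(os.path.join(checkpoint_path, 'spiece.model'))
model = load_trained_model_from_checkpoint(
    config_path=os.path.join(checkpoint_path, 'xlnet_config.json'),
    checkpoint_path=os.path.join(checkpoint_path, 'xlnet_model.ckpt'),
    batch_size=16,         # example value: maximum batch size
    memory_len=512,        # example value: maximum memory length
    target_len=128,        # example value: maximum segment length
    in_train_phase=False,  # False returns the model used for fine-tuning
    attention_type=ATTENTION_TYPE_BI,
)
model.summary()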

The arguments batch_size, memory_len, and target_len are the maximum sizes used to initialize the memories. If in_train_phase is True, the function returns a model for training a language model; otherwise, it returns a model for fine-tuning.

About I/O

If memories are used, set shuffle=False in both fit and fit_generator.
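Memories carry hidden states from one batch to the next, so the segments of a long text have to be fed in reading order; shuffling would pair a segment with memories from unrelated text. A minimal sketch of splitting one long token sequence into consecutive target_len segments follows. The helper iter_segments and PAD_ID are hypothetical, and the way memory_length is updated (growing by target_len per segment, capped at memory_len) is an assumption rather than documented package behavior.

```python
from typing import Iterator, List, Tuple

PAD_ID = 0  # hypothetical padding id; use the tokenizer's real pad id in practice


def iter_segments(token_ids: List[int], target_len: int,
                  memory_len: int) -> Iterator[Tuple[List[int], int]]:
    """Yield (segment, memory_length) pairs in reading order.

    Assumption: the valid memory grows by target_len after each segment
    and never exceeds memory_len.
    """
    memory_length = 0  # no memory before the first segment
    for start in range(0, len(token_ids), target_len):
        segment = token_ids[start:start + target_len]
        # Pad the last segment so every segment has exactly target_len ids.
        segment = segment + [PAD_ID] * (target_len - len(segment))
        yield segment, memory_length
        memory_length = min(memory_length + target_len, memory_len)


for segment, memory_length in iter_segments(list(range(1, 8)), target_len=3, memory_len=4):
    print(segment, memory_length)
# [1, 2, 3] 0
# [4, 5, 6] 3
# [7, 0, 0] 4
```

Because each memory_length depends on the segments before it, the batches built from these pairs only make sense in this order.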

in_train_phase is False

3 inputs:

  • IDs of tokens, with shape (batch_size, target_len).
  • IDs of segments, with shape (batch_size, target_len).
  • Length of memories, with shape (batch_size, 1).

1 output:

  • The feature for each token, with shape (batch_size, target_len, units).
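A sketch of assembling the three inputs with NumPy for a batch of already-tokenized texts. build_inputs and PAD_ID are hypothetical names, and using all-zero segment IDs assumes every text is a single segment. The returned list would be passed to the fine-tuning model (for example with model.predict), giving an output of shape (batch_size, target_len, units).

```python
import numpy as np

PAD_ID = 0  # hypothetical padding id; use the tokenizer's real pad id in practice


def build_inputs(batch_token_ids, target_len, memory_length=0):
    """Return [token_ids, segment_ids, memory_lengths] as int32 arrays."""
    batch_size = len(batch_token_ids)
    # Token IDs: (batch_size, target_len), padded on the right, truncated if too long.
    tokens = np.full((batch_size, target_len), PAD_ID, dtype="int32")
    for row, ids in enumerate(batch_token_ids):
        ids = ids[:target_len]
        tokens[row, :len(ids)] = ids
    # Segment IDs: (batch_size, target_len); all zeros assumes a single segment.
    segments = np.zeros((batch_size, target_len), dtype="int32")
    # Length of memories: (batch_size, 1), the same value for every row.
    memory_lengths = np.full((batch_size, 1), memory_length, dtype="int32")
    return [tokens, segments, memory_lengths]


inputs = build_inputs([[5, 6, 7], [8]], target_len=4)
print([x.shape for x in inputs])  # [(2, 4), (2, 4), (2, 1)]
```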

in_train_phase is True

4 inputs:

  • IDs of tokens, with shape (batch_size, target_len).
  • IDs of segments, with shape (batch_size, target_len).
  • Length of memories, with shape (batch_size, 1).
  • Masks of tokens, with shape (batch_size, target_len).

1 output:

  • The probability of each token in each position, with shape (batch_size, target_len, num_token).
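The output gives a probability for every token at every position, so the most likely token at each position is an argmax over the last axis. The sketch below assumes a mask value of 1 marks a position whose token the model should predict; that meaning of the masks input is an assumption, and masked_predictions is a hypothetical helper.

```python
import numpy as np


def masked_predictions(probs: np.ndarray, masks: np.ndarray) -> list:
    """Return, for each row, the most probable token id at every masked position.

    probs: (batch_size, target_len, num_token), the train-phase model output.
    masks: (batch_size, target_len), assumed 1 where a token is masked.
    """
    best_ids = probs.argmax(axis=-1)  # (batch_size, target_len)
    return [row_ids[row_mask.astype(bool)].tolist()
            for row_ids, row_mask in zip(best_ids, masks)]


# Toy output: batch_size=2, target_len=3, num_token=4.
probs = np.zeros((2, 3, 4))
probs[0, 0, 2] = probs[0, 1, 3] = probs[0, 2, 1] = 1.0
probs[1, :, 0] = 1.0
masks = np.array([[0, 1, 1], [1, 0, 0]])
print(masked_predictions(probs, masks))  # [[3, 1], [0]]
```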


Download files


Source Distribution

keras-xlnet-0.20.0.tar.gz (20.7 kB)

