
XLNet implemented in Keras


Keras XLNet


Unofficial implementation of XLNet. The embedding extraction and embedding extraction with memory demos show how to get the outputs of the last transformer layer using pre-trained checkpoints.

Install

pip install keras-xlnet

Usage

Fine-tuning on GLUE

Click a task name to see the demo with the base model:

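| Task Name | Metrics                           | Approximate Results on Dev Set |
|-----------|-----------------------------------|--------------------------------|
| CoLA      | Matthews Corr.                    | 52                             |
| SST-2     | Accuracy                          | 93                             |
| MRPC      | Accuracy/F1                       | 86/89                          |
| STS-B     | Pearson Corr. / Spearman Corr.    | 86/87                          |
| QQP       | Accuracy/F1                       | 90/86                          |
| MNLI      | Accuracy                          | 84/84                          |
| QNLI      | Accuracy                          | 86                             |
| RTE       | Accuracy                          | 64                             |
| WNLI      | Accuracy                          | 56                             |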

(Only 0s are predicted on the WNLI dataset.)

Load Pretrained Checkpoints

import os
from keras_xlnet import Tokenizer, load_trained_model_from_checkpoint, ATTENTION_TYPE_BI

checkpoint_path = '.../xlnet_cased_L-24_H-1024_A-16'

tokenizer = Tokenizer(os.path.join(checkpoint_path, 'spiece.model'))
model = load_trained_model_from_checkpoint(
    config_path=os.path.join(checkpoint_path, 'xlnet_config.json'),
    checkpoint_path=os.path.join(checkpoint_path, 'xlnet_model.ckpt'),
    batch_size=16,
    memory_len=512,
    target_len=128,
    in_train_phase=False,
    attention_type=ATTENTION_TYPE_BI,
)
model.summary()

The arguments batch_size, memory_len and target_len are the maximum sizes used to initialize the memories. If in_train_phase is True, the model for training a language model is returned; otherwise, a model for fine-tuning is returned.

About I/O

Note that shuffle should be False in both fit and fit_generator if memories are used, so that batches are processed in their original order.
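To illustrate why order matters when memories are used, here is a minimal NumPy-only sketch that cuts one long token sequence into consecutive segments. The helper name iter_segments, the pad_id value, and the right-side padding are hypothetical. The rule that the memory length equals the number of previously seen tokens, capped at memory_len, is an assumption and is not confirmed by this README.

```python
import numpy as np


def iter_segments(token_ids, target_len, memory_len, pad_id=0):
    """Yield consecutive fixed-length segments of one long sequence, in order.

    Hypothetical helper: pad_id and right-side padding are illustrative choices.
    """
    for start in range(0, len(token_ids), target_len):
        chunk = list(token_ids[start:start + target_len])
        chunk += [pad_id] * (target_len - len(chunk))
        tokens = np.array([chunk], dtype="int64")      # (1, target_len)
        segments = np.zeros_like(tokens)               # (1, target_len)
        # Assumption: memory length = tokens seen so far, capped at memory_len.
        mem_len = np.array([[min(start, memory_len)]], dtype="int64")  # (1, 1)
        yield tokens, segments, mem_len


batches = list(iter_segments(list(range(1, 11)), target_len=4, memory_len=6))
memory_lengths = [int(m[0, 0]) for _, _, m in batches]
```

Because each segment relies on what came before it, the segments have to reach the model in this order, which is what passing shuffle=False to fit preserves.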

in_train_phase is False

3 inputs:

  • IDs of tokens, with shape (batch_size, target_len).
  • IDs of segments, with shape (batch_size, target_len).
  • Length of memories, with shape (batch_size, 1).

1 output:

  • The feature for each token, with shape (batch_size, target_len, units).
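The input shapes listed above can be checked before calling the model. The following sketch uses only NumPy. The helper check_finetune_inputs is hypothetical and not part of keras_xlnet, and the all-zero arrays are placeholders rather than real tokenizer output.

```python
import numpy as np


def check_finetune_inputs(inputs, batch_size, target_len):
    """Raise ValueError unless inputs match the three documented shapes."""
    expected = [
        (batch_size, target_len),  # IDs of tokens
        (batch_size, target_len),  # IDs of segments
        (batch_size, 1),           # length of memories
    ]
    if len(inputs) != len(expected):
        raise ValueError(f"expected {len(expected)} inputs, got {len(inputs)}")
    for index, (array, shape) in enumerate(zip(inputs, expected)):
        if array.shape != shape:
            raise ValueError(f"input {index}: expected {shape}, got {array.shape}")
    return True


batch_size, target_len = 16, 128
token_ids = np.zeros((batch_size, target_len), dtype="int64")    # placeholder IDs
segment_ids = np.zeros((batch_size, target_len), dtype="int64")
memory_lengths = np.zeros((batch_size, 1), dtype="int64")
inputs = [token_ids, segment_ids, memory_lengths]
ok = check_finetune_inputs(inputs, batch_size, target_len)
```

A list like inputs could then be passed to the loaded model, for example with predict_on_batch, to obtain features of shape (batch_size, target_len, units).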

in_train_phase is True

4 inputs:

  • IDs of tokens, with shape (batch_size, target_len).
  • IDs of segments, with shape (batch_size, target_len).
  • Length of memories, with shape (batch_size, 1).
  • Masks of tokens, with shape (batch_size, target_len).

1 output:

  • The probability of each token in each position, with shape (batch_size, target_len, num_token).
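As a rough illustration of the training-phase shapes, the sketch below builds the four inputs with NumPy and decodes a stand-in output by taking the most probable token at each position. The random array only mimics a model output of shape (batch_size, target_len, num_token) and is not produced by keras_xlnet. The convention that 1 in the mask marks a position to predict is an assumption.

```python
import numpy as np

batch_size, target_len, num_token = 2, 4, 10
rng = np.random.default_rng(0)

token_ids = rng.integers(0, num_token, size=(batch_size, target_len))
segment_ids = np.zeros((batch_size, target_len), dtype="int64")
memory_lengths = np.zeros((batch_size, 1), dtype="int64")
# Assumption: 1 marks a masked position the model should predict.
masks = np.zeros((batch_size, target_len), dtype="int64")
masks[:, -1] = 1
inputs = [token_ids, segment_ids, memory_lengths, masks]

# Stand-in for the model output: random scores normalized into probabilities.
scores = rng.random((batch_size, target_len, num_token))
probs = scores / scores.sum(axis=-1, keepdims=True)

# Most probable token ID at every position, shape (batch_size, target_len).
predicted_ids = probs.argmax(axis=-1)
```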
