pretrain4keras

Motivation

  • Implement NLP pretrained models such as BERT and BART using the stable tf.keras (TF2.0+) API.
  • Avoid excessive custom classes and methods, aiming for code that is concise, easy to read, and easy to extend.

Supported models

  • BERT
  • BART

Usage examples

Installation

pip install pretrain4keras

BERT

from pretrain4keras.models.bert import BertBuilder

# 0. Download the weights and place them under bert_dir
# Google's original BERT: https://github.com/google-research/bert
bert_dir = "/Users/mos_luo/project/pretrain_model/bert/chinese_L-12_H-768_A-12/"
config_file = bert_dir + "bert_config.json"
checkpoint_file = bert_dir + "bert_model.ckpt"
vocab_file = bert_dir + "vocab.txt"

# 1. Create the keras BERT model and tokenizer
keras_bert, tokenizer, config = BertBuilder().build_bert(
    config_file=config_file, checkpoint_file=checkpoint_file, vocab_file=vocab_file
)

# 2. Build an input example
inputs = tokenizer(["语言模型"], return_tensors="tf")
print(keras_bert(inputs))
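For reference, the tokenizer call above returns a dict in the standard BERT input format. A minimal hand-built sketch of that structure, as plain Python lists (101/102 are the conventional [CLS]/[SEP] ids in Google's Chinese BERT vocab; the middle ids are illustrative, not the real vocab ids):

```python
# Hand-built equivalent of tokenizer(["语言模型"], return_tensors="tf"),
# shown as plain Python lists rather than tf tensors.
example_inputs = {
    "input_ids": [[101, 6427, 6241, 3563, 1798, 102]],  # [CLS] 语 言 模 型 [SEP]
    "token_type_ids": [[0, 0, 0, 0, 0, 0]],  # single segment
    "attention_mask": [[1, 1, 1, 1, 1, 1]],  # no padding
}
```

Wrapping each list with tf.constant gives tensors that keras_bert accepts in the same way as the tokenizer output.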

BART

import pprint
import tensorflow as tf
from pretrain4keras.models.bart import BartBuilder

# 0. Manually download the files for fnlp/bart-base-chinese
# Download all files in the folder at https://huggingface.co/fnlp/bart-base-chinese/tree/4e93f21dca95a07747f434b0f9fe5d49cacc0441
pretrain_dir = "/Users/normansluo/project/pretrain_model/huggingface_transformers/fnlp/bart-base-chinese-v2/"
checkpoint_file = pretrain_dir + "pytorch_model.bin"
config_file = pretrain_dir + "config.json"
vocab_file = pretrain_dir + "vocab.txt"

# 1. Create the keras BART model
builder = BartBuilder()
keras_bart, tokenizer, config = builder.build_bart(
    config_file=config_file, checkpoint_file=checkpoint_file, vocab_file=vocab_file
)

# 2. Build an input example
inputs = tokenizer(["北京是[MASK]的首都"], return_tensors="tf")
del inputs["token_type_ids"]
inputs["decoder_input_ids"] = tf.constant(
    [[102, 101, 6188, 5066, 11009, 4941, 7178, 15134, 23943, 21784]]
)
pprint.pprint(inputs)

# 3. Output of keras BART
print("=========== keras bart output ============>")
keras_bart_out = keras_bart(inputs)
print("keras_bart_out=")
print(keras_bart_out)
print(tokenizer.batch_decode(tf.argmax(keras_bart_out["lm"], axis=2).numpy()))
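The decoder_input_ids above are fixed in advance; in practice, seq2seq generation is autoregressive, feeding each predicted token back into the decoder. A minimal greedy-decoding sketch with the model abstracted as a callable (the step_fn interface is hypothetical, standing in for a call like keras_bart(...)["lm"] followed by taking the scores at the last decoder position):

```python
def greedy_decode(step_fn, bos_id, eos_id, max_len=20):
    """Greedy autoregressive decoding loop.

    step_fn takes the decoded ids so far and returns a list of scores over
    the vocabulary for the next token (hypothetical interface).
    """
    ids = [bos_id]
    for _ in range(max_len):
        scores = step_fn(ids)
        next_id = max(range(len(scores)), key=scores.__getitem__)  # argmax
        ids.append(next_id)
        if next_id == eos_id:
            break
    return ids


# Toy stand-in model over a 5-token vocab: always prefers (last_id + 1) % 5,
# so decoding from bos_id=0 walks 0, 1, 2, 3 and stops at eos_id=4.
def toy_step_fn(ids):
    scores = [0.0] * 5
    scores[(ids[-1] + 1) % 5] = 1.0
    return scores


print(greedy_decode(toy_step_fn, bos_id=0, eos_id=4))  # [0, 1, 2, 3, 4]
```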

Requirements

  • python>=3.6
  • tensorflow>=2.0.0
  • numpy
  • transformers==4.25.1
    • Only needed to provide the tokenizer; it is optional and can be left uninstalled.
    • You can also use any other tokenizer implementation.
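Since transformers is only needed for the tokenizer, a simple stand-in for Chinese text can be sketched as below. This assumes a BERT-style vocab.txt (one token per line, with [CLS]/[SEP]/[UNK] specials) and splits per character, a simplification of WordPiece that is usually equivalent for common Chinese characters:

```python
def load_vocab(vocab_file):
    """Map each line of a BERT-style vocab.txt to its row index."""
    with open(vocab_file, encoding="utf-8") as f:
        return {line.rstrip("\n"): i for i, line in enumerate(f)}


def encode(text, vocab):
    """Character-level encoding: [CLS] + characters + [SEP]."""
    ids = (
        [vocab["[CLS]"]]
        + [vocab.get(ch, vocab["[UNK]"]) for ch in text]
        + [vocab["[SEP]"]]
    )
    return {
        "input_ids": [ids],
        "token_type_ids": [[0] * len(ids)],  # single segment
        "attention_mask": [[1] * len(ids)],  # no padding
    }
```

Convert each list with tf.constant before passing the dict to keras_bert.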

Changelog

  • 2023.01.15: added BART
  • 2023.01.30: added BERT
