Implementations of NLP pre-trained models such as BERT and BART using the stable tf.keras API (TF 2.0+).
Motivation
- Implement NLP pre-trained models such as BERT and BART using the stable tf.keras API (TF 2.0+).
- Avoid excessive custom classes and methods; keep the code concise, readable, and easy to extend.
Supported models
- BERT
- BART
Usage examples
Installation

```shell
pip install pretrain4keras
```
BERT
- Download the BERT weights:
  - Google's original BERT: https://github.com/google-research/bert
- Code

```python
from pretrain4keras.models.bert import BertBuilder

# 0. Download the weights and place them under bert_dir.
# Google's original BERT: https://github.com/google-research/bert
bert_dir = "/Users/mos_luo/project/pretrain_model/bert/chinese_L-12_H-768_A-12/"
config_file = bert_dir + "bert_config.json"
checkpoint_file = bert_dir + "bert_model.ckpt"
vocab_file = bert_dir + "vocab.txt"

# 1. Build the keras BERT model and the tokenizer.
keras_bert, tokenizer, config = BertBuilder().build_bert(
    config_file=config_file, checkpoint_file=checkpoint_file, vocab_file=vocab_file
)

# 2. Build an input sample and run the model.
inputs = tokenizer(["语言模型"], return_tensors="tf")
print(keras_bert(inputs))
```
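As a point of reference for what `tokenizer` returns above, the sketch below mimics a BERT-style encoding step in pure Python with a tiny made-up vocabulary (illustrative only; the real tokenizer reads the full vocab.txt and applies WordPiece splitting):

```python
# Toy sketch of BERT-style tokenization (hypothetical mini-vocab, illustration only).
toy_vocab = {"[PAD]": 0, "[CLS]": 101, "[SEP]": 102, "语": 5, "言": 6, "模": 7, "型": 8}

def encode(text, vocab):
    """Wrap a character sequence with [CLS]/[SEP] and map tokens to ids, BERT-style."""
    tokens = ["[CLS]"] + list(text) + ["[SEP]"]
    input_ids = [vocab[t] for t in tokens]
    token_type_ids = [0] * len(input_ids)  # single segment => all zeros
    attention_mask = [1] * len(input_ids)  # no padding here => all ones
    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask,
    }

print(encode("语言模型", toy_vocab))
# {'input_ids': [101, 5, 6, 7, 8, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0],
#  'attention_mask': [1, 1, 1, 1, 1, 1]}
```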
BART
- Download the BART weights:
  - For example, download every file in the folder at https://huggingface.co/fnlp/bart-base-chinese/tree/4e93f21dca95a07747f434b0f9fe5d49cacc0441
  - The weights and config JSON of a Hugging Face transformers model may change over time, so pin a revision id when downloading, e.g. 4e93f21dca95a07747f434b0f9fe5d49cacc0441.
- Example code

```python
import pprint
import tensorflow as tf
from pretrain4keras.models.bart import BartBuilder

# 0. Manually download the files of fnlp/bart-base-chinese:
# every file in the folder at
# https://huggingface.co/fnlp/bart-base-chinese/tree/4e93f21dca95a07747f434b0f9fe5d49cacc0441
pretrain_dir = "/Users/normansluo/project/pretrain_model/huggingface_transformers/fnlp/bart-base-chinese-v2/"
checkpoint_file = pretrain_dir + "pytorch_model.bin"
config_file = pretrain_dir + "config.json"
vocab_file = pretrain_dir + "vocab.txt"

# 1. Build the keras BART model.
builder = BartBuilder()
keras_bart, tokenizer, config = builder.build_bart(
    config_file=config_file, checkpoint_file=checkpoint_file, vocab_file=vocab_file
)

# 2. Build an input sample.
inputs = tokenizer(["北京是[MASK]的首都"], return_tensors="tf")
del inputs["token_type_ids"]
inputs["decoder_input_ids"] = tf.constant(
    [[102, 101, 6188, 5066, 11009, 4941, 7178, 15134, 23943, 21784]]
)
pprint.pprint(inputs)

# 3. Run keras BART and decode its output.
print("=========== keras BART output ============>")
keras_bart_out = keras_bart(inputs)
print("keras_bart_out=")
print(keras_bart_out)
print(tokenizer.batch_decode(tf.argmax(keras_bart_out["lm"], axis=2).numpy()))
```
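The final line above greedy-decodes the language-model head: at each decoder position it takes the vocabulary id with the highest logit, then maps ids back to text. A pure-Python sketch of that argmax-and-decode step on toy logits (hypothetical 4-entry vocabulary, illustration only):

```python
# Greedy decoding sketch: at each position, take the argmax over vocab logits.
# Toy logits: batch of 1 sequence, 3 positions, 4-id vocabulary.
toy_id_to_token = {0: "[PAD]", 1: "北", 2: "京", 3: "[SEP]"}
lm_logits = [[[0.1, 2.5, 0.3, 0.0],   # position 0 -> id 1 ("北")
              [0.0, 0.2, 3.1, 0.4],   # position 1 -> id 2 ("京")
              [0.5, 0.1, 0.2, 2.0]]]  # position 2 -> id 3 ("[SEP]")

def greedy_decode(logits, id_to_token):
    """Pure-Python equivalent of argmax over axis=2 followed by batch_decode."""
    out = []
    for seq in logits:
        ids = [max(range(len(pos)), key=pos.__getitem__) for pos in seq]
        out.append("".join(id_to_token[i] for i in ids))
    return out

print(greedy_decode(lm_logits, toy_id_to_token))  # ['北京[SEP]']
```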
Requirements
- python>=3.6
- tensorflow>=2.0.0
- numpy
- transformers==4.25.1
  - Only needed to provide a tokenizer; it is optional and can be skipped.
  - You can also use another tokenizer implementation.
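If you skip transformers, any encoder that maps text to the id tensors the model expects will do. Below is a minimal sketch of reading a BERT-style vocab.txt (one token per line, line number = id) and encoding characters, with out-of-vocabulary characters mapped to [UNK] (illustrative only; the real tokenizer also performs WordPiece subword splitting):

```python
# Sketch of a minimal vocab.txt-based encoder that could stand in for the
# transformers tokenizer (hypothetical mini-vocab, illustration only).
def load_vocab(lines):
    """vocab.txt format: one token per line; the line number is the token id."""
    return {token: i for i, token in enumerate(lines)}

vocab = load_vocab(["[PAD]", "[UNK]", "[CLS]", "[SEP]", "语", "言"])
unk = vocab["[UNK]"]
ids = [vocab["[CLS]"]] + [vocab.get(ch, unk) for ch in "语言学"] + [vocab["[SEP]"]]
print(ids)  # [2, 4, 5, 1, 3] -> "学" is out-of-vocab, mapped to [UNK]
```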
Changelog
- 2023.01.15: added BART
- 2023.01.30: added BERT