
BERT for TensorFlow 2

Project description

This repo contains a TensorFlow 2.0 Keras implementation of google-research/bert that can load the original pre-trained weights and produces activations numerically identical to those calculated by the original model.

ALBERT and adapter-BERT are also supported by setting the corresponding configuration parameters (shared_layer=True and embedding_size for ALBERT, adapter_size for adapter-BERT). Setting both results in an adapter-ALBERT, which shares the BERT parameters across all layers while adapting every layer with a layer-specific adapter.
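To see why the ALBERT options matter, here is some back-of-the-envelope parameter arithmetic. It is a sketch, not bert4tf code: it counts only the word-embedding matrix and the encoder layers (ignoring position/token-type embeddings and the pooler), and the sizes used (vocabulary 30000, hidden 768, embedding 128, 12 layers) are assumed, roughly ALBERT-base-like values.

```python
# Rough parameter counts for BERT-style vs ALBERT-style configurations.
# Illustrative arithmetic only; not taken from bert4tf.

def embedding_params(vocab_size, hidden_size, embedding_size=None):
    """Word-embedding parameters; ALBERT factorizes V x H into V x E + E x H."""
    if embedding_size is None:                  # BERT: one V x H matrix
        return vocab_size * hidden_size
    return vocab_size * embedding_size + embedding_size * hidden_size

def layer_params(hidden_size, intermediate_size):
    """Parameters of one encoder layer (self-attention + feed-forward + layer norms)."""
    h, i = hidden_size, intermediate_size
    attention = 4 * (h * h + h)                 # query, key, value and output projections
    ffn = (h * i + i) + (i * h + h)             # two dense layers with biases
    layer_norms = 2 * (2 * h)                   # two LayerNorms, each with gamma and beta
    return attention + ffn + layer_norms

def encoder_params(num_layers, hidden_size, intermediate_size, shared_layer=False):
    """With shared_layer=True (ALBERT) every layer reuses one set of weights."""
    per_layer = layer_params(hidden_size, intermediate_size)
    return per_layer if shared_layer else num_layers * per_layer

if __name__ == "__main__":
    V, H, E, L = 30000, 768, 128, 12            # assumed sizes
    bert = embedding_params(V, H) + encoder_params(L, H, 4 * H)
    albert = embedding_params(V, H, E) + encoder_params(L, H, 4 * H, shared_layer=True)
    print(f"BERT-like:   {bert:,}")
    print(f"ALBERT-like: {albert:,}")
```

With these assumed sizes the factorized embedding and the single shared layer shrink the count from roughly 108M to roughly 11M parameters, which is the effect shared_layer=True and embedding_size are meant to achieve.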

The implementation is built from scratch using only basic TensorFlow operations, following the code in google-research/bert (but skipping dead code and applying some simplifications). It also uses kpe/params-flow to reduce common Keras boilerplate code (related to passing model and layer configuration arguments).

bert4tf should work with both TensorFlow 2.0 and TensorFlow 1.14 or newer.



License: MIT. See the LICENSE file.


bert4tf is on the Python Package Index (PyPI):

pip install bert4tf


BERT in bert4tf is implemented as a Keras layer. You can instantiate it like this:

from bert4tf import BertModelLayer

l_bert = BertModelLayer(**BertModelLayer.Params(
  vocab_size               = 16000,        # embedding params
  use_token_type           = True,
  use_position_embeddings  = True,
  token_type_vocab_size    = 2,

  num_layers               = 12,           # transformer encoder params
  hidden_size              = 768,
  hidden_dropout           = 0.1,
  intermediate_size        = 4*768,
  intermediate_activation  = "gelu",

  adapter_size             = None,         # see arXiv:1902.00751 (adapter-BERT)

  shared_layer             = False,        # True for ALBERT (arXiv:1909.11942)
  embedding_size           = None,         # None for BERT, wordpiece embedding size for ALBERT

  name                     = "bert"        # any other Keras layer params
))

Or by using the bert_config.json from a pre-trained Google model:

import bert4tf

model_dir = ".models/uncased_L-12_H-768_A-12"

bert_params = bert4tf.params_from_pretrained_ckpt(model_dir)
l_bert = bert4tf.BertModelLayer.from_params(bert_params, name="bert")
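params_from_pretrained_ckpt builds the layer parameters from the bert_config.json in model_dir. The sketch below only illustrates the idea of translating that file's keys into the parameter names used in the Params example above; KEY_MAP and config_to_params are hypothetical helpers, and the library's actual mapping may differ. The embedded JSON uses the field names and values of Google's BERT-Base, uncased config.

```python
import json

# Contents in the format of Google's bert_config.json (BERT-Base, uncased).
BERT_BASE_CONFIG = """{
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}"""

# Hypothetical mapping: bert_config.json key -> BertModelLayer.Params name.
KEY_MAP = {
    "vocab_size": "vocab_size",
    "type_vocab_size": "token_type_vocab_size",
    "num_hidden_layers": "num_layers",
    "hidden_size": "hidden_size",
    "hidden_dropout_prob": "hidden_dropout",
    "intermediate_size": "intermediate_size",
    "hidden_act": "intermediate_activation",
}

def config_to_params(config_json):
    """Return a dict keyed by the Params names for the keys in KEY_MAP."""
    config = json.loads(config_json)
    return {ours: config[theirs] for theirs, ours in KEY_MAP.items()}

print(config_to_params(BERT_BASE_CONFIG))
```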

Now you can use the BERT layer in your Keras model like this:

from tensorflow import keras

max_seq_len = 128
l_input_ids      = keras.layers.Input(shape=(max_seq_len,), dtype='int32')
l_token_type_ids = keras.layers.Input(shape=(max_seq_len,), dtype='int32')

# using the default token_type/segment id 0
output = l_bert(l_input_ids)                              # output: [batch_size, max_seq_len, hidden_size]
model = keras.Model(inputs=l_input_ids, outputs=output)
model.build(input_shape=(None, max_seq_len))

# provide a custom token_type/segment id as a layer input
output = l_bert([l_input_ids, l_token_type_ids])          # [batch_size, max_seq_len, hidden_size]
model = keras.Model(inputs=[l_input_ids, l_token_type_ids], outputs=output)
model.build(input_shape=[(None, max_seq_len), (None, max_seq_len)])
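Each row fed to l_input_ids and l_token_type_ids is a fixed-length list of token ids. A minimal sketch of packing a sentence pair in the usual BERT layout ([CLS] A [SEP] B [SEP], segment id 0 for the first sentence and 1 for the second, zero padding), assuming the ids 101/102/0 that [CLS]/[SEP]/[PAD] have in the English uncased vocabulary. encode_pair is a hypothetical helper; in practice the special-token ids should be looked up with the tokenizer.

```python
def encode_pair(ids_a, ids_b, max_seq_len, cls_id=101, sep_id=102, pad_id=0):
    """Pack two token-id lists as [CLS] A [SEP] B [SEP], then pad to max_seq_len.

    The default cls/sep/pad ids are an assumption (English uncased vocab).
    """
    input_ids = [cls_id] + ids_a + [sep_id] + ids_b + [sep_id]
    # segment 0 covers [CLS], A and the first [SEP]; segment 1 covers B and the last [SEP]
    token_type_ids = [0] * (len(ids_a) + 2) + [1] * (len(ids_b) + 1)
    if len(input_ids) > max_seq_len:
        raise ValueError("sequence pair too long; truncate before packing")
    padding = max_seq_len - len(input_ids)
    return input_ids + [pad_id] * padding, token_type_ids + [0] * padding

ids, types = encode_pair([7592], [2088, 999], max_seq_len=8)
print(ids)    # [101, 7592, 102, 2088, 999, 102, 0, 0]
print(types)  # [0, 0, 0, 1, 1, 1, 0, 0]
```

Stacking such rows into arrays of shape (batch_size, max_seq_len) gives the two inputs of the second model above; with a single input, only the input_ids rows are needed and every position uses segment id 0.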

If you choose to use adapter-BERT by setting the adapter_size parameter, you will probably also want to freeze all the original BERT layers by calling:


Once the model has been built or compiled, the original pre-trained weights can be loaded into the BERT layer:

import os
import bert4tf

bert_ckpt_file   = os.path.join(model_dir, "bert_model.ckpt")
bert4tf.load_stock_weights(l_bert, bert_ckpt_file)

N.B. see tests/ for a complete example.
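The adapter_size parameter refers to the bottleneck adapters of arXiv:1902.00751: each adapter projects the hidden state down to adapter_size, applies a nonlinearity, projects back up and adds a residual connection, and during fine-tuning only the adapters (plus, in the paper, the layer-norm parameters) are trained while the original weights stay frozen. Below is a minimal pure-Python sketch of one adapter on a single vector; it is not bert4tf's implementation, and the use of GELU inside the adapter is an assumption here.

```python
import math

def gelu(x):
    # tanh approximation of GELU, the form used in google-research/bert's modeling.py
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def dense(x, weights, bias):
    # x: n_in floats; weights: n_in rows of n_out floats; bias: n_out floats
    return [sum(x[i] * weights[i][j] for i in range(len(x))) + bias[j]
            for j in range(len(bias))]

def adapter(h, w_down, b_down, w_up, b_up):
    """Bottleneck adapter: hidden_size -> adapter_size -> hidden_size, plus residual."""
    z = [gelu(v) for v in dense(h, w_down, b_down)]  # down-projection + nonlinearity
    up = dense(z, w_up, b_up)                        # up-projection back to hidden_size
    return [a + b for a, b in zip(h, up)]            # residual (skip) connection

def adapter_param_count(hidden_size, adapter_size):
    """Weights and biases of one adapter's two dense layers."""
    return (hidden_size * adapter_size + adapter_size) + (adapter_size * hidden_size + hidden_size)

# With a zero up-projection the adapter is the identity, which is why adapters are
# initialized near zero: fine-tuning starts from the unmodified pre-trained model.
h = [0.5, -1.0, 2.0]
w_down = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]  # 3 x 2 (hidden_size=3, adapter_size=2)
zeros_up = [[0.0] * 3 for _ in range(2)]         # 2 x 3
print(adapter(h, w_down, [0.0, 0.0], zeros_up, [0.0] * 3))  # prints [0.5, -1.0, 2.0]
```

Because each adapter adds only about 2 x hidden_size x adapter_size weights (99,136 parameters for hidden_size=768 and adapter_size=64), a new task needs far fewer trainable parameters than fine-tuning the whole model.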


How to use BERT with the google-research/bert pre-trained weights?

import os
import bert4tf

model_name = "uncased_L-12_H-768_A-12"
model_dir = bert4tf.fetch_google_bert_model(model_name, ".models")
model_ckpt = os.path.join(model_dir, "bert_model.ckpt")

bert_params = bert4tf.params_from_pretrained_ckpt(model_dir)
l_bert = bert4tf.BertModelLayer.from_params(bert_params, name="bert")

# use in a Keras Model here, and call

bert4tf.load_bert_weights(l_bert, model_ckpt)      # should be called after model.build()
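The model_name follows the naming convention of Google's released BERT checkpoints, where L is the number of layers, H the hidden size and A the number of attention heads. A small, hypothetical helper that decodes such a name, for example to sanity-check the parameters read by params_from_pretrained_ckpt:

```python
import re

def parse_bert_model_name(model_name):
    """Decode names like 'uncased_L-12_H-768_A-12' (hypothetical helper)."""
    m = re.fullmatch(r"(?P<prefix>[a-z_]+?)_L-(?P<L>\d+)_H-(?P<H>\d+)_A-(?P<A>\d+)", model_name)
    if m is None:
        raise ValueError(f"unrecognized model name: {model_name!r}")
    return {
        "prefix": m.group("prefix"),        # e.g. "uncased", "cased", "multi_cased"
        "num_layers": int(m.group("L")),
        "hidden_size": int(m.group("H")),
        "num_heads": int(m.group("A")),
    }

print(parse_bert_model_name("uncased_L-12_H-768_A-12"))
```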
How to use ALBERT with the google-research/albert pre-trained weights?

import bert4tf

model_name = "albert_base"
model_dir    = bert4tf.fetch_tfhub_albert_model(model_name, ".models")
model_params = bert4tf.albert_params(model_name)
l_bert = bert4tf.BertModelLayer.from_params(model_params, name="albert")

# use in a Keras Model here, and call

bert4tf.load_albert_weights(l_bert, model_dir)      # should be called after model.build()
How to use ALBERT with the albert_zh pre-trained weights?

See tests/nonci/.

import os
import bert4tf

model_name = "albert_base"
model_dir = bert4tf.fetch_brightmart_albert_model(model_name, ".models")
model_ckpt = os.path.join(model_dir, "albert_model.ckpt")

bert_params = bert4tf.params_from_pretrained_ckpt(model_dir)
l_bert = bert4tf.BertModelLayer.from_params(bert_params, name="bert")

# use in a Keras Model here, and call

bert4tf.load_albert_weights(l_bert, model_ckpt)      # should be called after model.build()
How to tokenize the input for the google-research/bert models?

import os
import bert4tf

do_lower_case = not (model_name.find("cased") == 0 or model_name.find("multi_cased") == 0)
bert4tf.bert_tokenization.validate_case_matches_checkpoint(do_lower_case, model_ckpt)
vocab_file = os.path.join(model_dir, "vocab.txt")
tokenizer = bert4tf.bert_tokenization.FullTokenizer(vocab_file, do_lower_case)
tokens = tokenizer.tokenize("Hello, BERT-World!")
token_ids = tokenizer.convert_tokens_to_ids(tokens)
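The do_lower_case heuristic only keeps the case for model names that start with "cased" or "multi_cased". After basic tokenization, a BERT-style FullTokenizer splits each word into WordPiece units. The sketch below shows that WordPiece step as a greedy longest-match-first search, the approach used in google-research/bert's tokenization.py; the functions and the toy vocabulary are illustrative, not bert4tf's code.

```python
def infer_do_lower_case(model_name):
    # Same test as above: "uncased_..." contains "cased" at index 2, not 0, so it lowercases.
    return not (model_name.find("cased") == 0 or model_name.find("multi_cased") == 0)

def wordpiece(word, vocab, unk="[UNK]", max_chars=200):
    """Greedy longest-match-first WordPiece split of one already basic-tokenized word."""
    if len(word) > max_chars:
        return [unk]
    pieces, start = [], 0
    while start < len(word):
        end, piece = len(word), None
        while start < end:                       # try the longest remaining substring first
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate     # continuation pieces carry a ## prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]                         # some position matched nothing: whole word is unknown
        pieces.append(piece)
        start = end
    return pieces

toy_vocab = {"un", "##aff", "##able", "hello", ","}   # hypothetical toy vocabulary
print(wordpiece("unaffable", toy_vocab))              # ['un', '##aff', '##able']
```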
How to tokenize the input for the google-research/albert models?

import os
import sentencepiece as spm
import bert4tf

spm_model = os.path.join(model_dir, "assets", "30k-clean.model")
sp = spm.SentencePieceProcessor()
sp.load(spm_model)
do_lower_case = True

processed_text = bert4tf.albert_tokenization.preprocess_text("Hello, World!", lower=do_lower_case)
token_ids = bert4tf.albert_tokenization.encode_ids(sp, processed_text)
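preprocess_text normalizes the text before SentencePiece encodes it. The following approximates what the ALBERT reference code does (collapse whitespace, turn `` and '' into ", strip accents via NFKD decomposition, optionally lowercase). It is an illustrative re-implementation, not bert4tf's function, and the SentencePiece encoding itself is left to the library.

```python
import unicodedata

def preprocess_text(text, remove_space=True, lower=False):
    """Approximate ALBERT-style text normalization (illustrative only)."""
    out = " ".join(text.strip().split()) if remove_space else text  # collapse whitespace
    out = out.replace("``", '"').replace("''", '"')                 # LaTeX-style quotes -> "
    out = unicodedata.normalize("NFKD", out)                        # split accents off letters
    out = "".join(c for c in out if not unicodedata.combining(c))   # drop the combining accents
    return out.lower() if lower else out

print(preprocess_text("  Héllo,   World! ", lower=True))  # hello, world!
```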

Reference: kpe



Files for bert4tf, version 2.0.2:
bert4tf-2.0.2-py3-none-any.whl (38.8 kB), file type: wheel, Python version: py3
