
Make it easy to use BERT models.


bert_utils: a text embedding library

A standardized interface that makes it easy to integrate BERT language models into neural networks. It is built on top of Hugging Face Transformers and TensorFlow Hub.
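Both usage examples below rely on an `embed_tokenized` method, backed either by TensorFlow Hub (`TFHubTextEmbedding`) or by Hugging Face Transformers (`TransformerTextEmbedding`). The framework-free sketch below illustrates what such a shared interface buys you, assuming only that every backend maps a batch of token ids to one vector per example. `TextEmbedding`, `ToyEmbedding`, `build_features` and `dim` are hypothetical names for illustration; they are not part of bert_utils, and the toy backend is not a language model.

```python
from abc import ABC, abstractmethod
from typing import List, Sequence


class TextEmbedding(ABC):
    """Hypothetical common interface: token ids in, one vector per example out."""

    @abstractmethod
    def embed_tokenized(self, input_ids: Sequence[Sequence[int]]) -> List[List[float]]:
        ...


class ToyEmbedding(TextEmbedding):
    """Hypothetical stand-in backend: counts token ids into `dim` buckets
    (token_id % dim) and averages the counts over the sequence length."""

    def __init__(self, dim: int = 4):
        self.dim = dim

    def embed_tokenized(self, input_ids):
        vectors = []
        for ids in input_ids:
            vec = [0.0] * self.dim
            for token_id in ids:
                vec[token_id % self.dim] += 1.0
            n = max(len(ids), 1)  # avoid dividing by zero on empty sequences
            vectors.append([v / n for v in vec])
        return vectors


def build_features(model: TextEmbedding, batch):
    # Downstream code depends only on the interface, so backends are swappable.
    return model.embed_tokenized(batch)


print(build_features(ToyEmbedding(dim=4), [[1, 2, 5], [3]]))
```

Because `build_features` only knows about `embed_tokenized`, replacing the toy backend with a real one would not change the surrounding model code, which is the point of a standardized interface.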

Usage in a TensorFlow Estimator

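import tensorflow as tf
from bert_utils.text_embedding import TFHubTextEmbedding

embedding_model = TFHubTextEmbedding()

# Hypothetical feature key: use whatever key your input_fn stores the token ids under.
NLP_FEATURE_NAME = "input_ids"


def model_fn(features, labels, mode, params, config):
    # Embed the token ids with the BERT model, then add a binary classification head.
    x = embedding_model.embed_tokenized(features[NLP_FEATURE_NAME])
    predictions = tf.keras.layers.Dense(1, activation="sigmoid")(x)

    # No labels are available in PREDICT mode, so return the predictions only.
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # The Dense layer already applies a sigmoid, so the loss receives probabilities.
    loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)(
        labels, predictions
    )
    eval_metric_ops = {
        "auc": tf.compat.v1.metrics.auc(labels, predictions),
    }
    optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.01)
    train_op = optimizer.minimize(
        loss=loss, global_step=tf.compat.v1.train.get_global_step()
    )
    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops
    )


estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="model_output", params={})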
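The `model_fn` above applies a sigmoid in its final `Dense` layer and therefore passes `from_logits=False` to `BinaryCrossentropy`. The plain-Python sketch below shows the arithmetic behind that choice: cross-entropy on sigmoid probabilities equals cross-entropy computed directly on the raw scores with the numerically stable formula `max(z, 0) - z * y + log(1 + exp(-|z|))` documented for `tf.nn.sigmoid_cross_entropy_with_logits`. It illustrates the standard formula and is not Keras's implementation; the `eps` clipping is an assumption added here to avoid `log(0)`.

```python
import math
from typing import Sequence


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def bce_from_probs(labels: Sequence[float], probs: Sequence[float], eps: float = 1e-7) -> float:
    """Mean binary cross-entropy on probabilities (the from_logits=False case)."""
    total = 0.0
    for y, p in zip(labels, probs):
        p = min(max(p, eps), 1.0 - eps)  # clip so log() never sees 0
        total += -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
    return total / len(labels)


def bce_from_logits(labels: Sequence[float], logits: Sequence[float]) -> float:
    """Mean binary cross-entropy on raw scores (the from_logits=True case),
    using the stable form max(z, 0) - z*y + log(1 + exp(-|z|))."""
    total = 0.0
    for y, z in zip(labels, logits):
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(labels)


labels = [1.0, 0.0, 1.0]
logits = [2.0, -1.0, -0.5]
probs = [sigmoid(z) for z in logits]
print(bce_from_probs(labels, probs), bce_from_logits(labels, logits))
```

The two printed values agree to floating-point precision. The practical rule is to apply the sigmoid in exactly one place: either in the last layer with `from_logits=False`, as in the example, or in the loss with `from_logits=True` and a linear last layer.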

Usage in a Keras Model

import tensorflow as tf
from tensorflow import keras
from bert_utils.text_embedding import TransformerTextEmbedding

embedding_model = TransformerTextEmbedding()

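# Each example is a sequence of 32 token ids (already tokenized and padded).
input_ids = keras.layers.Input(shape=(32,), dtype=tf.int32)
embedding_features = embedding_model.embed_tokenized(input_ids)
# Binary classification head; the sigmoid outputs probabilities, not logits.
predictions = keras.layers.Dense(1, activation="sigmoid")(embedding_features)
model = keras.Model(inputs=input_ids, outputs=predictions)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])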
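The Keras example declares `Input(shape=(32,), dtype=tf.int32)`, so every example fed to the model must be exactly 32 token ids long. Below is a minimal sketch of padding and truncating already-tokenized sequences to that length. It assumes a padding id of 0 (the `[PAD]` id in the original BERT vocabularies; check your tokenizer), and `pad_to_length` is a hypothetical helper, not a bert_utils function. Truncation here simply cuts the tail; a real BERT pipeline would usually keep the final `[SEP]` token.

```python
from typing import List, Sequence, Tuple

MAX_LEN = 32  # must match keras.layers.Input(shape=(32,))
PAD_ID = 0    # assumption: id of [PAD]; check your tokenizer's vocabulary


def pad_to_length(
    batch: Sequence[Sequence[int]], max_len: int = MAX_LEN, pad_id: int = PAD_ID
) -> Tuple[List[List[int]], List[List[int]]]:
    """Truncate or right-pad each sequence to max_len; also return a 0/1 mask."""
    input_ids, attention_mask = [], []
    for ids in batch:
        ids = list(ids[:max_len])  # naive truncation of the tail
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return input_ids, attention_mask


ids, mask = pad_to_length([[101, 7592, 102], list(range(1, 40))])
print(len(ids[0]), len(ids[1]), sum(mask[0]))  # 32 32 3
```

The resulting `ids` list can be converted with `numpy.array(ids, dtype="int32")` before passing it to `model.fit` or `model.predict`. The attention mask is not consumed by the example model; it is returned because BERT-style encoders commonly accept one.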



Download files


Files for bert-utils, version 0.1.2:

Filename                           Size     File type  Python version
bert_utils-0.1.2-py3-none-any.whl  34.9 kB  Wheel      py3
bert_utils-0.1.2.tar.gz            31.0 kB  Source     None
