
Tools for adapting universal language models to specific tasks


Please note these are tools for rapid prototyping - not brute force hyperparameter tuning.

Adapted from: Google's BERT


  1. Clone repo (just for now - will be on pip soon)
  2. Download a pretrained BERT model - start with BERT-Base Uncased if you're not sure where to begin
  3. Unzip the model and make note of the path


Current Capabilities

BERT for Question Answering

BERT for Multi-Label Classification

BERT for Single-Label Classification

Future Capabilities

GPT-2 Training and Generation

Usage + Core Components

Configuring BERT

First, define the model paths

base_model_folder_path = "../models/uncased_L-12_H-768_A-12/"  # Folder containing downloaded Base Model
name_of_config_json_file = "bert_config.json"  # Located inside the Base Model folder
name_of_vocab_file = "vocab.txt"  # Located inside the Base Model folder

output_directory = "../models/trained_BERT/" # Trained model and results landing folder

# Multi-Label and Single-Label Specific
data_dir = None  # Directory where .tsv data is stored - typically for CoLA/MRPC or other datasets with known structure
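The paths above are joined by string concatenation, so they depend on the trailing slash in base_model_folder_path. A minimal sketch of a pre-flight check, assuming only the two file names listed above; `missing_model_files` is a hypothetical helper for illustration and is not part of critical_path:

```python
import os


def missing_model_files(model_dir, required=("bert_config.json", "vocab.txt")):
    """Return the required file names that are absent from model_dir.

    Hypothetical helper for illustration; not part of critical_path.
    """
    return [
        name for name in required
        # os.path.join works with or without a trailing slash on model_dir.
        if not os.path.isfile(os.path.join(model_dir, name))
    ]
```

Calling `missing_model_files(base_model_folder_path)` before configuring the model should return an empty list if the download was unzipped where you expect.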

Second, define the model run parameters

"""Settable parameters and their default values

Note: Most default values are perfectly fine
"""

# Administrative
init_checkpoint = None
save_checkpoints_steps = 1000
iterations_per_loop = 1000
do_lower_case = True   

# Technical
batch_size_train = 32
batch_size_eval = 8
batch_size_predict = 8
num_train_epochs = 3.0
max_seq_length = 128
warmup_proportion = 0.1
learning_rate = 3e-5

# SQuAD Specific
doc_stride = 128
max_query_length = 64
n_best_size = 20
max_answer_length = 30
is_squad_v2 = False  # SQuAD 2.0 has examples with no answer, aka "impossible", SQuAD 1.0 does not
verbose_logging = False
null_score_diff_threshold = 0.0
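For SQuAD, doc_stride, max_query_length and max_seq_length interact whenever a passage is longer than the model input: the passage is cut into windows that are fed to the model one at a time. The sketch below is modeled on the windowing logic in Google's run_squad.py; the function name is hypothetical, and whether critical_path applies exactly the same rule is an assumption.

```python
def split_into_doc_spans(num_doc_tokens, num_query_tokens,
                         max_seq_length=128, doc_stride=128):
    """Return (start, length) windows that cover a tokenized passage.

    Hypothetical helper for illustration; not part of critical_path.
    """
    # Reserve room for the query plus the [CLS], [SEP], [SEP] tokens.
    max_tokens_for_doc = max_seq_length - num_query_tokens - 3
    if max_tokens_for_doc <= 0:
        raise ValueError("max_seq_length is too small for this query")

    spans = []
    start = 0
    while start < num_doc_tokens:
        length = min(num_doc_tokens - start, max_tokens_for_doc)
        spans.append((start, length))
        if start + length == num_doc_tokens:
            break
        # Advance by doc_stride, but never past the end of the current window.
        start += min(length, doc_stride)
    return spans
```

Under this logic, a doc_stride at least as large as the window (as with the defaults of 128 and 128) gives windows that do not overlap; a smaller doc_stride makes consecutive windows share context.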

Initialize the configuration handler

from critical_path.BERT.configs import ConfigClassifier

Flags = ConfigClassifier()
    bert_config_file=base_model_folder_path + name_of_config_json_file,
    bert_vocab_file=base_model_folder_path + name_of_vocab_file,


# Retrieve a handle for the configs
FLAGS = Flags.get_handle()

A single GTX 1070 using BERT-Base Uncased can handle:

Model              max_seq_len  batch_size
BERT-Base Uncased  256          6
...                384          4

For full batch size and sequence length guidelines see Google's recommendations
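Dropping the batch size to fit a smaller GPU raises the number of optimizer steps per epoch, which in turn lengthens the warmup phase. A rough sketch of that arithmetic, mirroring how Google's reference run_classifier.py derives step counts from the parameters listed earlier; `training_schedule` is a hypothetical helper, and whether critical_path computes these values identically is an assumption:

```python
def training_schedule(num_examples, batch_size_train=32,
                      num_train_epochs=3.0, warmup_proportion=0.1):
    """Return (num_train_steps, num_warmup_steps) for a training run.

    Hypothetical helper for illustration; not part of critical_path.
    """
    # Total optimizer steps across all epochs, truncated to an integer.
    num_train_steps = int(num_examples / batch_size_train * num_train_epochs)
    # The learning rate ramps up over this leading share of the steps.
    num_warmup_steps = int(num_train_steps * warmup_proportion)
    return num_train_steps, num_warmup_steps
```

For example, 10,000 training examples at the default batch size of 32 work out to 937 steps with 93 warmup steps, while a batch size of 4 works out to 7,500 steps with 750 warmup steps.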

Using Configured Model

First, create a new model with the configured parameters

"""For Multi-Label Classification"""
from critical_path.BERT.model_multilabel_class import MultiLabelClassifier

model = MultiLabelClassifier(FLAGS)

Second, load your data source

  • SQuAD has dedicated dataloaders
  • Multi-Label Classification has a generic dataloader
    • DataProcessor in /BERT/model_multilabel_class
      • Note: This requires data labels to be in string format
      • labels = [
            ["label_1", "label_2", "label_3"],
        ]
  • Single-Label Classification dataloaders
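To see why string labels and a fixed label_list go together, here is a minimal sketch of turning per-sample label names into multi-hot vectors. This illustrates the general technique only; `to_multi_hot` is a hypothetical helper, and how DataProcessor encodes labels internally is not confirmed here.

```python
def to_multi_hot(labels, label_list):
    """Encode each sample's string labels as a 0/1 vector over label_list.

    Hypothetical helper for illustration; not part of critical_path.
    """
    # The position of each name in label_list fixes its column.
    index = {label: i for i, label in enumerate(label_list)}
    vectors = []
    for sample_labels in labels:
        row = [0] * len(label_list)
        for label in sample_labels:
            row[index[label]] = 1
        vectors.append(row)
    return vectors
```

A sample with no labels becomes an all-zero row, and a label that is missing from label_list raises a KeyError, so label_list needs to cover every name that appears in the data.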
"""For Multi-Label Classification with a custom .csv reading function"""
from critical_path.BERT.model_multilabel_class import DataProcessor

# read_data is dataset specific - see /
input_ids, input_text, input_labels, label_list = read_toxic_data(randomize=True)

processor = DataProcessor(label_list=label_list)
train_examples = processor.get_samples(
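The reading function is left to you, since it depends on the dataset. A minimal sketch of what such a reader might look like, assuming a CSV with one id column, one text column, and one 0/1 column per label. The column names `id` and `comment_text` and the function name are assumptions made for illustration; this is not the project's read_toxic_data.

```python
import csv
import random


def read_multilabel_csv(handle, id_col="id", text_col="comment_text",
                        randomize=False, seed=0):
    """Read ids, text, string labels, and the label list from an open CSV.

    Hypothetical reader for illustration; not part of critical_path.
    """
    reader = csv.DictReader(handle)
    # Every column other than the id and text columns is treated as a label.
    label_list = [c for c in reader.fieldnames if c not in (id_col, text_col)]
    rows = list(reader)
    if randomize:
        random.Random(seed).shuffle(rows)

    input_ids = [row[id_col] for row in rows]
    input_text = [row[text_col] for row in rows]
    # Keep the names of the columns flagged "1" so labels stay in string form.
    input_labels = [[c for c in label_list if row[c] == "1"] for row in rows]
    return input_ids, input_text, input_labels, label_list
```

Open the file with `open(path, newline="", encoding="utf-8")` and pass the handle in; the four return values line up with the ones unpacked in the snippet above.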

Third, run your task

"""Train and predict a Multi-Label Classifier"""

if do_train:
  model.train(train_examples, label_list)

if do_predict:
  model.predict(predict_examples, label_list)
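The snippet above expects train_examples and predict_examples to exist already. One simple way to produce them is to hold out the tail of a shuffled example list, sketched below; `split_examples` and its `predict_fraction` parameter are hypothetical and not part of critical_path.

```python
def split_examples(examples, predict_fraction=0.1):
    """Split a list into (train, predict) parts, holding out the tail.

    Hypothetical helper for illustration; not part of critical_path.
    """
    n_predict = int(len(examples) * predict_fraction)
    cut = len(examples) - n_predict
    return examples[:cut], examples[cut:]
```

Shuffle the examples first (for instance via the randomize option of your reader) so the held-out part is not biased by file order.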

For full examples please see:

Files for critical-path, version 0.0.9

Filename                              Size     File type  Python version
critical_path-0.0.9.tar.gz            40.1 kB  Source     None
critical_path-0.0.9-py3-none-any.whl  53.4 kB  Wheel      py3
