
critical-path-nlp

Tools for adapting universal language models to specific tasks

Please note: these are tools for rapid prototyping, not brute-force hyperparameter tuning.

Adapted from: Google's BERT

Installation

  1. Clone the repo (just for now; it will be on pip soon)
  2. Download a pretrained BERT model - start with BERT-Base Uncased if you're not sure where to begin
  3. Unzip the model and make note of the path
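A quick sanity check is to confirm the unzipped folder contains the files the configuration steps below expect. This is purely illustrative; the folder name matches the BERT-Base, Uncased release and should be adjusted for other models.

import os

base_model_folder_path = "../models/uncased_L-12_H-768_A-12/"  # Path noted in step 3

for required_file in ("bert_config.json", "vocab.txt", "bert_model.ckpt.index"):
    path = os.path.join(base_model_folder_path, required_file)
    print(path, "found" if os.path.exists(path) else "MISSING")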

Examples

Current Capabilities

  • BERT for Question Answering
  • BERT for Multi-Label Classification
  • BERT for Single-Label Classification

Future Capabilities

  • GPT-2 Training and Generation

Usage + Core Components

Configuring BERT

First, define the model paths

base_model_folder_path = "../models/uncased_L-12_H-768_A-12/"  # Folder containing downloaded Base Model
name_of_config_json_file = "bert_config.json"  # Located inside the Base Model folder
name_of_vocab_file = "vocab.txt"  # Located inside the Base Model folder

output_directory = "../models/trained_BERT/" # Trained model and results landing folder

# Multi-Label and Single-Label Specific
data_dir = None  # Directory the .tsv data is stored in - typically for CoLA/MRPC or other datasets with known structure

Second, define the model run parameters

"""Settable parameters and their default values

Note: Most default values are perfectly fine
"""

# Administrative
init_checkpoint = None
save_checkpoints_steps = 1000
iterations_per_loop = 1000
do_lower_case = True   

# Technical
batch_size_train = 32
batch_size_eval = 8
batch_size_predict = 8
num_train_epochs = 3.0
max_seq_length = 128
warmup_proportion = 0.1
learning_rate = 3e-5

# SQuAD Specific
doc_stride = 128
max_query_length = 64
n_best_size = 20
max_answer_length = 30
is_squad_v2 = False  # SQuAD 2.0 includes unanswerable ("impossible") questions; SQuAD 1.1 does not
verbose_logging = False
null_score_diff_threshold = 0.0

Initialize the configuration handler

from critical_path.BERT.configs import ConfigClassifier

Flags = ConfigClassifier()
Flags.set_model_paths(
    bert_config_file=base_model_folder_path + name_of_config_json_file,
    bert_vocab_file=base_model_folder_path + name_of_vocab_file,
    bert_output_dir=output_directory,
    data_dir=data_dir)

Flags.set_model_params(
    batch_size_train=8, 
    max_seq_length=256,
    num_train_epochs=3)

# Retrieve a handle for the configs
FLAGS = Flags.get_handle()

A single GTX 1070 using BERT-Base Uncased can handle:

Model             | max_seq_len | batch_size
BERT-Base Uncased | 256         | 6
...               | 384         | 4

For full batch size and sequence length guidelines, see Google's recommendations.
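For example, to match the first row of the table above, the batch size and sequence length can be set through the same set_model_params call shown earlier (illustrative values only):

# Illustrative: settings for a single GTX 1070 with BERT-Base Uncased
Flags.set_model_params(
    max_seq_length=256,
    batch_size_train=6)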

Using the Configured Model

First, create a new model with the configured parameters

"""For Multi-Label Classification"""
from critical_path.BERT.model_multilabel_class import MultiLabelClassifier

model = MultiLabelClassifier(FLAGS)

Second, load your data source

  • SQuAD has dedicated dataloaders
  • Multi-Label Classification has a generic dataloader
    • DataProcessor in /BERT/model_multilabel_class
      • Note: This requires data labels to be in string format, e.g.
        labels = [
            ["label_1", "label_2", "label_3"],
            ["label_2"]
        ]
  • Single-Label Classification has its own dataloaders

"""For Multi-Label Classification with a custom .csv reading function"""
from critical_path.BERT.model_multilabel_class import DataProcessor

# read_toxic_data is dataset-specific - see /bert_multilabel_example.py
input_ids, input_text, input_labels, label_list = read_toxic_data(randomize=True)

processor = DataProcessor(label_list=label_list)
train_examples = processor.get_samples(
        input_ids=input_ids,
        input_text=input_text,
        input_labels=input_labels,
        set_type='train')
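The reading function itself is dataset-specific and entirely up to you. As a rough sketch, a CSV reader that returns the four values used above might look like the following (the column names and one-hot label layout are hypothetical, and pandas is assumed; see /bert_multilabel_example.py for the real toxic-comment version):

"""Hypothetical CSV reader returning (ids, text, labels, label_list)"""
import pandas as pd


def read_my_data(csv_path, randomize=True):
    # Assumed layout: an 'id' column, a 'text' column, and one 0/1 column per label
    df = pd.read_csv(csv_path)
    if randomize:
        df = df.sample(frac=1).reset_index(drop=True)

    label_list = [c for c in df.columns if c not in ("id", "text")]

    input_ids = df["id"].astype(str).tolist()
    input_text = df["text"].tolist()
    # DataProcessor expects labels as lists of label-name strings
    input_labels = [
        [name for name in label_list if row[name] == 1]
        for _, row in df.iterrows()
    ]
    return input_ids, input_text, input_labels, label_list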

Third, run your task

"""Train and predict a Multi-Label Classifier"""

if do_train:
  model.train(train_examples, label_list)

if do_predict:
  model.predict(predict_examples, label_list)
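Here do_train and do_predict are plain booleans you set yourself, and predict_examples can be built with the same DataProcessor used above; the 'test' set_type below is an assumption made by analogy with the 'train' call:

# Illustrative: build prediction examples the same way as training examples
predict_examples = processor.get_samples(
    input_ids=pred_ids,
    input_text=pred_text,
    input_labels=pred_labels,
    set_type='test')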

For full examples, please see the example scripts in the repository (e.g. /bert_multilabel_example.py).

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

critical_path-0.0.9.tar.gz (40.1 kB)

Uploaded Source

Built Distribution

critical_path-0.0.9-py3-none-any.whl (53.4 kB)

Uploaded Python 3

File details

Details for the file critical_path-0.0.9.tar.gz.

File metadata

  • Download URL: critical_path-0.0.9.tar.gz
  • Upload date:
  • Size: 40.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.21.0 setuptools/41.0.0 requests-toolbelt/0.9.1 tqdm/4.31.1 CPython/3.6.8

File hashes

Hashes for critical_path-0.0.9.tar.gz

Algorithm   | Hash digest
SHA256      | a3efe908e1005292ae2b2a047120cfc6d7d501244d68ab601cd89ef4127f5cd1
MD5         | 3235b59eaeebb800def89894c2c55d7e
BLAKE2b-256 | db2f074d61f7de774c4074edfe9eb863799fe534f8788adb12822c63b65aa26c

See more details on using hashes here.

File details

Details for the file critical_path-0.0.9-py3-none-any.whl.

File metadata

  • Download URL: critical_path-0.0.9-py3-none-any.whl
  • Upload date:
  • Size: 53.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.21.0 setuptools/41.0.0 requests-toolbelt/0.9.1 tqdm/4.31.1 CPython/3.6.8

File hashes

Hashes for critical_path-0.0.9-py3-none-any.whl

Algorithm   | Hash digest
SHA256      | 9632300318eaa1482e696ad59275f65162a66586d2b12d60e1626169bcae95c9
MD5         | 6bb14c462e014044d7deabf80204433b
BLAKE2b-256 | 66185b839386fb54ceab87f3d768a8af565b5658c4efc6ef244d3254593d97dd

See more details on using hashes here.
