critical-path-nlp
Tools for adapting universal language models to specific tasks
Please note: these are tools for rapid prototyping, not brute-force hyperparameter tuning.
Adapted from: Google's BERT
Installation
- Clone the repo (just for now - it will be on pip soon)
- Download a pretrained BERT model - start with BERT-Base Uncased if you're not sure where to begin
- Unzip the model and make note of the path
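If you'd rather script the download and unzip steps, here is a minimal sketch using only the Python standard library. The URL is the BERT-Base Uncased archive linked from Google's BERT repo - treat it as an assumption and verify it against https://github.com/google-research/bert before relying on it.

import urllib.request
import zipfile

# BERT-Base Uncased archive published with Google's BERT repo
# (verify the URL against the repo's README before use)
url = "https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip"
archive_path = "uncased_L-12_H-768_A-12.zip"

urllib.request.urlretrieve(url, archive_path)

# Extract next to your models folder and note the resulting path
with zipfile.ZipFile(archive_path) as archive:
    archive.extractall("../models/")
# Model files now live at ../models/uncased_L-12_H-768_A-12/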
Examples
- Full implementation examples can be found here:
Current Capabilities
BERT for Question Answering
- Train and evaluate the SQuAD dataset
BERT for Multi-Label Classification
- Train and evaluate custom datasets for multi-label classification tasks (each example may have multiple labels)
BERT for Single-Label Classification
- Train and evaluate custom datasets for single-label classification tasks (each example has exactly one label)
Future Capabilities
GPT-2 Training and Generation
Usage + Core Components
Configuring BERT
First, define the model paths
base_model_folder_path = "../models/uncased_L-12_H-768_A-12/" # Folder containing downloaded Base Model
name_of_config_json_file = "bert_config.json" # Located inside the Base Model folder
name_of_vocab_file = "vocab.txt" # Located inside the Base Model folder
output_directory = "../models/trained_BERT/" # Trained model and results landing folder
# Multi-Label and Single-Label Specific
data_dir = None # Directory the .tsv data is stored in - typically for CoLA/MRPC or other datasets with a known structure
Second, define the model run parameters
"""Settable parameters and their default values
Note: Most default values are perfectly fine
"""
# Administrative
init_checkpoint = None
save_checkpoints_steps = 1000
iterations_per_loop = 1000
do_lower_case = True
# Technical
batch_size_train = 32
batch_size_eval = 8
batch_size_predict = 8
num_train_epochs = 3.0
max_seq_length = 128
warmup_proportion = 0.1
learning_rate = 3e-5
# SQuAD Specific
doc_stride = 128
max_query_length = 64
n_best_size = 20
max_answer_length = 30
is_squad_v2 = False # SQuAD 2.0 has examples with no answer, aka "impossible"; SQuAD 1.1 does not
verbose_logging = False
null_score_diff_threshold = 0.0
Third, initialize the configuration handler
from critical_path.BERT.configs import ConfigClassifier
Flags = ConfigClassifier()
Flags.set_model_paths(
    bert_config_file=base_model_folder_path + name_of_config_json_file,
    bert_vocab_file=base_model_folder_path + name_of_vocab_file,
    bert_output_dir=output_directory,
    data_dir=data_dir)
Flags.set_model_params(
    batch_size_train=8,
    max_seq_length=256,
    num_train_epochs=3)
# Retrieve a handle for the configs
FLAGS = Flags.get_handle()
A single GTX 1070 using BERT-Base Uncased can handle:

Model | max_seq_len | batch_size |
---|---|---|
BERT-Base Uncased | 256 | 6 |
BERT-Base Uncased | 384 | 4 |
For full batch size and sequence length guidelines see Google's recommendations
Using the Configured Model
First, create a new model with the configured parameters
"""For Multi-Label Classification"""
from critical_path.BERT.model_multilabel_class import MultiLabelClassifier
model = MultiLabelClassifier(FLAGS)
Second, load your data source
- SQuAD has dedicated dataloaders
- read_squad_examples(), write_squad_predictions() in /BERT/model_squad
- Multi-Label Classification has a generic dataloader
  - DataProcessor in /BERT/model_multilabel_class
  - Note: This requires data labels to be in string format, e.g.
    labels = [["label_1", "label_2", "label_3"], ["label_2"]]
- Single-Label Classification dataloaders
  - ColaProcessor is implemented in /BERT/model_classifier
  - More dataloader formats are available in pytorch-pretrained-BERT
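The reading function itself is dataset-specific; the real one lives in /bert_multilabel_example.py. As a hedged illustration only, a reader like read_toxic_data might look like the sketch below - the file path, column names, and label set are assumptions modeled on the Jigsaw toxic-comment CSV:

import random
import pandas as pd

def read_toxic_data(randomize=False):
    """Hypothetical reader for a Jigsaw-style toxic-comment CSV.

    Assumed layout: an 'id' column, a 'comment_text' column, and one
    binary column per label. Labels are returned as lists of strings,
    as DataProcessor requires.
    """
    df = pd.read_csv("../data/toxic_comments/train.csv")  # assumed path
    label_list = ["toxic", "severe_toxic", "obscene",
                  "threat", "insult", "identity_hate"]

    rows = list(df.itertuples(index=False))
    if randomize:
        random.shuffle(rows)

    input_ids = [row.id for row in rows]
    input_text = [row.comment_text for row in rows]
    # Keep only the names of the labels active for each example
    input_labels = [[name for name in label_list
                     if getattr(row, name) == 1] for row in rows]

    return input_ids, input_text, input_labels, label_list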
"""For Multi-Label Classification with a custom .csv reading function"""
from critical_path.BERT.model_multilabel_class import DataProcessor
# read_toxic_data is dataset specific - see /bert_multilabel_example.py
input_ids, input_text, input_labels, label_list = read_toxic_data(randomize=True)
processor = DataProcessor(label_list=label_list)
train_examples = processor.get_samples(
    input_ids=input_ids,
    input_text=input_text,
    input_labels=input_labels,
    set_type='train')
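predict_examples can presumably be built the same way. The snippet below is a sketch, not the canonical approach: it assumes the same get_samples call works with set_type='test' and that labels read from disk (or placeholders) are acceptable at predict time - check /bert_multilabel_example.py to confirm.

# Hypothetical: build predict examples with the same get_samples call.
# Whether real labels are required at predict time is an assumption.
test_ids, test_text, test_labels, _ = read_toxic_data(randomize=False)

predict_examples = processor.get_samples(
    input_ids=test_ids,
    input_text=test_text,
    input_labels=test_labels,
    set_type='test')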
Third, run your task
"""Train and predict a Multi-Label Classifier"""
if do_train:
model.train(train_examples, label_list)
if do_predict:
model.predict(predict_examples, label_list)
For full examples, please see the full implementations: