critical-path-nlp
Tools for adapting universal language models to specific tasks
Please note: these are tools for rapid prototyping, not brute-force hyperparameter tuning.
Adapted from: Google's BERT
Installation
- Clone the repo (just for now; it will be on pip soon)
- Download a pretrained BERT model - start with BERT-Base Uncased if you're not sure where to begin
- Unzip the model and make note of the path
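Before going further, it can help to confirm the unzipped folder contains the files referenced in the configuration step below. A minimal sketch, assuming the BERT-Base Uncased folder name used in the examples that follow:

import os

# Folder produced by unzipping BERT-Base Uncased (adjust to your location)
model_dir = "../models/uncased_L-12_H-768_A-12/"

# bert_config.json and vocab.txt are required by the configuration step below
for required in ["bert_config.json", "vocab.txt"]:
    path = os.path.join(model_dir, required)
    assert os.path.exists(path), f"Missing expected file: {path}"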
Examples
- Full implementation examples can be found here:
Current Capabilities
BERT for Question Answering
- Train and evaluate the SQuAD dataset
BERT for Multi-Label Classification
- Train and evaluate custom datasets for multi-label classification tasks (multiple labels possible)
BERT for Single-Label Classification
- Train and evaluate custom datasets for single-label classification tasks (one label possible)
Future Capabilities
GPT-2 Training and Generation
Usage + Core Components
Configuring BERT
First, define the model paths
base_model_folder_path = "../models/uncased_L-12_H-768_A-12/"  # Folder containing the downloaded Base Model
name_of_config_json_file = "bert_config.json"  # Located inside the Base Model folder
name_of_vocab_file = "vocab.txt"  # Located inside the Base Model folder
output_directory = "../models/trained_BERT/"  # Landing folder for the trained model and results

# Multi-Label and Single-Label Specific
data_dir = None  # Directory the .tsv data is stored in - typically for CoLA/MRPC or other datasets with known structure
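These paths are joined by plain string concatenation below, which relies on the trailing slash in base_model_folder_path. If you prefer, os.path.join from the standard library avoids that dependency; a small sketch reusing the variables defined above:

import os

# Equivalent to the concatenation used in Flags.set_model_paths below,
# but safe whether or not the folder path ends with a slash
bert_config_file = os.path.join(base_model_folder_path, name_of_config_json_file)
bert_vocab_file = os.path.join(base_model_folder_path, name_of_vocab_file)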
Second, define the model run parameters
"""Settable parameters and their default values
Note: Most default values are perfectly fine
"""
# Administrative
init_checkpoint = None
save_checkpoints_steps = 1000
iterations_per_loop = 1000
do_lower_case = True
# Technical
batch_size_train = 32
batch_size_eval = 8
batch_size_predict = 8
num_train_epochs = 3.0
max_seq_length = 128
warmup_proportion = 0.1
learning_rate = 3e-5
# SQuAD Specific
doc_stride = 128
max_query_length = 64
n_best_size = 20
max_answer_length = 30
is_squad_v2 = False  # SQuAD 2.0 has "impossible" examples with no answer; SQuAD 1.0 does not
verbose_logging = False
null_score_diff_threshold = 0.0
Initialize the configuration handler
from critical_path.BERT.configs import ConfigClassifier
Flags = ConfigClassifier()
Flags.set_model_paths(
    bert_config_file=base_model_folder_path + name_of_config_json_file,
    bert_vocab_file=base_model_folder_path + name_of_vocab_file,
    bert_output_dir=output_directory,
    data_dir=data_dir)
Flags.set_model_params(
    batch_size_train=8,
    max_seq_length=256,
    num_train_epochs=3)
# Retrieve a handle for the configs
FLAGS = Flags.get_handle()
A single GTX 1070 using BERT-Base Uncased can handle:
| Model | max_seq_len | batch_size |
|---|---|---|
| BERT-Base Uncased | 256 | 6 |
| ... | 384 | 4 |
For full batch size and sequence length guidelines, see Google's recommendations.
Using Configured Model
First, create a new model with the configured parameters
"""For Multi-Label Classification"""
from critical_path.BERT.model_multilabel_class import MultiLabelClassifier
model = MultiLabelClassifier(FLAGS)
Second, load your data source
- SQuAD has dedicated dataloaders
  - read_squad_examples(), write_squad_predictions() in /BERT/model_squad
- Multi-Label Classification has a generic dataloader
  - DataProcessor in /BERT/model_multilabel_class
  - Note: This requires data labels to be in string format, e.g.
    labels = [["label_1", "label_2", "label_3"], ["label_2"]]
- Single-Label Classification has dedicated dataloaders
  - ColaProcessor is implemented in /BERT/model_classifier
  - More dataloader formats have been implemented by pytorch-pretrained-BERT
"""For Multi-Label Classification with a custom .csv reading function"""
from critical_path.BERT.model_multilabel_class import DataProcessor
# read_data is dataset specifc - see /bert_multilabel_example.py
input_ids, input_text, input_labels, label_list = read_toxic_data(randomize=True)
processor = DataProcessor(label_list=label_list)
train_examples = processor.get_samples(
input_ids=input_ids,
input_text=input_text,
input_labels=input_labels,
set_type='train')
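The reading function itself is dataset specific and not part of the package. Below is a minimal sketch of what a read_toxic_data for a Kaggle-style toxic-comment CSV might look like; the file path and column names (id, comment_text, and one 0/1 column per label) are assumptions for illustration, while the return signature matches the call above:

import csv
import random

def read_toxic_data(path="../data/train.csv", randomize=False):
    """Hypothetical reader: returns (ids, texts, string labels, label_list)."""
    # Assumed label columns; adjust to your dataset
    label_list = ["toxic", "severe_toxic", "obscene", "threat",
                  "insult", "identity_hate"]

    with open(path, newline="", encoding="utf-8") as f:
        rows = list(csv.DictReader(f))
    if randomize:
        random.shuffle(rows)

    input_ids = [row["id"] for row in rows]
    input_text = [row["comment_text"] for row in rows]
    # Labels must be strings, per the DataProcessor note above
    input_labels = [[name for name in label_list if row[name] == "1"]
                    for row in rows]
    return input_ids, input_text, input_labels, label_list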
Third, run your task
"""Train and predict a Multi-Label Classifier"""
if do_train:
model.train(train_examples, label_list)
if do_predict:
model.predict(predict_examples, label_list)
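predict_examples can be built with the same DataProcessor used for training. A sketch, assuming a held-out CSV read with the hypothetical reader above, and assuming get_samples accepts a 'predict' set_type by analogy with 'train':

# Build prediction samples the same way as the training samples above;
# set_type='predict' is an assumption, by analogy with set_type='train'
pred_ids, pred_text, pred_labels, _ = read_toxic_data(
    path="../data/test.csv", randomize=False)
predict_examples = processor.get_samples(
    input_ids=pred_ids,
    input_text=pred_text,
    input_labels=pred_labels,
    set_type='predict')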
For full examples, please see the full implementations:
File details
Details for the file critical_path-0.0.9.tar.gz.
File metadata
- Download URL: critical_path-0.0.9.tar.gz
- Upload date:
- Size: 40.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.21.0 setuptools/41.0.0 requests-toolbelt/0.9.1 tqdm/4.31.1 CPython/3.6.8
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a3efe908e1005292ae2b2a047120cfc6d7d501244d68ab601cd89ef4127f5cd1 |
| MD5 | 3235b59eaeebb800def89894c2c55d7e |
| BLAKE2b-256 | db2f074d61f7de774c4074edfe9eb863799fe534f8788adb12822c63b65aa26c |
File details
Details for the file critical_path-0.0.9-py3-none-any.whl.
File metadata
- Download URL: critical_path-0.0.9-py3-none-any.whl
- Upload date:
- Size: 53.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.21.0 setuptools/41.0.0 requests-toolbelt/0.9.1 tqdm/4.31.1 CPython/3.6.8
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 9632300318eaa1482e696ad59275f65162a66586d2b12d60e1626169bcae95c9 |
| MD5 | 6bb14c462e014044d7deabf80204433b |
| BLAKE2b-256 | 66185b839386fb54ceab87f3d768a8af565b5658c4efc6ef244d3254593d97dd |