Skip to main content

AI Library using BERT

Project description

Fast-Bert

License Apache 2.0 PyPI version Python 3.6, 3.7

New - Learning Rate Finder for Text Classification Training (borrowed with thanks from https://github.com/davidtvs/pytorch-lr-finder)

Supports LAMB optimizer for faster training. Please refer to https://arxiv.org/abs/1904.00962 for the paper on LAMB optimizer.

Supports BERT and XLNet for both Multi-Class and Multi-Label text classification.

Fast-Bert is the deep learning library that allows developers and data scientists to train and deploy BERT and XLNet based models for natural language processing tasks beginning with Text Classification.

The work on FastBert is built on solid foundations provided by the excellent Hugging Face BERT PyTorch library and is inspired by fast.ai and strives to make the cutting edge deep learning technologies accessible for the vast community of machine learning practitioners.

With FastBert, you will be able to:

  1. Train (more precisely fine-tune) BERT, RoBERTa and XLNet text classification models on your custom dataset.

  2. Tune model hyper-parameters such as epochs, learning rate, batch size, optimiser schedule and more.

  3. Save and deploy trained model for inference (including on AWS Sagemaker).

Fast-Bert will support both multi-class and multi-label text classification for the following and in due course, it will support other NLU tasks such as Named Entity Recognition, Question Answering and Custom Corpus fine-tuning.

  1. BERT (from Google) released with the paper BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
  1. XLNet (from Google/CMU) released with the paper ​XLNet: Generalized Autoregressive Pretraining for Language Understanding by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.

  2. RoBERTa (from Facebook), a Robustly Optimized BERT Pretraining Approach by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du et al.

  3. DistilBERT (from HuggingFace), released together with the blogpost Smaller, faster, cheaper, lighter: Introducing DistilBERT, a distilled version of BERT by Victor Sanh, Lysandre Debut and Thomas Wolf.

Installation

This repo is tested on Python 3.6+.

With pip

PyTorch-Transformers can be installed by pip as follows:

pip install fast-bert

From source

Clone the repository and run:

pip install [--editable] .

or

pip install git+https://github.com/kaushaltrivedi/fast-bert.git

You will also need to install NVIDIA Apex.

git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

Usage

Text Classification

1. Create a DataBunch object

The databunch object takes training, validation and test csv files and converts the data into internal representation for BERT, RoBERTa, DistilBERT or XLNet. The object also instantiates the correct data-loaders based on device profile and batch_size and max_sequence_length.

from fast_bert.data_cls import BertDataBunch

databunch = BertDataBunch(DATA_PATH, LABEL_PATH,
                          tokenizer='bert-base-uncased',
                          train_file='train.csv',
                          val_file='val.csv',
                          label_file='labels.csv',
                          text_col='text',
                          label_col='label',
                          batch_size_per_gpu=16,
                          max_seq_length=512,
                          multi_gpu=True,
                          multi_label=False,
                          model_type='bert')

File format for train.csv and val.csv

index text label
0 Looking through the other comments, I'm amazed that there aren't any warnings to potential viewers of what they have to look forward to when renting this garbage. First off, I rented this thing with the understanding that it was a competently rendered Indiana Jones knock-off. neg
1 I've watched the first 17 episodes and this series is simply amazing! I haven't been this interested in an anime series since Neon Genesis Evangelion. This series is actually based off an h-game, which I'm not sure if it's been done before or not, I haven't played the game, but from what I've heard it follows it very well pos
2 his movie is nothing short of a dark, gritty masterpiece. I may be bias, as the Apartheid era is an area I've always felt for. pos

In case the column names are different than the usual text and labels, you will have to provide those names in the databunch text_col and label_col parameters.

labels.csv will contain a list of all unique labels. In this case the file will contain:

pos
neg

For multi-label classification, labels.csv will contain all possible labels:

toxic
severe_toxic
obscene
threat
insult
identity_hate

The file train.csv will then contain one column for each label, with each column value being either 0 or 1. Don't forget to change multi_label=True for multi-label classification in BertDataBunch.

id text toxic severe_toxic obscene threat insult identity_hate
0 Why the edits made under my username Hardcore Metallica Fan were reverted? 0 0 0 0 0 0
0 I will mess you up 1 0 0 1 0 0

label_col will be a list of label column names. In this case it will be:

['toxic','severe_toxic','obscene','threat','insult','identity_hate']

Tokenizer

You can either create a tokenizer object and pass it to DataBunch or you can pass the model name as tokenizer and DataBunch will automatically download and instantiate an appropriate tokenizer object.

For example for using XLNet base cased model, set tokenizer parameter to 'xlnet-base-cased'. DataBunch will automatically download and instantiate XLNetTokenizer with the vocabulary for xlnet-base-cased model.

Model Type

Fast-Bert supports XLNet, RoBERTa and BERT based classification models. Set model type parameter value to 'bert', roberta or 'xlnet' in order to initiate an appropriate databunch object.

2. Create a Learner Object

BertLearner is the ‘learner’ object that holds everything together. It encapsulates the key logic for the lifecycle of the model such as training, validation and inference.

The learner object will take the databunch created earlier as as input alongwith some of the other parameters such as location for one of the pretrained models, FP16 training, multi_gpu and multi_label options.

The learner class contains the logic for training loop, validation loop, optimiser strategies and key metrics calculation. This help the developers focus on their custom use-cases without worrying about these repetitive activities.

At the same time the learner object is flexible enough to be customised either via using flexible parameters or by creating a subclass of BertLearner and redefining relevant methods.

from fast_bert.learner_cls import BertLearner
from fast_bert.metrics import accuracy
import logging

logger = logging.getLogger()
device_cuda = torch.device("cuda")
metrics = [{'name': 'accuracy', 'function': accuracy}]

learner = BertLearner.from_pretrained_model(
						databunch,
						pretrained_path='bert-base-uncased',
						metrics=metrics,
						device=device_cuda,
						logger=logger,
						output_dir=OUTPUT_DIR,
						finetuned_wgts_path=None,
						warmup_steps=500,
						multi_gpu=True,
						is_fp16=True,
						multi_label=False,
						logging_steps=50)
parameter description
databunch Databunch object created earlier
pretrained_path Directory for the location of the pretrained model files or the name of one of the pretrained models i.e. bert-base-uncased, xlnet-large-cased, etc
metrics List of metrics functions that you want the model to calculate on the validation set, e.g. accuracy, beta, etc
device torch.device of type cuda or cpu
logger logger object
output_dir Directory for model to save trained artefacts, tokenizer vocabulary and tensorboard files
finetuned_wgts_path provide the location for fine-tuned language model (experimental feature)
warmup_steps number of training warms steps for the scheduler
multi_gpu multiple GPUs available e.g. if running on AWS p3.8xlarge instance
is_fp16 FP16 training
multi_label multilabel classification
logging_steps number of steps between each tensorboard metrics calculation. Set it to 0 to disable tensor flow logging. Keeping this value too low will lower the training speed as model will be evaluated each time the metrics are logged

3. Find the optimal learning rate

The learning rate is one of the most important hyperparameters for model training. We have incorporated the learining rate finder that was proposed by Leslie Smith and then built into the fastai library.

learner.lr_find(start_lr=1e-5,optimizer_type='lamb')

The code is heavily borrowed from David Silva's pytorch-lr-finder library.

Learning rate range test

4. Train the model

learner.fit(epochs=6,
			lr=6e-5,
			validate=True, 	# Evaluate the model after each epoch
			schedule_type="warmup_cosine",
			optimizer_type="lamb")

Fast-Bert now supports LAMB optmizer. Due to the speed of training, we have set LAMB as the default optimizer. You can switch back to AdamW by setting optimizer_type to 'adamw'.

5. Save trained model artifacts

learner.save_model()

Model artefacts will be persisted in the output_dir/'model_out' path provided to the learner object. Following files will be persisted:

File name description
pytorch_model.bin trained model weights
spiece.model sentence tokenizer vocabulary (for xlnet models)
vocab.txt workpiece tokenizer vocabulary (for bert models)
special_tokens_map.json special tokens mappings
config.json model config
added_tokens.json list of new tokens

As the model artefacts are all stored in the same folder, you will be able to instantiate the learner object to run inference by pointing pretrained_path to this location.

6. Model Inference

If you already have a Learner object with trained model instantiated, just call predict_batch method on the learner object with the list of text data:

texts = ['I really love the Netflix original movies',
		 'this movie is not worth watching']
predictions = learner.predict_batch(texts)

If you have persistent trained model and just want to run inference logic on that trained model, use the second approach, i.e. the predictor object.

from fast_bert.prediction import BertClassificationPredictor

MODEL_PATH = OUTPUT_DIR/'model_out'

predictor = BertClassificationPredictor(
				model_path=MODEL_PATH,
				label_path=LABEL_PATH, # location for labels.csv file
				multi_label=False,
				model_type='xlnet',
				do_lower_case=False,
				device=None) # set custom torch.device, defaults to cuda if available

# Single prediction
single_prediction = predictor.predict("just get me result for this text")

# Batch predictions
texts = [
	"this is the first text",
	"this is the second text"
	]

multiple_predictions = predictor.predict_batch(texts)

Language Model Fine-tuning

A useful approach to use BERT based models on custom datasets is to first finetune the language model task for the custom dataset, an apporach followed by fast.ai's ULMFit. The idea is to start with a pre-trained model and further train the model on the raw text of the custom dataset. We will use the masked LM task to finetune the language model.

This section will describe the usage of FastBert to finetune the language model.

1. Import the necessary libraries

The necessary objects are stored in the files with '_lm' suffix.

# Language model Databunch
from fast_bert.data_lm import BertLMDataBunch
# Language model learner
from fast_bert.learner_lm import BertLMLearner

from pathlib import Path
from box import Box

2. Define parameters and setup datapaths

# Box is a nice wrapper to create an object from a json dict
args = Box({
    "seed": 42,
    "task_name": 'imdb_reviews_lm',
    "model_name": 'roberta-base',
    "model_type": 'roberta',
    "train_batch_size": 16,
    "learning_rate": 4e-5,
    "num_train_epochs": 20,
    "fp16": True,
    "fp16_opt_level": "O2",
    "warmup_steps": 1000,
    "logging_steps": 0,
    "max_seq_length": 512,
    "multi_gpu": True if torch.cuda.device_count() > 1 else False
})

DATA_PATH = Path('../lm_data/')
LOG_PATH = Path('../logs')
MODEL_PATH = Path('../lm_model_{}/'.format(args.model_type))

DATA_PATH.mkdir(exist_ok=True)
MODEL_PATH.mkdir(exist_ok=True)
LOG_PATH.mkdir(exist_ok=True)

3. Create DataBunch object

The BertLMDataBunch class contains a static method 'from_raw_corpus' that will take the list of raw texts and create DataBunch for the language model learner.

The method will at first preprocess the text list by removing html tags, extra spaces and more and then create files lm_train.txt and lm_val.txt. These files will be used for training and evaluating the language model finetuning task.

The next step will be to featurize the texts. The text will be tokenized, numericalized and split into blocks on 512 tokens (including special tokens).

databunch_lm = BertLMDataBunch.from_raw_corpus(
					data_dir=DATA_PATH,
					text_list=texts,
					tokenizer=args.model_name,
					batch_size_per_gpu=args.train_batch_size,
					max_seq_length=args.max_seq_length,
                    multi_gpu=args.multi_gpu,
                    model_type=args.model_type,
                    logger=logger)

As this step can take some time based on the size of your custom dataset's text, the featurized data will be cached in pickled files in the data_dir/lm_cache folder.

The next time, instead of using from_raw_corpus method, you may want to directly instantiate the DataBunch object as shown below:

databunch_lm = BertLMDataBunch(
						data_dir=DATA_PATH,
						tokenizer=args.model_name,
                        batch_size_per_gpu=args.train_batch_size,
                        max_seq_length=args.max_seq_length,
                        multi_gpu=args.multi_gpu,
                        model_type=args.model_type,
                        logger=logger)

4. Create the LM Learner object

BertLearner is the ‘learner’ object that holds everything together. It encapsulates the key logic for the lifecycle of the model such as training, validation and inference.

The learner object will take the databunch created earlier as as input alongwith some of the other parameters such as location for one of the pretrained models, FP16 training, multi_gpu and multi_label options.

The learner class contains the logic for training loop, validation loop, and optimizer strategies. This help the developers focus on their custom use-cases without worrying about these repetitive activities.

At the same time the learner object is flexible enough to be customized either via using flexible parameters or by creating a subclass of BertLearner and redefining relevant methods.

learner = BertLMLearner.from_pretrained_model(
							dataBunch=databunch_lm,
							pretrained_path=args.model_name,
							output_dir=MODEL_PATH,
							metrics=[],
							device=device,
							logger=logger,
							multi_gpu=args.multi_gpu,
							logging_steps=args.logging_steps,
							fp16_opt_level=args.fp16_opt_level)

5. Train the model

learner.fit(epochs=6,
			lr=6e-5,
			validate=True, 	# Evaluate the model after each epoch
			schedule_type="warmup_cosine",
			optimizer_type="lamb")

Fast-Bert now supports LAMB optmizer. Due to the speed of training, we have set LAMB as the default optimizer. You can switch back to AdamW by setting optimizer_type to 'adamw'.

6. Save trained model artifacts

learner.save_model()

Model artefacts will be persisted in the output_dir/'model_out' path provided to the learner object. Following files will be persisted:

File name description
pytorch_model.bin trained model weights
spiece.model sentence tokenizer vocabulary (for xlnet models)
vocab.txt workpiece tokenizer vocabulary (for bert models)
special_tokens_map.json special tokens mappings
config.json model config
added_tokens.json list of new tokens

The pytorch_model.bin contains the finetuned weights and you can point the classification task learner object to this file throgh the finetuned_wgts_path parameter.

Amazon Sagemaker Support

The purpose of this library is to let you train and deploy production grade models. As transformer models require expensive GPUs to train, I have added support for training and deploying model on AWS SageMaker.

The repository contains the docker image and code for building BERT based classification models in Amazon SageMaker.

Please refer to my blog Train and Deploy the Mighty BERT based NLP models using FastBert and Amazon SageMaker that provides detailed explanation on using SageMaker with FastBert.

Citation

Please include a mention of this library and HuggingFace pytorch-transformers library and a link to the present repository if you use this work in a published or open-source project.

Also include my blogs on this topic:

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

fast_bert_no_plot-1.9.2.tar.gz (79.2 kB view details)

Uploaded Source

Built Distribution

fast_bert_no_plot-1.9.2-py3-none-any.whl (90.6 kB view details)

Uploaded Python 3

File details

Details for the file fast_bert_no_plot-1.9.2.tar.gz.

File metadata

  • Download URL: fast_bert_no_plot-1.9.2.tar.gz
  • Upload date:
  • Size: 79.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/46.0.0 requests-toolbelt/0.9.1 tqdm/4.50.2 CPython/3.7.7

File hashes

Hashes for fast_bert_no_plot-1.9.2.tar.gz
Algorithm Hash digest
SHA256 aad68be87cdfd714487af8b4dd400e2d93d0f09d65c4477a2bcaf9afd9bfc9ae
MD5 39f8a2b2363bd7d754bb5064b85938b4
BLAKE2b-256 bda4b6886e46324c6522444e1569e62ad44d789a6455af1f457e9c06cc4ac3db

See more details on using hashes here.

File details

Details for the file fast_bert_no_plot-1.9.2-py3-none-any.whl.

File metadata

  • Download URL: fast_bert_no_plot-1.9.2-py3-none-any.whl
  • Upload date:
  • Size: 90.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/46.0.0 requests-toolbelt/0.9.1 tqdm/4.50.2 CPython/3.7.7

File hashes

Hashes for fast_bert_no_plot-1.9.2-py3-none-any.whl
Algorithm Hash digest
SHA256 30efc8539c3f2e5f8dfa78d4f4a55a2461a797c4130629ac4ee8dc6cd175cbfb
MD5 48f9c445132236513e7b8342827ffc5b
BLAKE2b-256 2fd0a11463b2f2f9b99fb32594ea3211ce024b58ab93d8b1dc7161d5c9234151

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page