
AI Library using BERT


Fast-Bert

*** NEW *** Now supports BERT and XLNet for both Multi-Class and Multi-Label text classification.

Fast-Bert is a deep learning library that lets developers and data scientists train and deploy BERT- and XLNet-based models for natural language processing tasks, beginning with text classification.

Fast-Bert builds on the Hugging Face BERT PyTorch library and is inspired by fast.ai. It aims to make cutting-edge deep learning technologies accessible to the wider community of machine learning practitioners.

With Fast-Bert, you will be able to:

  1. Train (more precisely fine-tune) BERT and XLNet text classification models on your custom dataset.

  2. Tune model hyper-parameters such as epochs, learning rate, batch size, optimiser schedule and more.

  3. Save and deploy the trained model for inference (including on AWS SageMaker).

Fast-Bert supports both multi-class and multi-label text classification for the models listed below. In due course it will also support other NLU tasks such as named entity recognition, question answering and custom corpus fine-tuning.

  1. BERT (from Google) released with the paper BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.

  2. XLNet (from Google/CMU) released with the paper XLNet: Generalized Autoregressive Pretraining for Language Understanding by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.

Installation

This repo is tested on Python 3.6+.

With pip

Fast-Bert can be installed with pip as follows:

pip install fast-bert

From source

Clone the repository and run:

pip install [--editable] .

or

pip install git+https://github.com/kaushaltrivedi/fast-bert.git

Usage

1. Create a DataBunch object

The databunch object takes training, validation and test CSV files and converts the data into the internal representation used by BERT or XLNet. It also instantiates the appropriate data loaders based on the device profile and on the batch_size_per_gpu and max_seq_length parameters.

from fast_bert.data_cls import BertDataBunch

databunch = BertDataBunch(DATA_PATH, LABEL_PATH,
                          tokenizer='bert-base-uncased',
                          train_file='train.csv',
                          val_file='val.csv',
                          label_file='labels.csv',
                          text_col='text',
                          label_col='label',
                          batch_size_per_gpu=16,
                          max_seq_length=512,
                          multi_gpu=True,
                          multi_label=False,
                          model_type='bert')

File format for train.csv and val.csv

| index | text | label |
| --- | --- | --- |
| 0 | Looking through the other comments, I'm amazed that there aren't any warnings to potential viewers of what they have to look forward to when renting this garbage. First off, I rented this thing with the understanding that it was a competently rendered Indiana Jones knock-off. | neg |
| 1 | I've watched the first 17 episodes and this series is simply amazing! I haven't been this interested in an anime series since Neon Genesis Evangelion. This series is actually based off an h-game, which I'm not sure if it's been done before or not, I haven't played the game, but from what I've heard it follows it very well | pos |
| 2 | his movie is nothing short of a dark, gritty masterpiece. I may be bias, as the Apartheid era is an area I've always felt for. | pos |

If your column names differ from the default text and label, pass the actual names through the databunch text_col and label_col parameters.

labels.csv will contain a list of all unique labels. In this case the file will contain:

pos
neg
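
If you do not already have a labels.csv, it can be derived from the training file. The following is a minimal sketch using only the standard library; the helper name write_labels_file is hypothetical and not part of Fast-Bert, and it assumes train.csv has a header row containing the label column. Labels are sorted so that the output is reproducible.

```python
import csv
import tempfile
from pathlib import Path


def write_labels_file(train_csv: Path, labels_csv: Path, label_col: str = "label") -> list:
    """Collect the unique values of `label_col` and write one per line.

    Hypothetical helper for illustration; not part of Fast-Bert.
    """
    with train_csv.open(newline="", encoding="utf-8") as f:
        labels = sorted({row[label_col] for row in csv.DictReader(f)})
    labels_csv.write_text("\n".join(labels) + "\n", encoding="utf-8")
    return labels


# Small demonstration on a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    train = Path(tmp) / "train.csv"
    train.write_text(
        "index,text,label\n0,terrible film,neg\n1,simply amazing,pos\n2,a masterpiece,pos\n",
        encoding="utf-8",
    )
    print(write_labels_file(train, Path(tmp) / "labels.csv"))
```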

For multi-label classification, the training and validation files contain one column per label, with each value being either 0 or 1.

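| id | text | toxic | severe_toxic | obscene | threat | insult | identity_hate |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | Why the edits made under my username Hardcore Metallica Fan were reverted? | 0 | 0 | 0 | 0 | 0 | 0 |
| 0 | I will mess you up | 1 | 0 | 0 | 1 | 0 | 0 |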

label_col will be a list of label column names. In this case it will be:

['toxic','severe_toxic','obscene','threat','insult','identity_hate']
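
Rather than typing the label_col list by hand, it can be read from the CSV header. This is a sketch under stated assumptions: the helper infer_label_columns is hypothetical (not part of Fast-Bert), it treats every column other than the id and text columns as a label, and it rejects any label cell that is not 0 or 1.

```python
import csv
import io


def infer_label_columns(csv_text: str, id_col: str = "id", text_col: str = "text") -> list:
    """Return every header column except the id and text columns,
    after checking that each label cell is "0" or "1".

    Hypothetical helper for illustration; not part of Fast-Bert.
    """
    reader = csv.DictReader(io.StringIO(csv_text))
    label_cols = [c for c in reader.fieldnames if c not in (id_col, text_col)]
    for line_no, row in enumerate(reader, start=2):  # line 1 is the header
        for col in label_cols:
            if row[col] not in ("0", "1"):
                raise ValueError(f"line {line_no}: {col}={row[col]!r} is not 0 or 1")
    return label_cols


sample = (
    "id,text,toxic,severe_toxic,obscene,threat,insult,identity_hate\n"
    "0,Why the edits made under my username were reverted?,0,0,0,0,0,0\n"
    "0,I will mess you up,1,0,0,1,0,0\n"
)
print(infer_label_columns(sample))
```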

Tokenizer

You can either create a tokenizer object and pass it to DataBunch, or pass the model name as the tokenizer, in which case DataBunch will automatically download and instantiate an appropriate tokenizer object.

For example, to use the XLNet base cased model, set the tokenizer parameter to 'xlnet-base-cased'. DataBunch will automatically download and instantiate XLNetTokenizer with the vocabulary for the xlnet-base-cased model.

Model Type

Fast-Bert supports XLNet- and BERT-based classification models. Set the model_type parameter to 'bert' or 'xlnet' to instantiate an appropriate databunch object.

2. Create a Learner Object

BertLearner is the ‘learner’ object that holds everything together. It encapsulates the key logic for the lifecycle of the model such as training, validation and inference.

The learner object takes the databunch created earlier as input, along with other parameters such as the location of one of the pretrained models, FP16 training, and the multi_gpu and multi_label options.

The learner class contains the logic for the training loop, validation loop, optimiser strategies and key metrics calculation. This helps developers focus on their custom use-cases without worrying about these repetitive activities.

At the same time, the learner object can be customised either through its parameters or by creating a subclass of BertLearner and redefining the relevant methods.

import logging

import torch  # needed for torch.device below

from fast_bert.learner_cls import BertLearner
from fast_bert.metrics import accuracy

logger = logging.getLogger()
device_cuda = torch.device("cuda")
metrics = [{'name': 'accuracy', 'function': accuracy}]

learner = BertLearner.from_pretrained_model(
						databunch,
						pretrained_path='bert-base-uncased',
						metrics=metrics,
						device=device_cuda,
						logger=logger,
						output_dir=OUTPUT_DIR,
						finetuned_wgts_path=None,
						warmup_steps=500,
						multi_gpu=True,
						is_fp16=True,
						multi_label=False,
						logging_steps=50)

| parameter | description |
| --- | --- |
| databunch | Databunch object created earlier |
| pretrained_path | Directory containing the pretrained model files, or the name of one of the pretrained models, e.g. bert-base-uncased, xlnet-large-cased, etc. |
| metrics | List of metric functions that you want the model to calculate on the validation set, e.g. accuracy, beta, etc. |
| device | torch.device of type cuda or cpu |
| logger | Logger object |
| output_dir | Directory where the model saves trained artefacts, tokenizer vocabulary and tensorboard files |
| finetuned_wgts_path | Location of a fine-tuned language model (experimental feature) |
| warmup_steps | Number of training warmup steps for the scheduler |
| multi_gpu | Whether multiple GPUs are available, e.g. when running on an AWS p3.8xlarge instance |
| is_fp16 | FP16 training |
| multi_label | Multi-label classification |
| logging_steps | Number of steps between each tensorboard metrics calculation. Set it to 0 to disable tensorboard logging. Setting this value too low will slow down training, as the model is evaluated each time the metrics are logged |
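
When choosing warmup_steps and logging_steps, it helps to know roughly how many optimizer steps a run will take. The arithmetic below is a sketch, not Fast-Bert code. It assumes one optimizer step per global batch, a global batch of batch_size_per_gpu times the number of GPUs, and no gradient accumulation. The function name and the dataset size are hypothetical.

```python
import math


def total_training_steps(n_examples: int, batch_size_per_gpu: int, n_gpus: int, epochs: int) -> int:
    """Optimizer steps for a run, assuming one step per global batch.

    Assumption: global batch = batch_size_per_gpu * number of GPUs.
    """
    global_batch = batch_size_per_gpu * max(1, n_gpus)
    steps_per_epoch = math.ceil(n_examples / global_batch)
    return steps_per_epoch * epochs


# Hypothetical run: 25,000 training rows, 4 GPUs, 16 examples per GPU, 6 epochs.
steps = total_training_steps(25_000, batch_size_per_gpu=16, n_gpus=4, epochs=6)
print(steps)        # 2346
print(500 / steps)  # fraction of the run spent in warmup when warmup_steps=500
```

With these numbers, 500 warmup steps cover roughly a fifth of the run, which may be more than intended for a small dataset.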

3. Train the model

learner.fit(epochs=6,
            lr=6e-5,
            validate=True,  # Evaluate the model after each epoch
            schedule_type="warmup_cosine")
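
To give an intuition for what a warmup-then-cosine schedule does to the learning rate, here is a self-contained sketch of the general shape: a linear ramp over warmup_steps followed by a cosine decay to zero. It illustrates the common technique and is not necessarily the exact schedule Fast-Bert applies. The function name and the total step count of 2346 are hypothetical.

```python
import math


def warmup_cosine_factor(step: int, warmup_steps: int, total_steps: int) -> float:
    """Multiplier applied to the peak learning rate at a given step."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)  # linear ramp from 0 to 1
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))  # cosine decay from 1 to 0


peak_lr = 6e-5
for step in (0, 250, 500, 1423, 2346):
    print(step, peak_lr * warmup_cosine_factor(step, warmup_steps=500, total_steps=2346))
```

The learning rate reaches its peak exactly at the end of warmup and then falls smoothly, which tends to be gentler on pretrained weights than starting at the full rate.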

4. Save trained model artifacts

learner.save_model()

Model artefacts will be persisted in the output_dir/'model_out' path provided to the learner object. The following files will be persisted:

| File name | description |
| --- | --- |
| pytorch_model.bin | trained model weights |
| spiece.model | SentencePiece tokenizer vocabulary (for XLNet models) |
| vocab.txt | WordPiece tokenizer vocabulary (for BERT models) |
| special_tokens_map.json | special tokens mappings |
| config.json | model config |
| added_tokens.json | list of new tokens |

As the model artefacts are all stored in the same folder, you will be able to instantiate the learner object to run inference by pointing pretrained_path to this location.
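
Before pointing pretrained_path at a saved folder, it can be worth checking that the expected files are present. The sketch below uses the file names listed above. The helper missing_artifacts is hypothetical (not part of Fast-Bert), and treating only the weights, the config and the tokenizer vocabulary as required is an assumption.

```python
import tempfile
from pathlib import Path

# Assumption: these are the files needed to reload a model; the JSON token
# files from the list above are treated as optional here.
COMMON_FILES = ("pytorch_model.bin", "config.json")
TOKENIZER_FILES = {"bert": "vocab.txt", "xlnet": "spiece.model"}


def missing_artifacts(model_dir: Path, model_type: str) -> list:
    """Return the names of required files that are absent from model_dir."""
    required = COMMON_FILES + (TOKENIZER_FILES[model_type],)
    return [name for name in required if not (model_dir / name).is_file()]


# Demonstration on a throwaway directory that lacks the tokenizer vocabulary.
with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp)
    for name in COMMON_FILES:
        (model_dir / name).touch()
    print(missing_artifacts(model_dir, "bert"))
```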

5. Model Inference

If you already have a Learner object with a trained model, call the predict_batch method on it with a list of texts:

texts = ['I really love the Netflix original movies',
		 'this movie is not worth watching']
predictions = learner.predict_batch(texts)

If you have a persisted trained model and only want to run inference on it, use the second approach: the predictor object.

from fast_bert.prediction import BertClassificationPredictor

MODEL_PATH = OUTPUT_DIR/'model_out'

predictor = BertClassificationPredictor(
				model_path=MODEL_PATH,
				label_path=LABEL_PATH, # location for labels.csv file
				multi_label=False,
				model_type='xlnet', 
				do_lower_case=False)

# Single prediction
single_prediction = predictor.predict("just get me result for this text")

# Batch predictions
texts = [
	"this is the first text",
	"this is the second text"
	]

multiple_predictions = predictor.predict_batch(texts)
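
For multi-class use you usually want only the best-scoring label per text. The sketch below assumes that each prediction is a list of (label, score) pairs, which is an assumption about the output shape and should be checked against what your installed version actually returns. The helper top_label is hypothetical, and the scores are hand-written for illustration rather than real model output.

```python
def top_label(prediction):
    """Return the (label, score) pair with the highest score.

    Assumes `prediction` is a list of (label, score) pairs.
    """
    return max(prediction, key=lambda pair: pair[1])


# Hand-written example scores, not real model output.
example = [("neg", 0.12), ("pos", 0.88)]
print(top_label(example))
```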

AWS Sagemaker Support

The purpose of this library is to let you train and deploy production-grade models. As transformer models require expensive GPUs to train, I have added support for training and deploying models on AWS SageMaker.

The repository contains the Docker image and code for building BERT and XLNet models in SageMaker. Due to the sheer number of breaking changes in Fast-Bert and the underlying pytorch-transformers library, the SageMaker support currently targets an older version of the Fast-Bert library. I am hoping to update this in the coming weeks.

Citation

Please include a mention of this library and HuggingFace pytorch-transformers library and a link to the present repository if you use this work in a published or open-source project.

Also include my blogs on this topic:

