Adversarial Training and Data Augmentation for Neural Question-Answering Models
KitanaQA
Tool[KIT] for [A]dversarial Trai[N]ing and [A]ugmentation in [Q]uestion [A]nswering
About • Features • Install • Getting Started • Examples
About
KitanaQA is an adversarial training and data augmentation framework for fine-tuning Transformer-based language models on question-answering datasets.
Why KitanaQA?
While NLP models have made incredible progress on curated question-answer datasets in recent years, they are still brittle and unpredictable in production environments, making productization and enterprise adoption problematic. KitanaQA provides resources to "robustify" Transformer-based question-answer models against many types of natural and synthetic noise. The major features are:
- Adversarial Training can increase both robustness and performance of fine-tuned Transformer QA models. Here, we implement virtual adversarial training, which introduces embedding-space perturbations during fine-tuning to encourage the model to produce more stable results in the presence of noisy inputs.
Our experiments with BERT fine-tuned on the SQuAD v1.1 question-answering dataset show an improvement in both exact-match (EM) and F1 scores:
Model | EM | F1 |
---|---|---|
BERT-base | 80.8 | 88.5 |
BERT-base (ALUM) | 81.97 | 88.92 |
- Augment Your Dataset to increase model generalizability and robustness using token-level perturbations. While Adversarial Training provides some measure of robustness against bounded perturbations, Augmentation can accommodate a wide range of naturally occurring noise in user input. We provide tools to augment existing SQuAD-like datasets by perturbing the examples along a number of different dimensions, including synonym replacement, misspelling, repetition, and deletion.
- Workflow Automation to prototype robust NLP models faster for research and production. This package is structured for easy use and deployment. Using Prefect Flows, training, evaluation, and model selection can be executed in a single line of code, enabling faster iteration and easier integration of research into production pipelines.
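For reference, the exact-match and F1 columns in the results table above are usually computed per example with SQuAD-style answer normalization and token overlap. The following is a minimal sketch of those two metrics under that assumption; it is not KitanaQA's own evaluation code, and the function names are illustrative.

```python
import re
import string
from collections import Counter


def normalize_answer(text: str) -> str:
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    punctuation = set(string.punctuation)
    text = "".join(ch for ch in text if ch not in punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, truth: str) -> float:
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize_answer(prediction) == normalize_answer(truth))


def f1_score(prediction: str, truth: str) -> float:
    """Harmonic mean of token-level precision and recall."""
    pred_tokens = normalize_answer(prediction).split()
    true_tokens = normalize_answer(truth).split()
    if not pred_tokens or not true_tokens:
        # Both empty counts as a match; one empty counts as a miss.
        return float(pred_tokens == true_tokens)
    overlap = sum((Counter(pred_tokens) & Counter(true_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(true_tokens)
    return 2 * precision * recall / (precision + recall)
```

Dataset-level scores are then the mean of these per-example values, typically reported as percentages.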
Features
Adversarial Training
Our implementation is based on the smoothness-inducing regularization approach proposed here. We have adapted it for fine-tuning on question-answer datasets and added features such as adversarial hyperparameter scheduling and support for mixed-precision training.
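To make the idea of a smoothness-inducing penalty concrete, here is a minimal, dependency-free sketch. It assumes a toy linear-softmax model in place of a Transformer, an L-infinity bound `epsilon`, a single sign-gradient ascent step, and finite-difference gradients instead of backpropagation. All names and default values are illustrative and are not taken from the KitanaQA code base.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions of equal length."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def predict(weights, embedding):
    """Toy stand-in for a QA head: one linear layer followed by softmax."""
    logits = [sum(w * x for w, x in zip(row, embedding)) for row in weights]
    return softmax(logits)


def virtual_adversarial_penalty(weights, embedding, epsilon=0.1, step=0.05,
                                seed=0, h=1e-5):
    """Return (penalty, delta): the KL between clean and perturbed outputs."""
    rng = random.Random(seed)
    clean = predict(weights, embedding)

    def divergence(delta):
        noisy = [x + d for x, d in zip(embedding, delta)]
        return kl_divergence(clean, predict(weights, noisy))

    # Start from small random noise: the KL gradient is zero at delta = 0.
    delta = [rng.gauss(0.0, 1e-2) for _ in embedding]

    # Central finite differences approximate d(KL)/d(delta).
    grad = []
    for i in range(len(delta)):
        up, down = list(delta), list(delta)
        up[i] += h
        down[i] -= h
        grad.append((divergence(up) - divergence(down)) / (2 * h))

    # One ascent step along the gradient sign, clipped to the epsilon ball.
    delta = [max(-epsilon, min(epsilon, d + step * math.copysign(1.0, g)))
             for d, g in zip(delta, grad)]
    return divergence(delta), delta
```

During fine-tuning, a penalty of this kind would be added to the task loss with some weight (a hypothetical `alpha`), so that the model is rewarded for producing similar outputs on clean and perturbed embeddings. No labels are needed to compute it, which is what makes the perturbation "virtual".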
Adversarial Attack
A common way to measure the robustness of a neural network is the white-box adversarial attack. In the context of Transformer-based question-answer models, this attack injects noise into the model's input embeddings and assesses performance against the original labels. Here, we implement the projected gradient descent (PGD) attack, with perturbations bounded by a norm ball. Metrics can be calculated for both non-adversarial and adversarial evaluation, making robustness studies more streamlined and accessible.
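The PGD loop described above can be sketched on a toy model. This example assumes a logistic-regression classifier (so the input gradient has a closed form), an L-infinity ball of radius `epsilon`, and sign-gradient steps. These choices and all names are illustrative assumptions, not the KitanaQA implementation, which attacks Transformer input embeddings.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def loss_and_grad(w, b, x, y):
    """Binary cross-entropy of a logistic model and its gradient w.r.t. x."""
    p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
    loss = -(y * math.log(p) + (1 - y) * math.log(1 - p))
    grad = [(p - y) * wi for wi in w]
    return loss, grad


def pgd_attack(w, b, x, y, epsilon=0.3, step_size=0.1, n_steps=5):
    """Maximize the loss on the true label inside an L-infinity ball around x."""
    x_adv = list(x)
    for _ in range(n_steps):
        _, grad = loss_and_grad(w, b, x_adv, y)
        # Gradient ascent step on the loss, using only the gradient sign.
        x_adv = [xi + step_size * math.copysign(1.0, g)
                 for xi, g in zip(x_adv, grad)]
        # Projection: clip each coordinate back into [x - eps, x + eps].
        x_adv = [min(max(xa, x0 - epsilon), x0 + epsilon)
                 for xa, x0 in zip(x_adv, x)]
    return x_adv


# Usage: compare the loss on the clean input with the loss after the attack.
w, b, x, y = [2.0, -1.0], 0.0, [1.0, 0.5], 1
clean_loss, _ = loss_and_grad(w, b, x, y)
adv_loss, _ = loss_and_grad(w, b, pgd_attack(w, b, x, y), y)
```

Reporting a metric on both the clean and the attacked input, as in the last three lines, is the pattern behind the non-adversarial versus adversarial evaluation mentioned above.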
Data Augmentation
The following perturbation methods are available to augment SQuAD-like data:
- Synonym Replacement (SR) via 1) constrained word2vec, and 2) MLM using BERT
- (original) How many species of plants were [recorded] in Egypt?
+ (augmented) How many species of plants were [registered] in Egypt?
- Random Deletion (RD) using entity-aware term selection
- (original) How many species of plants [were] recorded in Egypt?
+ (augmented) How many species of plants [] recorded in Egypt?
- Random Repetition (RR) using entity-aware term selection
- (original) How many species of plants [were] recorded in Egypt?
+ (augmented) How many species of plants [were were] recorded in Egypt?
- Misspelling
- (original) How [many] species of plants were recorded in Egypt?
+ (augmented) How [mony] species of plants were recorded in Egypt?
Perturbation types can be combined, each with its own frequency, for fine-grained control of natural noise profiles.
- (original) How [many] species [of] plants [were] recorded in Egypt?
+ (augmented) How [mony] species [] plants [] recorded in Egypt?
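The combination of perturbation types with per-type frequencies can be illustrated with a small sketch. It assumes whitespace tokenization, a hypothetical one-entry misspelling table, and one independent random draw per token. It omits synonym replacement and entity-aware term selection, and none of the names below come from the KitanaQA API.

```python
import random

# Hypothetical lookup table; a real one would come from a misspelling corpus.
MISSPELLINGS = {"many": ["mony"]}


def augment(question, rates, seed=0):
    """Apply at most one perturbation per token, chosen by the given rates.

    `rates` maps "delete", "repeat", and "misspell" to probabilities whose
    sum should not exceed 1.0; the remainder leaves the token unchanged.
    """
    rng = random.Random(seed)
    p_delete = rates.get("delete", 0.0)
    p_repeat = rates.get("repeat", 0.0)
    p_misspell = rates.get("misspell", 0.0)

    out = []
    for token in question.split():
        draw = rng.random()
        if draw < p_delete:
            continue  # random deletion: drop the token
        if draw < p_delete + p_repeat:
            out.extend([token, token])  # random repetition: emit it twice
            continue
        candidates = MISSPELLINGS.get(token.lower())
        if draw < p_delete + p_repeat + p_misspell and candidates:
            out.append(rng.choice(candidates))  # misspelling from the table
            continue
        out.append(token)
    return " ".join(out)


question = "How many species of plants were recorded in Egypt?"
noisy = augment(question, {"delete": 0.1, "repeat": 0.1, "misspell": 0.3}, seed=1)
```

Raising or lowering an individual rate shifts the noise profile toward that perturbation type, and fixing the seed keeps an augmented dataset reproducible.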
Each perturbation type also supports custom term importance sampling, e.g. using scores generated by an MLM:
(How, 0.179), (many, 0.254), (species, 0.123), (of, 0.03), (plants, 0.136), (were, 0.039), (recorded, 0.067), (in, 0.012), (Egypt, 0.159)
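One way such scores could drive term selection is weighted sampling without replacement. The sketch below assumes that a higher score makes a term more likely to be picked for perturbation; the source does not state the direction, so treat that as an assumption. The helper name `sample_terms` is hypothetical.

```python
import random

SCORES = [("How", 0.179), ("many", 0.254), ("species", 0.123), ("of", 0.03),
          ("plants", 0.136), ("were", 0.039), ("recorded", 0.067),
          ("in", 0.012), ("Egypt", 0.159)]


def sample_terms(scored_terms, k, seed=0):
    """Pick up to k distinct terms with probability proportional to score."""
    rng = random.Random(seed)
    pool = list(scored_terms)
    chosen = []
    for _ in range(min(k, len(pool))):
        weights = [score for _, score in pool]
        if sum(weights) <= 0:
            break  # only zero-weight terms remain; nothing left to sample
        index = rng.choices(range(len(pool)), weights=weights, k=1)[0]
        chosen.append(pool.pop(index)[0])
    return chosen


targets = sample_terms(SCORES, k=3, seed=42)
```

The selected terms would then be passed to a perturbation such as deletion or misspelling, so that noise lands preferentially on the tokens that carry the most weight.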
ML Flows
Using the Prefect library, KitanaQA makes it easy to combine different workflows for end-to-end training, evaluation, and model selection. This system also supports rapid iteration in hyperparameter search: each experimental condition can be specified and deployed independently. You can even get results reported directly in Slack!
Installation
Entity-aware data augmentations make use of the John Snow Labs spark-nlp library, which requires pyspark. To enable this feature, make sure Java v8 is set by default for pyspark compatibility:
sudo apt install openjdk-8-jdk
sudo update-alternatives --config java
java -version
It is recommended that you use a virtual environment when installing from pip or source. Virtualenv and Conda are good options.
This package has been tested on Python 3.7+, PyTorch 1.5.1+, and transformers 3.3.1.
Install with pip:
pip install kitanaqa
Install from source:
git clone https://github.com/searchableai/KitanaQA.git
cd KitanaQA
python setup.py install
Getting Started
To run training or evaluation from the command line:
python src/kitanaqa/trainer/run_pipeline.py --args=args.json
See an example args.json
Examples
Augmentation
Training and Evaluation
Models Supported
We make use of the following models and their respective tokenizers and configurations provided by Hugging Face Inc.
- ALBERT
- BERT
- DistilBERT
Contributing to KitanaQA
We welcome suggestions and contributions! Submit an issue or pull request and we will do our best to respond in a timely manner. See CONTRIBUTING.md for detailed information on contributing.
Thanks!
- John Snow Labs
- Hugging Face Inc.
- PyTorch community