
Adversarial Training and Data Augmentation for Neural Question-Answering Models

Project description

 

KitanaQA

Tool[KIT] for [A]dversarial Trai[N]ing and [A]ugmentation in [Q]uestion [A]nswering



About

KitanaQA is an adversarial training and data augmentation framework for fine-tuning Transformer-based language models on question-answering datasets.

Why KitanaQA?

While NLP models have made remarkable progress on curated question-answering datasets in recent years, they remain brittle and unpredictable in production environments, which complicates productization and enterprise adoption. KitanaQA provides resources to "robustify" Transformer-based question-answering models against many types of natural and synthetic noise. The major features are:

  1. Adversarial Training can increase both the robustness and the performance of fine-tuned Transformer QA models. Here, we implement virtual adversarial training, which introduces embedding-space perturbations during fine-tuning to encourage the model to produce stable predictions in the presence of noisy inputs.

     Our experiments with BERT fine-tuned on the SQuAD v1.1 question-answering dataset show an improvement in both exact-match (EM) and F1 scores:

     Model              EM     F1
     BERT-base          80.8   88.5
     BERT-base (ALUM)   81.97  88.92

  2. Augment Your Dataset to increase model generalizability and robustness using token-level perturbations. While Adversarial Training provides some measure of robustness against bounded perturbations, Augmentation can accommodate a wide range of naturally occurring noise in user input. We provide tools to augment existing SQuAD-like datasets by perturbing the examples along several dimensions, including synonym replacement, misspelling, repetition, and deletion.

  3. Workflow Automation to prototype robust NLP models faster for research and production. This package is structured for easy use and deployment. Using Prefect Flows, training, evaluation, and model selection can be executed in a single line of code, enabling faster iteration and easier integration of research into production pipelines.
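For reference, EM and F1 on SQuAD v1.1 are conventionally computed as in the official SQuAD evaluation script: both answers are normalized (lowercased, punctuation and the articles a/an/the removed, whitespace collapsed), EM checks string equality, and F1 measures token overlap; each prediction is scored against every reference answer and the best score is kept. The sketch below re-implements that logic for illustration only. It is not KitanaQA's own evaluation code, and the function names are our own.

```python
import collections
import re
import string

PUNCTUATION = set(string.punctuation)


def normalize_answer(text: str) -> str:
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, ground_truth: str) -> float:
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))


def f1_score(prediction: str, ground_truth: str) -> float:
    """Harmonic mean of token-level precision and recall."""
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    common = collections.Counter(pred_tokens) & collections.Counter(gold_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def best_over_ground_truths(metric, prediction, ground_truths):
    """SQuAD questions may have several reference answers; keep the best score."""
    return max(metric(prediction, gt) for gt in ground_truths)
```

For example, "the Nile" against the reference "Nile River" gets EM 0.0 but partial F1 credit (precision 1, recall 0.5), which is why F1 is always at least as high as EM in the table.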

Features

Adversarial Training

Our implementation is based on the smoothness-inducing regularization approach proposed here. We have updated the implementation for fine-tuning on question-answering datasets and added features such as adversarial hyperparameter scheduling and support for mixed-precision training.
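To make the idea concrete, here is a minimal NumPy sketch of one common form of smoothness-inducing adversarial regularization, under simplifying assumptions: a toy linear softmax model `softmax(W @ x)` stands in for a Transformer QA head, `x` stands in for the input embeddings, and gradients are derived by hand instead of with autograd. The perturbation is found by normalized gradient ascent on KL(f(x) || f(x + delta)), starting from small random noise and projected onto an L2 ball; the training loss then adds a symmetric-KL penalty between clean and perturbed predictions. The function and parameter names (`adversarial_perturbation`, `eps`, `sigma`, `step_size`, `alpha`) are our own and are not KitanaQA's API.

```python
import numpy as np


def softmax(z):
    z = z - z.max()                     # numerical stability
    e = np.exp(z)
    return e / e.sum()


def kl(p, q):
    """KL(p || q) for strictly positive probability vectors."""
    return float(np.sum(p * (np.log(p) - np.log(q))))


def sym_kl(p, q):
    return kl(p, q) + kl(q, p)


def adversarial_perturbation(W, x, eps=1e-2, sigma=1e-5, step_size=1e-3,
                             n_steps=1, rng=None):
    """Search for delta with ||delta||_2 <= eps that maximizes
    KL(f(x) || f(x + delta)) for the toy model f(x) = softmax(W @ x)."""
    rng = np.random.default_rng(0) if rng is None else rng
    p_clean = softmax(W @ x)            # fixed target: no label is needed ("virtual")
    delta = sigma * rng.standard_normal(x.shape)
    for _ in range(n_steps):
        q_adv = softmax(W @ (x + delta))
        grad = W.T @ (q_adv - p_clean)  # d KL(p_clean || q_adv) / d delta
        norm = np.linalg.norm(grad)
        if norm > 0:
            delta = delta + step_size * grad / norm   # normalized ascent step
        delta_norm = np.linalg.norm(delta)
        if delta_norm > eps:
            delta = delta * (eps / delta_norm)        # project onto the eps-ball
    return delta


def regularized_loss(W, x, label, alpha=1.0, **kwargs):
    """Cross-entropy on the clean input plus alpha * symmetric-KL smoothness term."""
    p_clean = softmax(W @ x)
    task_loss = -np.log(p_clean[label])
    delta = adversarial_perturbation(W, x, **kwargs)
    p_adv = softmax(W @ (x + delta))
    return float(task_loss + alpha * sym_kl(p_clean, p_adv)), delta
```

Because the penalty compares the model with itself rather than with a gold label, it rewards predictions that stay stable inside a small neighborhood of each input, which is the robustness property described above.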

Adversarial Attack

A key measure of robustness in neural networks is performance under a white-box adversarial attack. In the context of Transformer-based question-answering models, this attack injects noise into the model's input embeddings and assesses performance against the original labels. Here, we implement the projected gradient descent (PGD) attack, with perturbations constrained to a norm ball around the original embeddings. Metrics can be calculated for both non-adversarial and adversarial evaluation, making robustness studies more streamlined and accessible.
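The PGD loop can be sketched in a few lines of NumPy. This is an illustration under stated assumptions, not KitanaQA's implementation: a toy linear classifier `softmax(X @ W.T)` stands in for the QA model, rows of `X` stand in for input embeddings, and we chose an L2 ball (the text above does not specify the norm). Each step ascends the cross-entropy of the original labels and then projects every perturbation back into the ball of radius `eps`; comparing clean and adversarial accuracy gives the robustness gap.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def pgd_attack(W, X, y, eps=0.5, step_size=0.1, n_steps=10):
    """L2-bounded PGD on the inputs of the toy model softmax(X @ W.T).

    Each step ascends the cross-entropy of the ORIGINAL labels y, then rescales
    any row of delta whose norm exceeds eps back onto the eps-ball.
    """
    onehot = np.eye(W.shape[0])[y]
    delta = np.zeros_like(X)
    for _ in range(n_steps):
        probs = softmax((X + delta) @ W.T)
        grad = (probs - onehot) @ W                     # d CE / d input, per row
        grad_norm = np.linalg.norm(grad, axis=1, keepdims=True) + 1e-12
        delta = delta + step_size * grad / grad_norm    # normalized ascent step
        delta_norm = np.linalg.norm(delta, axis=1, keepdims=True)
        delta = delta * np.minimum(1.0, eps / np.maximum(delta_norm, 1e-12))
    return delta


def accuracy(W, X, y):
    return float(np.mean(np.argmax(X @ W.T, axis=1) == y))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    W = rng.standard_normal((3, 8))
    X = rng.standard_normal((64, 8))
    y = np.argmax(X @ W.T, axis=1)      # labels the model gets right by construction
    delta = pgd_attack(W, X, y, eps=1.0)
    print("clean:", accuracy(W, X, y), "adversarial:", accuracy(W, X + delta, y))
```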

Data Augmentation

The following perturbation methods are available to augment SQuAD-like data:

  • Synonym Replacement (SR)
- (original)  How many species of plants were [recorded] in Egypt?
+ (augmented) How many species of plants were [registered] in Egypt?
  • Random Deletion (RD) using entity-aware term selection
- (original)  How many species of plants [were] recorded in Egypt?
+ (augmented) How many species of plants [] recorded in Egypt?
  • Random Repetition (RR) using entity-aware term selection
- (original)  How many species of plants [were] recorded in Egypt?
+ (augmented) How many species of plants [were were] recorded in Egypt?
  • Random Misspelling (RM) using open-source common-misspelling datasets (sources: Wikipedia, Birkbeck)
- (original)  How [many] species of plants were recorded in Egypt?
+ (augmented) How [mony] species of plants were recorded in Egypt?
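The four perturbations above can be sketched on a tokenized question as follows. This is a simplified, self-contained illustration, not the package's code: `SYNONYMS` and `MISSPELLINGS` are hypothetical one-entry stand-ins for real synonym and misspelling resources, and entity-aware selection is approximated by a fixed `PROTECTED` set (the package itself uses spark-nlp for entity detection). Punctuation and protected tokens are never chosen.

```python
import random

# Hypothetical lookup tables standing in for real synonym / misspelling resources.
SYNONYMS = {"recorded": ["registered", "documented"]}
MISSPELLINGS = {"many": ["mony", "meny"]}
# Stand-in for entity-aware selection: tokens that must never be perturbed.
PROTECTED = {"Egypt"}


def candidates(tokens, table=None):
    """Indices eligible for perturbation: alphabetic, not an entity, and in `table` if given."""
    return [i for i, t in enumerate(tokens)
            if t.isalpha() and t not in PROTECTED and (table is None or t in table)]


def synonym_replace(tokens, rng):        # SR
    idx = candidates(tokens, SYNONYMS)
    if not idx:
        return list(tokens)
    i = rng.choice(idx)
    return tokens[:i] + [rng.choice(SYNONYMS[tokens[i]])] + tokens[i + 1:]


def random_delete(tokens, rng):          # RD
    idx = candidates(tokens)
    if not idx:
        return list(tokens)
    i = rng.choice(idx)
    return tokens[:i] + tokens[i + 1:]


def random_repeat(tokens, rng):          # RR
    idx = candidates(tokens)
    if not idx:
        return list(tokens)
    i = rng.choice(idx)
    return tokens[:i + 1] + [tokens[i]] + tokens[i + 1:]


def random_misspell(tokens, rng):        # RM
    idx = candidates(tokens, MISSPELLINGS)
    if not idx:
        return list(tokens)
    i = rng.choice(idx)
    return tokens[:i] + [rng.choice(MISSPELLINGS[tokens[i]])] + tokens[i + 1:]


if __name__ == "__main__":
    rng = random.Random(0)
    tokens = "How many species of plants were recorded in Egypt ?".split()
    for op in (synonym_replace, random_delete, random_repeat, random_misspell):
        print(op.__name__, " ".join(op(tokens, rng)))
```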

Perturbation types can be combined and applied at different frequencies for fine-grained control over the natural noise profile:

- (original)  How [many] species [of] plants [were] recorded in Egypt?
+ (augmented) How [mony] species [] plants [] recorded in Egypt?
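One way to combine perturbation types with different frequencies is a single pass in which each token independently draws at most one perturbation, with per-type probabilities. The sketch below assumes that scheme; the `rates` keys, the one-entry `MISSPELLINGS` table, and the `protected` argument are hypothetical illustrations, not KitanaQA's configuration format.

```python
import random

MISSPELLINGS = {"many": "mony"}  # hypothetical one-entry misspelling table


def perturb(tokens, rates, rng, protected=frozenset()):
    """Apply several perturbation types in a single pass over the tokens.

    `rates` maps "delete", "repeat" and "misspell" to the per-token probability
    of that perturbation. The rates must sum to at most 1, and each token
    receives at most one perturbation.
    """
    if sum(rates.values()) > 1.0:
        raise ValueError("perturbation rates must sum to at most 1")
    d, rep, m = (rates.get(k, 0.0) for k in ("delete", "repeat", "misspell"))
    out = []
    for tok in tokens:
        if tok in protected or not tok.isalpha():
            out.append(tok)                 # entities and punctuation are kept
            continue
        r = rng.random()
        if r < d:
            pass                            # Random Deletion: drop the token
        elif r < d + rep:
            out.extend([tok, tok])          # Random Repetition
        elif r < d + rep + m and tok in MISSPELLINGS:
            out.append(MISSPELLINGS[tok])   # Random Misspelling
        else:
            out.append(tok)                 # token left unchanged
    return out


if __name__ == "__main__":
    rng = random.Random(3)
    tokens = "How many species of plants were recorded in Egypt ?".split()
    print(" ".join(perturb(tokens, {"delete": 0.2, "misspell": 0.5}, rng,
                           protected={"Egypt"})))
```

Splitting the unit interval into disjoint slices keeps the types mutually exclusive per token, so the noise profile is controlled directly by the rates.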

Each perturbation type also supports custom term-importance sampling, e.g. using importance scores generated by a masked language model (MLM):
(How, 0.179), (many, 0.254), (species, 0.123), (of, 0.03), (plants, 0.136), (were, 0.039), (recorded, 0.067), (in, 0.012), (Egypt, 0.159)
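The importance scores above sum to roughly 1, so they can be used directly as sampling weights when choosing which terms to perturb. The sketch below assumes that higher-scoring terms should be more likely to be selected (whether to favor or avoid important terms is a design choice we do not know KitanaQA's answer to) and draws k distinct positions, one weighted draw at a time. `sample_positions` is a hypothetical helper, not part of the package.

```python
import random

# Term-importance scores from the example above (e.g. produced by an MLM).
SCORES = [("How", 0.179), ("many", 0.254), ("species", 0.123), ("of", 0.03),
          ("plants", 0.136), ("were", 0.039), ("recorded", 0.067), ("in", 0.012),
          ("Egypt", 0.159)]


def sample_positions(scores, k, rng):
    """Draw k distinct token positions; each draw is proportional to importance."""
    pool = list(range(len(scores)))
    chosen = []
    for _ in range(min(k, len(pool))):
        weights = [scores[i][1] for i in pool]
        i = rng.choices(pool, weights=weights, k=1)[0]
        chosen.append(i)
        pool.remove(i)                  # sample without replacement
    return sorted(chosen)


if __name__ == "__main__":
    rng = random.Random(0)
    picked = sample_positions(SCORES, 3, rng)
    print([SCORES[i][0] for i in picked])
```

The selected positions can then be handed to any of the perturbation types, replacing uniform selection.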

ML Flows

Using the Prefect library, KitanaQA makes it easy to combine different workflows for end-to-end training, evaluation, and model selection. It also supports rapid iteration in hyperparameter search: each experimental condition can be specified separately and deployed independently. You can even have results reported directly in Slack!

Installation

Entity-aware data augmentation uses the John Snow Labs spark-nlp library, which requires pyspark. To enable this feature, make sure Java 8 is the default Java version, for pyspark compatibility:

sudo apt install openjdk-8-jdk
sudo update-alternatives --config java
java -version

It is recommended that you use a virtual environment when installing from pip or source. Virtualenv and Conda are good options.

This package has been tested on Python 3.7+, PyTorch 1.5.1+, and transformers 1.3.1.

Install with pip:

pip install kitanaqa

Install from source:

git clone https://github.com/searchableai/KitanaQA.git
cd KitanaQA
python setup.py install

Getting Started

To run training or evaluation from the command line:

python src/kitanaqa/trainer/run_pipeline.py --args=args.json

See an example args.json
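Because the pipeline command takes its whole configuration from one JSON file, a simple way to run a hyperparameter search is to write one args file per experimental condition and launch each run independently. The standard-library sketch below does exactly that. The keys in `BASE_ARGS` and `GRID` are placeholders we made up for illustration, not KitanaQA's actual schema; consult the example args.json for the real field names.

```python
import itertools
import json
import pathlib
import tempfile

# HYPOTHETICAL keys: replace them with the field names used in the real args.json.
BASE_ARGS = {"model_name_or_path": "bert-base-uncased", "num_train_epochs": 2}
GRID = {"learning_rate": [3e-5, 5e-5], "adv_eps": [1e-3, 1e-2]}


def write_conditions(out_dir):
    """Write one args file per grid point and return the matching commands."""
    out_dir = pathlib.Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    keys = sorted(GRID)
    commands = []
    for n, values in enumerate(itertools.product(*(GRID[k] for k in keys))):
        args = {**BASE_ARGS, **dict(zip(keys, values))}
        path = out_dir / f"args_{n:02d}.json"
        path.write_text(json.dumps(args, indent=2))
        commands.append(f"python src/kitanaqa/trainer/run_pipeline.py --args={path}")
    return commands


if __name__ == "__main__":
    for cmd in write_conditions(tempfile.mkdtemp()):
        print(cmd)
```

Each printed command is independent, so the conditions can be run sequentially, in parallel, or on separate machines.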

Examples

Augmentation

Training and Evaluation

Models Supported

We use the following models, along with their respective tokenizers and configurations, as provided by Hugging Face Inc.:

  • ALBERT
  • BERT
  • DistilBERT

Contributing to KitanaQA

We welcome suggestions and contributions! Submit an issue or pull request and we will do our best to respond in a timely manner. See CONTRIBUTING.md for detailed information on contributing.

Thanks!

  • John Snow Labs
  • Hugging Face Inc.
  • PyTorch community


Download files


Source Distributions

No source distribution files available for this release.

Built Distribution

kitanaqa-0.1.0-py3-none-any.whl (74.3 MB view details)

Uploaded Python 3

File details

Details for the file kitanaqa-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: kitanaqa-0.1.0-py3-none-any.whl
  • Size: 74.3 MB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.6.1 requests/2.22.0 setuptools/50.3.2 requests-toolbelt/0.9.1 tqdm/4.40.2 CPython/3.7.0

File hashes

Hashes for kitanaqa-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 3f009ad6be97f942861e4ddc325c483308c56e3eb78fee4f12d55066271d3368
MD5 1f33dafa4349588ff0d9f0750b40ce36
BLAKE2b-256 ee051386aae7ac4f810d3dedb915646ca4f09fdfe14eb2cefc4ac82bea98d54e

