
PyTorch Implementation of Models used for Uncertainty Estimation in Natural Language Processing.


🤖💬❓ nlp-uncertainty-zoo

This repository contains PyTorch implementations of several models used for uncertainty estimation in Natural Language Processing (NLP). You can install the package using pip:

pip3 install nlp-uncertainty-zoo

If you are using the repository in your academic research, please cite the paper below:

@inproceedings{ulmer-etal-2022-exploring,
  title = "Exploring Predictive Uncertainty and Calibration in {NLP}: A Study on the Impact of Method {\&} Data Scarcity",
  author = "Ulmer, Dennis  and
    Frellsen, Jes  and
    Hardmeier, Christian",
  booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2022",
  month = dec,
  year = "2022",
  address = "Abu Dhabi, United Arab Emirates",
  publisher = "Association for Computational Linguistics",
  url = "https://aclanthology.org/2022.findings-emnlp.198",
  pages = "2707--2735",

}

To learn more about the package, consult the documentation, the Jupyter notebook demo, or the Google Colab notebook.

Included models

The following models are implemented in the repository. They can all be imported using from nlp_uncertainty_zoo import <MODEL>. For transformer-based models, a version that uses a pre-trained BERT from HuggingFace's transformers library is also available.

| Name | Description | Implementation | Paper |
|---|---|---|---|
| LSTM | Vanilla LSTM | LSTM | Hochreiter & Schmidhuber, 1997 |
| LSTM Ensemble | Ensemble of LSTMs | LSTMEnsemble | Lakshminarayanan et al., 2017 |
| Bayesian LSTM | LSTM implementing Bayes-by-backprop (Blundell et al., 2015) | BayesianLSTM | Fortunato et al., 2017 |
| ST-tau LSTM | LSTM modelling transitions of a finite-state automaton | STTauLSTM | Wang et al., 2021 |
| Variational LSTM | LSTM with MC Dropout (Gal & Ghahramani, 2016a) | VariationalLSTM | Gal & Ghahramani, 2016b |
| DDU Transformer, DDU BERT | Transformer / BERT with a Gaussian Mixture Model fitted to hidden features | DDUTransformer, DDUBert | Mukhoti et al., 2021 |
| Variational Transformer, Variational BERT | Transformer / BERT with MC Dropout (Gal & Ghahramani, 2016a) | VariationalTransformer, VariationalBert | Xiao et al., 2021 |
| DPP Transformer, DPP BERT | Transformer / BERT using determinantal point process dropout | DPPTransformer, DPPBert | Shelmanov et al., 2021 |
| SNGP Transformer, SNGP BERT | Spectrally-normalized Transformer / BERT using a Gaussian Process output layer | SNGPTransformer, SNGPBert | Liu et al., 2022 |
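As a rough illustration of the MC Dropout idea behind the variational models in the table, the sketch below keeps dropout active at inference time and averages the softmax outputs of several stochastic forward passes. It is a generic PyTorch sketch, not the package's implementation; the toy classifier `ToyClassifier` and the `num_samples` parameter are hypothetical.

```python
import torch
import torch.nn as nn


class ToyClassifier(nn.Module):
    """Hypothetical classifier with a dropout layer, used only for illustration."""

    def __init__(self, input_dim: int, num_classes: int, dropout: float = 0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(32, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)  # logits


@torch.no_grad()
def mc_dropout_predict(model: nn.Module, x: torch.Tensor, num_samples: int = 10) -> torch.Tensor:
    """Return softmax outputs of shape (num_samples, batch_size, num_classes)."""
    model.train()  # keep dropout active at inference time (MC Dropout)
    samples = [torch.softmax(model(x), dim=-1) for _ in range(num_samples)]
    model.eval()  # restore evaluation mode afterwards
    return torch.stack(samples)


torch.manual_seed(0)
model = ToyClassifier(input_dim=8, num_classes=3)
x = torch.randn(4, 8)
probs = mc_dropout_predict(model, x, num_samples=20)
mean_probs = probs.mean(dim=0)  # approximate predictive distribution, shape (4, 3)
```

Calling `model.train()` switches every layer into training mode; in models with batch normalization, one would typically re-enable only the dropout layers instead.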

Contributions to include even more approaches are much appreciated!

Usage

Each model comes in two versions, for instance LSTMEnsemble and LSTMEnsembleModule. The first is intended as an out-of-the-box solution that encapsulates all training logic and convenience functions, such as fitting the model, making predictions, and computing the uncertainty for an input batch with a specific metric:

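# network_params, train_dataloader and X are placeholders for your own
# network hyperparameters, training DataLoader and input batch.
model = LSTMEnsemble(**network_params, ensemble_size=10, is_sequence_classifier=False)
model.fit(train_split=train_dataloader)
model.get_logits(X)  # Raw output logits for the batch
model.get_predictions(X)  # Model predictions for the batch
model.get_sequence_representation(X)  # Representation of the input sequences
model.available_uncertainty_metrics  # Names of the supported uncertainty metrics
model.get_uncertainty(X)  # Uncertainty scores using the default metric
model.get_uncertainty(X, metric_name="mutual_information")  # Uncertainty scores using a specific metric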
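The mutual_information metric requested in the last call is meaningful for models that produce several predictions per input, such as an ensemble. A common formulation is the entropy of the averaged prediction minus the average entropy of the individual predictions. The sketch below computes it for stacked member probabilities; it is a generic version under that assumption, and the package's own implementation in utils/metrics.py may differ in details such as input shapes.

```python
import torch


def entropy(probs: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Shannon entropy (in nats) over the last (class) dimension."""
    return -(probs * torch.log(probs + eps)).sum(dim=-1)


def mutual_information(member_probs: torch.Tensor) -> torch.Tensor:
    """
    member_probs: (num_members, batch_size, num_classes) softmax outputs.
    Returns a (batch_size,) tensor: H[mean prediction] - mean over members of H[prediction].
    """
    total_uncertainty = entropy(member_probs.mean(dim=0))
    expected_entropy = entropy(member_probs).mean(dim=0)
    return total_uncertainty - expected_entropy


agree = torch.tensor([[[0.9, 0.1]], [[0.9, 0.1]]])     # two members, one input
disagree = torch.tensor([[[0.9, 0.1]], [[0.1, 0.9]]])  # members predict opposite classes
print(mutual_information(agree))     # ≈ 0: the members agree
print(mutual_information(disagree))  # ≈ 0.368 nats: the members disagree
```

The score is high only when the members disagree with each other, which is why it is often used as a measure of model (epistemic) uncertainty.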

In comparison, the -Module class is meant to be simpler and more bare-bones, containing only the core model logic. It is intended for research purposes and for anyone who would like to embed the model into their own code base. While the model class (e.g. LSTMEnsemble) inherits from Model and must implement certain methods, every Module class stays close to torch.nn.Module.
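The split between a bare-bones module and a convenience wrapper can be pictured roughly as follows. This is a hypothetical minimal sketch of the pattern: TinyModule, TinyModel and their method bodies are invented for illustration (the method names mirror the usage example above) and do not reproduce the package's Model and Module base classes, whose required methods are described in the documentation.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class TinyModule(nn.Module):
    """Hypothetical bare-bones module: only the forward computation."""

    def __init__(self, input_dim: int, num_classes: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)  # logits


class TinyModel:
    """Hypothetical wrapper: owns the module plus training and convenience logic."""

    def __init__(self, input_dim: int, num_classes: int, lr: float = 0.1):
        self.module = TinyModule(input_dim, num_classes)
        self.optimizer = torch.optim.SGD(self.module.parameters(), lr=lr)
        self.loss_fn = nn.CrossEntropyLoss()

    def fit(self, train_split: DataLoader, num_epochs: int = 5) -> None:
        self.module.train()
        for _ in range(num_epochs):
            for x, y in train_split:
                self.optimizer.zero_grad()
                loss = self.loss_fn(self.module(x), y)
                loss.backward()
                self.optimizer.step()

    @torch.no_grad()
    def get_logits(self, x: torch.Tensor) -> torch.Tensor:
        self.module.eval()
        return self.module(x)

    def get_predictions(self, x: torch.Tensor) -> torch.Tensor:
        # Class probabilities derived from the logits
        return torch.softmax(self.get_logits(x), dim=-1)


torch.manual_seed(0)
X = torch.randn(64, 4)
y = (X[:, 0] > 0).long()  # toy binary labels
loader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)
model = TinyModel(input_dim=4, num_classes=2)
model.fit(loader)
predictions = model.get_predictions(X)  # shape (64, 2)
```

Because TinyModule is a plain torch.nn.Module, it can be dropped into any other training loop, while TinyModel bundles the optimizer and loop for quick experiments.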

To see which arguments are required to initialize and use the different models, consult the documentation.

Also, check out the demo, available as a Jupyter notebook or as a Google Colab notebook.

Repository structure

The repository has the following structure:

  • models: All model implementations.
  • tests: Unit tests. So far, these only contain rudimentary checks that output shapes are consistent across models and functions.
  • utils: Utility code (see below)
    • utils/custom_types.py: Custom types used in the repository for type annotations.
    • utils/data.py: Data collators and data builders, which build the dataloaders for a given task type and dataset. Currently, language modelling, sequence labelling and sequence classification are supported.
    • utils/metrics.py: Implementations of uncertainty metrics.
    • utils/samplers.py: Dataset subsamplers for language modelling, sequence labelling and sequence classification.
    • utils/task_eval.py: Functions used to evaluate task performance.
    • utils/uncertainty_eval.py: Function used to evaluate uncertainty quality.
    • utils/calibration_eval.py: Function used to evaluate calibration quality.
  • config.py: Defines the available datasets, models and tasks.
  • defaults.py: Defines default config parameters for sequence classification and language modelling (note: these parameters are not necessarily well-tuned).
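For readers unfamiliar with calibration evaluation: a widely used measure is the expected calibration error (ECE), which bins predictions by confidence and averages the gap between accuracy and confidence in each bin, weighted by the fraction of samples in that bin. The sketch below is a generic equal-width-bin version; the function name and the number of bins are assumptions for illustration, and it is not necessarily what utils/calibration_eval.py computes.

```python
import torch


def expected_calibration_error(probs: torch.Tensor, labels: torch.Tensor, num_bins: int = 10) -> float:
    """
    probs: (num_samples, num_classes) predicted probabilities.
    labels: (num_samples,) gold class indices.
    """
    confidences, predictions = probs.max(dim=-1)
    correct = (predictions == labels).float()
    bin_edges = torch.linspace(0, 1, num_bins + 1)
    ece = torch.zeros(())
    # Bins are half-open intervals (lower, upper], covering (0, 1]
    for lower, upper in zip(bin_edges[:-1], bin_edges[1:]):
        in_bin = (confidences > lower) & (confidences <= upper)
        if in_bin.any():
            bin_weight = in_bin.float().mean()  # fraction of samples in this bin
            gap = (correct[in_bin].mean() - confidences[in_bin].mean()).abs()
            ece += bin_weight * gap
    return ece.item()


perfect = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
overconfident = torch.tensor([[0.9, 0.1], [0.9, 0.1]])
labels = torch.tensor([0, 1])
```

A model that is always right with full confidence scores 0, while one that predicts the wrong class with 0.9 confidence scores 0.9.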

Other features

  • Weights & Biases integration: You can easily track your experiments with Weights & Biases by passing a wandb_run argument to model.fit()!
  • Easy fine-tuning via HuggingFace: You can fine-tune arbitrary BERT models by passing their model name from HuggingFace's transformers library.

Contributing

This repository is by no means perfect or complete. If you find any bugs, please report them using the issue template, and if you also happen to have a fix, create a pull request! A GitHub template is provided for pull requests as well.

Would you like to make a new addition to the repository? Follow the steps below:

  • Adding a new model: To add a new model, add a new module in the models directory. You will also need to implement corresponding Model and Module classes, inheriting from the classes of the same name in models/model.py and implementing all required functions. Model is meant to be an out-of-the-box solution that lets you start experimenting right away, while Module should only include the most basic model logic, so that it is easy to integrate into other codebases and to tinker with.

  • Adding a new uncertainty metric: To add a new uncertainty metric, add the function to utils/metrics.py. The function should take the logits of a model and output an uncertainty score (the higher the score, the more uncertain the model). The function should output a batch_size x sequence_length matrix, with batch_size x 1 for sequence classification tasks. After finishing the implementation, you can add the metric to the single_prediction_uncertainty_metrics of the models.model.Model class and multi_prediction_uncertainty_metrics of models.model.MultiPredictionMixin (if applicable).
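As an illustration of this contract, a metric based on the maximum softmax probability could look like the sketch below. The function name max_prob_uncertainty and the assumption that logits arrive as a batch_size x sequence_length x num_classes tensor are ours; check utils/metrics.py for the conventions the package actually uses.

```python
import torch


def max_prob_uncertainty(logits: torch.Tensor) -> torch.Tensor:
    """
    Hypothetical metric: one minus the maximum softmax probability.

    logits: (batch_size, sequence_length, num_classes); for sequence
    classification, sequence_length is 1.
    Returns a (batch_size, sequence_length) tensor; higher means more uncertain.
    """
    probs = torch.softmax(logits, dim=-1)
    return 1 - probs.max(dim=-1).values


scores = max_prob_uncertainty(torch.randn(2, 5, 3))        # shape (2, 5)
uniform_score = max_prob_uncertainty(torch.zeros(1, 1, 3))  # uniform prediction over 3 classes
```

A uniform prediction over three classes receives the highest possible score for this metric, 1 - 1/3.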

Would you like to add something else? Create an issue or contact me at dennis {dot} ulmer {at} mailbox {dot} org!



Download files


Source Distribution

nlp-uncertainty-zoo-1.0.3.tar.gz (56.3 kB)

Uploaded Source

Built Distribution

nlp_uncertainty_zoo-1.0.3-py3-none-any.whl (76.4 kB)

Uploaded Python 3

File details

Details for the file nlp-uncertainty-zoo-1.0.3.tar.gz.

File metadata

  • Download URL: nlp-uncertainty-zoo-1.0.3.tar.gz
  • Size: 56.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.11.2 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.59.0 CPython/3.9.10

File hashes

Hashes for nlp-uncertainty-zoo-1.0.3.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 372e6b6d324beffe9a68ca0eaacba9e5d70b2ee0c5de6463d194eb389ebd7a2e |
| MD5 | debabd7b2570e5107d21d710d0af372d |
| BLAKE2b-256 | 10d40adbbb2c4d9383026ba1f073d90f8a521a87583d9b248aecd91ff88c278e |


File details

Details for the file nlp_uncertainty_zoo-1.0.3-py3-none-any.whl.

File metadata

  • Download URL: nlp_uncertainty_zoo-1.0.3-py3-none-any.whl
  • Size: 76.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.11.2 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.59.0 CPython/3.9.10

File hashes

Hashes for nlp_uncertainty_zoo-1.0.3-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 63825ac61c4e659db9eed56c74ffa5405e34dfeeb04473c0b38018ffeac052ba |
| MD5 | 0479203eb77e814e7a981fe1a4c5042b |
| BLAKE2b-256 | 93bb3bdb0df0aa20209e276520ce4a401d8a588324b625cde6ff1dc1ddf519fe |

