
.. -*- mode: rst -*-

|pypi_version|_ |pypi_downloads|_

.. |pypi_version| image:: https://img.shields.io/pypi/v/explainable-transformers.svg
.. _pypi_version: https://pypi.python.org/pypi/explainable-transformers/

.. |pypi_downloads| image:: https://pepy.tech/badge/explainable-transformers/month
.. _pypi_downloads: https://pepy.tech/project/explainable-transformers

.. image:: artwork/cover.png
   :alt: Vision Transformers explanation

========================
explainable-transformers
========================

Explanation and interpretation techniques for Transformer-based architectures.


Installation
------------

Requirements:

- opencv-python
- numpy
- torch
- tqdm

.. code:: bash

   pip install explainable-transformers
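
To check the installation, both wrappers should import cleanly:

.. code:: python

   # smoke test: the two wrappers provided by the package
   from explainable_transformers import NLPTransformerWrapper
   from explainable_transformers.image_explainer import VisionTransformerWrapper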

Usage examples
--------------

Please see the ``notebook/`` directory for complete examples of how to create visual representations of the explanations.

For Vision Transformers, use ``VisionTransformerWrapper``, passing a PyTorch model.

.. code:: python

   import torch
   import torch.nn as nn
   from PIL import Image
   from torchvision import transforms
   from transformers import ViTModel

   # import the explainer module
   from explainable_transformers.image_explainer import VisionTransformerWrapper

   # define the classification head on top of the pre-trained backbone
   class PreTrainedViT(nn.Module):
       def __init__(self, vit_model, d_model, classes):
           ...

       def forward(self, x):
           ...

   # load the pre-trained model
   pretrained_vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k',
                                                   add_pooling_layer=False, output_attentions=True)

   model = PreTrainedViT(pretrained_vit_model, d_model=768, classes=10)

   # create the ViT wrapper and register hooks on the attention layers
   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
   vit_wrapper = VisionTransformerWrapper(model, device, num_attn_layers=12)
   vit_wrapper.register_hook()

   # explain a prediction using .generate_visualization(img)
   transform = transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor()])  # match the checkpoint's preprocessing
   image = Image.open('images/dogbird.png')
   processed_image = transform(image)
   explanation, _ = vit_wrapper.generate_visualization(processed_image)
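
The relevance map can be rendered as a heatmap over the input image with ``opencv-python`` (already a requirement). A minimal sketch follows; ``overlay_heatmap`` is an illustrative helper, not part of the package, and it assumes ``explanation`` is a 2-D relevance map:

.. code:: python

   import cv2
   import numpy as np

   def overlay_heatmap(image_bgr, relevance, alpha=0.5):
       # resize the relevance map to the image size and normalize it to [0, 1]
       relevance = cv2.resize(relevance, (image_bgr.shape[1], image_bgr.shape[0]))
       relevance = (relevance - relevance.min()) / (relevance.max() - relevance.min() + 1e-8)
       # colorize the map and blend it with the original image
       heatmap = cv2.applyColorMap(np.uint8(255 * relevance), cv2.COLORMAP_JET)
       return cv2.addWeighted(heatmap, alpha, image_bgr, 1 - alpha, 0)

   # move the map to CPU/NumPy first if the wrapper returns a tensor
   overlay = overlay_heatmap(cv2.imread('images/dogbird.png'),
                             np.asarray(explanation, dtype=np.float32))
   cv2.imwrite('dogbird_explanation.png', overlay)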

For text Transformers, you currently need to know how the attention component is organized in the model, i.e., the module path to each attention layer's dropout (see the sketch after the example for one way to discover it).

.. code:: python

   # first, the imports
   import torch
   from transformers import BertTokenizer, BertForSequenceClassification

   from explainable_transformers.utils import *
   from explainable_transformers import NLPTransformerWrapper

   # load a model and tokenizer, keeping the attention weights
   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
   tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
   model = BertForSequenceClassification.from_pretrained('bert-base-uncased',
                                                         output_attentions=True).to(device)
   NUM_LAYERS = 12  # number of attention layers in bert-base

   # for text, we provide the NLP wrapper;
   # we access the attention component as follows:
   #   - BERT or RoBERTa: '.encoder.layer.#.attention.self.dropout'
   #   - XLNet: '.layer.#.rel_attn.dropout'
   nlp_wrapper = NLPTransformerWrapper(model, device, NUM_LAYERS, 'bert', 'classifier',
                                       '.encoder.layer.#.attention.self.dropout')
   nlp_wrapper.register_hook()

   # tokenize an input and explain the prediction for a chosen class
   encoding = tokenizer('a compelling and well-acted drama', return_tensors='pt')
   input_ids = encoding['input_ids'].to(device)
   attention_mask = encoding['attention_mask'].to(device)
   true_class = 1  # index of the class to explain

   explanation = nlp_wrapper.generate_explanation(input_ids, attention_mask,
                                                  class_index=true_class,
                                                  start_layer=NUM_LAYERS - 1)
   explanation = explanation.detach().cpu().numpy()
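
The scores can then be paired with the input tokens; a short sketch, assuming the returned array has shape ``(batch, seq_len)`` (the heatmap rendering in the notebooks builds on the same pairing):

.. code:: python

   # pair each input token with its relevance score
   tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
   for token, score in zip(tokens, explanation[0]):
       print(f'{token}\t{score:.3f}')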
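
If your architecture is not listed in the comment above, one way to find the right module path is to inspect the model's named modules with standard PyTorch; the filter below is only a heuristic:

.. code:: python

   # print candidate attention-dropout modules; the '#' placeholder in the
   # wrapper's path stands for the layer index that appears in these names
   for name, _ in model.named_modules():
       if 'attention' in name and 'dropout' in name:
           print(name)  # e.g. 'bert.encoder.layer.0.attention.self.dropout'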

Citation
--------

Please cite the respective authors if you use any of these techniques.

Currently, we provide a PyTorch implementation of the following approach:

Transformer Interpretability Beyond Attention Visualization (`paper <https://arxiv.org/abs/2012.09838>`_):

  1. Transformers: BERT, RoBERTa, and XLNet

  2. Vision Transformers

.. code:: bibtex

   @InProceedings{Chefer_2021_CVPR,
       author    = {Chefer, Hila and Gur, Shir and Wolf, Lior},
       title     = {Transformer Interpretability Beyond Attention Visualization},
       booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
       month     = {June},
       year      = {2021},
       pages     = {782-791}
   }

License
-------

explainable-transformers follows the 3-clause BSD license and is based on other open-source implementations: `Chefer's <https://github.com/hila-chefer/Transformer-Explainability>`_.

We also use `nlp_understanding <https://github.com/ENSAE-CKW/nlp_understanding>`_ for generating the heatmaps.

E-mail me (wilson_jr at outlook dot com) if you would like to contribute.
