Explanation techniques for Transformer-based architectures
.. -*- mode: rst -*-
|pypi_version|_ |pypi_downloads|_
.. |pypi_version| image:: https://img.shields.io/pypi/v/explainable-transformers.svg
.. _pypi_version: https://pypi.python.org/pypi/explainable-transformers/

.. |pypi_downloads| image:: https://pepy.tech/badge/explainable-transformers/month
.. _pypi_downloads: https://pepy.tech/project/explainable-transformers

.. image:: artwork/cover.png
   :alt: Vision Transformers explanation

explainable-transformers
========================
Explanation and interpretation techniques for Transformer-based architectures.
Installation
Requirements:
- opencv-python
- numpy
- torch
- tqdm
.. code:: bash

    pip install explainable-transformers

Usage examples
See the ``notebook/`` directory for complete examples of how to create representations for the explanations.

For Vision Transformers, wrap a PyTorch model with ``VisionTransformerWrapper``:

.. code:: python

    import torch
    import torch.nn as nn
    from PIL import Image
    from transformers import ViTModel

    # import the explainer module
    from explainable_transformers.image_explainer import VisionTransformerWrapper

    # define the model with a final classification layer
    class PreTrainedViT(nn.Module):
        def __init__(self, vit_model, d_model, classes):
            ...

        def forward(self, x):
            ...

    # pick a device (assumed setup; not part of the original snippet)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # load the pre-trained model
    pretrained_vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k',
                                                    add_pooling_layer=False,
                                                    output_attentions=True)
    # keyword names must match the __init__ signature above
    model = PreTrainedViT(pretrained_vit_model, d_model=768, classes=10)

    # create the ViT wrapper and register the hooks
    vit_wrapper = VisionTransformerWrapper(model, device, num_attn_layers=12)
    vit_wrapper.register_hook()

    # explain a prediction using .generate_visualization(img);
    # `transform` is your own preprocessing pipeline (not defined here)
    image = Image.open('images/dogbird.png')
    processed_image = transform(image)
    cat_exp, _ = vit_wrapper.generate_visualization(processed_image)

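
The format of the returned explanation is not documented here. As a minimal sketch, assuming it can be reduced to a 2-D map of per-patch relevance scores (14 x 14 for a 224-pixel image split into 16-pixel patches), the map could be scaled, upsampled, and blended over the image with NumPy alone. The helper names ``normalize``, ``upsample_patches``, and ``overlay`` are hypothetical and not part of explainable-transformers; the inputs below are synthetic stand-ins.

```python
import numpy as np


def normalize(relevance):
    """Min-max scale a relevance map to [0, 1]; a constant map becomes all zeros."""
    relevance = np.asarray(relevance, dtype=np.float64)
    span = relevance.max() - relevance.min()
    if span == 0:
        return np.zeros_like(relevance)
    return (relevance - relevance.min()) / span


def upsample_patches(patch_scores, patch_size=16):
    """Nearest-neighbour upsampling: repeat each patch score over its pixel block."""
    return np.kron(patch_scores, np.ones((patch_size, patch_size)))


def overlay(image, heatmap, alpha=0.5):
    """Blend a [0, 1] heatmap into the red channel of an RGB image in [0, 1]."""
    heat_rgb = np.zeros_like(image)
    heat_rgb[..., 0] = heatmap
    return (1 - alpha) * image + alpha * heat_rgb


# Synthetic stand-ins (assumptions): random patch scores and a random RGB image.
rng = np.random.default_rng(0)
patch_scores = rng.random((14, 14))
image = rng.random((224, 224, 3))

heatmap = upsample_patches(normalize(patch_scores))
blended = overlay(image, heatmap)
```

With a real prediction, ``patch_scores`` would be replaced by the relevance map obtained from the wrapper, reshaped as needed, and ``image`` by the input image scaled to [0, 1].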
For text Transformers, you currently need to know how the model's attention component is organized, because its module path is passed to the wrapper:

.. code:: python

    # first the imports
    from transformers import BertTokenizer, BertForSequenceClassification
    from explainable_transformers.utils import *
    from explainable_transformers import NLPTransformerWrapper

    # For text, we provide the NLP wrapper.
    # The attention component is accessed as follows:
    #   - BERT or RoBERTa: '.encoder.layer.#.attention.self.dropout'
    #   - XLNet: '.layer.#.rel_attn.dropout'
    # `model`, `device`, `input_ids`, `attention_mask`, `true_class`, and
    # `NUM_LAYERS` are assumed to be defined beforehand.
    nlp_wrapper = NLPTransformerWrapper(model, device, 12, 'bert', 'classifier',
                                        '.encoder.layer.#.attention.self.dropout')
    nlp_wrapper.register_hook()

    explanation = nlp_wrapper.generate_explanation(input_ids, attention_mask,
                                                   class_index=true_class,
                                                   start_layer=NUM_LAYERS-1)
    explanation = explanation.detach().cpu().numpy()

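
What the ``explanation`` array contains is not spelled out here. As a rough sketch, assuming it holds one relevance score per input token, the scores could be min-max normalized and paired with the tokens to list the most relevant ones. The function ``rank_tokens`` and the sample tokens and scores are hypothetical, made up for illustration.

```python
import numpy as np


def rank_tokens(tokens, scores, top_k=3):
    """Return the top_k (token, normalized score) pairs, highest relevance first."""
    scores = np.asarray(scores, dtype=np.float64).reshape(-1)
    if len(tokens) != scores.shape[0]:
        raise ValueError("expected exactly one score per token")
    span = scores.max() - scores.min()
    # Min-max scale to [0, 1]; fall back to zeros when all scores are equal.
    normalized = (scores - scores.min()) / span if span > 0 else np.zeros_like(scores)
    order = np.argsort(-normalized)[:top_k]
    return [(tokens[i], float(normalized[i])) for i in order]


# Made-up tokens and scores standing in for a tokenized sentence and its explanation.
tokens = ["[CLS]", "a", "great", "movie", "[SEP]"]
scores = [0.1, 0.2, 0.9, 0.5, 0.1]
top = rank_tokens(tokens, scores, top_k=2)
```

In practice the token strings would come from the tokenizer (for example, by converting ``input_ids`` back to tokens) so that each score lines up with the token it explains.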
Citation
If you use any of these techniques, please cite the respective authors.

Currently, we provide PyTorch implementations of the following approaches:

Transformer Interpretability Beyond Attention Visualization (`paper <https://arxiv.org/abs/2012.09838>`_):

- Transformers: BERT, RoBERTa, and XLNet
- Vision Transformers

.. code:: bibtex

    @InProceedings{Chefer_2021_CVPR,
        author    = {Chefer, Hila and Gur, Shir and Wolf, Lior},
        title     = {Transformer Interpretability Beyond Attention Visualization},
        booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
        month     = {June},
        year      = {2021},
        pages     = {782-791}
    }

License
explainable-transformers is distributed under the 3-clause BSD license and is based on other open-source implementations: `Chefer's <https://github.com/hila-chefer/Transformer-Explainability>`_.

We also use `nlp_understanding <https://github.com/ENSAE-CKW/nlp_understanding>`_ to generate the heatmap.

E-mail me (wilson_jr at outlook dot com) if you would like to contribute.