.. -*- mode: rst -*-
|pypi_version|_ |pypi_downloads|_

.. |pypi_version| image:: https://img.shields.io/pypi/v/explainable-transformers.svg
.. _pypi_version: https://pypi.python.org/pypi/explainable-transformers/

.. |pypi_downloads| image:: https://pepy.tech/badge/explainable-transformers/month
.. _pypi_downloads: https://pepy.tech/project/explainable-transformers

.. image:: artwork/cover.png
   :alt: Vision Transformers explanation
explainable-transformers
========================
Explanation and interpretation techniques for Transformer-based architectures.
Installation
------------

Requirements:

- opencv-python
- numpy
- torch
- tqdm

.. code:: bash

    pip install explainable-transformers
Usage examples
--------------

Please see notebook/ for complete examples of how to create visual representations of the explanations.

For Vision Transformers, wrap a PyTorch model with ``VisionTransformerWrapper``:

.. code:: python

    import torch
    import torch.nn as nn
    from PIL import Image
    from torchvision import transforms
    from transformers import ViTModel

    # import the explainer module
    from explainable_transformers.image_explainer import VisionTransformerWrapper

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # define the last layer for classification
    class PreTrainedViT(nn.Module):
        def __init__(self, vit_model, d_model, classes):
            ...

        def forward(self, x):
            ...

    # load the pre-trained model
    pretrained_vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k',
                                                    add_pooling_layer=False, output_attentions=True)
    model = PreTrainedViT(pretrained_vit_model, d_model=768, classes=10)

    # create the ViT wrapper and register the hooks on its attention layers
    vit_wrapper = VisionTransformerWrapper(model, device, num_attn_layers=12)
    vit_wrapper.register_hook()

    # explain a prediction using .generate_visualization(img)
    transform = transforms.Compose([transforms.Resize((224, 224)),
                                    transforms.ToTensor()])  # example preprocessing; match your training setup
    image = Image.open('images/dogbird.png')
    processed_image = transform(image)
    cat_exp, _ = vit_wrapper.generate_visualization(processed_image)
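
Depending on your use case, you may want to overlay the returned relevance map on the original image. The snippet below is only a sketch: it assumes ``cat_exp`` is (or can be converted to) a 2D relevance map aligned with the input image; see notebook/ for the exact post-processing used by the library. OpenCV, already listed in the requirements, handles the colour map and blending.

.. code:: python

    import cv2
    import numpy as np

    # Assumption: cat_exp is (convertible to) a 2D relevance map; call
    # cat_exp.detach().cpu().numpy() first if the wrapper returns a tensor.
    relevance = np.asarray(cat_exp, dtype=np.float32)
    relevance = (relevance - relevance.min()) / (relevance.max() - relevance.min() + 1e-8)
    heatmap = cv2.applyColorMap((relevance * 255).astype(np.uint8), cv2.COLORMAP_JET)

    # resize the heatmap to the image resolution and blend the two
    original = cv2.cvtColor(np.array(image.convert('RGB')), cv2.COLOR_RGB2BGR)
    heatmap = cv2.resize(heatmap, (original.shape[1], original.shape[0]))
    overlay = cv2.addWeighted(original, 0.5, heatmap, 0.5, 0)
    cv2.imwrite('dogbird_explanation.png', overlay)
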
For text Transformers, the wrapper currently needs to know how the attention component is organised inside the model, i.e. the module path it should hook into:

.. code:: python

    import torch
    from transformers import BertTokenizer, BertForSequenceClassification
    from explainable_transformers.utils import *
    from explainable_transformers import NLPTransformerWrapper

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    NUM_LAYERS = 12

    # load the (fine-tuned) classifier and its tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', output_attentions=True).to(device)

    # for text, we provide the NLP wrapper; the attention component is accessed as follows:
    #   - BERT or RoBERTa: '.encoder.layer.#.attention.self.dropout'
    #   - XLNet: '.layer.#.rel_attn.dropout'
    nlp_wrapper = NLPTransformerWrapper(model, device, NUM_LAYERS, 'bert', 'classifier', '.encoder.layer.#.attention.self.dropout')
    nlp_wrapper.register_hook()

    # tokenize the input text
    encoding = tokenizer('The movie was surprisingly good!', return_tensors='pt')
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # explain the prediction for a given class index
    true_class = 1
    explanation = nlp_wrapper.generate_explanation(input_ids, attention_mask, class_index=true_class, start_layer=NUM_LAYERS - 1)
    explanation = explanation.detach().cpu().numpy()
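
To read the result, one option is to pair each token with its relevance score. This is a minimal sketch assuming ``explanation`` holds one score per input token; the bundled notebooks rely on nlp_understanding for proper heatmaps.

.. code:: python

    # Assumption: explanation contains one relevance score per input token.
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    scores = explanation.reshape(-1)
    for token, score in zip(tokens, scores):
        print(f'{token:>12s}  {score:.4f}')
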
Citation
--------

Please cite the respective authors if you use any of the techniques.
Currently, we provide PyTorch implementations of the following approaches:
Transformer Interpretability Beyond Attention Visualization (`paper <https://arxiv.org/abs/2012.09838>`_):

- Transformers: BERT, RoBERTa, and XLNet
- Vision Transformers

.. code:: bibtex

    @InProceedings{Chefer_2021_CVPR,
        author    = {Chefer, Hila and Gur, Shir and Wolf, Lior},
        title     = {Transformer Interpretability Beyond Attention Visualization},
        booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
        month     = {June},
        year      = {2021},
        pages     = {782-791}
    }
License
-------

explainable-transformers follows the 3-clause BSD license and is based on other open-source implementations: `Chefer's <https://github.com/hila-chefer/Transformer-Explainability>`_.

We also use `nlp_understanding <https://github.com/ENSAE-CKW/nlp_understanding>`_ for generating the heatmaps.

E-mail me (wilson_jr at outlook dot com) if you would like to contribute.