
Explain your 🤗 transformers without effort! Display the internal behavior of your model.


Transformers visualizer

Explain your 🤗 transformers without effort!


Transformers visualizer is a Python package designed to work with the 🤗 transformers package. Given a model and a tokenizer, it provides several ways to explain your model by plotting its internal behavior.

This package is mostly based on the Captum tutorials [1] [2].

Installation

pip install transformers-visualizer

Quickstart

Let's define a model, a tokenizer and a text input for the following examples.

from transformers import AutoModel, AutoTokenizer

model_name = "bert-base-uncased"
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks that include an encoder and a decoder."

Visualizers

Attention matrices of a specific layer

from transformers_visualizer import TokenToTokenAttentions

visualizer = TokenToTokenAttentions(model, tokenizer)
visualizer(text)

Instead of the __call__ function, you can use the compute method. Both work in place; the compute method additionally allows method chaining.
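The difference between the two entry points can be illustrated with a minimal sketch. The SketchVisualizer class below is hypothetical and is not the package's implementation; it only assumes, as described above, that both calls store their result on the object and that compute returns the object itself so that calls can be chained.

```python
class SketchVisualizer:
    """Hypothetical stand-in for a visualizer; not the package's real class."""

    def __init__(self):
        self.result = None  # filled in place by compute() or __call__()

    def compute(self, text):
        # Store the result on the instance, then return self to allow chaining.
        self.result = text.split()
        return self

    def __call__(self, text):
        # Same in-place work as compute(); returning nothing here is an
        # assumption made for this sketch.
        self.compute(text)

    def plot(self):
        # Placeholder for real plotting: report how many tokens were stored.
        return f"plotting {len(self.result)} tokens"


viz = SketchVisualizer()
viz("a b c")  # in place, plot() must be called in a separate statement

chained = SketchVisualizer().compute("a b").plot()  # one chained expression
```

Returning self from compute is a common way to support chaining while keeping the computed state on the object for later calls to plot.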

The plot method accepts a layer index as a parameter to specify which layer of your model to plot. By default, the last layer is plotted.

import matplotlib.pyplot as plt

visualizer.plot(layer_index=6)
plt.savefig("token_to_token.jpg")

[Figure: token-to-token attention matrices]

Attention matrices normalized across head axis

You can specify the order passed to torch.linalg.norm in the __call__ and compute methods. By default, an L2 norm is applied.

from transformers_visualizer import TokenToTokenNormalizedAttentions

visualizer = TokenToTokenNormalizedAttentions(model, tokenizer)
visualizer.compute(text).plot()

[Figure: normalized token-to-token attention matrix]
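To make the normalization across the head axis concrete, here is a minimal pure-Python sketch of a vector p-norm taken over heads. It is an illustration only: the package itself relies on torch.linalg.norm, and the helper name norm_across_heads, its order argument, and the toy values are assumptions introduced for this example.

```python
def norm_across_heads(attentions, order=2.0):
    """Collapse a [heads][rows][cols] nested list into [rows][cols]
    by taking the vector p-norm over the head axis."""
    n_rows = len(attentions[0])
    n_cols = len(attentions[0][0])
    return [
        [
            sum(abs(head[i][j]) ** order for head in attentions) ** (1.0 / order)
            for j in range(n_cols)
        ]
        for i in range(n_rows)
    ]


# Two heads, two tokens: each head is a 2x2 attention matrix (toy values).
heads = [
    [[0.6, 0.4], [0.3, 0.7]],
    [[0.8, 0.2], [0.4, 0.6]],
]

l2 = norm_across_heads(heads)             # default: L2 norm over heads
l1 = norm_across_heads(heads, order=1.0)  # L1 norm: sum of absolute values
```

Each output cell combines the attention that every head assigns to the same token pair, so a layer with many heads reduces to a single token-to-token matrix.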

Plotting

The plot method can skip special tokens through the skip_special_tokens parameter, which defaults to False.
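As a rough sketch of what skipping special tokens could mean for a token-to-token matrix, the hypothetical helper below drops the special tokens together with their rows and columns. The function name, the default list of special tokens (BERT-style markers), and the absence of any renormalization are assumptions for illustration, not the package's actual behavior.

```python
def drop_special_tokens(tokens, matrix, special_tokens=("[CLS]", "[SEP]", "[PAD]")):
    """Remove special tokens and their rows/columns from a square matrix."""
    # Indices of the tokens that should remain on both axes.
    keep = [i for i, tok in enumerate(tokens) if tok not in special_tokens]
    kept_tokens = [tokens[i] for i in keep]
    kept_matrix = [[matrix[i][j] for j in keep] for i in keep]
    return kept_tokens, kept_matrix


tokens = ["[CLS]", "hello", "world", "[SEP]"]
matrix = [
    [0.1, 0.2, 0.3, 0.4],
    [0.2, 0.5, 0.2, 0.1],
    [0.1, 0.3, 0.5, 0.1],
    [0.4, 0.3, 0.2, 0.1],
]

kept_tokens, kept_matrix = drop_special_tokens(tokens, matrix)
```

Note that the remaining rows no longer sum to 1 after filtering; this sketch leaves the values untouched rather than rescaling them.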

You can use the following imports to use plotting functions directly.

from transformers_visualizer.plotting import plot_token_to_token, plot_token_to_token_specific_dimension

These functions, as well as the plot method of a visualizer, accept the following parameters.

  • figsize (Tuple[int, int]): Figure size of the plot. Defaults to (20, 20).
  • ticks_fontsize (int): Ticks fontsize. Defaults to 7.
  • title_fontsize (int): Title fontsize. Defaults to 9.
  • cmap (str): Colormap. Defaults to "viridis".
  • colorbar (bool): Display colorbars. Defaults to True.

Upcoming features

  • Add an option to mask special tokens.
  • Add an option to specify head/layer indices to plot.
  • Add other plotting backends such as Plotly, Bokeh, Altair.
  • Implement other visualizers such as vector norm.

References

  • [1] Captum's BERT tutorial (part 1)
  • [2] Captum's BERT tutorial (part 2)

Acknowledgements
