KernelSHAP adaptation for recurrent models.
TimeSHAP
TimeSHAP is a model-agnostic, recurrent explainer that builds upon KernelSHAP and extends it to the sequential domain. TimeSHAP computes event/timestamp-, feature-, and cell-level attributions. As sequences can be arbitrarily long, TimeSHAP also implements a pruning algorithm based on Shapley values that finds a subset of consecutive, recent events that contribute the most to the decision.
This repository is the code implementation of the TimeSHAP algorithm presented in the paper "TimeSHAP: Explaining Recurrent Models through Sequence Perturbations", published at KDD 2021.
The paper is available at https://doi.org/10.1145/3447548.3467166, and a video presentation of it is also available.
Install TimeSHAP
Via pip
pip install timeshap
Via GitHub
Clone the repository into a local directory using:
git clone https://github.com/feedzai/timeshap.git
Move into the cloned repo and install the package:
cd timeshap
pip install .
Test your installation
Start a Python session in your terminal using
python
And import TimeSHAP
import timeshap
TimeSHAP in 30 seconds
Inputs
- Model being explained;
- Instance(s) to explain;
- Background instance.
Outputs
- Local pruning output (explaining a single instance);
- Local event explanations (explaining a single instance);
- Local feature explanations (explaining a single instance);
- Global pruning statistics (explaining multiple instances);
- Global event explanations (explaining multiple instances);
- Global feature explanations (explaining multiple instances).
Model Interface
For TimeSHAP to explain a model, an entry point must be provided. This entry point is a Callable that must receive a 3-D numpy array of shape (#sequences, #sequence length, #features). It must return a 2-D numpy array of shape (#sequences, 1) containing the score of each sequence.
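The following is a minimal sketch of an entry point that satisfies this contract, assuming a toy model. The fixed random weights and the mean-pooling "model" are illustrative stand-ins and not part of TimeSHAP. Any callable with the same input and output shapes should work the same way.

```python
import numpy as np

# Toy stand-in for a trained model: a fixed linear layer over the mean of
# each sequence, squashed to a probability. The weights are arbitrary.
rng = np.random.default_rng(0)
N_FEATURES = 3
WEIGHTS = rng.normal(size=N_FEATURES)


def f(x: np.ndarray) -> np.ndarray:
    """Entry point: (#sequences, #sequence length, #features) -> (#sequences, 1)."""
    if x.ndim != 3:
        raise ValueError(f"expected a 3-D array, got shape {x.shape}")
    pooled = x.mean(axis=1)                 # (#sequences, #features)
    logits = pooled @ WEIGHTS               # (#sequences,)
    scores = 1.0 / (1.0 + np.exp(-logits))  # one score per sequence
    return scores.reshape(-1, 1)            # (#sequences, 1)


batch = rng.normal(size=(4, 10, N_FEATURES))
print(f(batch).shape)  # (4, 1)
```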
In addition, to make TimeSHAP more efficient, the entry point can return the hidden state of the model together with the score (where applicable). This is optional, but we highly recommend it because it has a very high impact on performance. If you choose to return the hidden state, it should be one of the following (see the notebook for specific examples):
- a 3-D numpy array of shape (#rnn layers, #sequences, #hidden dimension) (class ExplainedRNN in the notebook);
- a tuple of numpy arrays, each following the shape described above, usually used for stacked RNNs with different hidden dimensions (class ExplainedGRU2Layer in the notebook);
- a tuple of tuples of numpy arrays, usually used for LSTMs (class ExplainedLSTM in the notebook).
TimeSHAP can explain any black-box model that complies with the interface described above, including both PyTorch and TensorFlow models. Both are exemplified in our tutorials (PyTorch, TensorFlow).
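Examples provided in our tutorials:
- TensorFlow:
import tensorflow as tf
# inputs and ff2 are the input and output tensors defined earlier in the TensorFlow tutorial
model = tf.keras.models.Model(inputs=inputs, outputs=ff2)
f = lambda x: model.predict(x)
- PyTorch (example where the model receives and returns hidden states):
from timeshap.wrappers import TorchModelWrapper
# model is the trained PyTorch model from the PyTorch tutorial
model_wrapped = TorchModelWrapper(model)
f_hs = lambda x, y=None: model_wrapped.predict_last_hs(x, y)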
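The sketch below illustrates the optional hidden-state variant. The function f_hs wraps a hypothetical toy one-layer RNN written in plain numpy, not one of the tutorial classes. It returns the hidden state in the 3-D (#rnn layers, #sequences, #hidden dimension) layout and accepts it back as a second argument, mirroring the f_hs signature above. This is presumably why returning the hidden state pays off: an unperturbed sequence prefix can be summarized once and resumed from, as the final check shows.

```python
import numpy as np

# Hypothetical toy model: a single-layer Elman RNN in plain numpy,
# used only to illustrate the optional hidden-state interface.
rng = np.random.default_rng(0)
N_FEATURES, HIDDEN = 3, 5
W_in = rng.normal(scale=0.5, size=(N_FEATURES, HIDDEN))
W_hh = rng.normal(scale=0.5, size=(HIDDEN, HIDDEN))
w_out = rng.normal(size=HIDDEN)


def f_hs(x, hs=None):
    """x: (#sequences, #sequence length, #features);
    hs: (#rnn layers, #sequences, #hidden dimension) or None.
    Returns (scores of shape (#sequences, 1), hidden state of shape (1, #sequences, HIDDEN))."""
    h = np.zeros((x.shape[0], HIDDEN)) if hs is None else hs[0]
    for t in range(x.shape[1]):
        h = np.tanh(x[:, t, :] @ W_in + h @ W_hh)
    scores = 1.0 / (1.0 + np.exp(-(h @ w_out)))
    return scores.reshape(-1, 1), h[np.newaxis]  # add the #rnn layers axis


x = rng.normal(size=(2, 8, N_FEATURES))
full_scores, _ = f_hs(x)
# Run the first 5 events, then resume from their hidden state for the rest.
_, prefix_hs = f_hs(x[:, :5, :])
resumed_scores, _ = f_hs(x[:, 5:, :], prefix_hs)
print(np.allclose(full_scores, resumed_scores))  # True
```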
Model Wrappers
To ease the interface between models and TimeSHAP, TimeSHAP implements ModelWrappers. These wrappers, used in the PyTorch tutorial notebook, allow greater flexibility in the models being explained, as they support:
- Batching logic: useful when very large inputs or NSamples values cannot fit in GPU memory, so a batching mechanism is required;
- Input format/type: useful when your model does not work with numpy arrays, as is the case in our PyTorch example;
- Hidden state logic: useful when the hidden states of your model do not match the hidden state format required by TimeSHAP.
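The batching point can be illustrated with a minimal sketch of a batching wrapper. BatchedEntryPoint and toy_model are hypothetical names used only here; the class is not TimeSHAP's ModelWrapper API. It only shows the idea of splitting a large 3-D input into chunks and concatenating the per-chunk scores.

```python
from typing import Callable

import numpy as np


class BatchedEntryPoint:
    """Hypothetical wrapper (not TimeSHAP's ModelWrapper API): splits a large
    3-D input into fixed-size batches, calls the model on each batch, and
    concatenates the (#batch, 1) scores into one (#sequences, 1) array."""

    def __init__(self, predict: Callable[[np.ndarray], np.ndarray], batch_size: int = 1024):
        if batch_size < 1:
            raise ValueError("batch_size must be positive")
        self.predict = predict
        self.batch_size = batch_size

    def __call__(self, x: np.ndarray) -> np.ndarray:
        outputs = [
            self.predict(x[start:start + self.batch_size])
            for start in range(0, x.shape[0], self.batch_size)
        ]
        return np.concatenate(outputs, axis=0)


# Toy model: the score of a sequence is the sum of all its values.
def toy_model(x: np.ndarray) -> np.ndarray:
    return x.sum(axis=(1, 2)).reshape(-1, 1)


x = np.random.default_rng(0).normal(size=(2500, 6, 3))
f = BatchedEntryPoint(toy_model, batch_size=1000)
print(f(x).shape)  # (2500, 1)
```

The wrapper is itself a callable that follows the entry-point contract, so it could be passed wherever a plain f is expected.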
TimeSHAP Explanation Methods
TimeSHAP offers several methods, depending on the desired explanations. Local methods provide a detailed view of the model's decision for the specific sequence being explained. Global methods aggregate the local explanations of a given dataset to present a global view of the model.
Local Explanations
Pruning
local_pruning() performs the pruning algorithm on a given sequence with a user-defined tolerance. It returns the pruning index along with the information needed for plotting.
plot_temp_coalition_pruning() plots the pruning algorithm information calculated by local_pruning().
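The following sketches the idea behind the pruning step under simplifying assumptions:
- The sequence is split into two coalitions, the older events and the most recent events.
- The exact two-player Shapley value of each coalition is computed, replacing absent events with a baseline event.
- The split moves back in time until the older coalition's importance drops below the tolerance.

coalition_shapley and prune_index are hypothetical names. This is not TimeSHAP's implementation of local_pruning().

```python
import numpy as np


def coalition_shapley(f, x, baseline_event, split):
    """Exact 2-player Shapley values for the coalitions 'older events'
    (x[:split]) and 'recent events' (x[split:]) of one sequence x of shape
    (seq_len, n_features). Absent events are replaced by baseline_event."""
    base = np.broadcast_to(baseline_event, x.shape)

    def score(keep_older, keep_recent):
        z = base.copy()
        if keep_older:
            z[:split] = x[:split]
        if keep_recent:
            z[split:] = x[split:]
        return float(f(z[np.newaxis])[0, 0])

    v_none, v_all = score(False, False), score(True, True)
    v_old, v_recent = score(True, False), score(False, True)
    phi_older = 0.5 * ((v_old - v_none) + (v_all - v_recent))
    phi_recent = 0.5 * ((v_recent - v_none) + (v_all - v_old))
    return phi_older, phi_recent


def prune_index(f, x, baseline_event, tol):
    """Walk back from the most recent event and return how many recent events
    to keep: the first split whose older coalition has |phi| < tol."""
    seq_len = x.shape[0]
    for n_recent in range(1, seq_len):
        phi_older, _ = coalition_shapley(f, x, baseline_event, seq_len - n_recent)
        if abs(phi_older) < tol:
            return n_recent
    return seq_len  # nothing could be pruned


# Toy model that only reads feature 0 of the three most recent events.
def f_last3(z):
    return z[:, -3:, 0].sum(axis=1).reshape(-1, 1)


x = np.ones((10, 2))
print(prune_index(f_last3, x, np.zeros(2), tol=0.025))  # 3
```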
Event level explanations
local_event() calculates the event-level explanations of a given sequence with the user-given parameters and returns them.
plot_event_heatmap() plots the event-level explanations calculated by local_event().
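For intuition about what event-level attributions measure, the sketch below computes exact Shapley values for each event by enumerating every coalition. As before, absent events are replaced by a baseline event. This is a swapped-in technique for illustration only. TimeSHAP builds on KernelSHAP, which estimates these values from sampled coalitions, because exact enumeration needs O(2^n) model calls for n events. exact_event_shapley is a hypothetical name.

```python
import itertools
import math

import numpy as np


def exact_event_shapley(f, x, baseline_event):
    """Exact Shapley value of each event of one sequence x (seq_len, n_features).
    Events outside a coalition are replaced by baseline_event. Needs O(2^seq_len)
    model calls, so it is only viable for very short sequences."""
    seq_len = x.shape[0]
    base = np.broadcast_to(baseline_event, x.shape)
    cache = {}

    def value(coalition):
        key = frozenset(coalition)
        if key not in cache:
            z = base.copy()
            idx = sorted(key)
            if idx:
                z[idx] = x[idx]  # present events keep their real values
            cache[key] = float(f(z[np.newaxis])[0, 0])
        return cache[key]

    phi = np.zeros(seq_len)
    for i in range(seq_len):
        others = [j for j in range(seq_len) if j != i]
        for size in range(seq_len):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions S of this size
            weight = math.factorial(size) * math.factorial(seq_len - size - 1) / math.factorial(seq_len)
            for coalition in itertools.combinations(others, size):
                phi[i] += weight * (value(coalition + (i,)) - value(coalition))
    return phi


# Toy model: sum of feature 0 over all events plus an interaction of events 0 and 1.
def f_toy(z):
    return (z[:, :, 0].sum(axis=1) + z[:, 0, 0] * z[:, 1, 0]).reshape(-1, 1)


x = np.array([[1.0], [2.0], [3.0]])
phi = exact_event_shapley(f_toy, x, np.zeros(1))
print(np.round(phi, 6))  # [2. 3. 3.]
```

The interaction term is split evenly between events 0 and 1. The attributions sum to the difference between the score of the sequence and the score of the all-baseline sequence.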
Feature level explanations
local_feat() calculates the feature-level explanations of a given sequence with the user-given parameters and returns them.
plot_feat_barplot() plots the feature-level explanations calculated by local_feat().
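A feature-level counterpart can be sketched with Monte Carlo permutation sampling. This is another swapped-in estimator, neither KernelSHAP nor TimeSHAP's code. It assumes a feature is "absent" when its whole column is replaced by the baseline value at every time step. permutation_feature_shapley is a hypothetical name.

```python
import numpy as np


def permutation_feature_shapley(f, x, baseline_event, n_permutations=200, seed=0):
    """Monte Carlo (permutation-sampling) estimate of each feature's Shapley
    value for one sequence x (seq_len, n_features). A feature is absent when
    its whole column is replaced by the matching baseline value."""
    rng = np.random.default_rng(seed)
    n_features = x.shape[1]
    base = np.broadcast_to(baseline_event, x.shape).copy()
    phi = np.zeros(n_features)
    for _ in range(n_permutations):
        z = base.copy()
        prev = float(f(z[np.newaxis])[0, 0])
        for j in rng.permutation(n_features):
            z[:, j] = x[:, j]  # add feature j (across all events) to the coalition
            curr = float(f(z[np.newaxis])[0, 0])
            phi[j] += curr - prev  # marginal contribution of feature j
            prev = curr
    return phi / n_permutations


# Toy additive model: weighted sum of every value in the sequence.
w = np.array([1.0, 2.0, -1.0])
f_add = lambda z: (z * w).sum(axis=(1, 2)).reshape(-1, 1)

x = np.ones((4, 3))
phi = permutation_feature_shapley(f_add, x, np.zeros(3))
print(phi.tolist())  # [4.0, 8.0, -4.0]
```

For an additive model every permutation yields the same marginal contributions, so the estimate is exact. For models with feature interactions it only converges as n_permutations grows.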
Cell level explanations
local_cell_level() calculates the cell-level explanations of a given sequence, using the respective event- and feature-level explanations and the user-given parameters, and returns the respective cell-level explanations.
plot_cell_level() plots the cell-level explanations calculated by local_cell_level().
Local Report
local_report() calculates the TimeSHAP local explanations for a given sequence and plots them.
Global Explanations
Global pruning statistics
prune_all() performs the pruning algorithm on multiple given sequences.
pruning_statistics() calculates the pruning statistics for several user-given pruning tolerances, using the pruning data calculated by prune_all(), and returns a pandas.DataFrame with the statistics.
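Since pruning_statistics() works from data already calculated by prune_all(), statistics for several tolerances can be derived without recomputing the pruning. The following is a minimal sketch of that idea. It assumes each sequence's pruning data is the curve of |older-coalition Shapley value| as more recent events are kept. pruned_length and pruning_stats are hypothetical names, and the returned dict is not TimeSHAP's DataFrame layout.

```python
import numpy as np


def pruned_length(older_importance, tol):
    """older_importance[k] = |phi_older| when the k + 1 most recent events are kept.
    Returns how many recent events survive pruning at tolerance tol."""
    below = np.flatnonzero(np.abs(older_importance) < tol)
    return int(below[0]) + 1 if below.size else len(older_importance) + 1


def pruning_stats(curves, tolerances):
    """Mean, standard deviation, and maximum pruned length across sequences,
    for each tolerance."""
    rows = {}
    for tol in tolerances:
        lengths = np.array([pruned_length(c, tol) for c in curves])
        rows[tol] = {
            "mean": float(lengths.mean()),
            "std": float(lengths.std()),
            "max": int(lengths.max()),
        }
    return rows


curves = [np.array([0.5, 0.2, 0.01, 0.0]), np.array([0.3, 0.001])]
stats = pruning_stats(curves, [0.025, 0.0001])
print(stats[0.025]["mean"])  # 2.5
```

If a tabular view is wanted, the nested dict can be passed to pandas.DataFrame.from_dict(stats, orient="index").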
Global event level explanations
event_explain_all() calculates TimeSHAP event-level explanations for multiple instances, given user-defined parameters.
plot_global_event() plots the global event-level explanations calculated by event_explain_all().
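Sequences usually have different lengths, especially after pruning. One way to aggregate event-level explanations globally is to align them by recency (last event, second-to-last event, and so on) and average each position. The sketch below assumes that alignment. align_by_recency and global_event_summary are hypothetical names, and plot_global_event() may aggregate differently.

```python
import numpy as np


def align_by_recency(event_attributions):
    """Stack per-sequence event attributions (1-D, oldest -> newest, possibly of
    different lengths) into a matrix whose column k holds the attribution of the
    k-th most recent event (k = 0 is the last event). Missing positions are NaN."""
    max_len = max(len(a) for a in event_attributions)
    out = np.full((len(event_attributions), max_len), np.nan)
    for row, attr in enumerate(event_attributions):
        out[row, :len(attr)] = np.asarray(attr, dtype=float)[::-1]
    return out


def global_event_summary(event_attributions):
    """Mean attribution per recency position and how many sequences reach it."""
    aligned = align_by_recency(event_attributions)
    return np.nanmean(aligned, axis=0), np.sum(~np.isnan(aligned), axis=0)


attrs = [[0.1, 0.3], [0.2, 0.0, 0.5]]
mean_by_position, counts = global_event_summary(attrs)
print(counts.tolist())  # [2, 2, 1]
```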
Global feature level explanations
feat_explain_all() calculates TimeSHAP feature-level explanations for multiple instances, given user-defined parameters.
plot_global_feat() plots the global feature-level explanations calculated by feat_explain_all().
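A simple global feature summary ranks features by their mean absolute attribution across instances. It also keeps the mean signed value, which shows the typical direction of each feature's effect. global_feature_ranking is a hypothetical helper. It is one plausible aggregation, not necessarily what plot_global_feat() computes.

```python
import numpy as np


def global_feature_ranking(feature_attributions, feature_names):
    """feature_attributions: (#instances, #features) local feature attributions.
    Returns (name, mean |phi|, mean phi) tuples sorted by mean |phi|, descending."""
    phi = np.asarray(feature_attributions, dtype=float)
    mean_abs = np.abs(phi).mean(axis=0)
    mean_signed = phi.mean(axis=0)
    order = np.argsort(-mean_abs, kind="stable")
    return [(feature_names[j], float(mean_abs[j]), float(mean_signed[j])) for j in order]


local_phi = [[0.5, -0.1, 0.0], [-0.3, 0.1, 0.3]]
ranking = global_feature_ranking(local_phi, ["a", "b", "c"])
print([name for name, _, _ in ranking])  # ['a', 'c', 'b']
```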
Global report
global_report() calculates TimeSHAP explanations for multiple instances, aggregates them into two plots, and returns them.
Tutorial
The tutorial AReM.ipynb demonstrates the TimeSHAP interfaces and methods. In this tutorial we take an open-source dataset, process it, train a PyTorch recurrent model on it, and use TimeSHAP to explain it, showcasing all the methods described above.
Additionally, AReM_TF.ipynb trains a TensorFlow model on the same dataset.
Repository Structure
- notebooks: tutorial notebooks demonstrating the package;
- src/timeshap: the package source code;
- src/timeshap/explainer: TimeSHAP methods to produce the explanations;
- src/timeshap/explainer/kernel: TimeSHAPKernel;
- src/timeshap/plot: TimeSHAP methods to produce explanation plots;
- src/timeshap/utils: util methods for TimeSHAP execution;
- src/timeshap/wrappers: wrapper classes for models, to ease TimeSHAP explanations.
Citing TimeSHAP
@inproceedings{bento2021timeshap,
author = {Bento, Jo\~{a}o and Saleiro, Pedro and Cruz, Andr\'{e} F. and Figueiredo, M\'{a}rio A.T. and Bizarro, Pedro},
title = {TimeSHAP: Explaining Recurrent Models through Sequence Perturbations},
year = {2021},
isbn = {9781450383325},
publisher = {Association for Computing Machinery},
address = {New York, NY, USA},
url = {https://doi.org/10.1145/3447548.3467166},
doi = {10.1145/3447548.3467166},
booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery & Data Mining},
pages = {2565–2573},
numpages = {9},
keywords = {SHAP, Shapley values, TimeSHAP, XAI, RNN, explainability},
location = {Virtual Event, Singapore},
series = {KDD '21}
}
Download files
Source Distribution
Built Distribution
File details
Details for the file timeshap-1.0.4.tar.gz.
File metadata
- Download URL: timeshap-1.0.4.tar.gz
- Upload date:
- Size: 47.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | a91cb82ef7eddd455bb01cd71695dfeef5d8b09c293e4055c5f00d379501f289
MD5 | 060cf342aebd400dd505311cb107602b
BLAKE2b-256 | e4d3bcd4cd16b92202a456970a4854a88bf45e72d9bb12a48bd76ebbfe1daa00
File details
Details for the file timeshap-1.0.4-py3-none-any.whl.
File metadata
- Download URL: timeshap-1.0.4-py3-none-any.whl
- Upload date:
- Size: 66.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 70c47020ecb3a3db0b3c64ef0a87e5c1861eb4156969fa4a4263b09bdc86c36c
MD5 | 1c9b51afc10783c9fff6756c6d791862
BLAKE2b-256 | 3f810ff13bd2ba8677c5e31829ff129a92c2bdf06960015884352019d8dc6e80