SimplEx - Explaining Latent Representations with a Corpus of Examples
Code Author: Jonathan Crabbé (jc2133@cam.ac.uk)
This repository contains the implementation of SimplEx, a method to explain the latent representations of black-box models with the help of a corpus of examples. For more details, please read our NeurIPS 2021 paper: 'Explaining Latent Representations with a Corpus of Examples'.
:rocket: Installation
The library can be installed from PyPI using
$ pip install simplexai
or from source, using
$ pip install .
Toy example
Below is a toy demonstration in which we compute a corpus decomposition of the latent representations of test examples. All the relevant code can be found in the file simplex.
from simplexai.explainers.simplex import Simplex
from simplexai.models.image_recognition import MnistClassifier
from simplexai.experiments.mnist import load_mnist
# Get a model
model = MnistClassifier() # Model should have the BlackBox interface
# Load corpus and test inputs
corpus_loader = load_mnist(subset_size=100, train=True, batch_size=100) # MNIST train loader
test_loader = load_mnist(subset_size=10, train=False, batch_size=10) # MNIST test loader
corpus_inputs, _ = next(iter(corpus_loader)) # A tensor of corpus inputs
test_inputs, _ = next(iter(test_loader)) # A set of inputs to explain
# Compute the corpus and test latent representations
corpus_latents = model.latent_representation(corpus_inputs).detach()
test_latents = model.latent_representation(test_inputs).detach()
# Initialize SimplEx, fit it on test examples
simplex = Simplex(corpus_examples=corpus_inputs,
                  corpus_latent_reps=corpus_latents)
simplex.fit(test_examples=test_inputs,
            test_latent_reps=test_latents,
            reg_factor=0)
# Get the weights of each corpus decomposition
weights = simplex.weights
We get a tensor weights that can be interpreted as follows: weights[i,c] = weight of corpus example c in the decomposition of example i.
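As a rough illustration of how one row of weights can be read, the dependency-free sketch below ranks corpus examples by weight and rebuilds an approximate latent vector as the weighted sum of corpus latents, assuming (as in the paper) that the decomposition is a weighted combination of corpus latent representations. The helper names top_corpus_examples and reconstruct_latent are hypothetical and not part of simplexai; the toy numbers stand in for weights[i].tolist() and corpus_latents.tolist().

```python
def top_corpus_examples(weight_row, k=3):
    """Return the k (corpus_index, weight) pairs with the largest weights."""
    ranked = sorted(enumerate(weight_row), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


def reconstruct_latent(weight_row, corpus_latents):
    """Weighted sum of corpus latent vectors: sum over c of w_c * h_c."""
    latent_dim = len(corpus_latents[0])
    return [
        sum(w * h[d] for w, h in zip(weight_row, corpus_latents))
        for d in range(latent_dim)
    ]


# Toy stand-ins (assumed values, not produced by SimplEx)
weight_row = [0.1, 0.6, 0.0, 0.3]          # one row weights[i], sums to 1
corpus_latents = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 2.0]]

top = top_corpus_examples(weight_row, k=2)  # two most important corpus examples
approx = reconstruct_latent(weight_row, corpus_latents)
print(top, approx)
```

With the actual tensors, and assuming two-dimensional latents of shape (number of examples, latent dimension), the same reconstruction for all test examples at once should be obtainable as weights @ corpus_latents.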
We can get the importance of each corpus feature for the decomposition of a given example i in the following way:
import torch
# Compute the Integrated Jacobian for a particular example
i = 4
input_baseline = torch.zeros(corpus_inputs.shape) # Baseline tensor of the same shape as corpus_inputs
simplex.jacobian_projection(test_id=i, model=model, input_baseline=input_baseline)
result = simplex.decompose(i)
We get a list result where each element corresponds to a corpus example. This list is sorted by decreasing order of importance in the corpus decomposition. Each element of the list is a tuple structured as follows:
w_c, x_c, proj_jacobian_c = result[c]
where w_c corresponds to the weight weights[i,c], x_c corresponds to corpus_inputs[c] and proj_jacobian_c is a tensor such that proj_jacobian_c[k] is the Projected Jacobian of feature k from corpus example c.
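Because the list is sorted by decreasing importance, a short prefix often accounts for most of the decomposition. The sketch below shows one possible way to keep only the leading entries that cover a chosen share of the total weight. The helper leading_contributors and the toy tuples are hypothetical; they only mimic the (w_c, x_c, proj_jacobian_c) structure described above.

```python
def leading_contributors(result, mass=0.9):
    """Walk a decomposition sorted by decreasing weight and stop once the
    accumulated weight reaches `mass`. Returns the kept (rank, weight) pairs
    and the weight actually covered."""
    kept, covered = [], 0.0
    for rank, (w_c, _x_c, _proj_jacobian_c) in enumerate(result):
        weight = float(w_c)  # works for Python floats and 0-dim tensors
        kept.append((rank, weight))
        covered += weight
        if covered >= mass:
            break
    return kept, covered


# Toy stand-in for simplex.decompose(i): weights already sorted, inputs and
# projected Jacobians replaced by placeholders.
toy_result = [(0.5, "x0", None), (0.3, "x1", None), (0.15, "x2", None), (0.05, "x3", None)]

kept, covered = leading_contributors(toy_result, mass=0.9)
print(kept, covered)
```

Here rank is the position in the sorted list, so rank 0 is the most important corpus example for the explained test example.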
Reproducing the paper results
Reproducing MNIST Approximation Quality Experiment
- Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9)
python -m simplexai.experiments.mnist -experiment "approximation_quality" -cv CV
- Run the following script by adding all the values of CV from the previous step
python -m simplexai.experiments.results.mnist.quality.plot_results -cv_list CV1 CV2 CV3 ...
- The resulting plots and data are saved here.
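The two steps above can be scripted instead of typed by hand. The following is a minimal sketch that only builds the command lines for every CV value and for the plotting step; the helper names are hypothetical, and it assumes the current Python interpreter is the one in which simplexai is installed. The actual launch is left commented out, since each run may take a while.

```python
import subprocess
import sys


def build_experiment_commands(module, experiment, cv_values):
    """One command per CV value, mirroring `python -m MODULE -experiment NAME -cv CV`."""
    return [
        [sys.executable, "-m", module, "-experiment", experiment, "-cv", str(cv)]
        for cv in cv_values
    ]


def build_plot_command(plot_module, cv_values):
    """Plot command listing every CV value after `-cv_list`."""
    return [sys.executable, "-m", plot_module, "-cv_list", *map(str, cv_values)]


def run_all(commands):
    """Run the commands one after the other, stopping at the first failure."""
    for command in commands:
        subprocess.run(command, check=True)


cv_values = range(10)  # all integer CV between 0 and 9, as in the paper
commands = build_experiment_commands(
    "simplexai.experiments.mnist", "approximation_quality", cv_values
)
plot_command = build_plot_command(
    "simplexai.experiments.results.mnist.quality.plot_results", cv_values
)
# run_all(commands + [plot_command])  # uncomment to launch the experiments
```

The same pattern should carry over to the other experiments below by swapping the module, the experiment name and the range of CV values.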
Reproducing Prostate Cancer Approximation Quality Experiment
This experiment requires access to the private datasets CUTRACT and SEER described in the paper.
- Copy the files cutract_internal_all.csv and seer_external_imputed_new.csv into the folder data/Prostate Cancer
- Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9)
python -m simplexai.experiments.prostate_cancer -experiment "approximation_quality" -cv CV
- Run the following script by adding all the values of CV from the previous step
python -m simplexai.experiments.results.prostate.quality.plot_results -cv_list CV1 CV2 CV3 ...
- The resulting plots are saved here.
Reproducing Prostate Cancer Outlier Experiment
This experiment requires access to the private datasets CUTRACT and SEER described in the paper.
- Make sure that the files cutract_internal_all.csv and seer_external_imputed_new.csv are in the folder data/Prostate Cancer
- Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9)
python -m simplexai.experiments.prostate_cancer -experiment "outlier_detection" -cv CV
- Run the following script by adding all the values of CV from the previous step
python -m simplexai.experiments.results.prostate.outlier.plot_results -cv_list CV1 CV2 CV3 ...
- The resulting plots are saved here.
Reproducing MNIST Jacobian Projection Significance Experiment
- Run the following script
python -m simplexai.experiments.mnist -experiment "jacobian_corruption"
- The resulting plots and data are saved here.
Reproducing MNIST Outlier Detection Experiment
- Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9)
python -m simplexai.experiments.mnist -experiment "outlier_detection" -cv CV
- Run the following script by adding all the values of CV from the previous step
python -m simplexai.experiments.results.mnist.outlier.plot_results -cv_list CV1 CV2 CV3 ...
- The resulting plots and data are saved here.
Reproducing MNIST Influence Function Experiment
- Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 4)
python -m simplexai.experiments.mnist -experiment "influence" -cv CV
- Run the following script by adding all the values of CV from the previous step
python -m simplexai.experiments.results.mnist.influence.plot_results -cv_list CV1 CV2 CV3 ...
- The resulting plots and data are saved here.
Note: some problems can appear with the package Pytorch Influence Functions. If this is the case, please change the following lines of calc_influence_function.py:
343: influences.append(tmp_influence) ==> influences.append(tmp_influence.cpu())
438: influences_meta['test_sample_index_list'] = sample_list ==> #influences_meta['test_sample_index_list'] = sample_list
Reproducing AR Approximation Quality Experiment
- Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 4)
python -m simplexai.experiments.time_series -experiment "approximation_quality" -cv CV
- Run the following script by adding all the values of CV from the previous step
python -m simplexai.experiments.results.ar.quality.plot_results -cv_list CV1 CV2 CV3 ...
- The resulting plots and data are saved here.
Reproducing AR Outlier Detection Experiment
- Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 4)
python -m simplexai.experiments.time_series -experiment "outlier_detection" -cv CV
- Run the following script by adding all the values of CV from the previous step
python -m simplexai.experiments.results.ar.outlier.plot_results -cv_list CV1 CV2 CV3 ...
- The resulting plots and data are saved here.
:hammer: Tests
Install the testing dependencies using
pip install .[testing]
The tests can be executed using
pytest -vsx
Citing
If you use this code, please cite the associated paper:
@inproceedings{Crabbe2021Simplex,
author = {Crabbe, Jonathan and Qian, Zhaozhi and Imrie, Fergus and van der Schaar, Mihaela},
booktitle = {Advances in Neural Information Processing Systems},
editor = {M. Ranzato and A. Beygelzimer and Y. Dauphin and P.S. Liang and J. Wortman Vaughan},
pages = {12154--12166},
publisher = {Curran Associates, Inc.},
title = {Explaining Latent Representations with a Corpus of Examples},
url = {https://proceedings.neurips.cc/paper/2021/file/65658fde58ab3c2b6e5132a39fae7cb9-Paper.pdf},
volume = {34},
year = {2021}
}