SimplEx - Explaining Latent Representations with a Corpus of Examples
Code Author: Jonathan Crabbé (jc2133@cam.ac.uk)
This repository contains the implementation of SimplEx, a method to explain the latent representations of black-box models with the help of a corpus of examples. For more details, please read our NeurIPS 2021 paper: 'Explaining Latent Representations with a Corpus of Examples'.
Installation
The library can be installed from PyPI using
$ pip install simplexai
or from source, by running the following from the root of a local clone of the repository:
$ pip install .
Toy example
Below, you can find a toy demonstration where we decompose the latent representations of test examples in terms of a corpus of examples. All the relevant code can be found in the module simplexai.explainers.simplex.
```python
from simplexai.explainers.simplex import Simplex
from simplexai.models.image_recognition import MnistClassifier
from simplexai.experiments.mnist import load_mnist

# Get a model
model = MnistClassifier()  # Model should have the BlackBox interface
# Note: a freshly constructed model is untrained; load trained weights for meaningful explanations

# Load corpus and test inputs
corpus_loader = load_mnist(subset_size=100, train=True, batch_size=100)  # MNIST train loader
test_loader = load_mnist(subset_size=10, train=False, batch_size=10)  # MNIST test loader
corpus_inputs, _ = next(iter(corpus_loader))  # A tensor of corpus inputs
test_inputs, _ = next(iter(test_loader))  # A set of inputs to explain

# Compute the corpus and test latent representations
corpus_latents = model.latent_representation(corpus_inputs).detach()
test_latents = model.latent_representation(test_inputs).detach()

# Initialize SimplEx, fit it on test examples
simplex = Simplex(corpus_examples=corpus_inputs,
                  corpus_latent_reps=corpus_latents)
simplex.fit(test_examples=test_inputs,
            test_latent_reps=test_latents,
            reg_factor=0)

# Get the weights of each corpus decomposition
weights = simplex.weights
```
We get a tensor `weights` that can be interpreted as follows: `weights[i,c]` is the weight of corpus example `c` in the decomposition of test example `i`.
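To make this interpretation concrete, here is a minimal NumPy sketch of the idea behind the corpus decomposition, assuming (as in the paper) that each test latent representation is approximated by a convex combination of the corpus latent representations, so that every row of `weights` is non-negative and sums to one. The helpers `project_to_simplex`, `fit_simplex_weights` and `approximate_latents` are hypothetical names for illustration only. The sketch fits the weights with projected gradient descent, which is not necessarily the optimiser used inside `Simplex.fit`, and it ignores `reg_factor`.

```python
import numpy as np


def project_to_simplex(v):
    """Euclidean projection of a 1-D vector onto the probability simplex."""
    u = np.sort(v)[::-1]                       # entries in decreasing order
    css = np.cumsum(u)
    ks = np.arange(1, v.size + 1)
    rho = np.nonzero(u * ks > css - 1)[0][-1]  # last index satisfying the condition
    theta = (css[rho] - 1) / (rho + 1)
    return np.maximum(v - theta, 0.0)


def fit_simplex_weights(test_latents, corpus_latents, n_steps=1000):
    """Hypothetical fit: minimise ||h_i - w_i @ corpus_latents||^2 with each w_i in the simplex."""
    n_test, n_corpus = test_latents.shape[0], corpus_latents.shape[0]
    # Step size 1/L, where L = 2 * sigma_max(corpus_latents)^2 is the gradient's Lipschitz constant
    lr = 1.0 / (2.0 * np.linalg.norm(corpus_latents, ord=2) ** 2)
    w = np.full((n_test, n_corpus), 1.0 / n_corpus)  # start from uniform weights
    for _ in range(n_steps):
        residual = w @ corpus_latents - test_latents  # (n_test, latent_dim)
        grad = 2.0 * residual @ corpus_latents.T       # (n_test, n_corpus)
        w = np.apply_along_axis(project_to_simplex, 1, w - lr * grad)
    return w


def approximate_latents(w, corpus_latents):
    """Row i is sum_c w[i, c] * corpus_latents[c], the corpus approximation of test example i."""
    return w @ corpus_latents


# Usage on a toy corpus whose latents are the vertices of a triangle
toy_corpus = np.eye(3)
toy_test = np.array([[0.2, 0.5, 0.3]])  # lies inside the convex hull of the corpus
toy_weights = fit_simplex_weights(toy_test, toy_corpus)
print(toy_weights.round(3))  # [[0.2 0.5 0.3]]
```

When a test latent lies inside the convex hull of the corpus latents, the recovered weights reproduce it exactly; otherwise the weights give the closest point of the hull, and the remaining gap is the approximation residual.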
We can get the importance of each corpus feature for the decomposition of a given example `i` in the following way:
```python
import torch

# Compute the Integrated Jacobian for a particular example
i = 4
input_baseline = torch.zeros(corpus_inputs.shape)  # Baseline tensor of the same shape as corpus_inputs
simplex.jacobian_projection(test_id=i, model=model, input_baseline=input_baseline)
result = simplex.decompose(i)
```
We get a list `result` where each element corresponds to a corpus example. This list is sorted by decreasing order of importance in the corpus decomposition. Each element of the list is a tuple structured as follows:
```python
w_c, x_c, proj_jacobian_c = result[j]
```
Here `c` denotes the corpus example found at position `j` of the sorted list: `w_c` corresponds to the weight `weights[i,c]`, `x_c` corresponds to `corpus_inputs[c]`, and `proj_jacobian_c` is a tensor such that `proj_jacobian_c[k]` is the Projected Jacobian of feature `k` from corpus example `c`.
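The Projected Jacobian builds on an Integrated Jacobian of the latent map along a straight path from the baseline to each corpus example. The NumPy sketch below illustrates that idea on a toy latent map `g(x) = tanh(A @ x)`, a hypothetical stand-in for `model.latent_representation`: it averages the analytic Jacobian over points on the path (midpoint rule) and then projects each feature's contribution onto a chosen latent direction. The function names, the quadrature and the choice of projection direction are assumptions for illustration; the exact normalisation used by `simplex.jacobian_projection` may differ.

```python
import numpy as np


def latent_map(A, x):
    """Toy latent map g(x) = tanh(A x)."""
    return np.tanh(A @ x)


def latent_jacobian(A, x):
    """Analytic Jacobian of g at x: diag(1 - tanh(Ax)^2) @ A, shape (latent_dim, n_features)."""
    return (1.0 - np.tanh(A @ x) ** 2)[:, None] * A


def integrated_jacobian(A, x_corpus, x_baseline, n_steps=200):
    """Average the Jacobian over the straight path baseline -> corpus example (midpoint rule)."""
    ts = (np.arange(n_steps) + 0.5) / n_steps
    path = [x_baseline + t * (x_corpus - x_baseline) for t in ts]
    return np.mean([latent_jacobian(A, x) for x in path], axis=0)


def projected_contributions(A, x_corpus, x_baseline, direction):
    """Contribution of each input feature to the latent shift, projected on `direction`."""
    J = integrated_jacobian(A, x_corpus, x_baseline)
    # Feature k: its displacement from the baseline times column k of J, projected on `direction`
    return (x_corpus - x_baseline) * (direction @ J)


rng = np.random.default_rng(0)
A = rng.normal(size=(2, 4))       # toy map with 2 latent dimensions and 4 input features
x_corpus = rng.normal(size=4)
x_baseline = np.zeros(4)          # plays the role of input_baseline above
direction = np.array([1.0, 0.0])  # hypothetical projection direction
contrib = projected_contributions(A, x_corpus, x_baseline, direction)
shift = direction @ (latent_map(A, x_corpus) - latent_map(A, x_baseline))
print(contrib.sum(), shift)       # agree up to the quadrature error
```

The feature contributions add up to the projected latent shift between the baseline and the corpus example, and a feature equal to its baseline value receives no contribution.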
Reproducing the paper results
Reproducing MNIST Approximation Quality Experiment
- Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9)
python -m simplexai.experiments.mnist -experiment "approximation_quality" -cv CV
- Run the following script, passing all the CV values used in the previous step
python -m simplexai.experiments.results.mnist.quality.plot_results -cv_list CV1 CV2 CV3 ...
- The resulting plots and data are saved here.
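The two steps above can be scripted. The sketch below is a hypothetical convenience helper (not part of simplexai) that builds the exact commands shown in this README for a list of CV values and runs them in sequence with `subprocess`; the other experiments follow the same pattern with a different module, experiment name and plotting module.

```python
import subprocess
import sys


def build_commands(experiment_module, experiment, plot_module, cv_values):
    """Return one experiment run per CV value, followed by the plotting command."""
    cv_values = [str(cv) for cv in cv_values]
    runs = [
        [sys.executable, "-m", experiment_module, "-experiment", experiment, "-cv", cv]
        for cv in cv_values
    ]
    plot = [sys.executable, "-m", plot_module, "-cv_list", *cv_values]
    return runs + [plot]


def run_all(commands):
    """Run each command in order; check=True stops at the first failing command."""
    for cmd in commands:
        subprocess.run(cmd, check=True)


# Example (MNIST approximation quality, CV = 0, ..., 9 as in the paper):
# run_all(build_commands("simplexai.experiments.mnist", "approximation_quality",
#                        "simplexai.experiments.results.mnist.quality.plot_results", range(10)))
```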
Reproducing Prostate Cancer Approximation Quality Experiment
This experiment requires access to the private CUTRACT and SEER datasets described in the paper.
- Copy the files cutract_internal_all.csv and seer_external_imputed_new.csv into the folder data/Prostate Cancer
- Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9)
python -m simplexai.experiments.prostate_cancer -experiment "approximation_quality" -cv CV
- Run the following script, passing all the CV values used in the previous step
python -m simplexai.experiments.results.prostate.quality.plot_results -cv_list CV1 CV2 CV3 ...
- The resulting plots are saved here.
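Before launching the prostate cancer experiments, it can help to check that the two CSV files are where the steps above expect them. The following is a minimal, hypothetical pre-flight check (not part of simplexai); it assumes that the folder data/Prostate Cancer is resolved relative to the directory from which you run the commands.

```python
from pathlib import Path

REQUIRED_FILES = ("cutract_internal_all.csv", "seer_external_imputed_new.csv")


def missing_prostate_files(root="."):
    """Return the required CSV files that are absent from <root>/data/Prostate Cancer."""
    data_dir = Path(root) / "data" / "Prostate Cancer"
    return [name for name in REQUIRED_FILES if not (data_dir / name).is_file()]


# Example:
# missing = missing_prostate_files()
# if missing:
#     raise FileNotFoundError(f"Missing files in data/Prostate Cancer: {missing}")
```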
Reproducing Prostate Cancer Outlier Experiment
This experiment requires access to the private CUTRACT and SEER datasets described in the paper.
- Make sure that the files cutract_internal_all.csv and seer_external_imputed_new.csv are in the folder data/Prostate Cancer
- Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9)
python -m simplexai.experiments.prostate_cancer -experiment "outlier_detection" -cv CV
- Run the following script, passing all the CV values used in the previous step
python -m simplexai.experiments.results.prostate.outlier.plot_results -cv_list CV1 CV2 CV3 ...
- The resulting plots are saved here.
Reproducing MNIST Jacobian Projection Significance Experiment
- Run the following script
python -m simplexai.experiments.mnist -experiment "jacobian_corruption"
- The resulting plots and data are saved here.
Reproducing MNIST Outlier Detection Experiment
- Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9)
python -m simplexai.experiments.mnist -experiment "outlier_detection" -cv CV
- Run the following script, passing all the CV values used in the previous step
python -m simplexai.experiments.results.mnist.outlier.plot_results -cv_list CV1 CV2 CV3 ...
- The resulting plots and data are saved here.
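The outlier detection experiments rest on the idea, described in the paper, that examples poorly explained by the corpus have a large latent approximation residual. As a hedged illustration (the helper names are hypothetical and this is not the code run by `simplexai.experiments`), one can rank test examples by the norm of `test_latents - weights @ corpus_latents` and count how many known outliers appear among the highest-ranked ones.

```python
import numpy as np


def latent_residuals(test_latents, corpus_latents, weights):
    """Norm of h_i - sum_c weights[i, c] * corpus_latents[c] for each test example i."""
    return np.linalg.norm(test_latents - weights @ corpus_latents, axis=1)


def outliers_found(residuals, is_outlier, k):
    """Number of true outliers among the k test examples with the largest residual."""
    top_k = np.argsort(residuals)[::-1][:k]
    return int(np.sum(is_outlier[top_k]))
```

Sweeping `k` from 1 to the number of test examples gives a curve of detected outliers against examples inspected.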
Reproducing MNIST Influence Function Experiment
- Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 4)
python -m simplexai.experiments.mnist -experiment "influence" -cv CV
- Run the following script, passing all the CV values used in the previous step
python -m simplexai.experiments.results.mnist.influence.plot_results -cv_list CV1 CV2 CV3 ...
- The resulting plots and data are saved here.
Note: some problems can occur with the package Pytorch Influence Functions. If this is the case, please modify calc_influence_function.py as follows:
- line 343: influences.append(tmp_influence) ==> influences.append(tmp_influence.cpu())
- line 438: influences_meta['test_sample_index_list'] = sample_list ==> #influences_meta['test_sample_index_list'] = sample_list
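If you prefer to apply the two edits above programmatically, here is a hypothetical helper (provided neither by simplexai nor by Pytorch Influence Functions) that performs the same text substitutions. It assumes the two lines appear verbatim as quoted above; you first need to locate calc_influence_function.py in your installed copy of the package.

```python
from pathlib import Path

# (old line, new line) pairs from the note above (lines 343 and 438 of the file)
REPLACEMENTS = [
    ("influences.append(tmp_influence)", "influences.append(tmp_influence.cpu())"),
    ("influences_meta['test_sample_index_list'] = sample_list",
     "#influences_meta['test_sample_index_list'] = sample_list"),
]


def patch_influence_file(path):
    """Apply the substitutions line by line, keeping indentation; return the number of changed lines."""
    path = Path(path)
    lines = path.read_text().splitlines(keepends=True)
    changed = 0
    for idx, line in enumerate(lines):
        stripped = line.strip()
        for old, new in REPLACEMENTS:
            if stripped == old:  # exact match only, so re-running the patch is a no-op
                lines[idx] = line.replace(old, new)
                changed += 1
                break
    path.write_text("".join(lines))
    return changed
```

Matching whole stripped lines rather than substrings keeps the patch idempotent: once a line has been changed, it no longer matches its original text.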
Reproducing AR Approximation Quality Experiment
- Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 4)
python -m simplexai.experiments.time_series -experiment "approximation_quality" -cv CV
- Run the following script, passing all the CV values used in the previous step
python -m simplexai.experiments.results.ar.quality.plot_results -cv_list CV1 CV2 CV3 ...
- The resulting plots and data are saved here.
Reproducing AR Outlier Detection Experiment
- Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 4)
python -m simplexai.experiments.time_series -experiment "outlier_detection" -cv CV
- Run the following script, passing all the CV values used in the previous step
python -m simplexai.experiments.results.ar.outlier.plot_results -cv_list CV1 CV2 CV3 ...
- The resulting plots and data are saved here.
Tests
Install the testing dependencies using
pip install ".[testing]"
The tests can be executed using
pytest -vsx
Citing
If you use this code, please cite the associated paper:
@inproceedings{Crabbe2021Simplex,
author = {Crabbe, Jonathan and Qian, Zhaozhi and Imrie, Fergus and van der Schaar, Mihaela},
booktitle = {Advances in Neural Information Processing Systems},
editor = {M. Ranzato and A. Beygelzimer and Y. Dauphin and P.S. Liang and J. Wortman Vaughan},
pages = {12154--12166},
publisher = {Curran Associates, Inc.},
title = {Explaining Latent Representations with a Corpus of Examples},
url = {https://proceedings.neurips.cc/paper/2021/file/65658fde58ab3c2b6e5132a39fae7cb9-Paper.pdf},
volume = {34},
year = {2021}
}