A framework to explain the latent representations of unsupervised black-box models using standard feature importance and example-based methods.
Label-Free XAI
Code Author: Jonathan Crabbé (jc2133@cam.ac.uk)
This repository contains the implementation of LFXAI, a framework to explain the latent representations of unsupervised black-box models using standard feature importance and example-based methods. For more details, please read our ICML 2022 paper: 'Label-Free Explainability for Unsupervised Models'.
1. Installation
From PyPI:
```bash
pip install lfxai
```
From the repository:
- Clone the repository
- Create a new virtual environment with Python 3.8
- Run the following command from the repository folder:
```bash
pip install .
```
Once the package is installed, you are ready to explain unsupervised models.
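To confirm that either installation route worked, a small stdlib-only check can query the installed distribution's metadata. This is a sketch of our own: `installed_version` is a hypothetical helper, not part of lfxai, and it assumes the distribution is registered under the name `lfxai`, as in the pip command above.

```python
import importlib.metadata
import platform
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # The README suggests a Python 3.8 virtual environment.
    print(f"Python {platform.python_version()}")
    version = installed_version("lfxai")
    print(f"lfxai {version}" if version else "lfxai is not installed in this environment")
```

Running it inside the virtual environment should print the installed lfxai version.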
2. Toy example
Below, you can find a toy demonstration in which we compute label-free feature and example importance for an MNIST autoencoder. The relevant code can be found in the explanations folder.
```python
import torch
from pathlib import Path
from torch.nn import MSELoss
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import MNIST
from captum.attr import IntegratedGradients
from lfxai.models.images import AutoEncoderMnist, EncoderMnist, DecoderMnist
from lfxai.models.pretext import Identity
from lfxai.explanations.features import attribute_auxiliary
from lfxai.explanations.examples import SimplEx

# Select torch device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load data (images are converted to tensors)
data_dir = Path.cwd() / "data/mnist"
train_dataset = MNIST(data_dir, train=True, download=True, transform=transforms.ToTensor())
test_dataset = MNIST(data_dir, train=False, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset, batch_size=100)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)

# Get a model (left untrained in this toy example; train it before explaining it in practice)
encoder = EncoderMnist(encoded_space_dim=10)
decoder = DecoderMnist(encoded_space_dim=10)
model = AutoEncoderMnist(encoder, decoder, latent_dim=10, input_pert=Identity())
model.to(device)

# Get label-free feature importance
baseline = torch.zeros((1, 1, 28, 28)).to(device)  # black image as baseline
attr_method = IntegratedGradients(model)
feature_importance = attribute_auxiliary(encoder, test_loader, device, attr_method, baseline)

# Get label-free example importance
train_subset = Subset(train_dataset, indices=list(range(500)))  # limit the number of training examples
train_subloader = DataLoader(train_subset, batch_size=500)
attr_method = SimplEx(model, loss_f=MSELoss())
example_importance = attr_method.attribute_loader(device, train_subloader, test_loader)
```
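As we understand the paper, label-free feature importance applies an ordinary attribution method to the scalar auxiliary function g_x(z) = <f(x), f(z)>, where f is the encoder and f(x) is held fixed; in the toy code, `attribute_auxiliary` appears to play this role for the encoder. The pure-Python sketch below illustrates the idea with Integrated Gradients on a hypothetical two-unit tanh encoder. Everything here is an illustrative assumption (`encode`, `auxiliary_grad`, `label_free_ig` and the weights `W` are ours, not lfxai's API). Because Integrated Gradients is complete, the attributions should sum to ||f(x)||^2 - <f(x), f(baseline)>.

```python
import math
from typing import List, Sequence

Vector = List[float]
Matrix = List[List[float]]


def encode(weights: Matrix, x: Sequence[float]) -> Vector:
    """Hypothetical toy encoder f(x) = tanh(W x), standing in for a black-box encoder."""
    return [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in weights]


def auxiliary_grad(weights: Matrix, anchor: Vector, z: Sequence[float]) -> Vector:
    """Gradient w.r.t. z of the auxiliary scalar g(z) = <anchor, f(z)>."""
    pre = [sum(w * zi for w, zi in zip(row, z)) for row in weights]
    # d tanh(u)/du = 1 - tanh(u)^2, then chain rule through u = W z.
    coeff = [a * (1.0 - math.tanh(u) ** 2) for a, u in zip(anchor, pre)]
    return [sum(c * weights[j][i] for j, c in enumerate(coeff)) for i in range(len(z))]


def label_free_ig(weights: Matrix, x: Sequence[float], baseline: Sequence[float],
                  steps: int = 200) -> Vector:
    """Integrated Gradients of g_x(z) = <f(x), f(z)>, using a midpoint Riemann sum."""
    anchor = encode(weights, x)  # f(x) is frozen: it defines the auxiliary function
    avg_grad = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(auxiliary_grad(weights, anchor, point)):
            avg_grad[i] += g / steps
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grad)]


W = [[0.5, -1.0, 0.3, 0.0], [1.2, 0.4, -0.7, 0.9]]
x = [0.8, 0.1, -0.5, 0.6]
baseline = [0.0, 0.0, 0.0, 0.0]
attributions = label_free_ig(W, x, baseline)
f_x, f_base = encode(W, x), encode(W, baseline)
# Completeness: attributions sum to <f(x), f(x)> - <f(x), f(baseline)>.
target = sum(a * a for a in f_x) - sum(a * b for a, b in zip(f_x, f_base))
print([round(a, 4) for a in attributions], round(sum(attributions), 4), round(target, 4))
```

Holding f(x) fixed is what turns a vector-valued encoder into a scalar function, so any attribution method that expects a scalar output can be reused unchanged.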
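For example importance, the toy code uses SimplEx, which (as introduced in the original SimplEx work) explains a test example's latent representation as a convex combination of corpus latent representations, with the mixture weights serving as importance scores. The sketch below reproduces that idea in pure Python on hypothetical 2-D latent vectors. It is an illustrative assumption throughout: `project_to_simplex` and `simplex_weights` are our own helpers, and we fit the weights with projected gradient descent, which may differ from how lfxai's `SimplEx` optimises them.

```python
from typing import List, Sequence

Vector = List[float]


def project_to_simplex(v: Sequence[float]) -> Vector:
    """Euclidean projection of v onto {w : w_i >= 0, sum(w) = 1} (sort-based algorithm)."""
    u = sorted(v, reverse=True)
    cumsum, theta = 0.0, 0.0
    for j, uj in enumerate(u, start=1):
        cumsum += uj
        t = (cumsum - 1.0) / j
        if uj - t > 0:
            theta = t  # the last j satisfying the condition fixes the threshold
    return [max(vi - theta, 0.0) for vi in v]


def simplex_weights(corpus: Sequence[Sequence[float]], test: Sequence[float],
                    iters: int = 2000) -> Vector:
    """Convex weights w minimising ||test - sum_i w_i corpus_i||^2 by projected gradient descent."""
    n = len(corpus)
    # 1 / (2 * sum of squared norms) is at most 1 / L, L being the gradient's Lipschitz constant.
    lr = 1.0 / (2.0 * sum(sum(c * c for c in rep) for rep in corpus))
    w = [1.0 / n] * n
    for _ in range(iters):
        approx = [sum(w[i] * corpus[i][d] for i in range(n)) for d in range(len(test))]
        residual = [t - a for t, a in zip(test, approx)]
        grad = [-2.0 * sum(c * r for c, r in zip(rep, residual)) for rep in corpus]
        w = project_to_simplex([wi - lr * gi for wi, gi in zip(w, grad)])
    return w


# Hypothetical encoder outputs for three training examples and one test example.
corpus_latents = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
test_latent = [0.3, 0.7]
weights = simplex_weights(corpus_latents, test_latent)
print([round(wi, 3) for wi in weights])
```

Here the test latent equals 0.3 times the first corpus vector plus 0.7 times the second, and that convex decomposition is unique, so the weights should converge to roughly [0.3, 0.7, 0.0]: the third training example receives no importance.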
3. Reproducing the paper results
MNIST experiments
In the experiments folder, run the following script:
```bash
python -m mnist --name experiment_name
```
where experiment_name can take the following values:
experiment_name | description |
---|---|
consistency_features | Consistency check for label-free feature importance (paper Section 4.1) |
consistency_examples | Consistency check for label-free example importance (paper Section 4.1) |
roar_test | ROAR test for label-free feature importance (paper Appendix C) |
pretext | Pretext task sensitivity use case (paper Section 4.2) |
disvae | Challenging assumptions with disentangled VAEs (paper Section 4.3) |
The resulting plots and data are saved here.
ECG5000 experiments
Run the following script:
```bash
python -m ecg5000 --name experiment_name
```
where experiment_name can take the following values:
experiment_name | description |
---|---|
consistency_features | Consistency check for label-free feature importance (paper Section 4.1) |
consistency_examples | Consistency check for label-free example importance (paper Section 4.1) |
The resulting plots and data are saved here.
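To run every listed MNIST or ECG5000 experiment in one go, a small stdlib-only launcher can build the documented `python -m <module> --name <experiment_name>` commands. This is a hypothetical helper (`EXPERIMENTS`, `build_command` and `run_all` are our names, and the name lists are copied from the tables above). It uses `sys.executable` in place of `python` so that the active virtual environment's interpreter is used, and it assumes it is launched from the experiments folder, as the MNIST instructions specify. By default it only lists the commands (dry run).

```python
import subprocess
import sys
from typing import Dict, List, Sequence

# Experiment names copied from the README tables; module names match the commands.
EXPERIMENTS: Dict[str, Sequence[str]] = {
    "mnist": ("consistency_features", "consistency_examples", "roar_test", "pretext", "disvae"),
    "ecg5000": ("consistency_features", "consistency_examples"),
}


def build_command(module: str, name: str) -> List[str]:
    """Build `python -m <module> --name <name>`, rejecting names not listed in the README."""
    if name not in EXPERIMENTS.get(module, ()):
        raise ValueError(f"unknown experiment {name!r} for module {module!r}")
    return [sys.executable, "-m", module, "--name", name]


def run_all(module: str, dry_run: bool = True) -> List[List[str]]:
    """Run (or only list, when dry_run is True) every experiment of `module` in sequence."""
    commands = [build_command(module, name) for name in EXPERIMENTS[module]]
    if not dry_run:
        for cmd in commands:
            # Assumed to be launched from the experiments folder so `-m <module>` resolves.
            subprocess.run(cmd, check=True)
    return commands


if __name__ == "__main__":
    for cmd in run_all("ecg5000"):
        print(" ".join(cmd[1:]))
```

Passing `dry_run=False` runs the experiments one after another and stops at the first failing command, because `check=True` raises on a non-zero exit code.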
CIFAR10 experiments
Run the following script:
```bash
python -m cifar10
```
The experiment can be selected by changing the experiment_name parameter in this file. The parameter can take the following values:
experiment_name | description |
---|---|
consistency_features | Consistency check for label-free feature importance (paper Section 4.1) |
consistency_examples | Consistency check for label-free example importance (paper Section 4.1) |
The resulting plots and data are saved here.
dSprites experiment
Run the following script:
```bash
python -m dsprites
```
The experiment takes several hours to run because several VAEs are trained. The resulting plots and data are saved here.
4. Citing
If you use this code, please cite the associated paper:
```bibtex
@misc{Crabbe2022LFXAI,
  doi = {10.48550/ARXIV.2203.01928},
  url = {https://arxiv.org/abs/2203.01928},
  author = {Crabbé, Jonathan and van der Schaar, Mihaela},
  keywords = {Machine Learning (cs.LG), Artificial Intelligence (cs.AI), FOS: Computer and information sciences},
  title = {Label-Free Explainability for Unsupervised Models},
  publisher = {arXiv},
  year = {2022},
  copyright = {Creative Commons Attribution 4.0 International}
}
```