Label-Free XAI

License: MIT

Code Author: Jonathan Crabbé (jc2133@cam.ac.uk)

This repository contains the implementation of LFXAI, a framework to explain the latent representations of unsupervised black-box models with the help of usual feature importance and example-based methods. For more details, please read our ICML 2022 paper: 'Label-Free Explainability for Unsupervised Models'.

1. Installation

From PyPI:

pip install lfxai

From repository:

  1. Clone the repository.
  2. Create a new virtual environment with Python 3.8.
  3. Run the following command from the repository folder:

pip install .

Once the package and its dependencies are installed, you are ready to explain unsupervised models.
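
A quick way to confirm that the installation succeeded is to check that the relevant packages are importable. The snippet below is a minimal sketch using only the standard library; the helper name installed_version is hypothetical, and the list of packages checked (lfxai, torch, captum) is taken from the imports used in the toy example of Section 2.

```python
from importlib import metadata, util


def installed_version(package: str):
    """Return the installed version of `package`, or None if it cannot be imported.

    Hypothetical helper for illustration; it is not part of lfxai.
    """
    if util.find_spec(package) is None:
        return None  # the package is not importable in this environment
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        # Importable, but no distribution metadata under this exact name
        return "unknown"


if __name__ == "__main__":
    for name in ("lfxai", "torch", "captum"):
        print(name, installed_version(name) or "NOT INSTALLED")
```

If any line reports NOT INSTALLED, rerun the pip command above inside the active virtual environment.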

2. Toy example

Below, you can find a toy demonstration where we compute label-free feature and example importance with an MNIST autoencoder. The relevant code can be found in the folder explanations.

import torch
from pathlib import Path
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torch.nn import MSELoss
from captum.attr import IntegratedGradients

from lfxai.models.images import AutoEncoderMnist, EncoderMnist, DecoderMnist
from lfxai.models.pretext import Identity
from lfxai.explanations.features import attribute_auxiliary
from lfxai.explanations.examples import SimplEx

# Select torch device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load data
data_dir = Path.cwd() / "data/mnist"
train_dataset = MNIST(data_dir, train=True, download=True)
test_dataset = MNIST(data_dir, train=False, download=True)
train_dataset.transform = transforms.Compose([transforms.ToTensor()])
test_dataset.transform = transforms.Compose([transforms.ToTensor()])
train_loader = DataLoader(train_dataset, batch_size=100)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)

# Get a model
encoder = EncoderMnist(encoded_space_dim=10)
decoder = DecoderMnist(encoded_space_dim=10)
model = AutoEncoderMnist(encoder, decoder, latent_dim=10, input_pert=Identity())
model.to(device)

# Get label-free feature importance
baseline = torch.zeros((1, 1, 28, 28)).to(device) # black image as baseline
attr_method = IntegratedGradients(model)
feature_importance = attribute_auxiliary(encoder, test_loader,
                                         device, attr_method, baseline)

# Get label-free example importance
train_subset = Subset(train_dataset, indices=list(range(500))) # Limit the number of training examples
train_subloader = DataLoader(train_subset, batch_size=500)
attr_method = SimplEx(model, loss_f=MSELoss())
example_importance = attr_method.attribute_loader(device, train_subloader, test_loader)
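
To give some intuition for the feature importance step: as described in the paper, label-free feature importance applies a standard attribution method to a scalar auxiliary function g_x(x') = <f(x), f(x')>, where f is the encoder and f(x) is held fixed. The sketch below illustrates this idea in pure Python and is not the lfxai implementation. The toy linear encoder, the helper names, and the use of finite-difference gradient x input in place of Integrated Gradients are all assumptions made for illustration.

```python
from typing import Callable, List

Vector = List[float]


def toy_encoder(x: Vector) -> Vector:
    """Hypothetical linear encoder (3 features -> 2 latent dims), a stand-in for a trained model."""
    weights = [[1.0, 0.0, 2.0],
               [0.0, 0.0, 1.0]]
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def auxiliary(encoder: Callable[[Vector], Vector], x: Vector) -> Callable[[Vector], float]:
    """Build g_x(x_tilde) = <encoder(x), encoder(x_tilde)> with encoder(x) held fixed."""
    anchor = encoder(x)

    def g(x_tilde: Vector) -> float:
        return sum(a * b for a, b in zip(anchor, encoder(x_tilde)))

    return g


def gradient_times_input(g: Callable[[Vector], float], x: Vector, eps: float = 1e-5) -> Vector:
    """Gradient x input attribution of the scalar function g at x (central differences)."""
    scores = []
    for i, xi in enumerate(x):
        up = x[:i] + [xi + eps] + x[i + 1:]
        down = x[:i] + [xi - eps] + x[i + 1:]
        grad_i = (g(up) - g(down)) / (2 * eps)
        scores.append(grad_i * xi)
    return scores


x = [1.0, 5.0, 1.0]
scores = gradient_times_input(auxiliary(toy_encoder, x), x)
print([round(s, 3) for s in scores])  # → [3.0, 0.0, 7.0]
```

The second feature gets a score of zero even though its value is the largest, because the toy encoder ignores it. No label is needed at any point: the latent representation of x itself defines the scalar function being explained.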

3. Reproducing the paper results

MNIST experiments

In the experiments folder, run the following command:

python -m mnist --name experiment_name

where experiment_name can take the following values:

experiment_name        description
consistency_features   Consistency check for label-free feature importance (paper Section 4.1)
consistency_examples   Consistency check for label-free example importance (paper Section 4.1)
roar_test              ROAR test for label-free feature importance (paper Appendix C)
pretext                Pretext task sensitivity use case (paper Section 4.2)
disvae                 Challenging assumptions with disentangled VAEs (paper Section 4.3)

The resulting plots and data are saved here.
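
To run every MNIST experiment in sequence, the command above can be wrapped in a small driver. The following is a sketch under stated assumptions: the helper names build_command and run_all are hypothetical, the experiment names are copied from the list above, and the relative path "experiments" assumes the script is launched from the repository root. By default it only prints the commands (dry run).

```python
import subprocess
import sys
from pathlib import Path
from typing import List

# Experiment names accepted by the mnist module, as listed above
MNIST_EXPERIMENTS = [
    "consistency_features",
    "consistency_examples",
    "roar_test",
    "pretext",
    "disvae",
]


def build_command(module: str, experiment_name: str) -> List[str]:
    """Build the equivalent of `python -m <module> --name <experiment_name>`."""
    return [sys.executable, "-m", module, "--name", experiment_name]


def run_all(experiments_dir: Path, dry_run: bool = True) -> List[List[str]]:
    """Print each MNIST command and, unless dry_run is set, execute it in experiments_dir."""
    commands = [build_command("mnist", name) for name in MNIST_EXPERIMENTS]
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the loop at the first failing experiment
            subprocess.run(cmd, cwd=experiments_dir, check=True)
    return commands


if __name__ == "__main__":
    run_all(Path("experiments"), dry_run=True)
```

Passing dry_run=False actually launches the experiments, which may take a long time depending on the hardware.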

ECG5000 experiments

Run the following command:

python -m ecg5000 --name experiment_name

where experiment_name can take the following values:

experiment_name        description
consistency_features   Consistency check for label-free feature importance (paper Section 4.1)
consistency_examples   Consistency check for label-free example importance (paper Section 4.1)

The resulting plots and data are saved here.

CIFAR10 experiments

Run the following command:

python -m cifar10

The experiment can be selected by changing the experiment_name parameter in this file. The parameter can take the following values:

experiment_name        description
consistency_features   Consistency check for label-free feature importance (paper Section 4.1)
consistency_examples   Consistency check for label-free example importance (paper Section 4.1)

The resulting plots and data are saved here.

dSprites experiment

Run the following command:

python -m dsprites

The experiment needs several hours to run since several VAEs are trained. The resulting plots and data are saved here.

4. Citing

If you use this code, please cite the associated paper:

@misc{Crabbe2022LFXAI,
  doi = {10.48550/ARXIV.2203.01928},
  url = {https://arxiv.org/abs/2203.01928},
  author = {Crabbé, Jonathan and van der Schaar, Mihaela},
  keywords = {Machine Learning (cs.LG), Artificial Intelligence (cs.AI), FOS: Computer and information sciences},
  title = {Label-Free Explainability for Unsupervised Models},
  publisher = {arXiv},
  year = {2022},
  copyright = {Creative Commons Attribution 4.0 International}
}
