Toolkit for quantitative evaluation of data attribution methods in PyTorch.

quanda

Interpretability toolkit for quantitative evaluation of data attribution methods in PyTorch.

quanda is under active development. Note the release version to ensure reproducibility of your work. Contributions, bug reports, and feature requests are welcome.

📑 Shortcut to paper!

🐼 Library overview

Training data attribution (TDA) methods attribute a model's output on a specific test sample to the training dataset that the model was trained on. They reveal the training datapoints responsible for the model's decisions. Existing methods achieve this by estimating the counterfactual effect of removing datapoints from the training set (Koh and Liang, 2017; Park et al., 2023; Bae et al., 2024), tracking the contributions of training points to the loss reduction throughout training (Pruthi et al., 2020), using interpretable surrogate models (Yeh et al., 2018), or finding training samples that the model deems similar to the test sample (Caruana et al., 1999; Hanawa et al., 2021). In addition to model understanding, TDA has been used in a variety of applications such as debugging model behavior (Koh and Liang, 2017; Yeh et al., 2018; K and Søgaard, 2021; Guo et al., 2021), data summarization (Khanna et al., 2019; Marion et al., 2023; Yang et al., 2023), dataset selection (Engstrom et al., 2024; Chhabra et al., 2024), fact tracing (Akyurek et al., 2022), and machine unlearning (Warnecke et al., 2023).

Although there are various demonstrations of TDA’s potential for interpretability and practical applications, the critical question of how TDA methods should be effectively evaluated remains open. Several approaches have been proposed by the community, which can be categorized into three groups:

  • Ground Truth: As some of the methods are designed to approximate leave-one-out (LOO) effects, ground truth can often be computed for TDA evaluation. However, this counterfactual ground truth approach requires retraining the model multiple times on different subsets of the training data, which quickly becomes computationally expensive. Additionally, this ground truth has been shown to be dominated by noise in practical deep learning settings, due to the inherent stochasticity of a typical training process (Basu et al., 2021; Nguyen et al., 2023).
  • Downstream Task Evaluators: To remedy the challenges associated with ground truth evaluation, the literature proposes assessing the utility of a TDA method within the context of an end task, such as model debugging or data selection (Koh and Liang, 2017; Khanna et al., 2019; Karthikeyan et al., 2021).
  • Heuristics: Finally, the community has also used heuristics (desirable properties or sanity checks) to evaluate the quality of TDA techniques. These include comparing the attributions of a trained model and a randomized model (Hanawa et al., 2021) and measuring the amount of overlap between the attributions for different test samples (Barshan et al., 2020).
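
To make the cost of the counterfactual ground truth concrete, here is a minimal sketch of leave-one-out retraining on a toy problem. The "model" is a one-parameter least-squares fit, and the names `fit_slope` and `loo_effects` are illustrative assumptions that are not part of quanda. The point to notice is that the model is refit once per training sample, which is what becomes infeasible for deep networks.

```python
from typing import List, Tuple

Point = Tuple[float, float]  # (x, y) training pair


def fit_slope(data: List[Point]) -> float:
    """Closed-form least-squares slope for y = w * x (no intercept)."""
    num = sum(x * y for x, y in data)
    den = sum(x * x for x, _ in data)
    return num / den


def loo_effects(train: List[Point], x_test: float) -> List[float]:
    """Change in the test prediction when each training point is removed."""
    full_pred = fit_slope(train) * x_test
    effects = []
    for i in range(len(train)):
        reduced = train[:i] + train[i + 1:]  # "retrain" without sample i
        effects.append(full_pred - fit_slope(reduced) * x_test)
    return effects


train = [(1.0, 1.0), (2.0, 2.0), (3.0, 9.0)]  # the last point is an outlier
effects = loo_effects(train, x_test=1.0)
most_influential = max(range(len(train)), key=lambda i: abs(effects[i]))
print(effects, most_influential)
```

In this toy setting the outlier has the largest leave-one-out effect. With a stochastic training procedure, repeated retraining would add noise on top of these differences, which is the limitation noted above.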

quanda is designed to meet the need for a comprehensive and systematic evaluation framework, allowing practitioners and researchers to obtain a detailed view of the performance of TDA methods in various contexts.

Library Features

  • Unified TDA Interface: quanda provides a unified interface for various TDA methods, allowing users to easily switch between different methods.
  • Metrics: quanda provides a set of metrics to evaluate the effectiveness of TDA methods. These metrics are based on the latest research in the field.
  • Benchmarking: quanda provides a benchmarking tool to evaluate the performance of TDA methods on a given model, dataset and problem. As many TDA evaluation methods require access to ground truth, our benchmarking tools allow users to generate a controlled setting with ground truth and then compare the performance of different TDA methods in this setting.

Supported TDA Libraries

Library | Reference
Captum (Similarity Influence, Arnoldi Influence Function, TracIn) | Caruana et al., 1999; Schioppa et al., 2022; Koh and Liang, 2017; Pruthi et al., 2020
TRAK (TRAK) | Park et al., 2023
Representer Point Selection (Representer Point Selection) | Yeh et al., 2018
Kronfluence (Kronfluence) | Grosse et al., 2023
Dattri (Influence Functions: Explicit / CG / LiSSA / DataInf, Arnoldi, EK-FAC, TracInCP, Grad-Dot, Grad-Cos, TRAK) | Deng et al., 2024

Metrics

  • Linear Datamodeling Score (Park et al., 2023): Measures the correlation between the (grouped) attribution scores and the actual output of models trained on different subsets of the training set. For each subset, the linear datamodeling score compares the actual model output to the sum of attribution scores from the subset using Spearman rank correlation.

  • Class Detection / Subclass Detection (Hanawa et al., 2021): Measures the proportion of identical classes or subclasses in the top-1 training samples over the test dataset. If the attributions are based on similarity, they are expected to be predictive of the class of the test datapoint, as well as different subclasses under a single label.

  • Model Randomization (Hanawa et al., 2021): Measures the correlation between the original TDA and the TDA of a model with randomized weights. Since the attributions are expected to depend on model parameters, the correlation between original and randomized attributions should be low.

  • Top-K Cardinality (Barshan et al., 2020): Measures the cardinality of the union of the top-K training samples. Since the attributions are expected to be dependent on the test input, they are expected to vary heavily for different test points, resulting in a low overlap (high metric value).

  • Mislabeled Data Detection (Koh and Liang, 2017): Computes the proportion of noisy training labels detected as a function of the percentage of inspected training samples. The samples are inspected in order according to their global TDA ranking, which is computed using local attributions. This produces a cumulative mislabeling detection curve. We expect this curve to rise rapidly as more of the training data is inspected, so we compute the area under this curve.

  • Shortcut Detection (Yolcu et al., 2025): Assuming a known shortcut, or Clever-Hans effect, has been identified in the model, this metric evaluates how effectively a TDA method can identify shortcut samples as the most influential in predicting cases with the shortcut artifact. This process is referred to as Domain Mismatch Debugging in the original paper.

  • Mixed Datasets (Hammoudeh and Lowd, 2022): In a setting where a model has been trained on two datasets, a clean dataset (e.g. CIFAR-10) and an adversarial dataset (e.g. zeros from MNIST), this metric evaluates how highly the TDA method ranks the adversarial samples relative to clean samples when attributing predictions on an adversarial example.

  • Mean Reciprocal Rank (MRR) (Akyurek et al., 2022): For fact-tracing settings, measures the mean reciprocal rank of the highest-ranked entailing proponent across fact queries.

  • Recall@k (Akyurek et al., 2022): For fact-tracing settings, measures the proportion of facts for which an entailing proponent appears in the top-k retrievals.

  • Tail Patch (Chang et al., 2024): For fact-tracing settings, measures the incremental change in target-sequence probability after taking a single training step on retrieved proponents.
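
As an illustration of the two retrieval-style fact-tracing metrics above, here is a minimal, dependency-free sketch of Mean Reciprocal Rank and Recall@k. It assumes each query comes with a ranking of training indices sorted by decreasing attribution and a set of ground-truth proponent indices. The function names are hypothetical, and quanda's own implementation may differ in details such as how queries without proponents are handled.

```python
from typing import List, Sequence, Set


def reciprocal_rank(ranking: Sequence[int], proponents: Set[int]) -> float:
    """1 / rank of the first retrieved proponent, or 0.0 if none is retrieved."""
    for rank, train_idx in enumerate(ranking, start=1):
        if train_idx in proponents:
            return 1.0 / rank
    return 0.0


def mean_reciprocal_rank(
    rankings: List[Sequence[int]], proponent_sets: List[Set[int]]
) -> float:
    """Average reciprocal rank over all fact queries."""
    scores = [reciprocal_rank(r, p) for r, p in zip(rankings, proponent_sets)]
    return sum(scores) / len(scores)


def recall_at_k(
    rankings: List[Sequence[int]], proponent_sets: List[Set[int]], k: int
) -> float:
    """Fraction of queries with at least one proponent among the top-k."""
    hits = [
        any(idx in p for idx in r[:k]) for r, p in zip(rankings, proponent_sets)
    ]
    return sum(hits) / len(hits)


# Three toy queries: training indices ranked by attribution, best first.
rankings = [[4, 2, 7], [1, 0, 3], [5, 6, 8]]
proponents = [{2}, {1, 3}, {9}]

mrr = mean_reciprocal_rank(rankings, proponents)
recall_1 = recall_at_k(rankings, proponents, k=1)
recall_2 = recall_at_k(rankings, proponents, k=2)
print(mrr, recall_1, recall_2)
```

The third query has no retrievable proponent, so it contributes zero to both metrics in this sketch.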

Metric interpretation guideline
Benchmark | Output range | Better
ClassDetection | [0, 1] | higher
SubclassDetection | [0, 1] | higher
MislabelingDetection | [0, 1] | higher
ShortcutDetection | [0, 1] | higher
MixedDatasets | [0, 1] | higher
TopKCardinality | [0, 1] | higher
ModelRandomization | [-1, 1] | closer to 0
LinearDatamodelingScore | [-1, 1] | higher
MRR | [0, 1] | higher
RecallAtK | [0, 1] | higher
TailPatch | [-1, 1] | higher
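
The [-1, 1] range of ModelRandomization and LinearDatamodelingScore follows from their use of rank correlation. The sketch below implements Spearman rank correlation in plain Python to show why identical rankings give 1 and reversed rankings give -1. It is an illustration under the assumption that attributions for one test sample are given as a flat list of scores; the helper names are hypothetical and quanda itself may rely on a library routine.

```python
from typing import List, Sequence


def average_ranks(values: Sequence[float]) -> List[float]:
    """Ranks starting at 1; tied values share the average of their ranks."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        shared = (i + j) / 2 + 1  # mean of the 1-based positions i+1 .. j+1
        for pos in range(i, j + 1):
            ranks[order[pos]] = shared
        i = j + 1
    return ranks


def pearson(a: Sequence[float], b: Sequence[float]) -> float:
    """Pearson correlation of two equally long, non-constant sequences."""
    mean_a, mean_b = sum(a) / len(a), sum(b) / len(b)
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / (var_a * var_b) ** 0.5


def spearman(a: Sequence[float], b: Sequence[float]) -> float:
    """Spearman correlation: Pearson correlation of the rank vectors."""
    return pearson(average_ranks(a), average_ranks(b))


original = [0.9, 0.1, 0.4, 0.7]          # attributions of the trained model
same_order = [9.0, 1.0, 4.0, 7.0]        # different scale, same ranking
reversed_order = [0.1, 0.9, 0.7, 0.4]    # exactly the opposite ranking

rho_same = spearman(original, same_order)
rho_reversed = spearman(original, reversed_order)
print(rho_same, rho_reversed)
```

For ModelRandomization, a value near 0 indicates that the attributions of the randomized model do not track those of the original model, which is the desired behavior.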

Benchmarks

quanda comes with a few pre-computed benchmarks that can be conveniently used for evaluation in a plug-and-play manner. We are planning to significantly expand the number of benchmarks in the future. The benchmark IDs listed below are to be passed to load_pretrained. The following benchmarks are currently available:

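Metric | Type | Modality | Benchmark IDs (Dataset / Model)
TopKCardinalityMetric | Heuristic | Vision | mnist_top_k_cardinality (MNIST / LeNet)
TopKCardinalityMetric | Heuristic | Vision | cifar_top_k_cardinality (CIFAR-10 / ResNet-9)
TopKCardinalityMetric | Heuristic | Vision | awa2_top_k_cardinality (AWA2 / ResNet-50)
TopKCardinalityMetric | Heuristic | Text | qnli_top_k_cardinality (QNLI / BERT)
ModelRandomizationMetric | Heuristic | Vision | mnist_model_randomization (MNIST / LeNet)
ModelRandomizationMetric | Heuristic | Vision | cifar_model_randomization (CIFAR-10 / ResNet-9)
ModelRandomizationMetric | Heuristic | Vision | awa2_model_randomization (AWA2 / ResNet-50)
ModelRandomizationMetric | Heuristic | Text | qnli_model_randomization (QNLI / BERT)
MixedDatasetsMetric | Heuristic | Vision | mnist_mixed_datasets (MNIST / LeNet)
MixedDatasetsMetric | Heuristic | Vision | cifar_mixed_datasets (CIFAR-10 / ResNet-9)
MixedDatasetsMetric | Heuristic | Vision | awa2_mixed_datasets (AWA2 / ResNet-50)
MixedDatasetsMetric | Heuristic | Text | qnli_mixed_datasets (QNLI / BERT)
ClassDetectionMetric | Downstream-Task-Evaluator | Vision | mnist_class_detection (MNIST / LeNet)
ClassDetectionMetric | Downstream-Task-Evaluator | Vision | cifar_class_detection (CIFAR-10 / ResNet-9)
ClassDetectionMetric | Downstream-Task-Evaluator | Vision | awa2_class_detection (AWA2 / ResNet-50)
ClassDetectionMetric | Downstream-Task-Evaluator | Text | qnli_class_detection (QNLI / BERT)
SubclassDetectionMetric | Downstream-Task-Evaluator | Vision | mnist_subclass_detection (MNIST / LeNet)
SubclassDetectionMetric | Downstream-Task-Evaluator | Vision | cifar_subclass_detection (CIFAR-10 / ResNet-9)
SubclassDetectionMetric | Downstream-Task-Evaluator | Vision | awa2_subclass_detection (AWA2 / ResNet-50)
MislabelingDetectionMetric | Downstream-Task-Evaluator | Vision | mnist_mislabeling_detection (MNIST / LeNet)
MislabelingDetectionMetric | Downstream-Task-Evaluator | Vision | cifar_mislabeling_detection (CIFAR-10 / ResNet-9)
MislabelingDetectionMetric | Downstream-Task-Evaluator | Vision | awa2_mislabeling_detection (AWA2 / ResNet-50)
MislabelingDetectionMetric | Downstream-Task-Evaluator | Text | qnli_mislabeling_detection (QNLI / BERT)
ShortcutDetectionMetric | Downstream-Task-Evaluator | Vision | mnist_shortcut_detection (MNIST / LeNet)
ShortcutDetectionMetric | Downstream-Task-Evaluator | Vision | cifar_shortcut_detection (CIFAR-10 / ResNet-9)
ShortcutDetectionMetric | Downstream-Task-Evaluator | Vision | awa2_shortcut_detection (AWA2 / ResNet-50)
MRRMetric | Downstream-Task-Evaluator | Causal LM | gpt2_trex_openwebtext_ft_mrr (T-REx / GPT-2 fine-tuned on OpenWebText)
RecallAtKMetric | Downstream-Task-Evaluator | Causal LM | gpt2_trex_openwebtext_ft_recall_at_k (T-REx / GPT-2 fine-tuned on OpenWebText)
TailPatchMetric | Downstream-Task-Evaluator | Causal LM | gpt2_trex_openwebtext_ft_tail_patch (T-REx / GPT-2 fine-tuned on OpenWebText)
LinearDatamodelingMetric | Ground Truth | Vision | mnist_linear_datamodeling (MNIST / LeNet)
LinearDatamodelingMetric | Ground Truth | Vision | cifar_linear_datamodeling (CIFAR-10 / ResNet-9)
LinearDatamodelingMetric | Ground Truth | Vision | awa2_linear_datamodeling (AWA2 / ResNet-50)
LinearDatamodelingMetric | Ground Truth | Text | qnli_linear_datamodeling (QNLI / BERT)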

🔬 Getting Started

Installation

To install quanda from a local clone of this repository, run:

pip install -e .

quanda requires Python 3.10, 3.11 or 3.12. It is recommended to use a virtual environment to install the package.

Basic Usage

In the following usage examples, we will be using the SimilarityInfluence data attribution from Captum.

Using Metrics

To begin using quanda metrics, you need the following components:

  • Trained PyTorch Model (model): A PyTorch model that has already been trained on a relevant dataset. The examples below use the layer name "fc_2" as a placeholder. Please replace it with the name of one of the layers in your model.
  • PyTorch Dataset (train_set): The dataset used during the training of the model.
  • Test Dataset (eval_set): The dataset to be used as test inputs for generating explanations. Explanations are generated with respect to an output neuron corresponding to a certain class. This class can be either the ground truth label of each test point or the class predicted by the model. In the following, we use the predicted labels to generate explanations.

Next, we demonstrate how to evaluate explanations using the Model Randomization metric.

1. Import dependencies and library components
import os

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from quanda.explainers.wrappers import CaptumSimilarity
from quanda.metrics.heuristics import ModelRandomizationMetric
2. Create the explainer object

We now create our explainer. The device to be used by the explainer and metrics is inherited from the model, thus we set the model device explicitly.

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model.to(DEVICE)

explainer_kwargs = {
    "layers": "fc_2",
    "model_id": "default_model_id",
    "cache_dir": cache_dir,
}
explainer = CaptumSimilarity(
    model=model, train_dataset=train_set, **explainer_kwargs
)
3. Initialize the metric

The ModelRandomizationMetric needs to instantiate a new explainer to generate explanations for a randomized model. These will be compared with the explanations of the original model. Therefore, explainer_cls is passed directly to the metric along with initialization parameters of the explainer for the randomized model.

explainer_kwargs = {
    "layers": "fc_2",
    "model_id": "randomized_model_id",
    "cache_dir": cache_dir,
}
ckpt_path = os.path.join(cache_dir, "model_rand_ckpt.pth")
torch.save(model.state_dict(), ckpt_path)
model_rand = ModelRandomizationMetric(
    model=model,
    model_id="randomized_model_id",
    cache_dir=cache_dir,
    train_dataset=train_set,
    checkpoints=ckpt_path,
    explainer_cls=CaptumSimilarity,
    expl_kwargs=explainer_kwargs,
    correlation_fn="spearman",
    seed=42,
)
4. Iterate over test set to generate explanations and update the metric

We now start producing explanations with our TDA method. We go through the test set batch-by-batch. For each batch, we first generate the attributions using the predicted labels, and we then update the metric with the produced explanations to showcase how to concurrently handle the explanation and evaluation processes.

test_loader = DataLoader(eval_set, batch_size=batch_size, shuffle=False)
for test_data, _ in tqdm(test_loader):
    test_data = test_data.to(DEVICE)
    target = model(test_data).argmax(dim=-1)
    tda = explainer.explain(test_data=test_data, targets=target)
    model_rand.update(
        explanations=tda, test_data=test_data, test_targets=target
    )

print("Randomization metric output:", model_rand.compute())
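
The loop above follows an accumulate-then-compute pattern: each call to update consumes one batch of attributions, and compute aggregates the final score. The following dependency-free sketch mimics that pattern for a Top-K Cardinality style score. `ToyTopKCardinality` is a hypothetical stand-in rather than the quanda class, attributions are plain nested lists instead of tensors, and the normalization by the number of test samples times k is an assumption chosen to keep the score in [0, 1].

```python
from typing import Sequence, Set


class ToyTopKCardinality:
    """Hypothetical metric that mimics an update/compute interface."""

    def __init__(self, top_k: int) -> None:
        self.top_k = top_k
        self.seen: Set[int] = set()  # union of top-k training indices so far
        self.n_test = 0              # number of test samples processed

    def update(self, explanations: Sequence[Sequence[float]]) -> None:
        """Consume one batch of attributions with shape (batch, n_train)."""
        for row in explanations:
            top = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
            self.seen.update(top[: self.top_k])
            self.n_test += 1

    def compute(self) -> float:
        """Unique top-k indices divided by the maximum possible count."""
        return len(self.seen) / (self.n_test * self.top_k)


metric = ToyTopKCardinality(top_k=2)
batches = [
    [[0.9, 0.8, 0.1, 0.0], [0.7, 0.9, 0.2, 0.1]],  # both rows select {0, 1}
    [[0.0, 0.1, 0.9, 0.8]],                        # selects {2, 3}
]
for batch in batches:
    metric.update(batch)

score = metric.compute()
print(score)
```

Because only a small running state is kept between batches, the full attribution matrix never needs to be held in memory at once.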

Using Benchmarks

quanda benchmarks allow us to streamline the evaluation process by downloading the necessary data and models, and running the evaluation in a single command. The following code demonstrates how to use the mnist_subclass_detection benchmark:

1. Import dependencies and library components
import torch

from quanda.explainers.wrappers import CaptumSimilarity
from quanda.benchmarks.downstream_eval import SubclassDetection
2. Prepare arguments for the explainer object
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model.to(DEVICE)

explainer_kwargs = {
    "layers": "fc_2",
    "model_id": "default_model_id",
    "cache_dir": cache_dir,
}
3. Load a benchmark and score an explainer
subclass_detect = SubclassDetection.load_pretrained(
    bench_id="mnist_subclass_detection",
    cache_dir=cache_dir,
)
score = subclass_detect.evaluate(
    explainer_cls=CaptumSimilarity,
    expl_kwargs=explainer_kwargs,
    batch_size=batch_size,
    max_eval_n=16,
)["score"]
print(f"Subclass Detection Score: {score}")

Generating the benchmark object from scratch

While we provide a number of benchmarks with pre-computed assets, quanda Benchmark objects also expose a train interface for preparing benchmarks from scratch. To train a benchmark, specify its components in a single YAML file (see quanda/benchmarks/resources/configs).

1. Import dependencies and library components
import os

import torch
import yaml

from quanda.explainers.wrappers import CaptumSimilarity
from quanda.benchmarks.downstream_eval import MislabelingDetection
2. Prepare arguments for the explainer object
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model.to(DEVICE)

explainer_kwargs = {
    "layers": "fc_2",
    "model_id": "top_k_model",
    "cache_dir": cache_dir,
}
3. Train the model for the benchmark

For mislabeling detection, we will train a model from scratch using a dataset with a portion of labels flipped.

with open(
    "tests/assets/mnist_local_bench/83edb41-default_MislabelingDetection.yaml",
    "r",
) as f:
    mislabel_config = yaml.safe_load(f)

mislabel_config["bench_save_dir"] = os.path.join(
    cache_dir, "mislabeling_detection_bench"
)
mislabeling_detection = MislabelingDetection.train(
    mislabel_config,
    device=DEVICE,
)
4. Run the evaluation

We can now call the evaluate method to directly start the evaluation process on the benchmark.

score = mislabeling_detection.evaluate(
    explainer_cls=CaptumSimilarity,
    expl_kwargs=explainer_kwargs,
    batch_size=batch_size,
    max_eval_n=16,
)["score"]
print(f"Mislabeling Detection Score: {score}")
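
For intuition about what this score summarizes, here is a toy sketch of a cumulative mislabeling detection curve and its area. It assumes the global ranking of training indices is already given and that the set of mislabeled indices is known. The function names are hypothetical, and the area is taken as the mean curve height, whereas quanda may use a different integration rule.

```python
from typing import List, Sequence, Set


def detection_curve(
    global_ranking: Sequence[int], mislabeled: Set[int]
) -> List[float]:
    """Fraction of mislabeled samples found after inspecting n samples."""
    found = 0
    curve = []
    for train_idx in global_ranking:
        if train_idx in mislabeled:
            found += 1
        curve.append(found / len(mislabeled))
    return curve


def area_under_curve(curve: Sequence[float]) -> float:
    """Mean curve height, i.e. the area with the x-axis scaled to [0, 1]."""
    return sum(curve) / len(curve)


mislabeled = {1, 4}
good_ranking = [4, 1, 0, 2, 3]  # noisy samples are inspected first
bad_ranking = [0, 2, 3, 1, 4]   # noisy samples are inspected last

good_curve = detection_curve(good_ranking, mislabeled)
bad_curve = detection_curve(bad_ranking, mislabeled)
good_auc = area_under_curve(good_curve)
bad_auc = area_under_curve(bad_curve)
print(good_curve, good_auc)
print(bad_curve, bad_auc)
```

A ranking that surfaces the noisy samples early yields a curve that saturates quickly and therefore a larger area.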

More detailed examples can be found in the tutorials folder. You can also use Hydra for benchmark training configuration, as shown in scripts/train.py.

Custom Explainers

In addition to the built-in explainers, quanda supports the evaluation of custom explainer methods. This section provides a guide on how to create a wrapper for a custom explainer that matches our interface.

Step 1. Create an explainer class

Your custom explainer should inherit from the base Explainer class provided by quanda. The first step is to initialize your custom explainer within the __init__ method.

from quanda.explainers.base import Explainer

class CustomExplainer(Explainer):
    def __init__(self, model, train_dataset, **kwargs):
        super().__init__(model, train_dataset, **kwargs)
        # Initialize your explainer here
Step 2. Implement the explain method

The core of your wrapper is the explain method. This function should take test samples and their corresponding target values as input and return a 2D tensor containing the influence scores.

  • test_data: The test batch for which explanations are generated.
  • targets: The target values for the explanations.

Ensure that the output tensor has the shape (test_samples, train_samples), where the entries in the train samples dimension are ordered in the same order as in the train_dataset that is being attributed.

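from typing import List, Union

import torch

# Method of CustomExplainer (place it inside the class body)
def explain(
    self,
    test_data: torch.Tensor,
    targets: Union[List[int], torch.Tensor],
) -> torch.Tensor:
    # Compute your influence scores here
    return influence_scores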
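
To illustrate the (test_samples, train_samples) contract described above, here is a small dependency-free sketch. It uses nested lists instead of torch tensors, a dot-product score as a stand-in for a real attribution method, and hypothetical helper names; a real explain method would return a torch.Tensor.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def dot_product_scores(
    test_feats: Sequence[Sequence[float]],
    train_feats: Sequence[Sequence[float]],
) -> Matrix:
    """Toy attribution: one row per test sample, one column per train sample."""
    return [
        [sum(a * b for a, b in zip(test_vec, train_vec))
         for train_vec in train_feats]
        for test_vec in test_feats
    ]


def check_shape(scores: Matrix, n_test: int, n_train: int) -> None:
    """Raise ValueError unless the matrix has shape (n_test, n_train)."""
    if len(scores) != n_test or any(len(row) != n_train for row in scores):
        raise ValueError(f"expected shape ({n_test}, {n_train})")


# Column j must refer to the j-th sample of the training set, in order.
train_feats = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
test_feats = [[2.0, 0.0], [0.0, 3.0]]

scores = dot_product_scores(test_feats, train_feats)
check_shape(scores, n_test=len(test_feats), n_train=len(train_feats))
print(scores)
```

A check of this kind can be a cheap safeguard while developing a wrapper, since a transposed or reordered matrix would otherwise silently distort every downstream metric.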
Step 3. Implement the self_influence method (Optional)

By default, quanda includes a built-in method for calculating self-influence scores. This base implementation computes all attributions over the training dataset, and collects the diagonal values in the attribution matrix. However, you can override this method to provide a more efficient implementation. This method should calculate how much each training sample influences itself and return a tensor of the computed self-influence scores.

def self_influence(self, batch_size: int = 1) -> torch.Tensor:
    # Compute your self-influence scores here
    return self_influence_scores
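
The default behavior described above, attributing training samples against the training set and collecting the diagonal, can be sketched as follows. This is a toy illustration with nested lists and a dot-product score, not quanda's base implementation; `toy_explain` and `toy_self_influence` are hypothetical names.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def toy_explain(
    sample_idx: Sequence[int], feats: Sequence[Sequence[float]]
) -> Matrix:
    """Attribute the given training samples against the whole training set."""
    return [
        [sum(a * b for a, b in zip(feats[i], other)) for other in feats]
        for i in sample_idx
    ]


def toy_self_influence(
    feats: Sequence[Sequence[float]], batch_size: int = 1
) -> List[float]:
    """Collect the diagonal of the train-vs-train attribution matrix."""
    scores: List[float] = []
    for start in range(0, len(feats), batch_size):
        batch = list(range(start, min(start + batch_size, len(feats))))
        attributions = toy_explain(batch, feats)  # shape (len(batch), n_train)
        # Row r belongs to training sample batch[r]; its self-influence
        # is the entry in column batch[r].
        scores.extend(attributions[r][idx] for r, idx in enumerate(batch))
    return scores


feats = [[1.0, 2.0], [3.0, 0.0], [0.0, 0.5]]
self_scores = toy_self_influence(feats, batch_size=2)
print(self_scores)
```

Each batch still computes a full row of attributions only to keep a single entry per sample, which is why a method-specific override can be considerably cheaper.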

For detailed examples, we refer to the existing explainer wrappers in quanda.

⚠️ Usage Tips and Caveats

  • Controlled Setting Evaluation: Many metrics require access to ground truth information about the dataset, such as the indices of the "shortcut samples" in the Shortcut Detection metric, or the indices of the mislabeled (noisy) samples for the Mislabeling Detection metric. However, users often do not have access to this information. To address this, we recommend either using one of our pre-built benchmark suites (see Benchmarks section) or generating a custom benchmark with the train method for comparing explainers. Benchmarks provide a controlled environment for systematic evaluation.

  • Explainer Caching: Many explainers in our library generate a reusable cache. The cache_dir and model_id parameters passed to various class instances determine where these intermediate results are stored. Ensure each experiment is assigned a unique combination of these arguments. Failing to do so could lead to incorrect reuse of cached results. If you wish to avoid reusing cached results, you can set the load_from_disk parameter to False.

  • Benchmark Dataset Caching: Benchmark initialization methods involve caching a HuggingFace dataset locally to the HF_HOME cache path. We recommend ensuring that the environment variable is set as needed and caching the dataset into the directory in advance of loading the benchmark.

  • Explainers Are Expensive To Calculate: Certain explainers, such as CaptumTracInCPFastRandProj, may cause out-of-memory (OOM) errors when applied to large models or datasets. In such cases, we recommend reducing memory usage by shrinking the dataset or using a smaller model.
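
One possible way to follow the caching advice above is to derive model_id and cache_dir from a description of the experiment, so that two different setups can never share a cache entry by accident. The sketch below is only a suggestion: the helper names and the structure of the experiment dictionary are hypothetical, and only the model_id and cache_dir parameter names come from the text above.

```python
import hashlib
import json
import os
from typing import Any, Dict


def make_model_id(experiment: Dict[str, Any]) -> str:
    """Derive a short, deterministic identifier from an experiment description."""
    canonical = json.dumps(experiment, sort_keys=True)  # key order is irrelevant
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:12]


def explainer_cache_kwargs(
    cache_root: str, experiment: Dict[str, Any]
) -> Dict[str, str]:
    """Build the model_id / cache_dir pair for one experiment."""
    model_id = make_model_id(experiment)
    return {
        "model_id": model_id,
        "cache_dir": os.path.join(cache_root, model_id),
    }


run_a = explainer_cache_kwargs(
    "./cache", {"model": "lenet", "seed": 0, "layer": "fc_2"}
)
run_b = explainer_cache_kwargs(
    "./cache", {"model": "lenet", "seed": 1, "layer": "fc_2"}
)
run_a_again = explainer_cache_kwargs(
    "./cache", {"layer": "fc_2", "seed": 0, "model": "lenet"}
)
print(run_a, run_b)
```

The resulting dictionary can be merged into the explainer keyword arguments. Rerunning the same configuration maps to the same cache location, while any change to the description yields a new one.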

📓 Tutorials

We have included a few tutorials to demonstrate the usage of quanda:

  • Explainers: shows how different explainers can be used with quanda
  • Metrics: shows how to use the metrics in quanda to evaluate the performance of a data attribution method
  • Benchmarks: shows how to use the benchmarking tools in quanda to evaluate a data attribution method

To install the library with tutorial dependencies, run:

pip install -e '.[tutorials]'

👩‍💻 Contributing

We welcome contributions to quanda! You could contribute by:

  • Opening an issue to report a bug or request a feature.
  • Submitting a pull request to fix a bug, add a new explainer wrapper, a new metric, or another feature.

A detailed guide on how to contribute to quanda can be found here.

✉️ Contact

If you have any questions regarding the codebase, please open an issue or contact us via email at dilyabareeva@gmail.com or galip.uemit.yolcu@hhi.fraunhofer.de.

🔗 Citation

@misc{bareeva2024quandainterpretabilitytoolkittraining,
      title={Quanda: An Interpretability Toolkit for Training Data Attribution Evaluation and Beyond},
      author={Dilyara Bareeva and Galip Ümit Yolcu and Anna Hedström and Niklas Schmolenski and Thomas Wiegand and Wojciech Samek and Sebastian Lapuschkin},
      year={2024},
      eprint={2410.07158},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2410.07158},
}
