
Counterfactual explanations for GNNs based on the visual graph dataset format

VGD Counterfactuals

A library for generating and, more importantly, easily visualizing counterfactuals for Graph Neural Networks (GNNs) based on the VisualGraphDatasets dataset format.

What are Counterfactuals?

Counterfactuals are a method of explaining the predictions of complex machine learning models. For a given prediction of a model, a counterfactual is an input element that is as similar as possible to the original input but causes the largest possible deviation from the original model prediction. Counterfactuals are a kind of “counterexample” for the behavior of a model and can help to understand its decision boundary.

This package deals with graph counterfactuals. They are generated by maximizing a customizable distance function with respect to the prediction output over all immediate neighbors of the original graph, where the neighbors are defined by the allowed, domain-specific graph edit operations.
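The neighborhood search described above can be sketched in a few lines of plain Python. This is a simplified illustration, not the package's actual implementation: the function top_k_counterfactuals, the toy string "graphs", the edit operations and the oxygen-counting "model" are all hypothetical stand-ins chosen so the example runs without any dependencies.

```python
from typing import Callable, List, Tuple


def top_k_counterfactuals(
    original: str,
    neighborhood_func: Callable[[str], List[str]],
    predict_func: Callable[[str], float],
    distance_func: Callable[[float, float], float],
    k: int = 10,
) -> List[Tuple[str, float]]:
    """Score every immediate neighbor of `original` by how far its prediction
    moves from the original prediction and return the k largest distances."""
    original_prediction = predict_func(original)
    scored = [
        (neighbor, distance_func(original_prediction, predict_func(neighbor)))
        for neighbor in neighborhood_func(original)
    ]
    # Largest distance first; Python's sort is stable, so ties keep neighbor order.
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Hypothetical toy domain: a "graph" is a string of atom symbols and the
# allowed edits are deleting one atom or appending one of C, O, N.
def toy_neighborhood(graph: str) -> List[str]:
    deletions = [graph[:i] + graph[i + 1:] for i in range(len(graph))]
    additions = [graph + atom for atom in "CON"]
    return deletions + additions


# Hypothetical toy "model": predicts the number of oxygen atoms.
def toy_predict(graph: str) -> float:
    return float(graph.count("O"))


results = top_k_counterfactuals(
    original="CCO",
    neighborhood_func=toy_neighborhood,
    predict_func=toy_predict,
    distance_func=lambda orig, mod: abs(orig - mod),
    k=2,
)
print(results)  # [('CC', 1.0), ('CCOO', 1.0)]
```

Because only immediate neighbors are scored, every result differs from the original by a single edit, which is what keeps the counterfactuals "as similar as possible" to the input.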

Installation

First, clone the repository:

git clone https://github.com/the16thpythonist/vgd_counterfactuals

Then run pip install from within the main folder:

cd vgd_counterfactuals
python3 -m pip install .

Afterwards, you can check the installation by invoking the CLI:

python3 -m vgd_counterfactuals.cli --version
python3 -m vgd_counterfactuals.cli --help

Usage

Quickstart

The generation of counterfactual graphs is implemented by the CounterfactualGenerator class. Instantiating a generator requires the following four main components:

  • processing: A visual_graph_dataset “Processing” object. These objects implement the functionality needed to convert a domain-specific graph representation (such as a SMILES string for molecules) into the full graph structure expected by the machine learning models. A suitable Processing class is shipped with each specific visual graph dataset.

  • model: The model to be explained. This model has to implement the visual_graph_dataset “PredictGraph” interface to ensure that the model can be directly queried with the vgd GraphDict representation of graph elements.

  • neighborhood_func: A function which receives the domain-specific representation of a graph as input and returns a list of the domain-specific representations of all immediate neighbors of that graph. Its implementation is highly specific to each application domain.

  • distance_func: A function which receives two arguments, the prediction for the original element and the prediction for a neighbor, and returns a single numeric value for the distance between the two predictions. The generator maximizes this distance measure.
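The choice of distance_func determines which kind of counterfactual the generator looks for. The sketch below shows three possible distance functions; the names are hypothetical, and it is an assumption that predictions arrive either as a single float (regression) or as an indexable sequence of class probabilities (classification). The actual prediction format depends on the model being explained.

```python
from typing import Sequence


def regression_distance(original: float, modified: float) -> float:
    """Absolute change of a scalar output, like the quickstart lambda."""
    return abs(original - modified)


def l1_distance(original: Sequence[float], modified: Sequence[float]) -> float:
    """Total absolute change over all output dimensions (e.g. class probabilities)."""
    return sum(abs(a - b) for a, b in zip(original, modified))


def class_drop_distance(original: Sequence[float], modified: Sequence[float]) -> float:
    """Drop in the probability of the originally predicted class; maximizing it
    favors neighbors that push the model away from its original decision."""
    predicted = max(range(len(original)), key=lambda i: original[i])
    return original[predicted] - modified[predicted]


print(regression_distance(1.5, -0.5))  # 2.0
print(round(l1_distance([0.9, 0.1], [0.3, 0.7]), 6))  # 1.2
print(round(class_drop_distance([0.9, 0.1], [0.3, 0.7]), 6))  # 0.6
```

The L1 variant rewards any change in the output distribution, while the class-drop variant only rewards moving away from the original decision, which is usually closer to the intuition of a counterfactual for classifiers.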

Once the generator object has been instantiated, its generate method can be used to create counterfactuals for any number of input elements.

The following mock example shows how all of this fits together. For more information, have a look at the example modules in the examples folder of the repository.

import tempfile

from visual_graph_datasets.processing.molecules import MoleculeProcessing

from vgd_counterfactuals.base import CounterfactualGenerator
from vgd_counterfactuals.testing import MockModel
from vgd_counterfactuals.generate.molecules import get_neighborhood

processing = MoleculeProcessing()
model = MockModel()

generator = CounterfactualGenerator(
    processing=processing,
    model=model,
    neighborhood_func=get_neighborhood,
    distance_func=lambda orig, mod: abs(orig - mod),
)

with tempfile.TemporaryDirectory() as path:
    # The "generate" function creates all possible neighbors of the
    # given "original" element, queries the model to predict the output
    # for each of them, and sorts them by their distance to the original.
    # The top k elements are turned into a temporary visual graph dataset
    # within the given folder "path". That means two files are created in
    # that folder per element: a metadata JSON file and a visualization PNG
    # file. The method returns the dictionary of the loaded visual graph dataset.
    index_data_map = generator.generate(
        original='CCCCCC',
        # Path to the folder into which to save the vgd element files
        path=path,
        # The number of counterfactuals to be returned.
        # Elements will be sorted by their distance.
        k_results=10,
    )

    # The keys of the resulting dict are the integer indices and the values
    # are dicts themselves which describe the corresponding vgd elements.
    # These dicts contain for example the absolute path to the PNG file,
    # the full graph representation and additional metadata.
    print(f'generated {len(index_data_map)} counterfactuals:')
    for index, data in index_data_map.items():
        print(f' * {data["metadata"]["name"]}'
              f' - distance: {data["metadata"]["distance"]:.2f}')
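
Because the returned dictionary is plain Python data, it can be post-processed directly. The sketch below assumes only the structure described in the comments above (integer keys mapping to dicts whose "metadata" entry contains "name" and "distance"); the helper names and the hand-made mock_map are hypothetical and stand in for a real generate result.

```python
from typing import Any, Dict

# Assumed shape, based on the quickstart comments: integer keys mapping to
# element dicts whose "metadata" entry holds at least "name" and "distance".
IndexDataMap = Dict[int, Dict[str, Any]]


def best_counterfactual(index_data_map: IndexDataMap) -> Dict[str, Any]:
    """Return the element whose prediction deviates most from the original."""
    return max(index_data_map.values(), key=lambda data: data["metadata"]["distance"])


def filter_by_distance(index_data_map: IndexDataMap, min_distance: float) -> IndexDataMap:
    """Keep only the counterfactuals whose distance reaches min_distance."""
    return {
        index: data
        for index, data in index_data_map.items()
        if data["metadata"]["distance"] >= min_distance
    }


# Hand-made mock in place of a real generator.generate(...) result.
mock_map: IndexDataMap = {
    0: {"metadata": {"name": "CCCCCO", "distance": 0.8}},
    1: {"metadata": {"name": "CCCCCN", "distance": 0.3}},
}
print(best_counterfactual(mock_map)["metadata"]["name"])  # CCCCCO
print(sorted(filter_by_distance(mock_map, 0.5)))  # [0]
```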

Credits

  • PyComex is a micro framework which simplifies the setup, processing and management of computational experiments. It is also used to auto-generate the command line interface that can be used to interact with these experiments.

  • VisualGraphDatasets is a library which deals with the VGD dataset format. In this format, a graph dataset for machine learning is represented by a folder in which each graph is stored as two files: a metadata JSON file that contains the full graph representation and additional metadata, and a PNG visualization of the graph. The library aims to provide a framework for explainable graph machine learning that is easier to use and produces more reproducible results.
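
To make the two-files-per-graph layout concrete, the following sketch pairs each metadata JSON file with its PNG visualization using only the standard library. It is not the VisualGraphDatasets loader; the function name load_vgd_folder, the assumption that both files share the same stem (e.g. 0.json and 0.png), and the mock metadata contents are all illustrative.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict


def load_vgd_folder(folder: str) -> Dict[str, Dict[str, Any]]:
    """Pair every metadata JSON file with the PNG file sharing its stem.

    Assumes the naming convention <stem>.json / <stem>.png; elements
    without a matching PNG are skipped.
    """
    elements: Dict[str, Dict[str, Any]] = {}
    for json_path in sorted(Path(folder).glob("*.json")):
        png_path = json_path.with_suffix(".png")
        if not png_path.exists():
            continue
        with json_path.open() as file:
            metadata = json.load(file)
        elements[json_path.stem] = {"metadata": metadata, "image_path": str(png_path)}
    return elements


# Build a tiny mock folder: one complete element and one without a PNG.
with tempfile.TemporaryDirectory() as folder:
    root = Path(folder)
    (root / "0.json").write_text(json.dumps({"name": "CCO"}))
    (root / "0.png").write_bytes(b"")  # placeholder instead of a real image
    (root / "1.json").write_text(json.dumps({"name": "CCN"}))  # no PNG, skipped
    dataset = load_vgd_folder(folder)
    print(list(dataset))  # ['0']
```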


Download files

Source Distribution

vgd_counterfactuals-0.1.5.tar.gz (425.5 kB)

Uploaded Source

Built Distribution

vgd_counterfactuals-0.1.5-py3-none-any.whl (433.6 kB)

Uploaded Python 3

File details

Details for the file vgd_counterfactuals-0.1.5.tar.gz.

File metadata

  • Download URL: vgd_counterfactuals-0.1.5.tar.gz
  • Upload date:
  • Size: 425.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.4.2 CPython/3.10.6 Linux/5.19.0-41-generic

File hashes

Hashes for vgd_counterfactuals-0.1.5.tar.gz
Algorithm Hash digest
SHA256 1904feefc1d5c72f904b3d0574344fa5a181b71b043ed0b8f1ec312687baa4f7
MD5 f202c7fc8743f99d42b8e8c1bff77a13
BLAKE2b-256 7b3e6eba5d1dbf61ed901f5458c40f7b78abb7e99943d6a5fae7ddadc04758cb

File details

Details for the file vgd_counterfactuals-0.1.5-py3-none-any.whl.

File metadata

File hashes

Hashes for vgd_counterfactuals-0.1.5-py3-none-any.whl
Algorithm Hash digest
SHA256 1e072f55c4f1b33dff5146b8e9ffb6a4b1811a563fdfa62f015df05f1826e258
MD5 530362eb5f15b4999674bd7e8b13189c
BLAKE2b-256 888763a01bb837086a5e335564f93bdb97787f9b0d8468f05790436729933c4b
