Counterfactual explanations for GNNs based on the visual graph dataset format

VGD Counterfactuals

Library for generating and, more importantly, easily visualizing counterfactuals for graph neural networks (GNNs), based on the VisualGraphDatasets dataset format.

What are Counterfactuals?

Counterfactuals are a method for explaining the predictions of complex machine learning models. For a given prediction of a model, a counterfactual is an input element that is as similar as possible to the original input but causes the largest possible deviation from the original model output. Counterfactuals act as “counterexamples” for the behavior of a model and can help to understand its decision boundary.

This package deals with graph counterfactuals. They are generated by maximizing a customizable distance function on the prediction output over all immediate neighbors of the original graph, where the neighbors are defined by the allowed, domain-specific graph edit operations.
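The search described above can be illustrated with a minimal, self-contained sketch. It does not use this package: the string-based toy domain, the stand-in model and all function names (toy_neighborhood, toy_predict, find_counterfactuals) are assumptions made purely for illustration.

```python
ALPHABET = "CNO"  # hypothetical toy vocabulary, not a real chemistry model


def toy_neighborhood(element):
    """Return all strings that differ from `element` by one substitution."""
    neighbors = []
    for i, current in enumerate(element):
        for symbol in ALPHABET:
            if symbol != current:
                neighbors.append(element[:i] + symbol + element[i + 1:])
    return neighbors


def toy_predict(element):
    """Stand-in model: +1 for every 'N' and +0.5 for every 'O'."""
    return element.count("N") + 0.5 * element.count("O")


def find_counterfactuals(original, neighborhood_func, predict_func,
                         distance_func, k):
    """Rank all immediate neighbors by prediction distance, keep the top k."""
    original_prediction = predict_func(original)
    scored = [
        (neighbor, distance_func(original_prediction, predict_func(neighbor)))
        for neighbor in neighborhood_func(original)
    ]
    # Largest distance first: these neighbors change the prediction the most.
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


results = find_counterfactuals(
    original="CCC",
    neighborhood_func=toy_neighborhood,
    predict_func=toy_predict,
    distance_func=lambda orig, mod: abs(orig - mod),
    k=3,
)
for neighbor, distance in results:
    print(f"{neighbor}: distance {distance:.2f}")
```

Because only immediate neighbors are considered, every candidate stays one edit away from the original, which is what keeps the counterfactuals similar to the input.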

Installation

First clone the repository:

git clone https://github.com/the16thpythonist/vgd_counterfactuals

Then in the main folder run a pip install:

cd vgd_counterfactuals
python3 -m pip install .

Afterwards, you can check the install by invoking the CLI:

python3 -m vgd_counterfactuals.cli --version
python3 -m vgd_counterfactuals.cli --help

Usage

Quickstart

The generation of counterfactual graphs is implemented via the CounterfactualGenerator class. The instantiation of one such object requires the following 4 main components:

  • processing: A visual_graph_datasets “Processing” object. It implements the functionality needed to convert a domain-specific graph representation into the full graph structure used by the machine learning models. Processing objects are shipped with each specific visual graph dataset.

  • model: The model to be explained. It has to implement the visual_graph_datasets “PredictGraph” interface so that it can be queried directly with the VGD GraphDict representation of graph elements.

  • neighborhood_func: A function which receives the domain-specific representation of a graph and returns a list of the domain-specific representations of all immediate neighbors of that graph. Its implementation is highly specific to each application domain.

  • distance_func: A function which receives two arguments: the prediction of the original element and the prediction of a neighbor. It should return a single numeric value for the distance between the two predictions. The generator will maximize this distance measure.
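Since the generator maximizes whatever distance_func returns, the choice of this function decides which neighbors count as counterfactuals. The following sketch shows a few possible choices. The function names are hypothetical, and the assumed prediction shapes (a single float, or a sequence of class probabilities) depend on the model being explained.

```python
def regression_distance(original, modified):
    """Absolute change of a scalar prediction, in either direction."""
    return abs(original - modified)


def increase_only_distance(original, modified):
    """Signed change: favors neighbors with a higher prediction only."""
    return modified - original


def class_confidence_drop(original, modified):
    """Drop in probability of the class favored by the original prediction.

    Assumes both predictions are sequences of class probabilities.
    """
    predicted_class = max(range(len(original)), key=lambda i: original[i])
    return original[predicted_class] - modified[predicted_class]


print(regression_distance(1.0, 3.5))
print(increase_only_distance(1.0, 0.5))
print(class_confidence_drop([0.9, 0.1], [0.4, 0.6]))
```

A signed distance such as increase_only_distance is useful when only changes in one direction are of interest, for example structures that raise a predicted property.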

Once the generator object has been instantiated, it can be used to create counterfactuals for any number of input elements using the generate method.

The following mock example shows how these pieces fit together. For more information, have a look at the example modules provided in the examples folder of the repository.

import tempfile

from visual_graph_datasets.processing.molecules import MoleculeProcessing

from vgd_counterfactuals.base import CounterfactualGenerator
from vgd_counterfactuals.testing import MockModel
from vgd_counterfactuals.generate.molecules import get_neighborhood

processing = MoleculeProcessing()
model = MockModel()

generator = CounterfactualGenerator(
    processing=processing,
    model=model,
    neighborhood_func=get_neighborhood,
    distance_func=lambda orig, mod: abs(orig - mod),
)

with tempfile.TemporaryDirectory() as path:
    # The "generate" function will create all the possible neighbors of the
    # given "original" element, then query the model to predict the
    # output for each of them, and sort them by their distance to the original.
    # The top k elements will be turned into a temporary visual graph dataset
    # within the given folder "path". That means in that folder two files will
    # be created per element: A metadata JSON file and a visualization PNG file.
    # Returns the dictionary for the loaded visual graph dataset.
    index_data_map = generator.generate(
        original='CCCCCC',
        # Path to the folder into which to save the vgd element files
        path=path,
        # The number of counterfactuals to be returned.
        # Elements will be sorted by their distance.
        k_results=10,
    )

    # The keys of the resulting dict are the integer indices and the values
    # are dicts themselves which describe the corresponding vgd elements.
    # These dicts contain for example the absolute path to the PNG file,
    # the full graph representation and additional metadata.
    print(f'generated {len(index_data_map)} counterfactuals:')
    for index, data in index_data_map.items():
        print(f' * {data["metadata"]["name"]} '
              f' - distance: {data["metadata"]["distance"]:.2f}')
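The returned dictionary can be post-processed like any other Python dict. The sketch below filters and ranks entries by distance. It relies only on the "metadata", "name" and "distance" keys used in the quickstart above; the helper name rank_counterfactuals and the mock values are placeholders written by hand for illustration, not real generator output.

```python
def rank_counterfactuals(index_data_map, min_distance=0.0):
    """Return (index, name, distance) tuples, largest distance first."""
    rows = [
        (index, data["metadata"]["name"], data["metadata"]["distance"])
        for index, data in index_data_map.items()
        if data["metadata"]["distance"] >= min_distance
    ]
    return sorted(rows, key=lambda row: row[2], reverse=True)


# Hand-written placeholder with the same nesting as in the quickstart.
mock_index_data_map = {
    0: {"metadata": {"name": "CCCCCN", "distance": 0.8}},
    1: {"metadata": {"name": "CCCCC", "distance": 0.1}},
    2: {"metadata": {"name": "CCCCCO", "distance": 0.5}},
}

for index, name, distance in rank_counterfactuals(mock_index_data_map, 0.3):
    print(f"{index}: {name} (distance {distance:.2f})")
```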

Credits

  • PyComex is a micro framework which simplifies the setup, processing and management of computational experiments. It is also used to auto-generate the command line interface that can be used to interact with these experiments.

  • VisualGraphDatasets is a library which deals with the VGD dataset format. In this format, graph datasets for machine learning are represented by a folder, where each graph is represented by two files: a metadata JSON file that contains the full graph representation and additional metadata, and a PNG visualization of the graph. The library aims to provide a framework for explainable graph machine learning which is easier to use and produces more reproducible results.
