Counterfactual explanations for GNNs based on the visual graph dataset format
VGD Counterfactuals
Library for the generation and, more importantly, the easy visualization of counterfactuals for graph neural networks (GNNs), based on the VisualGraphDatasets dataset format.
What are Counterfactuals?
Counterfactuals are a method of explaining the predictions of complex machine learning models. For a given prediction of a model, a counterfactual is an input element that is as similar as possible to the original input but causes the largest possible deviation from the original model output. Counterfactuals act as “counterexamples” for the behavior of a model and can help to understand its decision boundary.
The subject of this package is graph counterfactuals. They are generated by maximizing a customizable distance function between the original prediction and the predictions for all immediate neighbors of the original graph. The neighbors are determined by the allowed, domain-specific graph edit operations.
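The search described above can be illustrated with a minimal, library-independent sketch. Everything in it is an assumption made for illustration: the toy domain treats a "graph" as a string of atom symbols, and `top_k_counterfactuals`, `toy_predict` and `toy_neighbors` are hypothetical names that are not part of vgd_counterfactuals.

```python
from typing import Callable, List, Tuple


def top_k_counterfactuals(
    original: str,
    predict: Callable[[str], float],
    neighborhood_func: Callable[[str], List[str]],
    distance_func: Callable[[float, float], float],
    k: int = 3,
) -> List[Tuple[str, float]]:
    """Rank all immediate neighbors of `original` by prediction distance."""
    original_prediction = predict(original)
    scored = [
        (neighbor, distance_func(original_prediction, predict(neighbor)))
        for neighbor in neighborhood_func(original)
    ]
    # Largest distance first: these neighbors change the prediction the most.
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]


# Hypothetical toy model: the "prediction" is the number of oxygen atoms.
def toy_predict(graph: str) -> float:
    return float(graph.count('O'))


# Hypothetical edit operation: replace exactly one atom with C, N or O.
def toy_neighbors(graph: str) -> List[str]:
    neighbors = set()
    for i in range(len(graph)):
        for atom in 'CNO':
            if graph[i] != atom:
                neighbors.add(graph[:i] + atom + graph[i + 1:])
    return sorted(neighbors)


results = top_k_counterfactuals(
    original='CCO',
    predict=toy_predict,
    neighborhood_func=toy_neighbors,
    distance_func=lambda orig, mod: abs(orig - mod),
    k=3,
)
for graph, distance in results:
    print(f'{graph}: distance {distance:.1f}')
```

Because only single edits are considered, every candidate stays close to the original by construction, so similarity does not need to be scored separately in this sketch.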
Installation
First clone the repository:
git clone https://github.com/the16thpythonist/vgd_counterfactuals
Then in the main folder run a pip install:
cd vgd_counterfactuals
python3 -m pip install .
Afterwards, you can check the install by invoking the CLI:
python3 -m vgd_counterfactuals.cli --version
python3 -m vgd_counterfactuals.cli --help
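The install can also be checked from within Python. The following is a small sketch that uses only the standard library; it assumes the distribution is registered under the name `vgd_counterfactuals`, and `installed_version` is a hypothetical helper, not part of the package.

```python
from importlib import metadata
from typing import Optional


def installed_version(distribution: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(distribution)
    except metadata.PackageNotFoundError:
        return None


version = installed_version('vgd_counterfactuals')
if version is None:
    print('vgd_counterfactuals is not installed in this environment')
else:
    print(f'vgd_counterfactuals {version} is installed')
```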
Usage
Quickstart
The generation of counterfactual graphs is implemented via the CounterfactualGenerator class. Instantiating such an object requires the following four main components:
- processing: A visual_graph_datasets “Processing” object. These objects implement the functionality needed to convert a domain-specific graph representation into the full graph structure for the machine learning models. They are shipped with each specific visual graph dataset.
- model: The model to be explained. This model has to implement the visual_graph_datasets “PredictGraph” interface so that it can be queried directly with the vgd GraphDict representation of graph elements.
- neighborhood_func: A function which receives the domain-specific representation of a graph as input and returns a list of the domain-specific representations of all immediate neighbors of that graph. The implementation is highly specific to each application domain.
- distance_func: A function which receives two arguments, the prediction of the original element and the prediction of a neighbor, and returns a single numeric value for the distance between the two predictions. The generator will maximize this distance measure.
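As an illustration of the distance_func contract, here are a few possible distance functions. This is a sketch under assumptions: it treats predictions as plain floats or sequences of class probabilities, whereas the actual output type depends on the model being explained, and all three function names are hypothetical.

```python
from typing import Sequence


def regression_distance(original: float, modified: float) -> float:
    """Absolute change of a scalar prediction, regardless of direction."""
    return abs(original - modified)


def directed_distance(original: float, modified: float) -> float:
    """Signed change: only neighbors with a higher prediction score well."""
    return modified - original


def class_probability_distance(
    original: Sequence[float],
    modified: Sequence[float],
) -> float:
    """Drop in confidence for the class that was originally predicted."""
    predicted_class = max(range(len(original)), key=lambda i: original[i])
    return original[predicted_class] - modified[predicted_class]


print(regression_distance(1.0, 3.5))
print(directed_distance(3.5, 1.0))
print(class_probability_distance([0.9, 0.1], [0.25, 0.75]))
```

Since the generator maximizes the returned value, the choice of function determines which kind of counterfactual ranks highest: an absolute distance finds any strong change, while a signed distance restricts the search to changes in one direction.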
Once the generator object has been instantiated, it can be used to create counterfactuals for any number of input elements using the generate method.
The following mock example shows how these pieces fit together. For more information, have a look at the example modules provided in the examples folder of the repository.
import tempfile

from visual_graph_datasets.processing.molecules import MoleculeProcessing
from vgd_counterfactuals.base import CounterfactualGenerator
from vgd_counterfactuals.testing import MockModel
from vgd_counterfactuals.generate.molecules import get_neighborhood

processing = MoleculeProcessing()
model = MockModel()

generator = CounterfactualGenerator(
    processing=processing,
    model=model,
    neighborhood_func=get_neighborhood,
    distance_func=lambda orig, mod: abs(orig - mod),
)

with tempfile.TemporaryDirectory() as path:
    # The "generate" function will create all the possible neighbors of the
    # given "original" element, then query the model to predict the
    # output for each of them, and sort them by their distance to the original.
    # The top k elements will be turned into a temporary visual graph dataset
    # within the given folder "path". That means in that folder two files will
    # be created per element: A metadata JSON file and a visualization PNG file.
    # Returns the dictionary for the loaded visual graph dataset.
    index_data_map = generator.generate(
        original='CCCCCC',
        # Path to the folder into which to save the vgd element files
        path=path,
        # The number of counterfactuals to be returned.
        # Elements will be sorted by their distance.
        k_results=10,
    )

    # The keys of the resulting dict are the integer indices and the values
    # are dicts themselves which describe the corresponding vgd elements.
    # These dicts contain for example the absolute path to the PNG file,
    # the full graph representation and additional metadata.
    print(f'generated {len(index_data_map)} counterfactuals:')
    for index, data in index_data_map.items():
        print(f' * {data["metadata"]["name"]} '
              f' - distance: {data["metadata"]["distance"]:.2f}')
Credits
PyComex is a micro framework which simplifies the setup, processing and management of computational experiments. It is also used to auto-generate the command line interface that can be used to interact with these experiments.
VisualGraphDatasets is a library which deals with the VGD dataset format. In this format, graph datasets for machine learning are represented by a folder, where each graph is represented by two files: a metadata JSON file that contains the full graph representation and additional metadata, and a PNG visualization of the graph. The library aims to provide a framework for explainable graph machine learning which is easier to use and produces more reproducible results.
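To make the two-files-per-graph layout concrete, the sketch below reads such a folder with the standard library only. It is not the loader that VisualGraphDatasets provides. It assumes that the JSON and PNG file of one element share the same file stem, and `load_element_folder` as well as the metadata keys in the demo are hypothetical.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict


def load_element_folder(folder: str) -> Dict[str, dict]:
    """Pair every metadata JSON file with the PNG file of the same stem."""
    elements = {}
    for json_path in sorted(Path(folder).glob('*.json')):
        png_path = json_path.with_suffix('.png')
        if not png_path.exists():
            # Skip incomplete elements that lack a visualization.
            continue
        with open(json_path) as file:
            metadata = json.load(file)
        elements[json_path.stem] = {
            'metadata': metadata,
            'image_path': str(png_path),
        }
    return elements


# Demo with placeholder files instead of a real dataset.
with tempfile.TemporaryDirectory() as folder:
    root = Path(folder)
    (root / '0.json').write_text(json.dumps({'name': 'CCCCCO', 'distance': 0.42}))
    (root / '0.png').write_bytes(b'')  # empty placeholder, not a valid image
    (root / '1.json').write_text(json.dumps({'name': 'orphan'}))  # no PNG
    elements = load_element_folder(folder)

for key, element in elements.items():
    print(key, element['metadata']['name'])
```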