
Plotting tool for counterfactual explanations




CounterPlots: Plotting tool for counterfactuals

License: MIT

CounterPlots is a Python package for plotting counterfactual explanations. It integrates easily with any counterfactual generation algorithm.

Plot examples

Greedy Plot

The greedy plot shows the greedy path from the factual instance to the counterfactual: at each step, it applies the feature change with the highest impact towards the opposite class, until it reaches the counterfactual.
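
The greedy path idea can be sketched in a few lines. This is a minimal illustration, not CounterPlots' own implementation: the names greedy_path, SCORES and mock_predict are hypothetical, and the sketch assumes the predictor returns a 1-D array with the counterfactual-class score for each row (Scikit-Learn's predict_proba returns a 2-D array instead). The mock scores are the same as in the Usage example.

```python
from typing import Callable, List, Tuple

import numpy as np

# Hypothetical mock scores, copied from the Usage example's mock model
SCORES = {(0, 0, 0): 0.0, (1, 0, 0): 0.44, (0, 1, 0): 0.4, (0, 0, 1): 0.2,
          (1, 1, 0): 0.3, (0, 1, 1): 0.25, (1, 0, 1): 0.4, (1, 1, 1): 1.0}


def mock_predict(data: np.ndarray) -> np.ndarray:
    # One counterfactual-class score per input row
    return np.array([SCORES[tuple(int(v) for v in row)] for row in data])


def greedy_path(factual: np.ndarray, cf: np.ndarray,
                predict: Callable[[np.ndarray], np.ndarray]) -> List[Tuple[int, float]]:
    """Sketch: repeatedly apply the single remaining feature change
    (copying one value from cf) that gives the highest score."""
    current = factual.astype(float)  # astype returns a copy
    remaining = [i for i in range(len(factual)) if factual[i] != cf[i]]
    path = []
    while remaining:
        # Build one candidate point per feature change still available
        candidates = []
        for i in remaining:
            candidate = current.copy()
            candidate[i] = cf[i]
            candidates.append(candidate)
        scores = predict(np.array(candidates))
        best = int(np.argmax(scores))
        # Commit the best change and record (feature index, score after it)
        feature = remaining.pop(best)
        current[feature] = cf[feature]
        path.append((feature, float(scores[best])))
    return path


print(greedy_path(np.array([0, 0, 0]), np.array([1, 1, 1]), mock_predict))
# [(0, 0.44), (2, 0.4), (1, 1.0)]
```

With these mock scores the path changes feature 0 first, then feature 2, then feature 1. Note that the score dips from 0.44 to 0.4 at the second step: a greedy choice only picks the best option available at each step, so the path is not guaranteed to be monotone.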

CounterShapley Plot

This chart shows the contribution of each counterfactual feature change to the counterfactual prediction. The contributions are computed with Shapley values.
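
As a rough illustration of how Shapley values can be applied to feature changes, the sketch below treats each changed feature as a "player" and defines the value of a set of changes as the model score of the factual point with only those features copied from the counterfactual. This exact-enumeration version is an assumption for illustration, not necessarily how CounterPlots computes its values; shapley_of_changes, SCORES and mock_predict are hypothetical names, and the predictor is assumed to return a 1-D array of counterfactual-class scores.

```python
from itertools import combinations
from math import factorial
from typing import Callable, Dict

import numpy as np

# Hypothetical mock scores, copied from the Usage example's mock model
SCORES = {(0, 0, 0): 0.0, (1, 0, 0): 0.44, (0, 1, 0): 0.4, (0, 0, 1): 0.2,
          (1, 1, 0): 0.3, (0, 1, 1): 0.25, (1, 0, 1): 0.4, (1, 1, 1): 1.0}


def mock_predict(data: np.ndarray) -> np.ndarray:
    return np.array([SCORES[tuple(int(v) for v in row)] for row in data])


def shapley_of_changes(factual: np.ndarray, cf: np.ndarray,
                       predict: Callable[[np.ndarray], np.ndarray]) -> Dict[int, float]:
    changed = [i for i in range(len(factual)) if factual[i] != cf[i]]
    n = len(changed)

    def value(subset) -> float:
        # Score of the factual point with the features in `subset` set to cf values
        point = factual.astype(float)
        for i in subset:
            point[i] = cf[i]
        return float(predict(point[np.newaxis, :])[0])

    phi = {}
    for i in changed:
        others = [j for j in changed if j != i]
        total = 0.0
        # Weighted marginal contribution of change i over every subset of other changes
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                total += weight * (value(subset + (i,)) - value(subset))
        phi[i] = total
    return phi


values = shapley_of_changes(np.array([0, 0, 0]), np.array([1, 1, 1]), mock_predict)
print(values)
```

With the mock scores the values are roughly 0.413, 0.318 and 0.268 for features 0, 1 and 2. They sum to 1.0, which is the score of the counterfactual minus the score of the factual (the efficiency property of Shapley values). Exact enumeration grows exponentially with the number of changed features, which is acceptable for the small counterfactuals typically plotted.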

Constellation Plot

This chart shows how the prediction score changes for every possible combination of feature changes.
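
The data behind such a chart can be sketched by enumerating every subset of feature changes and scoring the resulting points in one batch. This is a hypothetical helper (all_change_scores, SCORES and mock_predict are illustrative names, not part of the CounterPlots API), and it assumes the predictor returns a 1-D array of counterfactual-class scores.

```python
from itertools import combinations
from typing import Callable, Dict, Tuple

import numpy as np

# Hypothetical mock scores, copied from the Usage example's mock model
SCORES = {(0, 0, 0): 0.0, (1, 0, 0): 0.44, (0, 1, 0): 0.4, (0, 0, 1): 0.2,
          (1, 1, 0): 0.3, (0, 1, 1): 0.25, (1, 0, 1): 0.4, (1, 1, 1): 1.0}


def mock_predict(data: np.ndarray) -> np.ndarray:
    return np.array([SCORES[tuple(int(v) for v in row)] for row in data])


def all_change_scores(factual: np.ndarray, cf: np.ndarray,
                      predict: Callable[[np.ndarray], np.ndarray]) -> Dict[Tuple[int, ...], float]:
    changed = [i for i in range(len(factual)) if factual[i] != cf[i]]
    # Every subset of the changed features, from no change to all changes
    subsets = [s for k in range(len(changed) + 1) for s in combinations(changed, k)]
    # One row per subset, starting from the factual point
    points = np.tile(factual.astype(float), (len(subsets), 1))
    for row, subset in enumerate(subsets):
        for i in subset:
            points[row, i] = cf[i]
    scores = predict(points)  # score all combinations in a single call
    return {subset: float(score) for subset, score in zip(subsets, scores)}


for subset, score in all_change_scores(np.array([0, 0, 0]), np.array([1, 1, 1]), mock_predict).items():
    print(subset, score)
```

For n changed features this produces 2^n points, so the chart is most readable when the counterfactual changes only a handful of features.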

Requirements

CounterPlots requires Python 3.8 or higher.

Installation

With pip:

pip install counterplots

Usage

To use CounterPlots, you only need the model's prediction function and the factual and counterfactual points. The example below uses a simple mock model:

from counterplots import CreatePlot
import numpy as np

# Simple mock model for the predict_proba function, which returns a probability for each input instance
def mock_predict_proba(data):
    # Probability for every combination of the three binary features
    scores = {
        (0.0, 0.0, 0.0): 0.0,
        (1.0, 0.0, 0.0): 0.44,
        (0.0, 1.0, 0.0): 0.4,
        (0.0, 0.0, 1.0): 0.2,
        (1.0, 1.0, 0.0): 0.3,
        (0.0, 1.0, 1.0): 0.25,
        (1.0, 0.0, 1.0): 0.4,
        (1.0, 1.0, 1.0): 1.0,
    }
    # Raises KeyError for an unknown input instead of silently returning a shorter array
    return np.array([scores[tuple(float(v) for v in x)] for x in data])

# Factual Instance
factual = np.array([0, 0, 0])
# Counterfactual Instance
cf = np.array([1, 1, 1])

# Create the plot object
cf_plots = CreatePlot(
    factual,
    cf,
    mock_predict_proba)

# Create the greedy plot
cf_plots.greedy('greedy_plot.png')
# Create the countershapley plot
cf_plots.countershapley('countershapley_plot.png')
# Create the constellation plot
cf_plots.constellation('constellation_plot.png')

# Print the countershapley values
print(cf_plots.countershapley_values())

To give the features custom names, pass the optional feature_names argument:

cf_plots = CreatePlot(
    factual,
    cf,
    mock_predict_proba,
    feature_names=['feature1', 'feature2', 'feature3'])

To add custom labels for the factual and counterfactual classes, pass the optional class_names argument:

cf_plots = CreatePlot(
    factual,
    cf,
    mock_predict_proba,
    class_names=['Factual', 'Counterfactual'])

Using with Scikit-Learn

CounterPlots can be used with any machine learning model that has a predict_proba function. For example, with Scikit-Learn:

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris

from counterplots import CreatePlot

iris = load_iris()

X = iris.data
y = (iris.target != 0).astype(int)  # Binary classification: Setosa (0) vs. non-Setosa (1)

clf = RandomForestClassifier(max_depth=2, random_state=0)
clf.fit(X, y)

preds = clf.predict(X)

# Factual point: the first instance predicted as class 0 (Setosa)
factual = X[preds == 0][0]
# Counterfactual point: the first instance predicted as class 1 (non-Setosa)
cf = X[preds == 1][0]

cf_plots = CreatePlot(
    factual,
    cf,
    clf.predict_proba,
    feature_names=iris.feature_names,
    class_names={0: 'Setosa', 1: 'Non-Setosa'}
)


# Create the greedy plot
cf_plots.greedy('iris_greedy_plot.png')
# Create the countershapley plot
cf_plots.countershapley('iris_countershapley_plot.png')
# Create the constellation plot
cf_plots.constellation('iris_constellation_plot.png')

# Print the countershapley values
print(cf_plots.countershapley_values())

Citation

If you use CounterPlots in your research, please cite the following paper:



Download files


Source Distribution

counterplots-0.0.7.tar.gz (15.9 kB)


File details

Details for the file counterplots-0.0.7.tar.gz.

File metadata

  • Download URL: counterplots-0.0.7.tar.gz
  • Size: 15.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.16

File hashes

Hashes for counterplots-0.0.7.tar.gz
Algorithm Hash digest
SHA256 1cc80ad3d555c762a988cb973e1c880716ce138328f6555534f7c9f4148eecba
MD5 3738c3b7e59798b76e67395b805ebe72
BLAKE2b-256 d7104350dcb4bbe49c50901e4427eb0ca98f1e322539d5d91af44cffc108fc04

