Plotting tool for counterfactual explanations
Project description
CounterPlots: Plotting tool for counterfactuals
Counterplots is a Python package that allows you to plot counterfactuals with easy integration with any counterfactual generation algorithm.
Plot examples
Greedy Plot
The greedy plot shows the greediest (feature change with the highest impact towards the opposite class) path from the factual instance until it reaches the counterfactual.
CounterShapley Plot
This chart shows each counterfactual feature change contribution to the counterfactual prediction. It uses Shapley values to calculate the contribution of each feature change.
Constellation Plot
This chart shows the prediction score change for all possible feature change combinations.
Requirements
CounterPlots requires Python 3.8 or higher.
Installation
With pip:
pip install counterplots
Usage
To use CounterPlots, you just need the machine learning model predictor, and the factual and counterfactual points. The example below uses a simple mock model:
from counterplots import CreatePlot
import numpy as np
# Simple mock model for the predict_proba function which returns a probability for each input instance
def mock_predict_proba(data):
out = []
for x in data:
if list(x) == [0.0, 0.0, 0.0]:
out.append(0.0)
elif list(x) == [1.0, 0.0, 0.0]:
out.append(0.44)
elif list(x) == [0.0, 1.0, 0.0]:
out.append(0.4)
elif list(x) == [0.0, 0.0, 1.0]:
out.append(0.2)
elif list(x) == [1.0, 1.0, 0.0]:
out.append(0.3)
elif list(x) == [0.0, 1.0, 1.0]:
out.append(0.25)
elif list(x) == [1.0, 0.0, 1.0]:
out.append(0.4)
elif list(x) == [1.0, 1.0, 1.0]:
out.append(1.0)
return np.array(out)
# Factual Instance
factual = np.array([0, 0, 0])
# Counterfactual Instance
cf = np.array([1, 1, 1])
# Create the plot object
cf_plots = CreatePlot(
factual,
cf,
mock_predict_proba)
# Create the greedy plot
cf_plots.greedy('greedy_plot.png')
# Create the countershapley plot
cf_plots.countershapley('countershapley_plot.png')
# Create the constellation plot
cf_plots.constellation('constellation_plot.png')
# Print the countershapley values
print(cf_plots.countershapley_values())
In case you want to add custom names to the features, use the optional argument feature_names
:
cf_plots = CreatePlot(
factual,
cf,
mock_predict_proba,
feature_names=['feature1', 'feature2', 'feature3'])
In case you want to add custom labels to the factual and counterfactual points, use the optional argument class_names
:
cf_plots = CreatePlot(
factual,
cf,
mock_predict_proba,
class_names=['Factual', 'Counterfactual'])
Using with Scikit-Learn
CounterPlots can be used with any machine learning model that has a predict_proba
function. For example, with Scikit-Learn:
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from counterplots import CreatePlot
iris = load_iris()
X = iris.data
y = [0 if l == 0 else 1 for l in iris.target] # Makes it a binary classification problem
clf = RandomForestClassifier(max_depth=2, random_state=0)
clf.fit(X, y)
preds = clf.predict(X)
# For the factual point, takes an instance with 0 classification
factual = X[np.argwhere(preds == 0)[0]][0]
# For the counterfactual point, takes an instance with 1 classification
cf = X[np.argwhere(preds == 1)[0]][0]
cf_plots = CreatePlot(
factual,
cf,
clf.predict_proba,
feature_names=iris.feature_names,
class_names={0: 'Setosa', 1: 'Non-Setosa'}
)
# Create the greedy plot
cf_plots.greedy('iris_greedy_plot.png')
# Create the countershapley plot
cf_plots.countershapley('iris_countershapley_plot.png')
# Create the constellation plot
cf_plots.constellation('iris_constellation_plot.png')
# Print the countershapley values
print(cf_plots.countershapley_values())
Citation
If you use CounterPlots in your research, please cite the following paper:
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
File details
Details for the file counterplots-0.0.7.tar.gz
.
File metadata
- Download URL: counterplots-0.0.7.tar.gz
- Upload date:
- Size: 15.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.8.16
File hashes
Algorithm | Hash digest | |
---|---|---|
SHA256 | 1cc80ad3d555c762a988cb973e1c880716ce138328f6555534f7c9f4148eecba |
|
MD5 | 3738c3b7e59798b76e67395b805ebe72 |
|
BLAKE2b-256 | d7104350dcb4bbe49c50901e4427eb0ca98f1e322539d5d91af44cffc108fc04 |