Plotting tool for counterfactual explanations
CounterPlots: Plotting tool for counterfactuals
CounterPlots is a Python package for plotting counterfactual explanations. It integrates easily with any counterfactual generation algorithm.
Installation
With pip:
pip install counterplots
Usage
To use CounterPlots, you only need the prediction function of your machine learning model, the factual point and the counterfactual point. The example below uses a simple mock model:
from counterplots import CreatePlot
import numpy as np
# Simple mock model for the predict_proba function which returns a probability for each input instance
def mock_predict_proba(data):
    out = []
    for x in data:
        if list(x) == [0.0, 0.0, 0.0]:
            out.append(0.0)
        elif list(x) == [1.0, 0.0, 0.0]:
            out.append(0.44)
        elif list(x) == [0.0, 1.0, 0.0]:
            out.append(0.4)
        elif list(x) == [0.0, 0.0, 1.0]:
            out.append(0.2)
        elif list(x) == [1.0, 1.0, 0.0]:
            out.append(0.3)
        elif list(x) == [0.0, 1.0, 1.0]:
            out.append(0.25)
        elif list(x) == [1.0, 0.0, 1.0]:
            out.append(0.4)
        elif list(x) == [1.0, 1.0, 1.0]:
            out.append(1.0)
    return np.array(out)
# Factual Instance
factual = np.array([0, 0, 0])
# Counterfactual Instance
cf = np.array([1, 1, 1])
# Create the plot object
cf_plots = CreatePlot(
    factual,
    cf,
    mock_predict_proba)
# Create the greedy plot
cf_plots.greedy('greedy_plot.png')
# Create the countershapley plot
cf_plots.countershapley('countershapley_plot.png')
# Create the constellation plot
cf_plots.constellation('constellation_plot.png')
# Print the countershapley values
print(cf_plots.countershapley_values())
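The greedy and CounterShapley plots are not explained further on this page. The sketch below works directly on the mock model to illustrate the kind of quantities they may summarize. It rests on two assumptions:

1. A "greedy" path applies, at each step, whichever remaining feature change gives the highest predicted score.
2. CounterShapley values are exact Shapley values over the changed features. The value of a coalition is the score of the factual point when only the coalition's features take their counterfactual values.

Both are our reading of the names, not documented behaviour. The helpers `score_with_changes`, `greedy_path` and `shapley_values` are hypothetical. The package's own `greedy()` and `countershapley_values()` remain the authoritative output.

```python
from itertools import combinations
from math import factorial

import numpy as np


def mock_predict_proba(data):
    # Same lookup table as the mock model above, keyed by the binary feature vector
    table = {
        (0, 0, 0): 0.0, (1, 0, 0): 0.44, (0, 1, 0): 0.4, (0, 0, 1): 0.2,
        (1, 1, 0): 0.3, (0, 1, 1): 0.25, (1, 0, 1): 0.4, (1, 1, 1): 1.0,
    }
    return np.array([table[tuple(int(v) for v in x)] for x in data])


def score_with_changes(factual, cf, predict, changed):
    # Hypothetical helper: copy the factual point, overwrite the features listed
    # in `changed` with their counterfactual values, and score the result
    point = np.array(factual, dtype=float)
    idx = list(changed)
    point[idx] = cf[idx]
    return float(predict(point.reshape(1, -1))[0])


def greedy_path(factual, cf, predict):
    # Assumed greedy rule: repeatedly apply the remaining feature change
    # that yields the highest predicted score; returns (feature, score) steps
    changed, path = [], []
    remaining = [i for i in range(len(factual)) if factual[i] != cf[i]]
    while remaining:
        best = max(
            remaining,
            key=lambda i: score_with_changes(factual, cf, predict, changed + [i]),
        )
        changed.append(best)
        remaining.remove(best)
        path.append((best, score_with_changes(factual, cf, predict, changed)))
    return path


def shapley_values(factual, cf, predict):
    # Exact Shapley values over the features that differ between the two points
    # (assumed value function: score with only the coalition's features changed)
    players = [i for i in range(len(factual)) if factual[i] != cf[i]]
    n = len(players)
    values = {}
    for i in players:
        others = [p for p in players if p != i]
        total = 0.0
        for size in range(n):
            # Standard Shapley weight |S|! (n - |S| - 1)! / n!
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                with_i = score_with_changes(factual, cf, predict, subset + (i,))
                without_i = score_with_changes(factual, cf, predict, subset)
                total += weight * (with_i - without_i)
        values[i] = total
    return values


factual = np.array([0, 0, 0])
cf = np.array([1, 1, 1])
print(greedy_path(factual, cf, mock_predict_proba))
# [(0, 0.44), (2, 0.4), (1, 1.0)]
phi = shapley_values(factual, cf, mock_predict_proba)
print({k: round(v, 4) for k, v in phi.items()})
# {0: 0.4133, 1: 0.3183, 2: 0.2683}
```

Shapley values are efficient, so the three values sum to the score of the counterfactual minus the score of the factual point (1.0 - 0.0 here). That property is a quick sanity check when comparing against the package's output.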
To add custom names to the features, use the optional feature_names argument:
cf_plots = CreatePlot(
    factual,
    cf,
    mock_predict_proba,
    feature_names=['feature1', 'feature2', 'feature3'])
To add custom labels to the factual and counterfactual points, use the optional class_names argument:
cf_plots = CreatePlot(
    factual,
    cf,
    mock_predict_proba,
    class_names=['Factual', 'Counterfactual'])
Using with Scikit-Learn
CounterPlots can be used with any machine learning model that has a predict_proba function. For example, with Scikit-Learn:
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from counterplots import CreatePlot
iris = load_iris()
X = iris.data
# Make it a binary classification problem: Setosa (0) vs. Non-Setosa (1)
y = (iris.target != 0).astype(int)
clf = RandomForestClassifier(max_depth=2, random_state=0)
clf.fit(X, y)
preds = clf.predict(X)
# For the factual point, take the first instance classified as 0
factual = X[preds == 0][0]
# For the counterfactual point, take the first instance classified as 1
cf = X[preds == 1][0]
cf_plots = CreatePlot(
    factual,
    cf,
    clf.predict_proba,
    feature_names=iris.feature_names,
    class_names={0: 'Setosa', 1: 'Non-Setosa'}
)
# Create the greedy plot
cf_plots.greedy('iris_greedy_plot.png')
# Create the countershapley plot
cf_plots.countershapley('iris_countershapley_plot.png')
# Create the constellation plot
cf_plots.constellation('iris_constellation_plot.png')
# Print the countershapley values
print(cf_plots.countershapley_values())
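The example above takes the first instance of each predicted class as the counterfactual. That point may be far from the factual one. Any counterfactual generator can supply `cf` instead.

A simple stand-in is the nearest unlike neighbor: the closest instance, by Euclidean distance, that the model assigns to the other class. The helper `nearest_unlike_neighbor` below is hypothetical and not part of CounterPlots. It is a sketch of that baseline, not a full counterfactual search.

```python
import numpy as np


def nearest_unlike_neighbor(factual, candidates, candidate_preds, target_class=1):
    # Hypothetical helper: among candidates predicted as `target_class`,
    # return the one closest to the factual point (Euclidean distance)
    mask = np.asarray(candidate_preds) == target_class
    if not mask.any():
        raise ValueError("no candidate is predicted as the target class")
    pool = np.asarray(candidates, dtype=float)[mask]
    distances = np.linalg.norm(pool - np.asarray(factual, dtype=float), axis=1)
    return pool[np.argmin(distances)]


# Toy demo with made-up 2-D points and predicted labels
points = np.array([[0.0, 0.0], [3.0, 4.0], [1.0, 1.0], [0.5, 0.0]])
labels = np.array([0, 1, 1, 0])
print(nearest_unlike_neighbor(points[0], points, labels))  # [1. 1.]
```

With the iris variables above, `cf = nearest_unlike_neighbor(factual, X, preds, target_class=1)` could replace the first-instance choice before calling CreatePlot.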
Citation
If you use CounterPlots in your research, please cite the following paper:
0.0.2 / 2023-06-10
==================
- Updated documentation
0.0.1 / 2023-06-10
==================
- Initial package release