A SHAP Waterfall Chart for interpreting local differences between observations
Using pip (recommended)
pip install shapwaterfall==0.3.1
When VMware data science teams present propensity-to-buy scores (estimated probabilities) from their machine learning classification models, stakeholders often ask why one customer's propensity to buy is higher than another's. That question was our primary motivation.
We were also mindful of the recent algorithm transparency language in the EU's General Data Protection Regulation (GDPR) and the California Consumer Privacy Act (CCPA). Although the 'right to explanation' is not necessarily clear, our desire is to act in good faith by providing local explainability and interpretability between two references, observations, clients, or customers.
This chart provides local interpretability for a classification model by comparing two observations, which we internally call customers. It takes each customer's estimated probability and fills the gap between the two probabilities with SHAP values, ordered from highest to lowest importance.
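The gap-filling idea can be sketched with plain arithmetic. The block below is an illustration only, not the package's internal code: the function name waterfall_steps, the feature names, and every number are made up for the example. It relies on the additive property of SHAP values, where the base value plus an observation's SHAP values equals its predicted score, so the per-feature differences between two observations add up to the difference between their scores.

```python
import math


def waterfall_steps(feature_names, shap_a, shap_b, num_features):
    """Return (label, delta) steps that bridge observation A's score to B's.

    Hypothetical helper for illustration; not part of shapwaterfall.
    """
    # Per-feature difference in SHAP value between the two observations.
    deltas = [(name, b - a) for name, a, b in zip(feature_names, shap_a, shap_b)]
    # Order from most to least important by absolute difference.
    ranked = sorted(deltas, key=lambda item: abs(item[1]), reverse=True)
    steps = ranked[:num_features]
    # Collapse the remaining features into one bar so the chart still closes the gap.
    if len(ranked) > num_features:
        steps.append(("other features", sum(d for _, d in ranked[num_features:])))
    return steps


# Hand-made numbers, assumed to be in probability units.
base_value = 0.50
feature_names = ["tenure", "seats", "region", "discount"]
shap_a = [0.10, -0.05, 0.02, 0.03]
shap_b = [0.25, 0.05, 0.01, -0.01]

prob_a = base_value + sum(shap_a)
prob_b = base_value + sum(shap_b)
steps = waterfall_steps(feature_names, shap_a, shap_b, num_features=2)

for label, delta in steps:
    print(f"{label:>15}: {delta:+.2f}")

# The steps account for the whole gap between the two scores.
assert math.isclose(sum(d for _, d in steps), prob_b - prob_a, abs_tol=1e-9)
```

Each returned step corresponds to one bar of a waterfall chart that starts at the first customer's probability and ends at the second's.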
Update: This package works for all classification models. We added the Kernel Explainer. When using SVC ensure that
The package requires a classifier, training data, validation/test/scoring data, the two observations of interest (row index), and the desired number of important features. The package produces a Waterfall Chart.
shapwaterfall(clf, X_tng, X_val, index1, index2, num_features)
- clf: a classifier that is fitted to X_tng, training data.
- X_tng: the training data frame used to fit the model.
- X_val: the validation, test, or scoring data frame under observation.
- index1 and index2: the first and second row index numbers.
- num_features: the number of important features to display. These features describe the local difference between the two observations.
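Before calling the function, it may help to confirm that the arguments fit the data. The check below is a hypothetical helper, not something the package provides, and it assumes that index1 and index2 are positional row numbers in X_val and that num_features cannot exceed the number of columns.

```python
def check_comparison_args(n_rows, n_columns, index1, index2, num_features):
    """Raise ValueError if the arguments do not fit the scoring data.

    Hypothetical helper; assumes positional (0-based) row indices.
    """
    for name, idx in (("index1", index1), ("index2", index2)):
        if not 0 <= idx < n_rows:
            raise ValueError(f"{name}={idx} is outside 0..{n_rows - 1}")
    if not 1 <= num_features <= n_columns:
        raise ValueError(
            f"num_features={num_features} must be between 1 and {n_columns}"
        )


# Example with a scoring frame of 188 rows and 30 columns.
check_comparison_args(n_rows=188, n_columns=30, index1=5, index2=100, num_features=5)
```

With a data frame, n_rows and n_columns would come from X_val.shape.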
The shapwaterfall package requires the following python packages:
import pandas as pd
import numpy as np
import shap
import matplotlib.pyplot as plt
import waterfall_chart
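A quick way to see whether these modules are importable in the current environment is sketched below. The helper name missing_modules is made up for this example; it uses only the standard library and the module names from the import list above.

```python
from importlib.util import find_spec


def missing_modules(module_names):
    """Return the top-level module names that cannot be found for import."""
    return [name for name in module_names if find_spec(name) is None]


# Top-level modules used by the imports listed above.
required = ["pandas", "numpy", "shap", "matplotlib", "waterfall_chart"]

missing = missing_modules(required)
if missing:
    print("Not importable in this environment:", ", ".join(missing))
else:
    print("All required modules are importable.")
```

The check only reports what is missing; it does not install anything.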
Random Forest on WI Breast Cancer Data
# Scikit-Learn WI Breast Cancer Data Example
!pip install shapwaterfall==0.3.1

# packages
import pandas as pd
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
import shap
import matplotlib.pyplot as plt
import waterfall_chart
from shapwaterfall import shapwaterfall

# models
# max_features="sqrt" matches the former "auto" setting for classifiers,
# which newer scikit-learn releases no longer accept.
rf_clf = RandomForestClassifier(n_estimators=1666,
    max_features="sqrt", min_samples_split=2, min_samples_leaf=2,
    max_depth=20, bootstrap=True, n_jobs=1)

# load and organize Wisconsin Breast Cancer Data
data = load_breast_cancer()
label_names = data['target_names']
labels = data['target']
feature_names = data['feature_names']
features = data['data']

# data splits
X_tng, X_val, y_tng, y_val = train_test_split(features,
    labels, test_size=0.33, random_state=42)
print(X_tng.shape) # (381, 30)
print(X_val.shape) # (188, 30)
X_tng = pd.DataFrame(X_tng)
X_tng.columns = feature_names
X_val = pd.DataFrame(X_val)
X_val.columns = feature_names

# fit RandomForest and measure AUC
clf = rf_clf.fit(X_tng, y_tng)
pred_rf = clf.predict_proba(X_val)
score_rf = roc_auc_score(y_val, pred_rf[:,1])
print(score_rf, 'Random Forest AUC') # 0.9951893425434809 Random Forest AUC

# Use Case 1
shapwaterfall(clf, X_tng, X_val, 5, 100, 5)
shapwaterfall(clf, X_tng, X_val, 100, 5, 7)

# Use Case 2
shapwaterfall(clf, X_tng, X_val, 36, 94, 5)
shapwaterfall(clf, X_tng, X_val, 94, 36, 7)

# Logistic Regression Example
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix
clf = LogisticRegression(solver='liblinear', random_state=0)
clf.fit(X_tng, y_tng)
y_pred = clf.predict(X_val)
confusion_matrix(y_val, y_pred)

# Use Case 3
shapwaterfall(clf, X_tng, X_val, 1, 44, 5)
shapwaterfall(clf, X_tng, X_val, 44, 1, 5)

# Support Vector Classification Example
from sklearn.svm import SVC
clf = SVC(probability=True)
clf.fit(X_tng, y_tng)
y_pred = clf.predict(X_val)
confusion_matrix(y_val, y_pred)

# Use Case 4
shapwaterfall(clf, X_tng, X_val, 1, 44, 5)
shapwaterfall(clf, X_tng, X_val, 44, 1, 5)
John Halstead, firstname.lastname@example.org
Rajesh Vikraman, email@example.com
Ravi Prasad K, firstname.lastname@example.org