
A SHAP Waterfall Chart for interpreting local differences between observations


Install

Using pip (recommended)

pip install shapwaterfall==0.3.0

Introduction

When VMware Data Science Teams present their machine learning classification models' propensity-to-buy scores (estimated probabilities) to stakeholders, the stakeholders often ask why one customer's propensity to buy is higher than another customer's. That question was our primary motivation.

We were further concerned with the recent algorithm transparency language in the EU's General Data Protection Regulation (GDPR) and the California Consumer Privacy Act (CCPA). Although the scope of a 'right to explanation' is not clear-cut, we want to act in good faith by providing local explainability and interpretability between two references (observations, clients, or customers).

This chart provides local interpretability for a classification model by comparing two observations, which internally we call customers. It starts from each customer's estimated probability and fills the gap between the two probabilities with SHAP values, ordered from highest to lowest importance.
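The arithmetic behind such a chart can be sketched as follows. This is a minimal illustration, not the package's own code. It assumes both observations' SHAP values come from the same explainer and are expressed in probability units, so that (by SHAP's additivity property) their per-feature differences sum to the probability gap. It also assumes "importance" means the absolute SHAP difference. `waterfall_steps` and the "other features" bar are hypothetical, and the numbers are invented purely for illustration.

```python
import numpy as np


def waterfall_steps(p1, p2, shap1, shap2, feature_names, num_features):
    """Hypothetical helper: split the gap p2 - p1 into per-feature bars.

    Assumes shap1 and shap2 are SHAP values (probability units) for the two
    observations, so sum(shap2 - shap1) == p2 - p1.
    """
    delta = np.asarray(shap2, dtype=float) - np.asarray(shap1, dtype=float)
    # Rank features by the size of their contribution to the gap (assumption).
    order = np.argsort(-np.abs(delta))
    steps = [(feature_names[i], float(delta[i])) for i in order[:num_features]]
    # Lump the remaining features into one bar so the chart ends exactly at p2.
    steps.append(("other features", float(delta[order[num_features:]].sum())))
    return [("observation 1", float(p1))] + steps + [("observation 2", float(p2))]


# Illustrative, made-up SHAP values around a shared base value.
names = ["feature A", "feature B", "feature C", "feature D"]
base = 0.4
shap1 = np.array([0.05, -0.02, 0.01, 0.0])
shap2 = np.array([0.30, 0.10, -0.05, 0.02])
p1, p2 = base + shap1.sum(), base + shap2.sum()

steps = waterfall_steps(p1, p2, shap1, shap2, names, num_features=2)
for name, value in steps:
    print(f"{name:<15} {value:+.3f}")
```

Starting at the first probability and adding every intermediate bar lands exactly on the second probability, which is what lets the chart "fill the gap".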

Update: The package now supports all classification models; we added SHAP's model-agnostic Kernel Explainer. When using SVC, set probability=True so that the model exposes predict_proba.
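The SVC requirement can be checked directly in scikit-learn: with the default probability=False, SVC does not expose predict_proba at all. The sketch below fits both variants on the Wisconsin breast cancer data used in the examples; random_state=0 is our addition, used only to make the Platt-scaled probabilities reproducible.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.svm import SVC

X, y = load_breast_cancer(return_X_y=True)

# Default SVC: probability=False, so predict_proba is not available.
plain_svc = SVC().fit(X, y)
print(hasattr(plain_svc, "predict_proba"))  # False

# probability=True adds Platt scaling (fitted with internal cross-validation),
# after which predict_proba returns one column per class.
prob_svc = SVC(probability=True, random_state=0).fit(X, y)
proba = prob_svc.predict_proba(X[:3])
print(proba.round(3))
```

Fitting with probability=True is slower because of the extra calibration step, which is why scikit-learn leaves it off by default.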

The package requires a classifier, training data, validation/test/scoring data, the two observations of interest (row indices), and the desired number of important features. It produces a waterfall chart.
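One way to pick a contrasting pair of observations is to take the rows of X_val with the lowest and highest estimated probability. The sketch below uses a hypothetical helper, `extreme_pair`, and assumes the row indices are labels of X_val's default 0..n-1 index, as in the examples; a smaller forest (100 trees) is used here only to keep it fast.

```python
import numpy as np
import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split


def extreme_pair(clf, X_val):
    """Hypothetical helper: index labels of the lowest- and highest-scoring rows."""
    proba = clf.predict_proba(X_val)[:, 1]
    return X_val.index[np.argmin(proba)], X_val.index[np.argmax(proba)]


data = load_breast_cancer()
X = pd.DataFrame(data.data, columns=data.feature_names)
X_tng, X_val, y_tng, y_val = train_test_split(
    X, data.target, test_size=0.33, random_state=42)
# Reset to a 0..n-1 index, matching how the examples build their DataFrames.
X_tng = X_tng.reset_index(drop=True)
X_val = X_val.reset_index(drop=True)

clf = RandomForestClassifier(n_estimators=100, random_state=0).fit(X_tng, y_tng)
low, high = extreme_pair(clf, X_val)
print(low, high)
# The pair could then be passed on, e.g. shapwaterfall(clf, X_tng, X_val, low, high, 5)
```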

Command

shapwaterfall(clf, X_tng, X_val, index1, index2, num_features)

Required

  • clf: a classifier (tree-based or otherwise, per the update above) fitted to X_tng, the training data.
  • X_tng: the training Data Frame used to fit the model.
  • X_val: the validation, test, or scoring Data Frame under observation. Note that the data frame must contain an extra column whose label is 'Reference'.
  • index1 and index2: the index numbers of the first and second observations in X_val.
  • num_features: the number of important features that describe the local interpretability between the two observations.
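The requirements above can be checked before calling shapwaterfall. The helper below, `check_shapwaterfall_inputs`, is hypothetical and not part of the package. It assumes that index1 and index2 are labels in X_val's index and that the 'Reference' column, if present, is the only extra column in X_val.

```python
import pandas as pd
from sklearn.linear_model import LogisticRegression


def check_shapwaterfall_inputs(clf, X_tng, X_val, index1, index2, num_features):
    """Hypothetical pre-flight check; returns a list of problems (empty if OK)."""
    problems = []
    if not hasattr(clf, "predict_proba"):
        problems.append("clf has no predict_proba (for SVC, set probability=True)")
    # Ignore the extra 'Reference' column when comparing with the training features.
    val_features = [c for c in X_val.columns if c != "Reference"]
    if val_features != list(X_tng.columns):
        problems.append("X_val feature columns do not match X_tng")
    for idx in (index1, index2):
        if idx not in X_val.index:
            problems.append(f"index {idx!r} is not in X_val")
    if index1 == index2:
        problems.append("index1 and index2 refer to the same observation")
    if not 1 <= num_features <= len(X_tng.columns):
        problems.append("num_features must be between 1 and the number of features")
    return problems


# Tiny illustrative data set.
X_tng = pd.DataFrame({"a": [0.0, 1.0, 2.0, 3.0], "b": [1.0, 0.0, 1.0, 0.0]})
y_tng = [0, 0, 1, 1]
X_val = X_tng.copy()
clf = LogisticRegression().fit(X_tng, y_tng)

ok = check_shapwaterfall_inputs(clf, X_tng, X_val, 0, 3, 2)
bad = check_shapwaterfall_inputs(clf, X_tng, X_val, 0, 9, 5)
print(ok)   # []
print(bad)  # missing index 9 and num_features out of range
```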

Dependent Packages

The shapwaterfall package requires the following Python packages:

import pandas as pd
import numpy as np
import shap
import matplotlib.pyplot as plt
import waterfall_chart
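Before running the examples, it can help to confirm that these modules are importable. The sketch below uses only the standard library; `missing_dependencies` is a hypothetical helper, and mapping the `waterfall_chart` module to the `waterfallcharts` distribution on PyPI is our assumption, so check it against your package index.

```python
import importlib.util

# Module name -> pip distribution name ("waterfallcharts" is an assumption).
REQUIRED = {
    "pandas": "pandas",
    "numpy": "numpy",
    "shap": "shap",
    "matplotlib": "matplotlib",
    "waterfall_chart": "waterfallcharts",
}


def missing_dependencies(required=REQUIRED):
    """Return pip names of required modules that cannot be found (no import is run)."""
    return [pip_name for module, pip_name in required.items()
            if importlib.util.find_spec(module) is None]


missing = missing_dependencies()
if missing:
    print("pip install " + " ".join(missing))
else:
    print("All shapwaterfall dependencies are importable.")
```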

Examples

Random Forest on WI Breast Cancer Data

# Scikit-Learn WI Breast Cancer Data Example
!pip install shapwaterfall==0.3.0
# packages
import pandas as pd
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
import shap
import matplotlib.pyplot as plt
import waterfall_chart
from shapwaterfall import shapwaterfall

# models
rf_clf = RandomForestClassifier(n_estimators=1666,
    max_features="sqrt",  # "auto" (an alias for "sqrt" in classifiers) was removed in recent scikit-learn
    min_samples_split=2, min_samples_leaf=2,
    max_depth=20, bootstrap=True, n_jobs=1)

# load and organize Wisconsin Breast Cancer Data
data = load_breast_cancer()
label_names = data['target_names']
labels = data['target']
feature_names = data['feature_names']
features = data['data']

# data splits
X_tng, X_val, y_tng, y_val = train_test_split(features,
    labels, test_size=0.33, random_state=42)

print(X_tng.shape) # (381, 30)
print(X_val.shape) # (188, 30)

X_tng = pd.DataFrame(X_tng)
X_tng.columns = feature_names
X_val = pd.DataFrame(X_val)
X_val.columns = feature_names

# fit RandomForest and measure AUC
clf = rf_clf.fit(X_tng, y_tng)
pred_rf = clf.predict_proba(X_val)
score_rf = roc_auc_score(y_val, pred_rf[:, 1])
print(score_rf, 'Random Forest AUC')

# 0.9951893425434809 Random Forest AUC

# Use Case 1
shapwaterfall(clf, X_tng, X_val, 5, 100, 5)
shapwaterfall(clf, X_tng, X_val, 100, 5, 7)

# Use Case 2
shapwaterfall(clf, X_tng, X_val, 36, 94, 5)
shapwaterfall(clf, X_tng, X_val, 94, 36, 7)

# Logistic Regression Example
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix
clf = LogisticRegression(solver='liblinear', random_state=0)
clf.fit(X_tng, y_tng)
y_pred = clf.predict(X_val)
print(confusion_matrix(y_val, y_pred))

# Use Case 3
shapwaterfall(clf, X_tng, X_val, 1, 44, 5)
shapwaterfall(clf, X_tng, X_val, 44, 1, 5)

# Support Vector Classification Example
from sklearn.svm import SVC
clf = SVC(probability=True)
clf.fit(X_tng, y_tng)
y_pred = clf.predict(X_val)
print(confusion_matrix(y_val, y_pred))

# Use Case 4
shapwaterfall(clf, X_tng, X_val, 1, 44, 5)
shapwaterfall(clf, X_tng, X_val, 44, 1, 5)
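For the logistic regression in Use Case 3, the gap between two observations can be cross-checked by hand. For a linear model, and assuming independent features with the training-set mean as the background, the exact SHAP value of feature i in log-odds space is w_i (x_i - mean_i), so the per-feature differences between two rows sum exactly to the difference in their decision-function values. This sketch works in log-odds, not probabilities, so its bars will not match the package's chart; `linear_shap` is a hypothetical helper.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

data = load_breast_cancer()
X_tng, X_val, y_tng, y_val = train_test_split(
    data.data, data.target, test_size=0.33, random_state=42)
clf = LogisticRegression(solver="liblinear", random_state=0).fit(X_tng, y_tng)

w = clf.coef_[0]
mu = X_tng.mean(axis=0)  # background: training-set feature means


def linear_shap(x):
    """Exact log-odds SHAP values for a linear model (independent features assumed)."""
    return w * (x - mu)


# Move from observation 1 to observation 44, as in shapwaterfall(clf, X_tng, X_val, 1, 44, 5).
diff = linear_shap(X_val[44]) - linear_shap(X_val[1])
logit_gap = clf.decision_function(X_val[[44]])[0] - clf.decision_function(X_val[[1]])[0]
print(np.isclose(diff.sum(), logit_gap))  # True: the intercept and the mean terms cancel

top = np.argsort(-np.abs(diff))[:5]
for i in top:
    print(f"{data.feature_names[i]:<25} {diff[i]:+.3f}")
```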

Authors

John Halstead, jhalstead@vmware.com

Rajesh Vikraman, rvikraman@vmware.com

Ravi Prasad K, rkondapalli@vmware.com

Preprint available

https://www.researchgate.net/publication/354733308_SHAPWaterfall_A_Simplified_Visualization_Solution_for_Local_Interpretability_in_Machine_Learning_Models_Enabling_Precise_Business_Decision-Making_by_Visually_Comparing_Probabilities_of_Two_Observation

