

Shapley Bootstrapping

  • Flexible
  • Effective
  • Explainable

Shapley bootstrapping is a novel machine learning methodology that combines ensemble learners with Shapley values. For a detailed explanation, see my [thesis].

Install

Shapley bootstrapping can be installed from PyPI:

pip install shap-bootstrap

This library automatically installs the following dependencies:

  1. scipy
  2. pandas
  3. openml
  4. xgboost
  5. scikit-learn
  6. seaborn
  7. shap
  8. matplotlib

Overview

This library includes implementations of the eight pipelines from [paper]:

[Figure: the eight thesis pipelines]

Each pipeline except 3 and 6 (which are special cases) is implemented in the library and can be used directly to train and predict over datasets.

Usage (Flexible)

The module is named shap_bootstrap. From it, you can import the following sub-modules (an import sketch follows the list):

  1. building_blocks
  2. cluster
  3. custom_pipeline
  4. datasets
  5. utils
  6. visualisation
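
For a quick orientation, a typical session imports the helpers and pipeline builders directly. A minimal sketch, using only names that appear in the example further below:

from shap_bootstrap import datasets, visualisation
from shap_bootstrap.building_blocks import ProcessingBlock, ExplainerBlock, ClusterBlock, EnsembleBlock, ReduceBlock
from shap_bootstrap.custom_pipeline import B1_Branch_Pipeline, B8_Branch_Pipeline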

To follow the experiments from the [paper], or to start with a quick example, you can use the custom_pipeline module to create a pre-built pipeline from custom building blocks.

The five building-block classes are:

  1. ProcessingBlock
    • Handles input pre-processing: input scaling, imputing and train-test splitting
  2. ExplainerBlock
    • Trains the Shap explainer. Currently fits either a linear regressor (keyword Linear) or an XGBoost regressor (keyword XGBoost); this will be generalised in the future.
  3. ClusterBlock
    • Takes two algorithms: one unsupervised clustering algorithm and one classifier model. In our research we used K-Means and K-NN, but these can be swapped for other models for experimentation.
  4. EnsembleBlock
    • Trains a set of individual regressors over the clustered data, currently either linear regressors (keyword Linear) or XGBoost regressors (keyword XGBoost); this will be made parametric.
  5. ReduceBlock
    • Runs PCA over the data to project it into a lower-dimensional space. Currently, PCA is fitted until a 95% explained-variance ratio is captured (see the PCA sketch after this list).
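
The 95% figure in ReduceBlock matches scikit-learn's fractional n_components option, which keeps the smallest number of principal components whose cumulative explained variance reaches the given ratio. A minimal stand-alone sketch with plain scikit-learn, shown here only to illustrate the mechanism (not the ReduceBlock internals):

from sklearn.datasets import load_diabetes
from sklearn.decomposition import PCA

X, _ = load_diabetes(return_X_y=True)
pca = PCA(n_components=0.95)  # keep components until 95% of the variance is explained
X_reduced = pca.fit_transform(X)
print(pca.n_components_, pca.explained_variance_ratio_.sum())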

In the example below, we implement Branch 8, which combines dimensionality reduction, clustering and ensemble training using these building blocks.

[Figure: workflow of the Branch 8 pipeline]

The processing sequence of this workflow is as follows:

  1. The Explainer Block takes the dataset, fits a Shapley explainer and computes Shapley values for each instance
  2. The Reduce Block projects the feature space to a lower dimensionality using PCA with a 95% explained-variance ratio
  3. The Cluster Block runs the given clustering algorithm (K-Means in our case) and labels the instances
  4. A one-to-one mapper maps these labels back to the original instances
  5. The Ensemble Block trains a model (XGBoost in this case) over each cluster
import math
from shap_bootstrap.building_blocks import *
from shap_bootstrap import datasets
from shap_bootstrap.custom_pipeline import B1_Branch_Pipeline, B8_Branch_Pipeline
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
from sklearn import metrics

# Load dataset 9 (student grades) and create a train-test split
X, y, name = datasets.returnDataset(9)
X_train, X_test, y_train, y_test = prepare_pipeline_data(X, y, random_state=42)

# Building blocks used in the pipelines
# All algorithms except PCA can be swapped for other models
explainer_type_b1 = 'XGBoost'
explainer_type_b8 = 'Linear'
model_type = 'XGBoost'
nClusters = 4

processing_block_b1 = ProcessingBlock()
explainer_block_b1 = ExplainerBlock(explainer_type_b1)
cluster_block_b1 = ClusterBlock(nClusters, KMeans(n_clusters=nClusters, random_state=0), KNeighborsClassifier(n_neighbors=nClusters))
ensemble_block_b1 = EnsembleBlock(model_type)

# Instantiate Branch 1 pipeline (plain XGBoost baseline for comparison)
branch1 = B1_Branch_Pipeline(processing_block_b1, explainer_block_b1, cluster_block_b1, ensemble_block_b1)

processing_block_b8 = ProcessingBlock()
explainer_block_b8 = ExplainerBlock(explainer_type_b8)
reduce_block_b8 = ReduceBlock(PCA(1))
cluster_block_b8 = ClusterBlock(nClusters, KMeans(n_clusters=nClusters, random_state=0), KNeighborsClassifier(n_neighbors=nClusters))
ensemble_block_b8 = EnsembleBlock(model_type)

# Instantiate Branch 8 pipeline
branch8 = B8_Branch_Pipeline(processing_block_b8, explainer_block_b8, reduce_block_b8, cluster_block_b8, ensemble_block_b8)

# Fit both pipelines and compute RMSE on the test set
branch1.fit(X_train, y_train)
y_pred_b1 = branch1.predict(X_test)
err_b1 = math.sqrt(metrics.mean_squared_error(y_test, y_pred_b1))

branch8.fit(X_train, y_train)
y_pred_b8 = branch8.predict(X_test)
err_b8 = math.sqrt(metrics.mean_squared_error(y_test, y_pred_b8))

This code snippet implements Branch 8 (alongside a Branch 1 baseline), trains over the student_grades dataset and makes predictions. Now we can evaluate the predictions further.

Evaluation of results (Effective)

We can visualise model predictions via:

import matplotlib.pyplot as plt

ax = plt.subplot()
ax.scatter(x=y_test, y=y_pred_b1)
ax.scatter(x=y_test, y=y_pred_b8)
ax.plot([0, 25], [0, 25], color='red', linestyle='--')
ax.set_xlabel('True label')
ax.set_ylabel('Pipeline predictions')
ax.set_title('Divergence of predictions from true label')
# Legend entries follow plotting order: baseline scatter, shap-bootstrap scatter, identity line
new_labels = ['XGBoost Model - RMSE: {:.3f}'.format(err_b1), 'Shap-bootstrap - RMSE: {:.3f}'.format(err_b8), 'Identity']
ax.legend(new_labels)
plt.show()

[Figure: scatter plot of prediction error. Blue points are the baseline XGBoost predictions, orange points are the shap-bootstrap predictions, and the red dashed line marks perfect prediction.]

In the plot, we observe that the proposed methodology improves on the existing XGBoost model by fitting the data better: the orange points lie closer to the identity line. The RMSE improves by roughly 5%, a modest rather than a major gain.
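
To read the comparison without the plot, the RMSE values computed earlier can be printed side by side (using the err_b1 and err_b8 variables from the training snippet):

print('Baseline XGBoost RMSE : {:.3f}'.format(err_b1))
print('Shap-bootstrap RMSE   : {:.3f}'.format(err_b8))
print('Relative improvement  : {:.1%}'.format((err_b1 - err_b8) / err_b1))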

Interpreting Features (Explainable)

We can also inspect the Shapley values for a further interpretation of feature importances:

import shap

explainer = branch8.explainer_block.explainer
shap.initjs()
shap.summary_plot(explainer.shap_values(X_train), X_train, X.columns, max_display=10)

which outputs the following plot:

[Figure: Shapley summary plot]

Here, the features are ranked in descending order of the sum of absolute Shapley values over all samples; that is, they are ordered by feature importance from top to bottom. For example, feature G2 is positively correlated with the output: high values of G2 increase the predicted label and low values decrease it.
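
The same ranking can be reproduced numerically by averaging the absolute Shapley values per feature. A short sketch, reusing the explainer, X and X_train from the snippets above:

import numpy as np

shap_values = explainer.shap_values(X_train)
mean_abs = np.abs(shap_values).mean(axis=0)  # mean |SHAP| per feature
for feature, importance in sorted(zip(X.columns, mean_abs), key=lambda t: -t[1])[:10]:
    print('{:<12} {:.4f}'.format(feature, importance))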

With this capability, we can inspect feature importances, which aids the interpretability of the model.

Model-agnostic functionality (Flexible)

The training pipelines take machine learning models as arguments at instantiation, so the experiments can be re-run with different models. Shap-bootstrap therefore offers flexibility in implementation.
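
For instance, assuming ClusterBlock only requires a clusterer with fit/predict and any scikit-learn classifier (as the constructor call in the example suggests), the K-Means/K-NN pair could hypothetically be swapped out like this:

from sklearn.cluster import Birch
from sklearn.ensemble import RandomForestClassifier
from shap_bootstrap.building_blocks import ClusterBlock

nClusters = 4
cluster_block = ClusterBlock(nClusters,
                             Birch(n_clusters=nClusters),
                             RandomForestClassifier(n_estimators=100))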

