A Python port of the R package mshap for interpreting combined model outputs.

mshap

This is a Python port of srmatth/mshap.

The goal of mshap is to make it easy to compute SHAP values for two-part models. A two-part model is one where the output of one model is multiplied by the output of another model. These models are often used in the actuarial industry, but they have other use cases as well.

Installation

[WIP] Install mshap from PyPI with the following command:

pip install mshap

Or install the development version from GitHub with:

pip install git+https://github.com/Diadochokinetic/mshap

Basic Use

We will demonstrate a simple use case on simulated data. Suppose that we wish to predict the total amount of money a consumer will spend on a subscription to a software product. We might simulate 4 explanatory variables that look like the following:

## R
set.seed(16)
age <- runif(1000, 18, 60)
income <- runif(1000, 50000, 150000)
married <- as.numeric(runif(1000, 0, 1) > 0.5)
sex <- as.numeric(runif(1000, 0, 1) > 0.5)
# For the sake of simplicity we will have these as numeric already, where 0 represents male and 1 represents female

Now because this is a contrived example, we will knowingly set the response variables as follows (suppose here that cost_per_month is usage based, so as to be continuous):

## R
cost_per_month <- (0.0006 * income - 0.2 * sex + 0.5 * married - 0.001 * age) + 10
num_months <- 15 * (0.001 * income  * 0.001 * sex * 0.5 * married - 0.05 * age)^2

Thus, we have our data. We will combine the covariates into a single data frame for ease of use in Python.

## R
X <- data.frame(age, income, married, sex)
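For readers who prefer to stay entirely in Python, the R simulation above can be mirrored with NumPy. This is only a sketch: the seed is arbitrary, and NumPy's generator will not reproduce the draws from R's set.seed(16), so the numbers shown later in this README will differ. The name feature_names is introduced here for illustration.

```python
import numpy as np

# Arbitrary seed; the draws will NOT match R's set.seed(16)
rng = np.random.default_rng(16)
n = 1000

# Four explanatory variables, mirroring the R simulation
age = rng.uniform(18, 60, n)
income = rng.uniform(50_000, 150_000, n)
married = (rng.uniform(0, 1, n) > 0.5).astype(float)
sex = (rng.uniform(0, 1, n) > 0.5).astype(float)  # 0 = male, 1 = female

# Response variables, using the same formulas as the R code
cost_per_month = (0.0006 * income - 0.2 * sex + 0.5 * married - 0.001 * age) + 10
num_months = 15 * (0.001 * income * 0.001 * sex * 0.5 * married - 0.05 * age) ** 2

# Covariates as one (n, 4) array, columns in the same order as the R data frame
X = np.column_stack([age, income, married, sex])
feature_names = ["age", "income", "married", "sex"]
```

With the data built this way, X, cost_per_month, and num_months can be passed directly to the model-fitting step instead of being read from the R session.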

The end goal of this exercise is to predict the total revenue from the given customer, which mathematically will be cost_per_month * num_months. Instead of multiplying these two vectors together initially, we will instead create two models: one to predict cost_per_month and the other to predict num_months. We can then multiply the output of the two models together to get our predictions.

We now move over to Python to create our two models and predict on the training set:

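## Python
# Assumption: `sk` is an alias for sklearn.ensemble
import sklearn.ensemble as sk

# `r` is the reticulate handle to the R session holding the simulated data
X = r.X
y1 = r.cost_per_month
y2 = r.num_months

# Model 1: cost per month
cpm_mod = sk.RandomForestRegressor(n_estimators=100, max_depth=10, max_features=2)
cpm_mod.fit(X, y1)
#> RandomForestRegressor(max_depth=10, max_features=2)

# Model 2: number of months
nm_mod = sk.RandomForestRegressor(n_estimators=100, max_depth=10, max_features=2)
nm_mod.fit(X, y2)
#> RandomForestRegressor(max_depth=10, max_features=2)

cpm_preds = cpm_mod.predict(X)
nm_preds = nm_mod.predict(X)

# Two-part model prediction: the product of the two model outputs
tot_rev = cpm_preds * nm_preds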

We will now proceed to use TreeSHAP and subsequently mSHAP to explain the ultimate model predictions.

## Python
import shap

# because these are tree-based models, shap.Explainer uses TreeSHAP to calculate
# fast, exact SHAP values for each model individually
cpm_ex = shap.Explainer(cpm_mod)
cpm_shap = cpm_ex.shap_values(X)
cpm_expected_value = cpm_ex.expected_value

nm_ex = shap.Explainer(nm_mod)
nm_shap = nm_ex.shap_values(X)
nm_expected_value = nm_ex.expected_value

## R
final_shap <- mshap(
  shap_1 = py$cpm_shap, 
  shap_2 = py$nm_shap, 
  ex_1 = py$cpm_expected_value, 
  ex_2 = py$nm_expected_value
)

head(final_shap$shap_vals)
#> # A tibble: 6 x 4
#>       V1     V2     V3     V4
#>    <dbl>  <dbl>  <dbl>  <dbl>
#> 1  1149. -1200. 13.9   -11.8 
#> 2 -2711.  1149.  5.69  -11.2 
#> 3 -1027.  1301.  5.81    9.58
#> 4 -2064.  -879. -0.916 -22.7 
#> 5  3803.  2096. 37.7   -27.4 
#> 6 -2146.   897. 25.4   -14.3

final_shap$expected_value
#> [1] 4398.19

As a check, you can see that the expected value for mSHAP is indeed the mean prediction of the combined model across the training data.

## R
mean(py$tot_rev)
#> [1] 4398.19

We have now calculated the mSHAP values for the multiplied model outputs! This will allow us to explain our final model.
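To make the combination step less of a black box, here is a minimal NumPy sketch of how SHAP values from two multiplied models can be merged. It is not the package's implementation, and mshap_sketch is a hypothetical name used only for illustration. It assumes each prediction equals the expected value plus the row sum of SHAP values, expands the product of the two predictions, splits every cross term equally between the two features involved, and then shifts the baseline to the mean combined prediction by distributing the gap in proportion to absolute contribution. That last step is one possible way to distribute the gap.

```python
import numpy as np


def mshap_sketch(shap_1, shap_2, ex_1, ex_2):
    """Hypothetical helper: combine SHAP values of two multiplied models.

    shap_1, shap_2: arrays of shape (n_samples, n_features)
    ex_1, ex_2: expected values of each model (scalar or length-1 array)
    Returns (shap_vals, expected_value).
    """
    shap_1 = np.asarray(shap_1, dtype=float)
    shap_2 = np.asarray(shap_2, dtype=float)
    ex_1 = float(np.ravel(ex_1)[0])
    ex_2 = float(np.ravel(ex_2)[0])

    # Per-row total contribution of each model
    tot_1 = shap_1.sum(axis=1, keepdims=True)
    tot_2 = shap_2.sum(axis=1, keepdims=True)

    # Expand (ex_1 + sum_i a_i) * (ex_2 + sum_j b_j); each cross term
    # a_i * b_j is split equally between features i and j
    raw = ex_1 * shap_2 + ex_2 * shap_1 + 0.5 * (shap_1 * tot_2 + shap_2 * tot_1)

    # Combined predictions and their mean, used as the new baseline
    preds = (ex_1 + tot_1) * (ex_2 + tot_2)
    expected_value = float(preds.mean())

    # Rows of `raw` sum to preds - ex_1 * ex_2; spread the baseline gap
    # across features in proportion to |raw| (equal split if a row is all zero)
    gap = ex_1 * ex_2 - expected_value
    abs_raw = np.abs(raw)
    denom = abs_raw.sum(axis=1, keepdims=True)
    weights = np.divide(
        abs_raw, denom,
        out=np.full_like(raw, 1.0 / raw.shape[1]),
        where=denom > 0,
    )
    return raw + weights * gap, expected_value


# Usage on made-up SHAP values (not the README data)
rng = np.random.default_rng(0)
a = rng.normal(size=(5, 4))
b = rng.normal(size=(5, 4))
vals, ev = mshap_sketch(a, b, 10.0, 3.0)
preds = (10.0 + a.sum(axis=1)) * (3.0 + b.sum(axis=1))

# Local accuracy: contributions plus baseline recover each combined prediction
print(np.allclose(vals.sum(axis=1) + ev, preds))
```

The final check mirrors the mean(py$tot_rev) comparison above: the baseline returned by the sketch is the mean of the combined predictions, and each row of contributions sums to that row's prediction minus the baseline.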

The mSHAP package comes with additional functions that can be used to visualize SHAP values in R. What is shown here are the default outputs, but these functions return {ggplot2} objects that are easily customizable.

## R
summary_plot(
  variable_values = X,
  shap_values = final_shap$shap_vals, 
  names = c("age", "income", "married", "sex") # this is optional, since X has column names
)

## R
observation_plot(
  variable_values = X[23,],
  shap_values = final_shap$shap_vals[23,],
  expected_value = final_shap$expected_value,
  names = c("age", "income", "married", "sex")
)

For another, more complex, use case run vignette("mshap"). For more examples and options for plotting, run vignette("mshap_plots").

Citations
