For calculating global feature importance using Shapley values.

SAGE

SAGE (Shapley Additive Global importancE) is a game-theoretic approach for understanding black-box machine learning models. It quantifies each feature's importance based on how much predictive power it contributes, and it accounts for complex feature interactions using the Shapley value.

SAGE was introduced in the paper "Understanding Global Feature Contributions With Additive Importance Measures" (see References), but if you're new to using Shapley values you may want to start by reading this blog post.

Install

The easiest way to get started is to install the sage-importance package with pip:

pip install sage-importance

Alternatively, you can clone the repository and install the package in your Python environment as follows:

git clone https://github.com/iancovert/sage.git
cd sage
pip install .

Usage

SAGE is model-agnostic, so you can use it with any kind of machine learning model (linear models, GBMs, neural networks, etc.). All you need to do is set up an imputer to handle held-out features, and then estimate the Shapley values:

import sage

# Get data
x, y = ...
feature_names = ...

# Get model
model = ...

# Set up an imputer to handle missing features
imputer = sage.MarginalImputer(model, x[:128])

# Set up an estimator
estimator = sage.PermutationEstimator(imputer, 'mse')

# Calculate SAGE values
sage_values = estimator(x, y)
sage_values.plot(feature_names)

Our implementation supports several features to make estimating the Shapley values easier:

  • Uncertainty estimation: confidence intervals are provided for each feature's importance value.
  • Convergence detection: convergence is determined based on the size of the confidence intervals, and a progress bar displays the estimated time until convergence.
  • Model conversion: our back-end requires models to be represented in a consistent format, and this conversion step is performed automatically for XGBoost, CatBoost, LightGBM, sklearn and PyTorch models. If you're using a different kind of model, it needs to be converted to a callable function (see here for examples).
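The uncertainty and convergence features above can be illustrated with a small, library-free sketch. It assumes that each sampled permutation yields one noisy importance sample per feature, tracks a running mean and variance, and stops once the widest standard error is small relative to the spread of the estimates. The names RunningImportance, gap_ratio and run_until_converged, as well as the 0.025 threshold and the minimum of 30 iterations, are illustrative assumptions rather than the package's internals.

```python
import math
import random


class RunningImportance:
    """Running mean and variance per feature (Welford's algorithm)."""

    def __init__(self, num_features):
        self.n = 0
        self.mean = [0.0] * num_features
        self.m2 = [0.0] * num_features  # running sum of squared deviations

    def update(self, sample):
        self.n += 1
        for i, value in enumerate(sample):
            delta = value - self.mean[i]
            self.mean[i] += delta / self.n
            self.m2[i] += delta * (value - self.mean[i])

    def std_error(self):
        # Standard error of the mean; undefined with fewer than two samples.
        if self.n < 2:
            return [math.inf] * len(self.mean)
        return [math.sqrt(m2 / (self.n - 1) / self.n) for m2 in self.m2]

    def confidence_intervals(self, z=1.96):
        # Normal-approximation intervals around each running mean.
        return [(m - z * s, m + z * s) for m, s in zip(self.mean, self.std_error())]

    def gap_ratio(self):
        # Largest standard error relative to the range of the estimates.
        spread = max(self.mean) - min(self.mean)
        if spread == 0:
            return math.inf
        return max(self.std_error()) / spread


def run_until_converged(draw_sample, num_features, thresh=0.025,
                        min_iter=30, max_iter=100_000):
    tracker = RunningImportance(num_features)
    for _ in range(max_iter):
        tracker.update(draw_sample())
        if tracker.n >= min_iter and tracker.gap_ratio() < thresh:
            break
    return tracker


# Hypothetical stand-in for per-permutation importance samples.
rng = random.Random(0)
true_values = [1.0, 0.5, 0.0]
tracker = run_until_converged(
    lambda: [v + rng.gauss(0, 0.3) for v in true_values], num_features=3
)
print(tracker.n, tracker.mean, tracker.confidence_intervals())
```

A relative criterion like this keeps the stopping rule independent of the scale of the loss, which is one plausible reason to base convergence on interval width rather than a fixed number of samples.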

Examples

Check out the following notebooks to get started:

  • Bike: a simple example using XGBoost that shows how to calculate SAGE values and Shapley Effects (an alternative explanation when no labels are available)
  • Credit: generate explanations using a surrogate model to approximate the conditional distribution (using CatBoost)
  • Airbnb: calculate SAGE values with grouped features (using a PyTorch MLP)
  • Bank: a model monitoring example that uses SAGE to identify features that hurt the model's performance (using CatBoost)
  • MNIST: shows strategies to accelerate convergence for datasets with many features (feature grouping, different imputing setups)
  • Consistency: verifies that our various Shapley value estimators return the same results (see the estimators listed below)
  • Calibration: verifies that SAGE's confidence intervals are representative of the uncertainty across runs
  • Losses: shows how SAGE can be used in classification with alternative loss functions

If you want to replicate the experiments described in our paper, see this separate repository.

More details

This repository gives you flexibility in how explanations are generated. You can make the following choices.

1. Feature removal approach

The original SAGE paper proposes marginalizing out missing features using their conditional distribution. Since this is challenging to implement in practice, several approximations are available. For example, you can:

  1. Use default values for missing features (see MNIST for an example). This is a fast but low-quality approximation.
  2. Sample features from the marginal distribution (see Bike for an example). This approximation is discussed in the SAGE paper.
  3. Train a supervised surrogate model (see Credit for an example). This approach is described in this paper, and it can provide a better approximation than the other approaches. However, it requires training an additional model (typically a neural network).
  4. Train a model that accommodates missingness. This approach is not shown here, but it's described in this paper.
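The difference between the first two approaches can be seen in a small, library-free sketch. It assumes a model that is a plain function of one row, and it uses a hypothetical three-feature model with an interaction term; the helper names predict_with_defaults and predict_with_marginal are made up for this illustration and are not part of the package.

```python
def model(row):
    # Hypothetical model with an interaction between features 1 and 2.
    return 2.0 * row[0] + row[1] * row[2]


def predict_with_defaults(model, row, kept, defaults):
    """Replace every removed feature with a fixed default value."""
    filled = [row[i] if i in kept else defaults[i] for i in range(len(row))]
    return model(filled)


def predict_with_marginal(model, row, kept, background):
    """Average predictions while removed features are drawn from background rows."""
    total = 0.0
    for ref in background:
        filled = [row[i] if i in kept else ref[i] for i in range(len(row))]
        total += model(filled)
    return total / len(background)


background = [[0.0, 1.0, 2.0], [2.0, 3.0, 0.0], [4.0, -1.0, 4.0]]
means = [sum(col) / len(col) for col in zip(*background)]
row = [1.0, 2.0, 3.0]

# Keep only feature 0 and remove features 1 and 2.
default_pred = predict_with_defaults(model, row, kept={0}, defaults=means)
marginal_pred = predict_with_marginal(model, row, kept={0}, background=background)
print(default_pred, marginal_pred)
```

With mean values as defaults, the two approaches agree for a linear model but differ here because of the interaction term, which is one way to see why default values are the coarser approximation.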

2. Explanation type

Two types of explanations can be calculated, both based on Shapley values:

  1. SAGE. This approach quantifies how much each feature improves the model's performance (this is the default).
  2. Shapley Effects. Described in this paper, this explanation method quantifies the model's sensitivity to each feature. Since Shapley Effects is closely related to SAGE (see here for details), our implementation generates this type of explanation when labels are not provided. See the Bike notebook for an example.
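For reference, the quantities behind both explanation types can be written as follows. The notation (model f, loss \ell, feature index set D with d features, value function v) is chosen here for illustration, and the cited papers remain the authoritative source for the definitions.

```latex
\begin{align}
f_S(x_S) &= \mathbb{E}\left[ f(X) \mid X_S = x_S \right] \\
v_{\mathrm{SAGE}}(S) &= \mathbb{E}\left[ \ell\big(f_\varnothing(X_\varnothing), Y\big) \right]
  - \mathbb{E}\left[ \ell\big(f_S(X_S), Y\big) \right] \\
v_{\mathrm{Effects}}(S) &= \operatorname{Var}\big( \mathbb{E}\left[ f(X) \mid X_S \right] \big) \\
\phi_i(v) &= \frac{1}{d} \sum_{S \subseteq D \setminus \{i\}} \binom{d-1}{|S|}^{-1}
  \big( v(S \cup \{i\}) - v(S) \big)
\end{align}
```

Under these definitions, replacing the label Y with the full model output f(X) and using squared error as the loss turns the SAGE value function into the Shapley Effects one, which is consistent with generating Shapley Effects when no labels are provided.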

3. Shapley value estimator

Shapley values are computationally costly to calculate exactly, so we provide several estimation approaches:

  1. Permutation sampling. This is the approach described in the original paper (see PermutationEstimator).
  2. KernelSAGE. This is a linear regression-based estimator that's similar to KernelSHAP (see KernelEstimator). It's described in this paper, and the Bank notebook shows an example.
  3. Iterated sampling. This is a variation on the permutation sampling approach where we calculate Shapley values sequentially for each feature (see IteratedEstimator). This enables faster convergence for features with low variance, but it can result in wider confidence intervals.
  4. Sign estimation. This method estimates SAGE values to a lower precision by focusing only on their sign (i.e., whether they help or hurt performance). It's implemented in SignEstimator, and the Bank notebook shows an example.

The estimators all target the same Shapley values, so their results should agree up to estimation error (see Consistency), but there may be differences in convergence speed. Permutation sampling is a good approach to start with. KernelSAGE may converge a bit faster, but the uncertainty is spread more evenly among the features rather than being highest for more important features.
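To make the idea of permutation sampling concrete, here is a minimal, library-free sketch that compares it with exact enumeration on a toy game. The value function is hypothetical and stands in for "performance gained from a feature subset"; the function names exact_shapley and permutation_shapley are illustrative and are not the package's estimators.

```python
import itertools
import math
import random


def exact_shapley(value, d):
    """Exact Shapley values by enumerating all subsets (cost grows as 2^d)."""
    phi = [0.0] * d
    for i in range(d):
        others = [j for j in range(d) if j != i]
        for size in range(d):
            weight = 1.0 / (d * math.comb(d - 1, size))
            for subset in itertools.combinations(others, size):
                s = frozenset(subset)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi


def permutation_shapley(value, d, num_permutations, seed=0):
    """Monte Carlo estimate: average each feature's marginal gain over random orderings."""
    rng = random.Random(seed)
    phi = [0.0] * d
    order = list(range(d))
    for _ in range(num_permutations):
        rng.shuffle(order)
        included = frozenset()
        prev = value(included)
        for i in order:
            included = included | {i}
            curr = value(included)
            phi[i] += curr - prev
            prev = curr
    return [p / num_permutations for p in phi]


def value(subset):
    # Hypothetical game: feature 0 helps on its own,
    # features 1 and 2 only help when both are present.
    total = 0.0
    if 0 in subset:
        total += 1.0
    if 1 in subset and 2 in subset:
        total += 2.0
    return total


exact = exact_shapley(value, 3)
approx = permutation_shapley(value, 3, num_permutations=2000)
print(exact, approx)
```

In this toy game the interacting pair splits its joint contribution evenly, and the sampled estimates approach the exact values as the number of permutations grows.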

4. Grouped features

Rather than removing features individually, you can specify groups of features to be removed jointly. This will likely speed up convergence because there are fewer feature subsets. See Airbnb for an example.
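As a rough illustration of why grouping helps, the sketch below represents groups as lists of feature indices and expands a set of kept groups into a per-feature mask. The helper names, the feature names and the grouping are all hypothetical, and the sketch does not use the package's own grouped imputers.

```python
def check_groups(groups, num_features):
    """Groups should be disjoint and reference only valid feature indices."""
    seen = set()
    for group in groups:
        for i in group:
            if not 0 <= i < num_features:
                raise ValueError(f"feature index {i} is out of range")
            if i in seen:
                raise ValueError(f"feature index {i} appears in two groups")
            seen.add(i)


def expand_groups(kept_groups, groups, num_features):
    """Translate a set of kept group indices into a per-feature keep mask."""
    mask = [False] * num_features
    for g in kept_groups:
        for i in groups[g]:
            mask[i] = True
    return mask


# Hypothetical features and grouping.
feature_names = ["latitude", "longitude", "bedrooms", "bathrooms", "reviews"]
group_names = ["location", "rooms", "reviews"]
groups = [[0, 1], [2, 3], [4]]

check_groups(groups, len(feature_names))

# Keeping "location" and "reviews" removes both room features jointly.
mask = expand_groups({0, 2}, groups, len(feature_names))

# Subsets to reason about: 2^5 for individual features, 2^3 for groups.
num_feature_subsets = 2 ** len(feature_names)
num_group_subsets = 2 ** len(groups)
print(mask, num_feature_subsets, num_group_subsets)
```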

References

Ian Covert, Scott Lundberg, Su-In Lee. "Understanding Global Feature Contributions With Additive Importance Measures." NeurIPS 2020

Ian Covert, Scott Lundberg, Su-In Lee. "Explaining by Removing: A Unified Framework for Model Explanation." JMLR 2021

Ian Covert, Su-In Lee. "Improving KernelSHAP: Practical Shapley Value Estimation via Linear Regression." AISTATS 2021

Art Owen. "Sobol' Indices and Shapley Value." SIAM/ASA Journal on Uncertainty Quantification 2014
