Python/vimpy: inference on algorithm-agnostic variable importance

License: MIT

Software author: Brian Williamson

Methodology authors: Brian Williamson, Peter Gilbert, Noah Simon, Marco Carone

R package: https://github.com/bdwilliamson/vimp

Introduction

In predictive modeling applications, it is often of interest to determine the relative contribution of subsets of features in explaining an outcome; this is often called variable importance. It is useful to consider variable importance as a function of the unknown, underlying data-generating mechanism rather than the specific predictive algorithm used to fit the data. This package provides functions that, given fitted values from predictive algorithms, compute nonparametric estimates of variable importance based on $R^2$, deviance, classification accuracy, and area under the receiver operating characteristic curve, along with asymptotically valid confidence intervals for the true importance.

For more details, please see the accompanying manuscripts "Nonparametric variable importance assessment using machine learning techniques" by Williamson, Gilbert, Carone, and Simon (Biometrics, 2020), "A unified approach for inference on algorithm-agnostic variable importance" by Williamson, Gilbert, Simon, and Carone (arXiv, 2020), and "Efficient nonparametric statistical inference on population feature importance using Shapley values" by Williamson and Feng (arXiv, 2020; to appear in the Proceedings of the Thirty-seventh International Conference on Machine Learning [ICML 2020]).
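To fix ideas: for a predictiveness measure $V$ (e.g., $R^2$) and a feature subset $s$, the importance estimand studied in the manuscripts above can be sketched as

$$\psi_{0,s} = V(f_0, P_0) - V(f_{0,s}, P_0),$$

where $P_0$ denotes the true data-generating distribution, $f_0$ is the oracle prediction function based on all features, and $f_{0,s}$ is the oracle prediction function based on all features except those with index in $s$. vimpy plugs flexible (e.g., machine learning-based) estimates of $f_0$ and $f_{0,s}$ into this contrast; influence-function-based standard errors then give the asymptotically valid confidence intervals mentioned above. (This is an informal summary; see the manuscripts for precise definitions and conditions.)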

Installation

You may install a stable release of vimpy from the Python Package Index by running pip install vimpy in a terminal window. Alternatively, you may install within a virtualenv environment.

You may install the current development version of vimpy by downloading this repository directly.

Issues

If you encounter any bugs or have any specific feature requests, please file an issue.

Example

This example shows how to use vimpy in a simple setting, with simulated data and a single regression function. For more examples and a detailed explanation, please see the vignette for the R package, vimp.

## load required libraries
import numpy as np
import vimpy
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import GridSearchCV

## -------------------------------------------------------------
## problem setup
## -------------------------------------------------------------
## define a function for the conditional mean of Y given X
def cond_mean(x):
    f1 = np.where(np.logical_and(-2 <= x[:, 0], x[:, 0] < 2), np.floor(x[:, 0]), 0)
    f2 = np.where(x[:, 1] <= 0, 1, 0)
    f3 = np.where(x[:, 2] > 0, 1, 0)
    f6 = np.absolute(x[:, 5]/4) ** 3
    f7 = np.absolute(x[:, 6]/4) ** 5
    f11 = (7./3)*np.cos(x[:, 10]/2)
    ret = f1 + f2 + f3 + f6 + f7 + f11
    return ret

## create data
np.random.seed(4747)
n = 100
p = 15
s = 1 # index of the feature whose importance is desired (0-based)
x = np.zeros((n, p))
for i in range(x.shape[1]):
    x[:, i] = np.random.normal(0, 2, n)

y = cond_mean(x) + np.random.normal(0, 1, n)

## -------------------------------------------------------------
## preliminary step: get regression estimators
## -------------------------------------------------------------
## use grid search to get optimal number of trees and learning rate
ntrees = np.arange(100, 500, 100)
lr = np.arange(.01, .1, .05)

param_grid = [{'n_estimators':ntrees, 'learning_rate':lr}]

## set up cv objects
## note: scikit-learn >= 1.0 names the squared-error loss 'squared_error' (older versions used 'ls')
cv_full = GridSearchCV(GradientBoostingRegressor(loss = 'squared_error', max_depth = 1), param_grid = param_grid, cv = 5)
cv_small = GridSearchCV(GradientBoostingRegressor(loss = 'squared_error', max_depth = 1), param_grid = param_grid, cv = 5)

## -------------------------------------------------------------
## get variable importance estimates
## -------------------------------------------------------------
# set seed
np.random.seed(12345)
## set up the vimp object
vimp = vimpy.vim(y = y, x = x, s = s, pred_func = cv_full, measure_type = "r_squared")
## get the point estimate of variable importance
vimp.get_point_est()
## get the influence function estimate
vimp.get_influence_function()
## get a standard error
vimp.get_se()
## get a confidence interval
vimp.get_ci()
## do a hypothesis test, compute p-value
vimp.hypothesis_test(alpha = 0.05, delta = 0)
## display the estimates, etc.
vimp.vimp_
vimp.se_
vimp.ci_
vimp.p_value_
vimp.hyp_test_
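In an interactive session the attributes above display directly; in a script, one might instead print a compact summary (a minimal sketch using only the attributes populated by the calls above):

## print a compact summary for the feature of interest
print("importance estimate:", vimp.vimp_)
print("standard error:", vimp.se_)
print("95% confidence interval:", vimp.ci_)
print("p-value:", vimp.p_value_)
print("reject the null at level 0.05:", vimp.hyp_test_)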

## -------------------------------------------------------------
## using precomputed fitted values
## -------------------------------------------------------------
np.random.seed(12345)
## randomly split the data into two halves: one for the full regression, one for the reduced regression
folds_outer = np.random.choice(a = np.arange(2), size = n, replace = True, p = np.array([0.5, 0.5]))
## fit the full regression
cv_full.fit(x[folds_outer == 1, :], y[folds_outer == 1])
full_fit = cv_full.best_estimator_.predict(x[folds_outer == 1, :])

## fit the reduced regression
x_small = np.delete(x[folds_outer == 0, :], s, 1) # delete the columns in s
cv_small.fit(x_small, y[folds_outer == 0])
small_fit = cv_small.best_estimator_.predict(x_small)
## get variable importance estimates
np.random.seed(12345)
vimp_precompute = vimpy.vim(y = y, x = x, s = s, f = full_fit, r = small_fit, measure_type = "r_squared", folds = folds_outer)
## get the point estimate of variable importance
vimp_precompute.get_point_est()
## get the influence function estimate
vimp_precompute.get_influence_function()
## get a standard error
vimp_precompute.get_se()
## get a confidence interval
vimp_precompute.get_ci()
## do a hypothesis test, compute p-value
vimp_precompute.hypothesis_test(alpha = 0.05, delta = 0)
## display the estimates, etc.
vimp_precompute.vimp_
vimp_precompute.se_
vimp_precompute.ci_
vimp_precompute.p_value_
vimp_precompute.hyp_test_

## -------------------------------------------------------------
## get variable importance estimates using cross-validation
## -------------------------------------------------------------
np.random.seed(12345)
## set up the cross-validated vimp object (V-fold cross-fitting with V = 5)
vimp_cv = vimpy.cv_vim(y = y, x = x, s = s, pred_func = cv_full, V = 5, measure_type = "r_squared")
## get the point estimate
vimp_cv.get_point_est()
## get the influence function estimate
vimp_cv.get_influence_function()
## get the standard error
vimp_cv.get_se()
## get a confidence interval
vimp_cv.get_ci()
## do a hypothesis test, compute p-value
vimp_cv.hypothesis_test(alpha = 0.05, delta = 0)
## display estimates, etc.
vimp_cv.vimp_
vimp_cv.se_
vimp_cv.ci_
vimp_cv.p_value_
vimp_cv.hyp_test_
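The same workflow extends to the other predictiveness measures listed in the introduction. As an illustration for a binary outcome, the sketch below estimates importance on the AUC scale using the precomputed-fit interface from above; the simulated outcome, the choice of classifier, and the measure_type = "auc" value are illustrative assumptions rather than part of the original example.

## -------------------------------------------------------------
## sketch: binary outcome, importance measured via AUC
## -------------------------------------------------------------
from sklearn.ensemble import GradientBoostingClassifier

np.random.seed(12345)
## simulate a binary outcome whose log-odds follow the conditional mean above (illustrative)
prob = 1. / (1. + np.exp(-cond_mean(x)))
y_bin = np.random.binomial(1, prob)
folds_bin = np.random.choice(a = np.arange(2), size = n, replace = True, p = np.array([0.5, 0.5]))
## fit full and reduced classifiers; AUC needs predicted probabilities, not class labels
clf_full = GradientBoostingClassifier(max_depth = 1).fit(x[folds_bin == 1, :], y_bin[folds_bin == 1])
f_bin = clf_full.predict_proba(x[folds_bin == 1, :])[:, 1]
x_bin_small = np.delete(x[folds_bin == 0, :], s, 1)
clf_small = GradientBoostingClassifier(max_depth = 1).fit(x_bin_small, y_bin[folds_bin == 0])
r_bin = clf_small.predict_proba(x_bin_small)[:, 1]
## estimate the importance of feature s on the AUC scale
vimp_auc = vimpy.vim(y = y_bin, x = x, s = s, f = f_bin, r = r_bin, measure_type = "auc", folds = folds_bin)
vimp_auc.get_point_est()
vimp_auc.get_influence_function()
vimp_auc.get_se()
vimp_auc.get_ci()
vimp_auc.hypothesis_test(alpha = 0.05, delta = 0)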

Logo

The logo was created using hexSticker, lisa, and a python image distributed under the CC0 license. Many thanks to the maintainers of these packages and the Color Lisa team.
