A Python port of the R package shapFlex, which computes asymmetric Shapley values to assess causality in any trained machine learning model

Project description

shapflex

Warnings

This is an alpha version of a Python port of https://github.com/nredell/shapFlex

Examples

#02.05.22


import pandas as pd
import numpy as np
from shapflex.shapflex import shapFlex_plus
from catboost import CatBoostClassifier 

data = pd.read_csv('https://kolodezev.ru/download/data_adult.csv', index_col=0)
outcome_name = 'income'
outcome_col = pd.Series(data.columns)[data.columns==outcome_name].index[0]
X, y = data.drop(outcome_name, axis=1), data[outcome_name].values
cat_features = [inx for inx, value in zip(X.dtypes.index, X.dtypes) if value =='object']
model = CatBoostClassifier(iterations=100)
model.fit(X, y, cat_features=cat_features, verbose=False)
def predict_function(model, data):
  # Running pd.DataFrame(model.predict_proba(X)).loc[:, 0][9] gives 0.98, which matches
  # the output for row 9 (also 0.98). Whatever the algorithm, a confidence this high
  # identifies the returned column unambiguously.
  return pd.DataFrame(model.predict_proba(data)[:, [0]])


explain, reference = data.iloc[:300, :data.shape[1]-1], data.iloc[:, :data.shape[1]-1]
sample_size = 50
target_features = pd.Series(["marital_status", "education", "relationship",  "native_country",
                     "age", "sex", "race", "hours_per_week"])
causal = pd.DataFrame(
  dict(cause=pd.Series(["age", "sex", "race", "native_country",
              "age", "sex", "race", "native_country", "age",
              "sex", "race", "native_country"]),
  effect = pd.Series(np.concatenate([np.tile("marital_status", 4), np.tile("education", 4), np.tile("relationship", 4)])))
)
exmpl_of_test = shapFlex_plus(
  explain, model, predict_function,
  target_features=target_features,
  causal=causal,
  causal_weights=[1.0] * len(causal),
)
result = exmpl_of_test.forward()
print(result.groupby('feature_name').mean())
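
The `causal` table in the example pairs each of four causes with each of three effects, one row per pair. As a minimal sketch of an equivalent construction that may be easier to read, the same 12 rows can be generated with `itertools.product`. The names `causes`, `effects` and `causal_alt` are introduced here for illustration only and are not part of shapflex.

```python
from itertools import product

import pandas as pd

causes = ["age", "sex", "race", "native_country"]
effects = ["marital_status", "education", "relationship"]

# One row per (cause, effect) pair. Effects vary slowest, so the row order
# matches the hand-written table: all causes of "marital_status" first,
# then "education", then "relationship".
causal_alt = pd.DataFrame(
    [(cause, effect) for effect, cause in product(effects, causes)],
    columns=["cause", "effect"],
)

print(causal_alt)
```

Passing `causal_alt` in place of `causal` should behave the same way, assuming shapflex only reads the `cause` and `effect` columns.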

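The `predict_function` in the example returns a one-column DataFrame holding column 0 of `predict_proba`, with one row per input row. The sketch below checks that shape contract without training a model. `StubClassifier` is a hypothetical stand-in, not a CatBoost or shapflex class, and which class column 0 refers to depends on the class ordering of the real model.

```python
import numpy as np
import pandas as pd


class StubClassifier:
    """Hypothetical stand-in for a fitted binary classifier."""

    def predict_proba(self, data):
        # Fake class-1 probability: a logistic function of the first column.
        p1 = 1.0 / (1.0 + np.exp(-data.iloc[:, 0].to_numpy(dtype=float)))
        # Two columns, as a binary classifier would return: P(class 0), P(class 1).
        return np.column_stack([1.0 - p1, p1])


def predict_function(model, data):
    # Same body as in the example: keep column 0 only, as a one-column DataFrame.
    return pd.DataFrame(model.predict_proba(data)[:, [0]])


demo = pd.DataFrame({"x": [-2.0, 0.0, 2.0]})
preds = predict_function(StubClassifier(), demo)
print(preds.shape)  # (3, 1)
```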

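To give a rough intuition for what cause/effect pairs change, the toy sketch below computes exact Shapley values by enumerating feature orderings, then repeats the calculation using only orderings in which every cause precedes its effect. This is an illustration of the general idea under simplifying assumptions (two features, a made-up value function, non-causal orderings dropped entirely). It is not the algorithm used by shapflex or shapFlex, and `shapley_by_orderings` and `value` are hypothetical names.

```python
from itertools import permutations


def shapley_by_orderings(players, value, causal_pairs=()):
    """Average each player's marginal contribution over the allowed orderings.

    With no causal_pairs, every ordering counts (ordinary Shapley values).
    With causal_pairs, only orderings where each cause precedes its effect count.
    """
    allowed = [
        order for order in permutations(players)
        if all(order.index(c) < order.index(e) for c, e in causal_pairs)
    ]
    totals = {p: 0.0 for p in players}
    for order in allowed:
        seen = set()
        for p in order:
            before = value(frozenset(seen))
            seen.add(p)
            totals[p] += value(frozenset(seen)) - before
    return {p: totals[p] / len(allowed) for p in players}


def value(coalition):
    # Toy value function: the output is fully determined as soon as either
    # feature is known, i.e. the two features carry the same information.
    return 1.0 if coalition else 0.0


players = ["age", "marital_status"]
symmetric = shapley_by_orderings(players, value)
asymmetric = shapley_by_orderings(
    players, value, causal_pairs=[("age", "marital_status")]
)
print(symmetric)   # {'age': 0.5, 'marital_status': 0.5}
print(asymmetric)  # {'age': 1.0, 'marital_status': 0.0}
```

In the symmetric case the credit is split evenly between the two redundant features. Once `age` is declared a cause of `marital_status`, the whole contribution moves to the cause, while the attributions still sum to the same total.
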
Credits

This package was created with Cookiecutter and the giswqs/pypackage project template.

Download files

Source Distribution

shapflex-0.0.2.tar.gz (10.0 kB)
