A Python port of the R package shapFlex for computing asymmetric Shapley values to assess causality in any trained machine learning model
shapflex
- Free software: MIT license
- Documentation: https://gregory-ch.github.io/shapflex
Warnings
This is an alpha version of a Python port of https://github.com/nredell/shapFlex
Examples
#02.05.22
import pandas as pd
import numpy as np
from shapflex.shapflex import shapFlex_plus
from catboost import CatBoostClassifier
data = pd.read_csv('https://kolodezev.ru/download/data_adult.csv', index_col=0)
outcome_name = 'income'
outcome_col = pd.Series(data.columns)[data.columns==outcome_name].index[0]
X, y = data.drop(outcome_name, axis=1), data[outcome_name].values
cat_features = [name for name, dtype in X.dtypes.items() if dtype == 'object']
model = CatBoostClassifier(iterations=100)
model.fit(X, y, cat_features=cat_features, verbose=False)
def predict_function(model, data):
    # Running pd.DataFrame(model.predict_proba(X)).loc[:, 0][9] returns 0.98, which matches
    # the output for row 9 (also 0.98). Whatever the algorithm, such a high degree of
    # confidence makes it possible to identify the returned column unambiguously.
    return pd.DataFrame(model.predict_proba(data)[:, [0]])
explain, reference = data.iloc[:300, :data.shape[1]-1], data.iloc[:, :data.shape[1]-1]
sample_size = 50
target_features = pd.Series(["marital_status", "education", "relationship", "native_country",
                             "age", "sex", "race", "hours_per_week"])
causal = pd.DataFrame(
    dict(cause=pd.Series(["age", "sex", "race", "native_country",
                          "age", "sex", "race", "native_country",
                          "age", "sex", "race", "native_country"]),
         effect=pd.Series(np.concatenate([np.tile("marital_status", 4),
                                          np.tile("education", 4),
                                          np.tile("relationship", 4)])))
)
exmpl_of_test = shapFlex_plus(explain, model, predict_function,
                              target_features=target_features,
                              causal=causal,
                              causal_weights=[1. for x in range(len(causal))])
result = exmpl_of_test.forward()
print(result.groupby('feature_name').mean())
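The causal table in the example lists all twelve cause-effect pairs by hand. As a minimal sketch, the same pairs can be generated from the two lists of names using only the standard library. The helper name causal_edges is hypothetical and is not part of shapflex; it only builds the cause and effect columns in the same order as the example.

```python
from itertools import product


def causal_edges(causes, effects):
    """Build 'cause' and 'effect' columns with one entry per (cause, effect) pair.

    Hypothetical helper for illustration only; it is not part of shapflex.
    """
    # Iterate effects in the outer position so each effect is repeated once per
    # cause, matching the row order of the hand-written table in the example.
    pairs = [(cause, effect) for effect, cause in product(effects, causes)]
    return {
        "cause": [cause for cause, _ in pairs],
        "effect": [effect for _, effect in pairs],
    }


causes = ["age", "sex", "race", "native_country"]
effects = ["marital_status", "education", "relationship"]
edges = causal_edges(causes, effects)

# One weight per cause-effect pair, all 1.0 as in the example.
weights = [1.0] * len(edges["cause"])

for cause, effect in zip(edges["cause"], edges["effect"]):
    print(cause, "->", effect)
```

With pandas imported as in the example, pd.DataFrame(edges) should give a table with the same cause and effect columns as the hand-written causal table, and weights has one entry per row.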
Credits
This package was created with Cookiecutter and the giswqs/pypackage project template.