A Python port of the R package shapFlex for computing asymmetric Shapley values to assess causality in any trained machine learning model
shapflex
- Free software: MIT license
- Documentation: https://gregory-ch.github.io/shapflex
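Asymmetric Shapley values differ from ordinary (symmetric) Shapley values in which feature orderings are averaged: orderings in which an assumed cause enters before its effect receive more (or all) of the weight. The toy sketch below illustrates that idea on a two-feature model using exact enumeration and a single fixed baseline. `shapley_values`, `model`, `value` and the `weight` argument are hypothetical names for illustration, not part of the shapflex API, and a practical implementation would have to approximate these averages (for example by sampling) instead of enumerating every ordering.

```python
from itertools import permutations


def shapley_values(value, n_features, cause=None, effect=None, weight=0.5):
    """Hypothetical helper (not shapflex API): exact Shapley values by averaging
    marginal contributions over feature orderings.

    value(coalition) returns the model output when only the features in the
    frozenset `coalition` take their observed values.
    If cause/effect are given, orderings where `cause` precedes `effect` share
    total weight `weight`; the other orderings share `1 - weight`.
    weight=0.5 reproduces symmetric Shapley values, weight=1.0 is fully asymmetric.
    """
    orderings = list(permutations(range(n_features)))
    if cause is None:
        groups = [(orderings, 1.0)]  # plain average over all orderings
    else:
        first = [o for o in orderings if o.index(cause) < o.index(effect)]
        second = [o for o in orderings if o.index(cause) > o.index(effect)]
        groups = [(first, weight), (second, 1.0 - weight)]

    phi = [0.0] * n_features
    for group, group_weight in groups:
        if not group or group_weight == 0:
            continue
        for order in group:
            coalition = frozenset()
            for feature in order:
                # Marginal contribution of `feature` given the features added before it.
                gain = value(coalition | {feature}) - value(coalition)
                phi[feature] += group_weight * gain / len(group)
                coalition = coalition | {feature}
    return phi


# Toy model with an interaction term; feature 0 is treated as the cause of feature 1.
x, baseline = [1.0, 1.0], [0.0, 0.0]


def model(z):
    return 2 * z[0] + 3 * z[1] + z[0] * z[1]


def value(coalition):
    # Features outside the coalition are set to the baseline (a simplifying assumption).
    z = [x[i] if i in coalition else baseline[i] for i in range(len(x))]
    return model(z)


print(shapley_values(value, 2))                                 # [2.5, 3.5]
print(shapley_values(value, 2, cause=0, effect=1, weight=1.0))  # [2.0, 4.0]
```

With `weight=1.0` the interaction term is credited entirely to the effect feature, because the cause always enters the coalition first; `weight=0.5` recovers the symmetric values. In both cases the values sum to `model(x) - model(baseline)`, here 6. We assume, without having checked this port's source, that the package's `causal_weights` play a role similar to `weight` in this sketch.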
Warnings
This is an alpha version of a port of https://github.com/nredell/shapFlex.
Examples
#02.05.22
import pandas as pd
import numpy as np
from shapflex.shapflex import shapFlex_plus
from catboost import CatBoostClassifier
data = pd.read_csv('https://kolodezev.ru/download/data_adult.csv', index_col=0)
outcome_name = 'income'
outcome_col = data.columns.get_loc(outcome_name)  # integer position of the outcome column (not used below)
X, y = data.drop(outcome_name, axis=1), data[outcome_name].values
cat_features = X.select_dtypes(include='object').columns.tolist()  # string columns, passed to CatBoost as categorical
model = CatBoostClassifier(iterations=100)
model.fit(X, y, cat_features=cat_features, verbose=False)
def predict_function(model, data):
    # Running pd.DataFrame(model.predict_proba(X)).loc[:, 0][9] gives 0.98, which matches
    # the output for row 9 (also 0.98). Regardless of the algorithm, such a high level of
    # confidence makes it possible to identify the output column unambiguously.
    # Return the predicted probability of the first class (column 0) as a one-column DataFrame.
    return pd.DataFrame(model.predict_proba(data)[:, [0]])
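# Rows to explain (the first 300) and reference data; both drop the last column, assumed to be the outcome `income`.
# Note: `reference` and `sample_size` are not passed to shapFlex_plus in this example.
explain, reference = data.iloc[:300, :data.shape[1]-1], data.iloc[:, :data.shape[1]-1]
sample_size = 50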
target_features = pd.Series(["marital_status", "education", "relationship", "native_country",
                             "age", "sex", "race", "hours_per_week"])
# Each row is a (cause, effect) pair of features.
causal = pd.DataFrame(
    dict(
        cause=pd.Series(["age", "sex", "race", "native_country",
                         "age", "sex", "race", "native_country",
                         "age", "sex", "race", "native_country"]),
        effect=pd.Series(np.concatenate([np.tile("marital_status", 4),
                                         np.tile("education", 4),
                                         np.tile("relationship", 4)])),
    )
)
exmpl_of_test = shapFlex_plus(explain, model, predict_function,
                              target_features=target_features,
                              causal=causal,
                              causal_weights=[1.0] * len(causal))
result = exmpl_of_test.forward()
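# Average per feature; numeric_only=True skips non-numeric columns, which newer pandas would otherwise reject.
print(result.groupby('feature_name').mean(numeric_only=True))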
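The hand-written `causal` table in the example pairs each of four causes with each of three effects. The sketch below builds the same twelve pairs, in the same order, with `itertools.product` and runs two sanity checks before the table is used. The checks are our own assumptions about sensible input, not documented shapflex requirements, and `causal_pairs` is an illustrative name.

```python
from itertools import product

# Feature names taken from the example above.
target_features = ["marital_status", "education", "relationship", "native_country",
                   "age", "sex", "race", "hours_per_week"]
causes = ["age", "sex", "race", "native_country"]
effects = ["marital_status", "education", "relationship"]

# One (cause, effect) pair per combination; effects form the outer loop so the
# order matches the hand-written `causal` table in the example.
causal_pairs = [(cause, effect) for effect, cause in product(effects, causes)]

# Sanity checks (our assumptions, not shapflex requirements): every feature in a
# pair is one of the target features, and no feature is declared to cause itself.
for cause, effect in causal_pairs:
    assert cause in target_features and effect in target_features, (cause, effect)
    assert cause != effect, cause

# Same values as `[1. for x in range(len(causal))]` in the example.
causal_weights = [1.0] * len(causal_pairs)

print(len(causal_pairs), causal_pairs[0], causal_pairs[-1])
```

To use the pairs in the example, wrap them as `pd.DataFrame(causal_pairs, columns=["cause", "effect"])`.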
Credits
This package was created with Cookiecutter and the giswqs/pypackage project template.
Download files
Source Distribution
shapflex-0.0.2.tar.gz (10.0 kB)
File details
Details for the file shapflex-0.0.2.tar.gz.
File metadata
- Download URL: shapflex-0.0.2.tar.gz
- Upload date:
- Size: 10.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.0 CPython/3.9.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | 2e82100d1c179c3bde5c1b479e74adeb0ace93ad237fba2a13835b9b2162eed5
MD5 | 26e0fffcc1f88c40707694c1e32156bf
BLAKE2b-256 | 5e2fc22df26150336b52319e5acf8e17860e1b83f8b87e6d91e9efb26be440c0
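To check a downloaded archive against the SHA256 digest listed above, a short standard-library sketch follows. `sha256_of` and `matches` are illustrative helper names, and the archive is assumed to sit in the current working directory.

```python
import hashlib
from pathlib import Path

# SHA256 digest listed for shapflex-0.0.2.tar.gz in the table above.
EXPECTED_SHA256 = "2e82100d1c179c3bde5c1b479e74adeb0ace93ad237fba2a13835b9b2162eed5"


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches(path, expected=EXPECTED_SHA256):
    """True if the file's SHA256 equals the expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected.lower()


if __name__ == "__main__":
    archive = Path("shapflex-0.0.2.tar.gz")  # assumed location of the downloaded file
    if archive.exists():
        print("OK" if matches(archive) else "hash mismatch")
```

Reading in fixed-size chunks keeps memory use constant regardless of file size; for a 10 kB archive a single `read()` would also work.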