
A Python port of the R package shapFlex for computing asymmetric Shapley values to assess causality in any trained machine learning model

Project description

shapflex


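A Shapley value averages a feature's marginal contribution over the orderings in which features can be added to a model's input. Asymmetric Shapley values average only over the orderings in which every cause comes before its effects, which tends to shift credit toward causally upstream features. The brute-force sketch below illustrates that idea on a toy problem. It is not how shapflex computes its estimates: `ordered_shapley` and the coalition value function `value` are hypothetical names, the value function is assumed to be supplied by the caller, and enumerating all n! orderings is only feasible for a handful of features.

```python
from itertools import permutations
from typing import Callable, FrozenSet, Iterable, List, Tuple


def ordered_shapley(
    n_features: int,
    value: Callable[[FrozenSet[int]], float],
    causal_edges: Iterable[Tuple[int, int]] = (),
) -> List[float]:
    """Exact Shapley values by enumerating feature orderings (hypothetical helper).

    With no causal_edges every ordering counts (symmetric Shapley values).
    With (cause, effect) pairs, only orderings that place each cause before
    its effect are averaged (asymmetric Shapley values).
    """
    edges = list(causal_edges)
    phi = [0.0] * n_features
    n_orders = 0
    for order in permutations(range(n_features)):
        position = {feat: k for k, feat in enumerate(order)}
        # Skip orderings in which some effect precedes its cause.
        if any(position[cause] > position[effect] for cause, effect in edges):
            continue
        n_orders += 1
        coalition: FrozenSet[int] = frozenset()
        previous = value(coalition)
        for feat in order:
            coalition = coalition | {feat}
            current = value(coalition)
            phi[feat] += current - previous  # marginal contribution of feat
            previous = current
    if n_orders == 0:
        raise ValueError("no ordering is consistent with causal_edges (cycle?)")
    return [p / n_orders for p in phi]


# Toy value function: feature 1 is fully determined by feature 0,
# so knowing either feature already reveals the prediction.
v = lambda s: 1.0 if s else 0.0
print(ordered_shapley(2, v))                         # [0.5, 0.5]
print(ordered_shapley(2, v, causal_edges=[(0, 1)]))  # [1.0, 0.0]
```

With `causal_edges=[(0, 1)]` only the ordering (0, 1) survives, so the cause (feature 0) takes the credit that the symmetric version splits evenly.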

Warnings

This is an alpha version of a port of https://github.com/nredell/shapFlex

Examples

#02.05.22


import pandas as pd
import numpy as np
from shapflex.shapflex import shapFlex_plus
from catboost import CatBoostClassifier

# Adult census-income data; "income" is the binary outcome
data = pd.read_csv('https://kolodezev.ru/download/data_adult.csv', index_col=0)
outcome_name = 'income'
X, y = data.drop(outcome_name, axis=1), data[outcome_name].values
# Declare the string (object) columns as categorical so CatBoost handles them natively
cat_features = X.select_dtypes(include='object').columns.tolist()
model = CatBoostClassifier(iterations=100)
model.fit(X, y, cat_features=cat_features, verbose=False)
def predict_function(model, data):
  # Return the predicted probability of the first class (column 0 of predict_proba) as a one-column DataFrame.
  # Check: running pd.DataFrame(model.predict_proba(X)).loc[:, 0][9] gives 0.98, which matches
  # the output for row 9 (also 0.98). Whatever the algorithm, such a high degree of confidence
  # makes it possible to identify the output column unambiguously.
  return pd.DataFrame(model.predict_proba(data)[:, [0]])


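# Rows to explain (first 300) and reference data; both drop the last column, assumed to be the outcome.
# Note: `reference` and `sample_size` are defined here but are not passed to shapFlex_plus below.
explain, reference = data.iloc[:300, :data.shape[1]-1], data.iloc[:, :data.shape[1]-1]
sample_size = 50
target_features = pd.Series(["marital_status", "education", "relationship", "native_country",
                             "age", "sex", "race", "hours_per_week"])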
# Each row is one cause -> effect pair
causal = pd.DataFrame(dict(
  cause=["age", "sex", "race", "native_country"] * 3,
  effect=np.repeat(["marital_status", "education", "relationship"], 4),
))
exmpl_of_test = shapFlex_plus(explain, model, predict_function,
                              target_features=target_features, causal=causal,
                              causal_weights=[1.0] * len(causal))
result = exmpl_of_test.forward()
# numeric_only=True averages only numeric columns; pandas >= 2.0 raises TypeError on non-numeric ones
print(result.groupby('feature_name').mean(numeric_only=True))
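In the example above, each row of `causal` is one cause/effect pair, and `causal_weights` supplies one weight per row. Because asymmetric Shapley values rely on orderings that put every cause ahead of its effects, a cyclic table (A causes B and B causes A) admits no such ordering. A minimal pre-flight check, assuming Python 3.9+ for the standard-library `graphlib` module; `causal_order` is a hypothetical helper, not part of shapflex:

```python
from graphlib import CycleError, TopologicalSorter

import pandas as pd


def causal_order(causal: pd.DataFrame) -> list:
    """Return one cause-before-effect ordering of the features in `causal`.

    Hypothetical helper: raises ValueError if the cause/effect pairs contain
    a cycle, since then no ordering can put every cause ahead of its effect.
    """
    sorter = TopologicalSorter()
    for cause, effect in zip(causal["cause"], causal["effect"]):
        sorter.add(effect, cause)  # `effect` must come after `cause`
    try:
        return list(sorter.static_order())
    except CycleError as err:
        raise ValueError(f"causal table contains a cycle: {err.args[1]}") from err


# Same cause/effect table as in the example
causal = pd.DataFrame({
    "cause": ["age", "sex", "race", "native_country"] * 3,
    "effect": ["marital_status"] * 4 + ["education"] * 4 + ["relationship"] * 4,
})
print(causal_order(causal))
```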


Credits

This package was created with Cookiecutter and the giswqs/pypackage project template.



Download files


Source Distribution

shapflex-0.0.2.tar.gz (10.0 kB)

File details

Details for the file shapflex-0.0.2.tar.gz.

File metadata

  • Download URL: shapflex-0.0.2.tar.gz
  • Upload date:
  • Size: 10.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.0 CPython/3.9.12

File hashes

Hashes for shapflex-0.0.2.tar.gz
Algorithm     Hash digest
SHA256        2e82100d1c179c3bde5c1b479e74adeb0ace93ad237fba2a13835b9b2162eed5
MD5           26e0fffcc1f88c40707694c1e32156bf
BLAKE2b-256   5e2fc22df26150336b52319e5acf8e17860e1b83f8b87e6d91e9efb26be440c0
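Before installing from a manually downloaded archive, the file can be checked against the SHA256 digest listed above. A minimal sketch using the standard-library `hashlib`; `sha256_of` and `verify` are hypothetical helper names, and the archive is assumed to have been saved as `shapflex-0.0.2.tar.gz` in the working directory:

```python
import hashlib

# SHA256 digest published for shapflex-0.0.2.tar.gz
EXPECTED_SHA256 = "2e82100d1c179c3bde5c1b479e74adeb0ace93ad237fba2a13835b9b2162eed5"


def sha256_of(path, chunk_size=1 << 16):
    """Hash the file in chunks so large archives need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected=EXPECTED_SHA256):
    """True if the file's SHA256 hex digest matches `expected` (case-insensitive)."""
    return sha256_of(path) == expected.strip().lower()
```

Usage: `verify("shapflex-0.0.2.tar.gz")` should return `True` for an intact download. pip can enforce the same check itself: put `shapflex==0.0.2 --hash=sha256:<digest above>` in a requirements file and install with `pip install --require-hashes -r requirements.txt`.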

