Scaling Shapley Value computation using Spark
Project description
Shparkley is a PySpark implementation of Shapley values that uses a Monte Carlo approximation algorithm.
Given a dataset and a machine learning model, Shparkley computes the Shapley value of every feature for a given feature vector (the row being explained). Shparkley also supports training weights and is model-agnostic.
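To make the idea of a Monte Carlo approximation concrete, here is a minimal single-machine sketch of permutation-sampling Shapley estimation in plain Python. It is an illustration under stated assumptions, not Shparkley's actual implementation: the function name `monte_carlo_shapley`, its parameters, and the way weights are applied (each sampled background row's contribution is scaled by that row's weight) are all hypothetical choices made for this example.

```python
import random
from typing import Callable, Dict, Optional, Sequence

Row = Dict[str, float]


def monte_carlo_shapley(
    predict: Callable[[Row], float],
    row: Row,
    background: Sequence[Row],
    n_iter: int = 1000,
    weights: Optional[Sequence[float]] = None,
    seed: int = 0,
) -> Dict[str, float]:
    """Estimate Shapley values for `row` by sampling background rows and feature orders.

    Hypothetical helper for illustration only; not part of Shparkley.
    """
    rng = random.Random(seed)
    features = list(row)
    totals = {f: 0.0 for f in features}
    weight_sum = 0.0

    for _ in range(n_iter):
        # Draw one background row (and its weight, if weights are given).
        idx = rng.randrange(len(background))
        w = 1.0 if weights is None else weights[idx]

        # Draw a random feature ordering.
        order = features[:]
        rng.shuffle(order)

        # Start from the background row and switch features to the
        # investigated row's values one at a time, in the sampled order.
        current = dict(background[idx])
        prev = predict(current)
        for f in order:
            current[f] = row[f]
            new = predict(current)
            totals[f] += w * (new - prev)  # weighted marginal contribution
            prev = new
        weight_sum += w

    if weight_sum == 0:
        raise ValueError("all sampled background rows had zero weight")
    return {f: totals[f] / weight_sum for f in features}
```

For a linear model the estimate is exact for any sampled ordering, which makes a convenient sanity check: with `predict = lambda r: 2 * r["a"] + 3 * r["b"]`, a single all-zero background row, and `row = {"a": 1.0, "b": 2.0}`, the contributions are 2.0 for `a` and 6.0 for `b`. A distributed implementation would spread the sampled background rows across Spark partitions rather than looping over them locally.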
Installation
pip install shparkley
Requirements
You must have Apache Spark installed on your machine/cluster.
Example Usage
from collections import OrderedDict
from typing import List

import pandas as pd
from sklearn.base import ClassifierMixin

from affirm.model_interpretation.shparkley.estimator_interface import OrderedSet, ShparkleyModel
from affirm.model_interpretation.shparkley.spark_shapley import compute_shapley_for_sample


class MyShparkleyModel(ShparkleyModel):
    """
    You need to wrap your model with this interface (by subclassing ShparkleyModel).
    """

    def __init__(self, model: ClassifierMixin, required_features: OrderedSet):
        self._model = model
        self._required_features = required_features

    def predict(self, feature_matrix: List[OrderedDict]) -> List[float]:
        """
        Generates one prediction per row, taking in a list of ordered dictionaries (one per row).
        """
        pd_df = pd.DataFrame.from_dict(feature_matrix)
        # Probability of the positive class (column 1 of predict_proba).
        preds = self._model.predict_proba(pd_df)[:, 1]
        return preds

    def _get_required_features(self) -> OrderedSet:
        """
        An ordered set of feature column names.
        """
        return self._required_features


# The row (feature vector) to explain; 'xxxx' is a placeholder row ID.
row = dataset.filter(dataset.row_id == 'xxxx').rdd.first()

# my_model is your trained classifier; required_features is an OrderedSet of its feature column names.
shparkley_wrapped_model = MyShparkleyModel(my_model, required_features)

# You need to sample your dataset based on convergence criteria.
# More samples result in more accurate Shapley values.
# Repartitioning and caching the sampled DataFrame will speed up computation.
sampled_df = training_df.sample(withReplacement=True, fraction=0.1).repartition(75).cache()

shapley_scores_by_feature = compute_shapley_for_sample(
    df=sampled_df,
    model=shparkley_wrapped_model,
    row_to_investigate=row,
    weight_col_name='training_weight_column_name'
)
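The usage example says to size the sample according to convergence criteria without naming one. One simple option, sketched below, is to repeat the computation on several independently drawn samples and check how much the per-feature scores vary between runs. The helpers `max_spread` and `has_converged` are hypothetical names for this illustration, and the sketch assumes each run's result can be represented as a mapping from feature name to score; check the actual return type of `compute_shapley_for_sample` before relying on that.

```python
import statistics
from typing import Dict, List


def max_spread(runs: List[Dict[str, float]]) -> float:
    """Largest per-feature sample standard deviation across repeated runs."""
    features = list(runs[0].keys())
    return max(statistics.stdev([run[f] for run in runs]) for f in features)


def has_converged(runs: List[Dict[str, float]], tolerance: float) -> bool:
    """True when every feature's score varies by at most `tolerance` across runs."""
    return max_spread(runs) <= tolerance


# Made-up scores from two runs, purely to show the call pattern.
runs = [
    {"age": 0.10, "income": 0.30},
    {"age": 0.12, "income": 0.29},
]
print(has_converged(runs, tolerance=0.05))
```

If the spread exceeds the tolerance, increasing the sampling fraction (or the number of sampled rows) and repeating the check is a reasonable next step; the tolerance itself depends on how precisely the scores need to be compared.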