Scaling Shapley Value computation using Spark
Project description
Shparkley is a PySpark implementation of Shapley values that uses a Monte Carlo approximation algorithm.
Given a dataset and a machine learning model, Shparkley computes the Shapley value of every feature for a single feature vector (the row being explained). Shparkley also supports training weights and is model-agnostic.
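To make the Monte Carlo idea concrete, here is a minimal single-machine sketch in plain Python of a permutation-sampling estimator for one feature, with optional per-row weights. It illustrates the general technique only and is not Shparkley's actual implementation. The names `monte_carlo_shapley`, `predict`, `instance`, `background`, and `weights` are hypothetical and chosen for this example.

```python
import random
from typing import Callable, Dict, Optional, Sequence

Row = Dict[str, float]


def monte_carlo_shapley(
    predict: Callable[[Row], float],
    instance: Row,
    background: Sequence[Row],
    feature: str,
    weights: Optional[Sequence[float]] = None,
    seed: int = 0,
) -> float:
    """Estimate the Shapley value of `feature` for `instance` (illustrative sketch)."""
    rng = random.Random(seed)
    features = list(instance)
    if weights is None:
        weights = [1.0] * len(background)

    total, weight_sum = 0.0, 0.0
    for z, w in zip(background, weights):
        # Draw a random feature ordering for this background row.
        order = features[:]
        rng.shuffle(order)
        preceding = set(order[: order.index(feature)])

        # Features before `feature` in the ordering come from the instance,
        # the remaining ones come from the background row z.
        with_feature = {
            f: instance[f] if (f in preceding or f == feature) else z[f]
            for f in features
        }
        # Same row, except `feature` itself is taken from z.
        without_feature = dict(with_feature)
        without_feature[feature] = z[feature]

        # Weighted marginal contribution of `feature` for this draw.
        total += w * (predict(with_feature) - predict(without_feature))
        weight_sum += w

    return total / weight_sum
```

For a linear model the marginal contribution does not depend on the sampled ordering, so the estimate equals the coefficient times the (weighted) average difference between the instance's value and the background values. That makes a linear model a convenient sanity check.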
Installation
pip install shparkley
Requirements
You must have Apache Spark installed on your machine/cluster.
Example Usage
from typing import Any, Dict, List, Set

import pandas as pd

from affirm.model_interpretation.shparkley.spark_shapley import (
    compute_shapley_for_sample,
    ShparkleyModel,
)

class MyShparkleyModel(ShparkleyModel):
    """
    You need to wrap your model with the ShparkleyModel interface.
    """

    def get_required_features(self):
        # type: () -> Set[str]
        """
        Needs to return the set of feature names the model uses.
        """
        return {'feature-1', 'feature-2', 'feature-3'}

    def predict(self, feature_matrix):
        # type: (List[Dict[str, Any]]) -> List[float]
        """
        Wrapper function to convert the feature matrix into an acceptable format for your model.
        This function should return the predicted probabilities.
        The feature_matrix is a list of feature dictionaries.
        Each dictionary maps a feature name to its value.
        :return: Model predictions for all feature vectors
        """
        # Convert the feature matrix into an appropriate form for your model object.
        pd_df = pd.DataFrame(feature_matrix)
        preds = self._model.my_predict(pd_df)
        return preds

# Select the row whose prediction you want to explain.
row = dataset.filter(dataset.row_id == 'xxxx').rdd.first()
shparkley_wrapped_model = MyShparkleyModel(my_model)

# You need to sample your dataset based on convergence criteria.
# More samples result in more accurate Shapley values.
# Repartitioning and caching the sampled dataframe will speed up computation.
# sample(0.1, True) is sample(fraction, withReplacement): 10% of rows, drawn with replacement.
sampled_df = training_df.sample(0.1, True).repartition(75).cache()

shapley_scores_by_feature = compute_shapley_for_sample(
    df=sampled_df,
    model=shparkley_wrapped_model,
    row_to_investigate=row,
    weight_col_name='training_weight_column_name'
)
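
The example above says to choose the sample size "based on convergence criteria" but leaves the criterion open. One simple heuristic, offered here as an assumption rather than anything Shparkley prescribes, is to recompute the scores on increasingly large samples and stop once no feature's score moves by more than a tolerance. The helper `first_stable_size` and its `estimate_at` callback are hypothetical; in practice `estimate_at` would wrap a call to `compute_shapley_for_sample` on a sample of the given size, assuming the result can be read as a mapping from feature name to score.

```python
from typing import Callable, Dict, Optional, Sequence


def first_stable_size(
    estimate_at: Callable[[int], Dict[str, float]],
    sizes: Sequence[int],
    tol: float,
) -> Optional[int]:
    """Return the first sample size whose scores differ from the previous
    size's scores by at most `tol` for every feature, or None if none does."""
    previous = None
    for n in sizes:
        current = estimate_at(n)
        if previous is not None and all(
            abs(current[f] - previous[f]) <= tol for f in current
        ):
            return n
        previous = current
    return None
```

Because Monte Carlo error shrinks roughly with the square root of the number of samples, growing the sizes geometrically (for example by a factor of ten) tends to be more informative than growing them linearly.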