
Scaling Shapley Value computation using Spark

Project description

Shparkley is a PySpark implementation of Shapley values that uses a Monte Carlo approximation algorithm.
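
The Monte Carlo idea can be illustrated in plain Python, without Spark. The sketch below is the standard permutation-sampling estimator, not Shparkley's own code. The function name monte_carlo_shapley and the dict-per-row format are assumptions made for this illustration. For each background row, it draws a random feature order and switches features one at a time from the background value to the value of the row being explained. The change in prediction at each step is that feature's marginal contribution.

```python
import random
from typing import Callable, Dict, List, Sequence

Row = Dict[str, float]


def monte_carlo_shapley(
    predict: Callable[[List[Row]], List[float]],
    row: Row,
    background: Sequence[Row],
    features: Sequence[str],
    seed: int = 0,
) -> Dict[str, float]:
    """Estimate the Shapley value of every feature of `row` (hypothetical helper).

    Averages marginal contributions over the background rows, using one
    random feature permutation per background row.
    """
    rng = random.Random(seed)
    totals = {f: 0.0 for f in features}
    for z in background:
        order = list(features)
        rng.shuffle(order)
        current = dict(z)  # start from the background row
        chain = [dict(current)]
        for f in order:
            current[f] = row[f]  # reveal feature f from the row being explained
            chain.append(dict(current))
        preds = predict(chain)  # one batched call with len(features) + 1 rows
        for i, f in enumerate(order):
            totals[f] += preds[i + 1] - preds[i]
    n = len(background)
    return {f: totals[f] / n for f in features}
```

Because the contributions for one background row telescope, they always sum to f(row) - f(z). So the estimates sum to the prediction for the row minus the mean background prediction. In a distributed version, the loop over background rows is the natural unit to parallelize.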

Given a dataset and a machine learning model, Shparkley computes the Shapley value of every feature for a single feature vector (row). Shparkley also supports training weights and is model-agnostic.
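
This description does not say exactly how training weights enter the estimate. A common approach, assumed in this sketch, is to combine per-row marginal contributions with a weighted mean. With that scheme, a row with weight 3 counts the same as three copies of that row. The helper weighted_average_contributions is hypothetical and is not part of Shparkley's API.

```python
from typing import Dict, Sequence


def weighted_average_contributions(
    contributions: Sequence[Dict[str, float]],
    weights: Sequence[float],
) -> Dict[str, float]:
    """Combine per-row marginal contributions (feature -> value) into one
    estimate, weighting each sampled row by its training weight."""
    if len(contributions) != len(weights):
        raise ValueError("need exactly one weight per contribution row")
    total_weight = sum(weights)
    if total_weight <= 0:
        raise ValueError("weights must sum to a positive number")
    features = contributions[0].keys()
    return {
        f: sum(w * c[f] for c, w in zip(contributions, weights)) / total_weight
        for f in features
    }
```

For example, contributions {"a": 2.0} and {"a": -2.0} with weights 3 and 1 give (6 - 2) / 4 = 1.0 for feature a.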

Installation

pip install shparkley

Requirements

You must have Apache Spark installed on your machine/cluster.

Example Usage

from collections import OrderedDict
from typing import List

import pandas as pd
from sklearn.base import ClassifierMixin

from affirm.model_interpretation.shparkley.estimator_interface import OrderedSet, ShparkleyModel
from affirm.model_interpretation.shparkley.spark_shapley import compute_shapley_for_sample


class MyShparkleyModel(ShparkleyModel):
    """
    You need to wrap your model with this interface (by subclassing ShparkleyModel)
    """
    def __init__(self, model: ClassifierMixin, required_features: OrderedSet):
        self._model = model
        self._required_features = required_features

    def predict(self, feature_matrix: List[OrderedDict]) -> List[float]:
        """
        Generate one prediction per row from a list of ordered dictionaries (one per row).
        """
        # Each OrderedDict maps feature name -> value, so the list converts directly to rows.
        pd_df = pd.DataFrame(feature_matrix)
        # Probability of the positive class (column 1 of predict_proba).
        preds = self._model.predict_proba(pd_df)[:, 1]
        return preds

    def _get_required_features(self) -> OrderedSet:
        """
        An ordered set of feature column names
        """
        return self._required_features

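# `dataset`, `my_model` and `my_required_features` are placeholders: the DataFrame holding
# the row to explain, your fitted classifier, and its ordered set of feature names.
row = dataset.filter(dataset.row_id == 'xxxx').rdd.first()
shparkley_wrapped_model = MyShparkleyModel(my_model, my_required_features)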

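# Sample your dataset based on a convergence criterion:
# more samples give more accurate Shapley values.
# Repartitioning and caching the sampled DataFrame speeds up computation.
# Pass the arguments by keyword: in PySpark 2.3+, sample(0.1, True) is read as
# fraction=0.1, seed=True, which samples without replacement.
sampled_df = training_df.sample(withReplacement=True, fraction=0.1).repartition(75).cache()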

shapley_scores_by_feature = compute_shapley_for_sample(
    df=sampled_df,
    model=shparkley_wrapped_model,
    row_to_investigate=row,
    weight_col_name='training_weight_column_name'
)
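
A convergence criterion for the sample size can be sketched with the normal approximation. This sketch assumes you have per-row marginal contributions for each feature from a small pilot run, for instance from a pure-Python estimator. The helpers required_samples and sample_fraction are hypothetical and are not part of Shparkley. The confidence half-width of a mean is about z * s / sqrt(n), so reaching a tolerance eps needs n >= (z * s / eps)^2. The feature with the noisiest contributions sets the size.

```python
import math
from statistics import stdev
from typing import Dict, Sequence


def required_samples(
    pilot_contributions: Dict[str, Sequence[float]],
    tolerance: float,
    z: float = 1.96,
) -> int:
    """Rows needed so every feature's estimate has a confidence half-width
    of at most `tolerance` (normal approximation, hypothetical helper)."""
    worst = 0
    for values in pilot_contributions.values():
        s = stdev(values)  # sample standard deviation; needs >= 2 pilot values
        worst = max(worst, math.ceil((z * s / tolerance) ** 2))
    return worst


def sample_fraction(n_required: int, n_rows: int) -> float:
    """Fraction to pass to DataFrame.sample(...). Spark's sample size is
    random, so this only hits n_required on average."""
    return n_required / n_rows
```

For example, pilot contributions [0.0, 2.0, 4.0] have a standard deviation of 2.0. With z = 2.0 and a tolerance of 0.5, that requires (2 * 2 / 0.5)^2 = 64 rows.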


Download files


Source distribution: shparkley-1.0.1.tar.gz (10.5 kB)

Built distribution: shparkley-1.0.1-py3-none-any.whl (13.0 kB, Python 3)
