
Scaling Shapley Value computation using Spark

Project description

Shparkley is a PySpark implementation of Shapley values that uses a Monte Carlo approximation algorithm.

Given a dataset and a machine learning model, Shparkley computes Shapley values for every feature of a given feature vector. Shparkley also handles training weights and is model-agnostic.
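The Monte Carlo estimate repeatedly samples a random feature ordering and a random background row from the dataset, then measures how the model's prediction changes when a feature's value is swapped between the row being explained and the background row. The sketch below illustrates that estimator on a single machine; the function name, the predict_fn signature, and the dict-based rows are illustrative assumptions, not Shparkley's internal API.

import random
from typing import Callable, Dict, Sequence


def monte_carlo_shapley(
    predict_fn: Callable[[Dict[str, float]], float],
    x: Dict[str, float],                      # row being explained
    background: Sequence[Dict[str, float]],   # sampled training rows
    n_iter: int = 1000,
) -> Dict[str, float]:
    features = list(x.keys())
    contributions = {f: 0.0 for f in features}
    for _ in range(n_iter):
        z = random.choice(background)                    # random background row
        order = random.sample(features, len(features))   # random feature permutation
        position = {f: i for i, f in enumerate(order)}
        for j in features:
            # Features at or before j in the permutation take their value from x,
            # the remaining features take their value from the background row z.
            with_j = {f: (x[f] if position[f] <= position[j] else z[f]) for f in features}
            # Identical coalition, except feature j itself comes from z.
            without_j = dict(with_j, **{j: z[j]})
            contributions[j] += predict_fn(with_j) - predict_fn(without_j)
    return {f: total / n_iter for f, total in contributions.items()}

Shparkley scales this sampling loop out with Spark, which is why the usage example below samples, repartitions, and caches a Spark dataframe before computing Shapley values.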

Installation

pip install shparkley

Requirements

You must have Apache Spark installed on your machine/cluster.
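If you only want to experiment locally rather than on an existing cluster, one common option (an option, not a project requirement beyond having Spark available) is to install PySpark from PyPI:

pip install pyspark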

Example Usage

from collections import OrderedDict
from typing import List

import pandas as pd
from sklearn.base import ClassifierMixin

from affirm.model_interpretation.shparkley.estimator_interface import OrderedSet, ShparkleyModel
from affirm.model_interpretation.shparkley.spark_shapley import compute_shapley_for_sample


class MyShparkleyModel(ShparkleyModel):
    """
    You need to wrap your model with this interface (by subclassing ShparkleyModel)
    """
    def __init__(self, model: ClassifierMixin, required_features: OrderedSet):
        self._model = model
        self._required_features = required_features

    def predict(self, feature_matrix: List[OrderedDict]) -> List[float]:
        """
        Generates one prediction per row, taking in a list of ordered dictionaries (one per row).
        """
        pd_df = pd.DataFrame.from_dict(feature_matrix)
        preds = self._model.predict_proba(pd_df)[:, 1]
        return preds

    def _get_required_features(self) -> OrderedSet:
        """
        An ordered set of feature column names
        """
        return self._required_features

row = dataset.filter(dataset.row_id == 'xxxx').rdd.first()
# my_feature_names is a placeholder for an OrderedSet of the feature columns your model expects.
shparkley_wrapped_model = MyShparkleyModel(my_model, my_feature_names)

# You need to sample your dataset based on convergence criteria.
# More samples result in more accurate Shapley values.
# Repartitioning and caching the sampled dataframe will speed up computation.
sampled_df = training_df.sample(0.1, True).repartition(75).cache()

shapley_scores_by_feature = compute_shapley_for_sample(
    df=sampled_df,
    model=shparkley_wrapped_model,
    row_to_investigate=row,
    weight_col_name='training_weight_column_name'
)
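Assuming compute_shapley_for_sample returns a mapping from feature name to estimated Shapley value (as the variable name suggests), you can rank features by the magnitude of their contribution:

# Illustrative follow-up; assumes shapley_scores_by_feature behaves like Dict[str, float].
for feature, score in sorted(
    shapley_scores_by_feature.items(), key=lambda kv: abs(kv[1]), reverse=True
):
    print(f"{feature}: {score:.4f}")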

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

shparkley-1.0.1.tar.gz (10.5 kB)

Uploaded Source

Built Distribution

shparkley-1.0.1-py3-none-any.whl (13.0 kB)

Uploaded Python 3

File details

Details for the file shparkley-1.0.1.tar.gz.

File metadata

  • Download URL: shparkley-1.0.1.tar.gz
  • Upload date:
  • Size: 10.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.6.1 requests/2.22.0 setuptools/50.3.2 requests-toolbelt/0.9.1 tqdm/4.36.1 CPython/3.8.1

File hashes

Hashes for shparkley-1.0.1.tar.gz

  • SHA256: 466eece0b8f943ee7c01c5d50d4605896182edc8d36f838bb6b0653708ccde86
  • MD5: 665bb177221a18e1f75795e22e70800f
  • BLAKE2b-256: 45d3cc2bdceda131aee61f15e9e734d4ed99c1132e9cb5e9f9f70913174d98f1

See more details on using hashes here.

File details

Details for the file shparkley-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: shparkley-1.0.1-py3-none-any.whl
  • Upload date:
  • Size: 13.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.6.1 requests/2.22.0 setuptools/50.3.2 requests-toolbelt/0.9.1 tqdm/4.36.1 CPython/3.8.1

File hashes

Hashes for shparkley-1.0.1-py3-none-any.whl

  • SHA256: 49e57cb95049d83364d76f7d6894e0a87e6a3f439e49761f5ef74bc17184a57a
  • MD5: 6c96706ea6c2fa363e4c90126b77daca
  • BLAKE2b-256: a228dcfafb75fe67b616afdf38da5178475b7a9cfd212cc8b956cded9dff19f9

See more details on using hashes here.
