
Calibrating model scores/probabilities with PySpark DataFrames

Project description

Model calibration with pyspark

This package provides a Betacal class that fits (trains) the default beta calibration model on PySpark DataFrames and uses it to predict calibrated scores.
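For reference, the standard three-parameter ("abm") beta calibration map, which the default model is assumed to follow, sends an uncalibrated score s in (0, 1) to a calibrated probability using two shape parameters a, b and an offset c:

```latex
\mu(s; a, b, c) = \frac{1}{1 + \dfrac{1}{e^{c}\, s^{a} / (1 - s)^{b}}}
               = \sigma\!\left(a \ln s - b \ln(1 - s) + c\right),
\qquad \sigma(z) = \frac{1}{1 + e^{-z}}
```

With a = b = 1 and c = 0 the map reduces to the identity, mu(s) = s, because the argument of sigma becomes the logit ln(s / (1 - s)).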

Setup

The spark-calibration package is published on PyPI and can be installed with:

pip install spark-calibration

Usage

Training

train_df should be a PySpark DataFrame with a score column (the uncalibrated model probability) and a label column (the binary 0/1 target).

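from pyspark.sql import SparkSession

from spark_calibration import Betacal
from spark_calibration import display_classification_calib_metrics
from spark_calibration import plot_calibration_curve

spark = SparkSession.builder.getOrCreate()

# "abm" selects the default three-parameter beta calibration model
bc = Betacal(parameters="abm")

# train_df needs a score column and a label column
train_df = spark.read.parquet("s3://train/")

bc.fit(train_df)  # training

# a, b are assumed to be exposed on the fitted object, like lr_model
print(bc.lr_model, bc.a, bc.b)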

a, b -> coefficients of the fitted logistic regression model

lr_model -> the fitted PySpark ML logistic regression model
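In the standard beta calibration formulation, fitting a and b is an ordinary logistic regression on two transformed features, ln(score) and -ln(1 - score), with the intercept playing the role of a third parameter c. The sketch below assumes the package follows that formulation. It uses plain batch gradient descent in pure Python as a stand-in for the pyspark.ml logistic regression behind lr_model, so `beta_features`, `mean_log_loss` and `fit_beta_abm` are hypothetical names and the fitted numbers will not match the package exactly.

```python
import math


def beta_features(s, eps=1e-12):
    """Return (ln s, -ln(1 - s)); eps clipping keeps the logs finite (assumption)."""
    s = min(max(s, eps), 1.0 - eps)
    return math.log(s), -math.log(1.0 - s)


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def mean_log_loss(a, b, c, scores, labels):
    """Mean log loss of the beta map sigmoid(a*ln s - b*ln(1 - s) + c)."""
    total = 0.0
    for s, y in zip(scores, labels):
        x1, x2 = beta_features(s)
        p = min(max(sigmoid(a * x1 + b * x2 + c), 1e-15), 1.0 - 1e-15)
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / len(scores)


def fit_beta_abm(scores, labels, lr=0.1, n_iter=2000):
    """Fit weights a, b and intercept c by batch gradient descent on the mean
    log loss (hypothetical stand-in for pyspark.ml LogisticRegression)."""
    feats = [beta_features(s) for s in scores]
    n = len(feats)
    a = b = c = 0.0
    for _ in range(n_iter):
        ga = gb = gc = 0.0
        for (x1, x2), y in zip(feats, labels):
            err = sigmoid(a * x1 + b * x2 + c) - y  # d(log loss)/d(logit)
            ga += err * x1
            gb += err * x2
            gc += err
        a -= lr * ga / n
        b -= lr * gb / n
        c -= lr * gc / n
    return a, b, c


# Toy data: uncalibrated scores with 0/1 labels
scores = [0.1, 0.2, 0.3, 0.4, 0.6, 0.7, 0.8, 0.9]
labels = [0, 0, 0, 1, 0, 1, 1, 1]
a, b, c = fit_beta_abm(scores, labels)
print(a, b, c, mean_log_loss(a, b, c, scores, labels))
```

Because the second feature is already negated, the fitted weight on it is b itself, so the logit is a*ln(s) - b*ln(1 - s) + c.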

Prediction

test_df is a PySpark DataFrame with a score column. predict returns the DataFrame with a new prediction column holding the calibrated score.

test_df = spark.read.parquet("s3://test/")
test_df = bc.predict(test_df)
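To make the new column concrete: assuming predict applies the standard beta calibration map with the fitted a, b and the logistic-regression intercept (called c here), each row keeps its original columns and gains prediction = sigmoid(a*ln(score) - b*ln(1 - score) + c). In the sketch below, rows are plain dicts standing in for DataFrame rows, and `beta_calibrate` and `add_prediction` are hypothetical names. On a real DataFrame the same expression could be built with pyspark.sql.functions.log and pyspark.sql.functions.exp inside withColumn.

```python
import math


def beta_calibrate(score, a, b, c, eps=1e-12):
    """Beta calibration map sigmoid(a*ln(s) - b*ln(1 - s) + c).
    eps clipping (an assumption) keeps the logs finite at 0 and 1."""
    s = min(max(score, eps), 1.0 - eps)
    z = a * math.log(s) - b * math.log(1.0 - s) + c
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def add_prediction(rows, a, b, c):
    """Row-wise analogue of adding a prediction column next to score."""
    return [{**row, "prediction": beta_calibrate(row["score"], a, b, c)} for row in rows]


rows = add_prediction([{"score": 0.2, "label": 0}, {"score": 0.8, "label": 1}], 2.0, 1.0, 0.5)
print(rows)
```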

Pre & Post Calibration Classification Metrics

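test_df must contain score, prediction and label columns. display_classification_calib_metrics prints the Brier score loss, log loss, area under the precision-recall curve and area under the ROC curve for both the original score and the calibrated prediction, together with the relative change (delta) for each metric.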

display_classification_calib_metrics(test_df)

Output

model brier score loss: 0.08072683729933376
calibrated model brier score loss: 0.01014015353257748
delta: -87.44%

model log loss: 0.3038106859864252
calibrated model log loss: 0.053275633947890755
delta: -82.46%

model aucpr: 0.03471287564672635
calibrated model aucpr: 0.03471240518472563
delta: -0.0%

model roc_auc: 0.7490639506966398
calibrated model roc_auc: 0.7490649764289607
delta: 0.0%
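The two probability-quality metrics can be written in a few lines of plain Python, and the printed delta lines are consistent with the relative change from the uncalibrated to the calibrated metric, in percent, rounded to two decimals: (0.01014 - 0.08073) / 0.08073 is about -87.44%. The helper names below are hypothetical illustrations, not the package's implementation. The near-zero AUC deltas are expected, since both areas depend only on how the scores rank the examples, and a calibration map that is increasing in the score leaves that ranking unchanged.

```python
import math


def brier_score(labels, probs):
    """Mean squared difference between predicted probability and 0/1 label."""
    return sum((p - y) ** 2 for y, p in zip(labels, probs)) / len(labels)


def log_loss(labels, probs, eps=1e-15):
    """Mean negative log-likelihood; probabilities are clipped to avoid log(0)."""
    total = 0.0
    for y, p in zip(labels, probs):
        p = min(max(p, eps), 1.0 - eps)
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / len(labels)


def relative_delta_pct(before, after):
    """Percentage change from the uncalibrated metric to the calibrated one."""
    return round(100.0 * (after - before) / before, 2)


# Reproduces the Brier score delta from the output above
print(relative_delta_pct(0.08072683729933376, 0.01014015353257748))  # -87.44
```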

Plot the Calibration Curve

plot_calibration_curve computes the observed (true) and mean predicted probabilities, before and after calibration, using quantile binning with 50 bins, and plots the resulting calibration curves.

plot_calibration_curve(test_df)
[Figure: calibration curve before and after calibration]
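Quantile (equal-count) binning sorts the rows by predicted probability and cuts them into bins holding roughly the same number of rows, so every point on the curve rests on a similar sample size even when scores are concentrated near 0. The pure-Python sketch below illustrates the idea on lists of labels and probabilities; the function name and the bin-edge handling are assumptions, and the package's Spark implementation may place edges differently (for example when many rows share the same score). Running it once on the score column and once on the prediction column gives the pre- and post-calibration curves; points from a well-calibrated model lie close to the diagonal.

```python
def calibration_curve_quantile(labels, probs, n_bins=50):
    """Equal-count binning: sort by predicted probability, split into n_bins
    groups of (nearly) equal size, and return per-bin
    (mean predicted probability, observed fraction of positives)."""
    pairs = sorted(zip(probs, labels))
    n = len(pairs)
    curve = []
    for i in range(n_bins):
        lo = i * n // n_bins
        hi = (i + 1) * n // n_bins
        chunk = pairs[lo:hi]
        if not chunk:
            continue  # more bins than rows: skip empty bins
        mean_pred = sum(p for p, _ in chunk) / len(chunk)
        frac_pos = sum(y for _, y in chunk) / len(chunk)
        curve.append((mean_pred, frac_pos))
    return curve


probs = [i / 100 for i in range(100)]
labels = [1 if i >= 50 else 0 for i in range(100)]
curve = calibration_curve_quantile(labels, probs, n_bins=50)
print(curve[:3])
```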

Download files


Source Distribution

spark_calibration-1.0.8.tar.gz (7.5 kB)

Uploaded Source

Built Distribution


spark_calibration-1.0.8-py3-none-any.whl (9.1 kB)

Uploaded Python 3

File details

Details for the file spark_calibration-1.0.8.tar.gz.

File metadata

  • Download URL: spark_calibration-1.0.8.tar.gz
  • Upload date:
  • Size: 7.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.4

File hashes

Hashes for spark_calibration-1.0.8.tar.gz
Algorithm Hash digest
SHA256 38d8f904dcf0ee009a42e0699fac8ab76e76a1de1bc3d3e4fe5192ee72a6b04d
MD5 11e2b39341c8ba9dd9a8bd32695309f3
BLAKE2b-256 2323d4f5330869c6277591c06570c955a4d31cec6196d64a4d811fc049115dfe


File details

Details for the file spark_calibration-1.0.8-py3-none-any.whl.

File hashes

Hashes for spark_calibration-1.0.8-py3-none-any.whl
Algorithm Hash digest
SHA256 b79beeb514b3759cf8389c6fbdf1c879864f13f77e2de05072fc05a658556c52
MD5 520c1646932914a364dff1b87b4b8e98
BLAKE2b-256 9a5f9ecdd741bcbea861111c058fea45de0810e1f4338fe14871ed170975d057

