Confidence intervals and p-values for scikit-learn.
Project description
Statkit
Supplement your scikit-learn models with 95% confidence intervals, p-values, and decision curves.
Quickstart
- Estimate 95% confidence intervals for your test scores.
For example, to compute a 95% confidence interval of the area under the receiver operating characteristic curve (ROC AUC):
from sklearn.metrics import roc_auc_score
from statkit.non_parametric import bootstrap_score

# Predicted probability of the positive class from a fitted classifier `model`.
y_prob = model.predict_proba(X_test)[:, 1]
auc_95ci = bootstrap_score(y_test, y_prob, metric=roc_auc_score)
print('Area under the ROC curve:', auc_95ci)
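The same call pattern should work with other scikit-learn metrics that take true labels and predicted scores. As a sketch (reusing y_test and y_prob from the example above), a 95% confidence interval for the average precision:

from sklearn.metrics import average_precision_score
from statkit.non_parametric import bootstrap_score

# Reuses y_test and y_prob from the ROC AUC example above.
ap_95ci = bootstrap_score(y_test, y_prob, metric=average_precision_score)
print('Average precision:', ap_95ci)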
- Compute a p-value to test whether one model is significantly better than another.
For example, to test if the area under the receiver operating characteristic curve (ROC AUC) of model 1 is significantly larger than that of model 2:
from sklearn.metrics import roc_auc_score
from statkit.non_parametric import paired_permutation_test

# Predicted probabilities of the positive class from two fitted classifiers.
y_pred_1 = model_1.predict_proba(X_test)[:, 1]
y_pred_2 = model_2.predict_proba(X_test)[:, 1]
p_value = paired_permutation_test(y_test, y_pred_1, y_pred_2, metric=roc_auc_score)
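As a minimal sketch of interpreting the result (assuming the returned p_value behaves like a plain float), compare it against your significance level:

# Reject the null hypothesis of equal performance at the 5% level.
alpha = 0.05
if p_value < alpha:
    print(f'Model 1 significantly outperforms model 2 (p = {p_value:.3f}).')
else:
    print(f'No significant difference detected (p = {p_value:.3f}).')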
- Perform decision curve analysis by making net benefit plots of your scikit-learn models. Compare the utility of different models against the default decision policies of always or never taking an action/intervention.
from matplotlib import pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from statkit.decision import NetBenefitDisplay

# Generate a small synthetic binary classification problem.
centers = [[0, 0], [1, 1]]
X_train, y_train = make_blobs(
    centers=centers, cluster_std=1, n_samples=20, random_state=5
)
X_test, y_test = make_blobs(
    centers=centers, cluster_std=1, n_samples=20, random_state=1005
)

# Fit two models and compute their predicted probabilities on the test set.
baseline_model = LogisticRegression(random_state=5).fit(X_train, y_train)
y_pred_base = baseline_model.predict_proba(X_test)[:, 1]
tree_model = GradientBoostingClassifier(random_state=5).fit(X_train, y_train)
y_pred_tree = tree_model.predict_proba(X_test)[:, 1]

# Overlay the net benefit curves of both models in a single figure.
NetBenefitDisplay.from_predictions(y_test, y_pred_base, name='Baseline model')
NetBenefitDisplay.from_predictions(y_test, y_pred_tree, name='Gradient boosted trees', show_references=False, ax=plt.gca())
plt.show()
Detailed documentation can be found on the Statkit API documentation pages.
Installation
pip3 install statkit
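As a quick sanity check (not part of the official instructions), you can verify the installation from Python with the standard library:

# Importing confirms the package is importable; version() reads the installed metadata.
from importlib.metadata import version
import statkit  # raises ImportError if the install did not succeed
print('Installed statkit version:', version('statkit'))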
Support
You can open a ticket in the Issue tracker.
Contributing
We are open to contributions. If you open a pull request, make sure that your code is:
- Well documented,
- Formatted with black,
- And accompanied by a unit test.
Authors and acknowledgment
Hylke C. Donker
License
This code is licensed under the MIT license.
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
statkit-1.0.0.tar.gz (103.2 kB)
Built Distribution
statkit-1.0.0-py3-none-any.whl (21.2 kB)
File details
Details for the file statkit-1.0.0.tar.gz.
File metadata
- Download URL: statkit-1.0.0.tar.gz
- Upload date:
- Size: 103.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 7714c329a2a4798f388c9826acef5cba143b2e9b8899a19c9c1f87e462b896ef
MD5 | 7e652332a6132e548ccc303324d89ecc
BLAKE2b-256 | 51307c5654e808c2301b26f4d041f7f75e4eef1bfddae25f2eacb3e05b94ce38
File details
Details for the file statkit-1.0.0-py3-none-any.whl.
File metadata
- Download URL: statkit-1.0.0-py3-none-any.whl
- Upload date:
- Size: 21.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | d73e5dc2495a94dd9c8f03653c283f459d77f58dbc4a296c75bf2c492d215de6
MD5 | f371fc417878cba88602d371cfd6ab7f
BLAKE2b-256 | b527c6952afa38308198dd1a1df9f0ea1f1acddd5fb959262b822e4d08b9764c