Precision-Recall-Gain curves and metrics for scikit-learn
sklearn_prg
A Python module implementing Precision-Recall-Gain (PRG) curves and metrics compatible with scikit-learn.
Why Precision-Recall-Gain?
Precision-Recall-Gain (PRG) curves improve upon traditional ROC and Precision-Recall curves in heavily imbalanced scenarios (few positives, many negatives). Use PRG when true negatives aren't valuable (e.g., information retrieval, fraud detection).
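Both gains are measured relative to the always-positive classifier. Following Flach & Kull (2015), with π the proportion of positives, precision gain is (precision − π) / ((1 − π) · precision), and recall gain is the same transform applied to recall. A minimal sketch of that transform; `gain` is a hypothetical helper for illustration, not part of sklearn_prg's API:

```python
def gain(value: float, pi: float) -> float:
    """Gain transform of Flach & Kull (2015) for precision or recall.

    pi is the proportion of positives. A score equal to pi maps to 0,
    a perfect score of 1 maps to 1, and scores below pi become negative.
    """
    return (value - pi) / ((1 - pi) * value)


pi = 0.1  # 10% positives
# The always-positive classifier has precision pi and recall 1
always_positive = (gain(pi, pi), gain(1.0, pi))
print(always_positive)  # (precision gain, recall gain) = (0.0, 1.0)
print(gain(0.5, pi))    # 0.4 / 0.45, about 0.889
print(gain(0.05, pi))   # below the baseline, so negative
```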
Advantages:
- Stable under imbalance: precision and recall gains are measured relative to the always-positive classifier, so the baseline stays the same whatever the class distribution.
- Direct interpretation: AUPRG directly relates to the expected F₁ score.
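- Improved model selection: the area under a standard PR curve can favor models with a lower expected F₁ score; selecting on AUPRG avoids this.
- Intuitive thresholding: the convex hull of a PRG curve identifies, for each operating point on it, the range of β values for which that threshold optimizes Fβ.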
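The thresholding property rests on a linear relationship in PRG space: applying the same gain transform to Fβ gives FGβ = (PG + β²·RG) / (1 + β²), a weighted average of precision gain and recall gain, so every Fβ isometric is a straight line. A small numerical check of that identity, assuming the gain transform of Flach & Kull (2015); `gain` and `f_beta` are hypothetical helpers, not part of sklearn_prg:

```python
def gain(value: float, pi: float) -> float:
    """Gain transform of Flach & Kull (2015): pi maps to 0, 1 maps to 1."""
    return (value - pi) / ((1 - pi) * value)


def f_beta(precision: float, recall: float, beta: float) -> float:
    """Standard F-beta score: weighted harmonic mean of precision and recall."""
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)


# Hypothetical operating point: 100 positives, 900 negatives
tp, fp, fn, tn = 60, 20, 40, 880
pi = (tp + fn) / (tp + fp + fn + tn)  # 0.1; true negatives only enter through pi
precision = tp / (tp + fp)            # 0.75
recall = tp / (tp + fn)               # 0.6
pg, rg = gain(precision, pi), gain(recall, pi)

results = {}
for beta in (0.5, 1.0, 2.0):
    # Gain of the F-beta score, computed directly
    fg_direct = gain(f_beta(precision, recall, beta), pi)
    # The same value as a weighted average of the two gains
    fg_linear = (pg + beta ** 2 * rg) / (1 + beta ** 2)
    results[beta] = (fg_direct, fg_linear)
    print(f"beta={beta}: direct={fg_direct:.6f}, linear={fg_linear:.6f}")
```

Because each Fβ-gain is linear in (RG, PG), the operating point that maximizes it for a given β always lies on the convex hull of the curve.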
Installation
```bash
pip install sklearn_prg
```
Usage
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, precision_recall_curve, auc, average_precision_score
from sklearn_prg.metrics import precision_recall_gain_curve, average_precision_recall_gain

# Generate an imbalanced dataset (roughly 10% positives)
X, y = make_classification(n_samples=2000,
                           n_features=20,
                           weights=[0.9, 0.1],
                           flip_y=0.03,
                           class_sep=0.5,
                           random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=42)

# Train models and score the positive class on the test set
clf_rf = RandomForestClassifier(random_state=42).fit(X_train, y_train)
clf_lr = LogisticRegression(max_iter=1000, random_state=42).fit(X_train, y_train)
y_scores_rf = clf_rf.predict_proba(X_test)[:, 1]
y_scores_lr = clf_lr.predict_proba(X_test)[:, 1]

print(f"Average precision recall gain (Logistic Regression): {average_precision_recall_gain(y_test, y_scores_lr):.3f}")
print(f"Average precision recall gain (Random Forest): {average_precision_recall_gain(y_test, y_scores_rf):.3f}")

# ROC curve
fpr_rf, tpr_rf, _ = roc_curve(y_test, y_scores_rf)
fpr_lr, tpr_lr, _ = roc_curve(y_test, y_scores_lr)

# Standard PR curve
prec_rf, rec_rf, _ = precision_recall_curve(y_test, y_scores_rf)
prec_lr, rec_lr, _ = precision_recall_curve(y_test, y_scores_lr)
ap_rf = average_precision_score(y_test, y_scores_rf)
ap_lr = average_precision_score(y_test, y_scores_lr)

# PRG curve
pg_rf, rg_rf = precision_recall_gain_curve(y_test, y_scores_rf)
pg_lr, rg_lr = precision_recall_gain_curve(y_test, y_scores_lr)
auprg_rf = auc(rg_rf, pg_rf)
auprg_lr = auc(rg_lr, pg_lr)

fig, axs = plt.subplots(1, 3, figsize=(18, 6))

# ----- ROC Curve -----
axs[0].plot(fpr_rf, tpr_rf, label=f'Random Forest (AUC={auc(fpr_rf, tpr_rf):.3f})')
axs[0].plot(fpr_lr, tpr_lr, label=f'Logistic (AUC={auc(fpr_lr, tpr_lr):.3f})')
axs[0].plot([0, 1], [0, 1], 'k--', alpha=0.6, label='Random Classifier')
axs[0].set_title('ROC Curves')
axs[0].set_xlabel('False Positive Rate')
axs[0].set_ylabel('True Positive Rate')
axs[0].legend()
axs[0].grid(True)
axs[0].set_aspect('equal', adjustable='box')

# ----- Precision-Recall Curve -----
# The expected precision of a random classifier equals the positive-class prevalence
prevalence = np.mean(y_test)
axs[1].plot(rec_rf, prec_rf, label=f'Random Forest (AP={ap_rf:.3f})')
axs[1].plot(rec_lr, prec_lr, label=f'Logistic (AP={ap_lr:.3f})')
axs[1].axhline(prevalence, linestyle='--', color='black', alpha=0.6, label='Random Classifier')
axs[1].set_title('Precision-Recall Curves')
axs[1].set_xlabel('Recall')
axs[1].set_ylabel('Precision')
axs[1].legend()
axs[1].grid(True)
axs[1].set_aspect('equal', adjustable='box')

# ----- Precision-Recall-Gain Curve -----
axs[2].plot(rg_rf, pg_rf, label=f'Random Forest (AUPRG={auprg_rf:.3f})')
axs[2].plot(rg_lr, pg_lr, label=f'Logistic (AUPRG={auprg_lr:.3f})')
# Line PG + RG = 1: operating points with the same F1 score as the always-positive classifier
axs[2].plot([1, 0], [0, 1], linestyle='-', color='black', alpha=0.6, label='Always Positive Classifier')
axs[2].set_xlim(0, 1)
axs[2].set_ylim(0, 1)
axs[2].set_xlabel('Recall Gain')
axs[2].set_ylabel('Precision Gain')
axs[2].set_title('Precision-Recall-Gain Curves')
axs[2].legend()
axs[2].grid(True)
axs[2].set_aspect('equal', adjustable='box')

plt.suptitle('ROC, PR, and PRG Curve Comparison', fontsize=16)
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()
```
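To cross-check a PRG curve or to pick an operating threshold, PRG points can also be computed directly from scikit-learn's `precision_recall_curve`. A minimal sketch, assuming the Flach & Kull (2015) gain transform; `prg_points` and `best_f1_gain_threshold` are hypothetical helpers, not part of sklearn_prg, and may handle edge cases differently from the package:

```python
import numpy as np
from sklearn.metrics import precision_recall_curve


def prg_points(y_true, y_score):
    """Precision gain, recall gain and threshold for every operating point.

    Sketch only: applies the Flach & Kull (2015) gain transform to the
    output of sklearn's precision_recall_curve.
    """
    y_true = np.asarray(y_true)
    pi = y_true.mean()  # prevalence of the positive class
    precision, recall, thresholds = precision_recall_curve(y_true, y_score)
    # The last (precision=1, recall=0) entry has no threshold; drop it
    precision, recall = precision[:-1], recall[:-1]
    # Zero precision or recall maps to -inf on the gain scale
    with np.errstate(divide="ignore", invalid="ignore"):
        pg = (precision - pi) / ((1 - pi) * precision)
        rg = (recall - pi) / ((1 - pi) * recall)
    return pg, rg, thresholds


def best_f1_gain_threshold(y_true, y_score):
    """Threshold with the highest F1-gain, i.e. the mean of the two gains."""
    pg, rg, thresholds = prg_points(y_true, y_score)
    fg1 = (pg + rg) / 2
    fg1 = np.where(np.isfinite(fg1), fg1, -np.inf)  # treat undefined points as worst
    best = int(np.argmax(fg1))
    return thresholds[best], fg1[best]


# Toy data: 4 positives out of 10
y_true = np.array([0, 0, 1, 0, 1, 1, 0, 1, 0, 0])
y_score = np.array([0.1, 0.3, 0.35, 0.4, 0.6, 0.7, 0.2, 0.8, 0.05, 0.5])
threshold, fg1 = best_f1_gain_threshold(y_true, y_score)
print(threshold, fg1)
```

Because the gain transform is increasing for positive scores, the threshold that maximizes F₁-gain is also the one that maximizes F₁. With the variables from the usage example, `best_f1_gain_threshold(y_test, y_scores_lr)` would return such a threshold for the logistic regression model.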
Citation
If you use this package, please cite the original paper:
```bibtex
@inproceedings{NIPS2015_33e8075e,
  author = {Flach, Peter and Kull, Meelis},
  booktitle = {Advances in Neural Information Processing Systems},
  editor = {C. Cortes and N. Lawrence and D. Lee and M. Sugiyama and R. Garnett},
  pages = {},
  publisher = {Curran Associates, Inc.},
  title = {Precision-Recall-Gain Curves: PR Analysis Done Right},
  url = {https://proceedings.neurips.cc/paper_files/paper/2015/file/33e8075e9970de0cfea955afd4644bb2-Paper.pdf},
  volume = {28},
  year = {2015}
}
```
Download files
File details
Details for the file sklearn_prg-0.1.3.tar.gz.
File metadata
- Download URL: sklearn_prg-0.1.3.tar.gz
- Size: 5.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e72e642f3defa66e1389818bd1d38b4bc8bdb09d7bd9fd26af1ebb52afc8c139 |
| MD5 | 28d56cdf2ae59a9c70d0a857dad4bdd9 |
| BLAKE2b-256 | 5ea56f5fde521bef868bfcd8c6f381d621d3275da596146f32b7238b76555642 |
Provenance
The following attestation bundles were made for sklearn_prg-0.1.3.tar.gz:
- Publisher: python-publish.yml on aburkard/sklearn-prg
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: sklearn_prg-0.1.3.tar.gz
- Subject digest: e72e642f3defa66e1389818bd1d38b4bc8bdb09d7bd9fd26af1ebb52afc8c139
- Sigstore transparency entry: 186270445
- Permalink: aburkard/sklearn-prg@845d79430b9915d416e5f5e311cf6be57c9b0bf8
- Branch / Tag: refs/tags/v0.1.3
- Owner: https://github.com/aburkard
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: python-publish.yml@845d79430b9915d416e5f5e311cf6be57c9b0bf8
- Trigger Event: release
File details
Details for the file sklearn_prg-0.1.3-py3-none-any.whl.
File metadata
- Download URL: sklearn_prg-0.1.3-py3-none-any.whl
- Size: 5.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e1ab4c83a55b96c2009a36d32432ec43b32e7a04e000c07aa4d09944f7e0c7f3 |
| MD5 | 282d703885d0d6d28dd061481664a26a |
| BLAKE2b-256 | f0c8c46606b283cc6808514a52785aa912dcfac4fc3af58d95b7ec4974ef2f57 |
Provenance
The following attestation bundles were made for sklearn_prg-0.1.3-py3-none-any.whl:
- Publisher: python-publish.yml on aburkard/sklearn-prg
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: sklearn_prg-0.1.3-py3-none-any.whl
- Subject digest: e1ab4c83a55b96c2009a36d32432ec43b32e7a04e000c07aa4d09944f7e0c7f3
- Sigstore transparency entry: 186270449
- Permalink: aburkard/sklearn-prg@845d79430b9915d416e5f5e311cf6be57c9b0bf8
- Branch / Tag: refs/tags/v0.1.3
- Owner: https://github.com/aburkard
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: python-publish.yml@845d79430b9915d416e5f5e311cf6be57c9b0bf8
- Trigger Event: release