
Precision-Recall-Gain curves and metrics for scikit-learn

sklearn_prg

A Python module implementing Precision-Recall-Gain (PRG) curves and metrics compatible with scikit-learn.

Why Precision-Recall-Gain?

Precision-Recall-Gain (PRG) curves are an alternative to ROC and standard Precision-Recall (PR) curves for heavily imbalanced problems (few positives, many negatives). They rescale precision and recall relative to the always-positive classifier. Use PRG when true negatives carry little value, as in information retrieval or fraud detection.

Advantages:

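  • Universal baseline: gains are measured relative to the always-positive classifier, which lands at the same point (recall gain 1, precision gain 0) for every class distribution.
  • Direct interpretation: the area under the PRG curve (AUPRG) is directly related to the expected F₁ score.
  • Better model selection: the area under a standard PR curve can favour a model with a lower expected F₁ score; ranking models by AUPRG avoids this.
  • Threshold selection: lines of equal Fβ are straight in PRG space, so the convex hull of a PRG curve identifies the Fβ-optimal threshold for each value of β.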
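
The properties above follow from how the gains are defined. The sketch below is written from the definitions in the Flach & Kull paper (cited below), not taken from sklearn_prg: with π the fraction of positives, precision gain is (precision - π) / ((1 - π) · precision), and recall gain applies the same transform to recall. It checks numerically that the Fβ score, transformed the same way, equals (precision gain + β² · recall gain) / (1 + β²), which is why lines of equal Fβ are straight in PRG space. The helper names `gain` and `f_beta` and the example values are hypothetical.

```python
import numpy as np

def gain(value, pi):
    """Harmonic-scale gain of a precision, recall or F-score relative to prevalence pi.

    value == pi maps to 0 and value == 1 maps to 1; values below pi are negative.
    """
    return (value - pi) / ((1 - pi) * value)

def f_beta(precision, recall, beta=1.0):
    """Standard F-beta score from precision and recall."""
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)

pi = 0.1                      # prevalence of the positive class (assumed value)
precision, recall = 0.4, 0.7  # an arbitrary operating point (assumed values)

prec_gain = gain(precision, pi)
rec_gain = gain(recall, pi)

results = {}
for beta in (0.5, 1.0, 2.0):
    direct = gain(f_beta(precision, recall, beta), pi)              # transform the F-score itself
    linear = (prec_gain + beta ** 2 * rec_gain) / (1 + beta ** 2)  # weighted mean of the two gains
    results[beta] = (direct, linear)
    print(f"beta={beta}: F-gain={direct:.6f}, linear form={linear:.6f}")

# The always-positive classifier (precision = pi, recall = 1) lands at
# precision gain 0 and recall gain 1, whatever the value of pi.
print(gain(pi, pi), gain(1.0, pi))
```

For β = 1 the linear form reduces to the plain average of precision gain and recall gain.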

Installation

pip install sklearn_prg

Usage

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import (roc_curve, precision_recall_curve, auc, average_precision_score)
from sklearn_prg.metrics import precision_recall_gain_curve, average_precision_recall_gain

# Generate imbalanced dataset
X, y = make_classification(n_samples=2000,
                           n_features=20,
                           weights=[0.9, 0.1],
                           flip_y=0.03,
                           class_sep=0.5,
                           random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=42)

# Train models
clf_rf = RandomForestClassifier(random_state=42).fit(X_train, y_train)
clf_lr = LogisticRegression(max_iter=1000, random_state=42).fit(X_train, y_train)
y_scores_rf = clf_rf.predict_proba(X_test)[:, 1]
y_scores_lr = clf_lr.predict_proba(X_test)[:, 1]

print(f"Average precision recall gain (Logistic Regression): {average_precision_recall_gain(y_test, y_scores_lr):.3f}")
print(f"Average precision recall gain (Random Forest): {average_precision_recall_gain(y_test, y_scores_rf):.3f}")

# ROC curve
fpr_rf, tpr_rf, _ = roc_curve(y_test, y_scores_rf)
fpr_lr, tpr_lr, _ = roc_curve(y_test, y_scores_lr)

# Standard PR curve
prec_rf, rec_rf, _ = precision_recall_curve(y_test, y_scores_rf)
prec_lr, rec_lr, _ = precision_recall_curve(y_test, y_scores_lr)
ap_rf = average_precision_score(y_test, y_scores_rf)
ap_lr = average_precision_score(y_test, y_scores_lr)

# PRG curve
pg_rf, rg_rf = precision_recall_gain_curve(y_test, y_scores_rf)
pg_lr, rg_lr = precision_recall_gain_curve(y_test, y_scores_lr)
auprg_rf = auc(rg_rf, pg_rf)
auprg_lr = auc(rg_lr, pg_lr)

fig, axs = plt.subplots(1, 3, figsize=(18, 6))

# ----- ROC Curve -----
axs[0].plot(fpr_rf, tpr_rf, label=f'Random Forest (AUC={auc(fpr_rf,tpr_rf):.3f})')
axs[0].plot(fpr_lr, tpr_lr, label=f'Logistic (AUC={auc(fpr_lr,tpr_lr):.3f})')
axs[0].plot([0, 1], [0, 1], 'k--', alpha=0.6, label='Random Classifier')
axs[0].set_title('ROC Curves')
axs[0].set_xlabel('False Positive Rate')
axs[0].set_ylabel('True Positive Rate')
axs[0].legend()
axs[0].grid(True)
axs[0].set_aspect('equal', adjustable='box')

# ----- Precision-Recall Curve -----
prevalence = np.mean(y_test)
axs[1].plot(rec_rf, prec_rf, label=f'Random Forest (AP={ap_rf:.3f})')
axs[1].plot(rec_lr, prec_lr, label=f'Logistic (AP={ap_lr:.3f})')
axs[1].axhline(prevalence, linestyle='--', color='black', alpha=0.6, label='Random Classifier')
axs[1].set_title('Precision-Recall Curves')
axs[1].set_xlabel('Recall')
axs[1].set_ylabel('Precision')
axs[1].legend()
axs[1].grid(True)
axs[1].set_aspect('equal', adjustable='box')

# ----- Precision-Recall-Gain Curve -----
axs[2].plot(rg_rf, pg_rf, label=f'Random Forest (AUPRG={auprg_rf:.3f})')
axs[2].plot(rg_lr, pg_lr, label=f'Logistic (AUPRG={auprg_lr:.3f})')
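# Points on this line have the same F1 score as the always-positive classifier, which sits at (1, 0)
axs[2].plot([1, 0], [0, 1], linestyle='-', color='black', alpha=0.6, label='Always-positive F₁ baseline')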
axs[2].set_xlim(0, 1)
axs[2].set_ylim(0, 1)
axs[2].set_xlabel('Recall Gain')
axs[2].set_ylabel('Precision Gain')
axs[2].set_title('Precision-Recall-Gain Curves')
axs[2].legend()
axs[2].grid(True)
axs[2].set_aspect('equal', adjustable='box')

plt.suptitle('ROC, PR, and PRG Curve Comparison', fontsize=16)
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

A comparison of ROC, PR, and PRG curves
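
To show where PRG points come from, here is an independent NumPy sketch that derives them from a score ranking. It is not the sklearn_prg implementation, and the names `prg_points`, `trapezoid_area` and `best_threshold_for_f_beta` are hypothetical. Its numbers are not guaranteed to match `precision_recall_gain_curve` or `average_precision_recall_gain`: it drops thresholds with no true positives (where recall gain is minus infinity), and its area is a plain trapezoidal sum over the points with non-negative recall gain. Because Fβ gain is linear in the two gains, choosing the Fβ-optimal threshold reduces to maximising a weighted sum of precision gain and recall gain.

```python
import numpy as np

def prg_points(y_true, y_score):
    """Precision gain and recall gain at each distinct score threshold.

    Predicts positive when score >= threshold; thresholds run from high to low.
    Returns (precision_gain, recall_gain, thresholds, precision, recall).
    """
    y_true = np.asarray(y_true, dtype=int)
    y_score = np.asarray(y_score, dtype=float)
    n_pos = y_true.sum()
    pi = n_pos / y_true.size  # prevalence of the positive class

    order = np.argsort(-y_score, kind="mergesort")
    y_sorted, s_sorted = y_true[order], y_score[order]

    # Cumulative counts; keep the last index of each run of tied scores so that
    # tied samples enter the positive class together.
    tp = np.cumsum(y_sorted)
    fp = np.cumsum(1 - y_sorted)
    last_of_tie = np.r_[np.diff(s_sorted) != 0, True]
    tp, fp, thr = tp[last_of_tie], fp[last_of_tie], s_sorted[last_of_tie]

    # Thresholds with no true positives have recall 0, i.e. recall gain -inf.
    keep = tp > 0
    tp, fp, thr = tp[keep], fp[keep], thr[keep]

    precision = tp / (tp + fp)
    recall = tp / n_pos
    prec_gain = (precision - pi) / ((1 - pi) * precision)
    rec_gain = (recall - pi) / ((1 - pi) * recall)
    return prec_gain, rec_gain, thr, precision, recall

def trapezoid_area(x, y):
    """Trapezoidal area under y(x) for non-decreasing x."""
    return float(np.sum(np.diff(x) * (y[1:] + y[:-1]) / 2))

def best_threshold_for_f_beta(y_true, y_score, beta=1.0):
    """Threshold maximising F-beta, found by maximising F-beta gain in PRG space."""
    prec_gain, rec_gain, thr, _, _ = prg_points(y_true, y_score)
    f_gain = (prec_gain + beta ** 2 * rec_gain) / (1 + beta ** 2)
    i = int(np.argmax(f_gain))
    return thr[i], f_gain[i]

# Synthetic example (assumed data): about 10% positives, positives score higher.
rng = np.random.default_rng(0)
y = (rng.random(1000) < 0.1).astype(int)
scores = rng.normal(loc=1.5 * y, scale=1.0)

pg, rg, thr, precision, recall = prg_points(y, scores)
nonneg = rg >= 0
area = trapezoid_area(rg[nonneg], pg[nonneg])
best_thr, best_f1_gain = best_threshold_for_f_beta(y, scores, beta=1.0)
print(f"approximate area (recall gain >= 0): {area:.3f}")
print(f"F1-optimal threshold: {best_thr:.3f} (F1 gain {best_f1_gain:.3f})")
```

Restricting the area to recall gain ≥ 0 mirrors the plotted range in the example above (`axs[2].set_xlim(0, 1)`). Since Fβ gain increases strictly with Fβ, the threshold returned here is also the one with the highest Fβ among the evaluated thresholds.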

Citation

If you use this package, please cite the original paper:

@inproceedings{NIPS2015_33e8075e,
 author = {Flach, Peter and Kull, Meelis},
 booktitle = {Advances in Neural Information Processing Systems},
 editor = {C. Cortes and N. Lawrence and D. Lee and M. Sugiyama and R. Garnett},
 pages = {},
 publisher = {Curran Associates, Inc.},
 title = {Precision-Recall-Gain Curves: PR Analysis Done Right},
 url = {https://proceedings.neurips.cc/paper_files/paper/2015/file/33e8075e9970de0cfea955afd4644bb2-Paper.pdf},
 volume = {28},
 year = {2015}
}

Download files

Source Distribution

sklearn_prg-0.1.3.tar.gz (5.2 kB)

Uploaded Source

Built Distribution

sklearn_prg-0.1.3-py3-none-any.whl (5.0 kB)

Uploaded Python 3

File details

Details for the file sklearn_prg-0.1.3.tar.gz.

File metadata

  • Download URL: sklearn_prg-0.1.3.tar.gz
  • Upload date:
  • Size: 5.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for sklearn_prg-0.1.3.tar.gz
Algorithm Hash digest
SHA256 e72e642f3defa66e1389818bd1d38b4bc8bdb09d7bd9fd26af1ebb52afc8c139
MD5 28d56cdf2ae59a9c70d0a857dad4bdd9
BLAKE2b-256 5ea56f5fde521bef868bfcd8c6f381d621d3275da596146f32b7238b76555642

Provenance

The following attestation bundles were made for sklearn_prg-0.1.3.tar.gz:

Publisher: python-publish.yml on aburkard/sklearn-prg

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file sklearn_prg-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: sklearn_prg-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 5.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for sklearn_prg-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 e1ab4c83a55b96c2009a36d32432ec43b32e7a04e000c07aa4d09944f7e0c7f3
MD5 282d703885d0d6d28dd061481664a26a
BLAKE2b-256 f0c8c46606b283cc6808514a52785aa912dcfac4fc3af58d95b7ec4974ef2f57

Provenance

The following attestation bundles were made for sklearn_prg-0.1.3-py3-none-any.whl:

Publisher: python-publish.yml on aburkard/sklearn-prg

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
