
Automatic threshold optimization for classifiers.


Threshold Optimization Library for Binary and Multiclass Classification

threshopt is a lightweight Python library that automatically finds the optimal decision threshold for classification models.
Instead of relying on a default threshold (e.g. 0.5), it selects the threshold that gives the best score on a chosen evaluation metric. This can noticeably improve results on that metric, especially on imbalanced datasets.
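The core idea can be illustrated with a small NumPy-only sketch; it is not taken from threshopt's source. It sweeps a grid of candidate thresholds over positive-class scores and keeps the one with the best metric value. The helper names `binary_f1` and `sweep_threshold`, the 0.01-step grid, and the "positive when score >= threshold" rule are all assumptions for illustration.

```python
import numpy as np


def binary_f1(y_true, y_pred):
    """F1 score for 0/1 labels, computed from confusion-matrix counts."""
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    denom = 2 * tp + fp + fn
    return 0.0 if denom == 0 else float(2 * tp / denom)


def sweep_threshold(y_true, scores, metric, grid=None):
    """Return (threshold, score) for the grid point with the highest metric value."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    if grid is None:
        grid = np.linspace(0.0, 1.0, 101)  # assumed grid: 0.00, 0.01, ..., 1.00
    best_t, best_m = None, -np.inf
    for t in grid:
        # assumed rule: a row is predicted positive when its score reaches t
        y_pred = (scores >= t).astype(int)
        m = metric(y_true, y_pred)
        if m > best_m:  # strict ">" keeps the lowest threshold among ties
            best_t, best_m = float(t), float(m)
    return best_t, best_m


# Toy data: two positives whose scores sit below the default 0.5 cut-off
y = np.array([0, 0, 0, 0, 1, 1])
p = np.array([0.05, 0.2, 0.3, 0.345, 0.415, 0.9])
print(binary_f1(y, (p >= 0.5).astype(int)))  # F1 at the default 0.5 cut-off
print(sweep_threshold(y, p, binary_f1))      # (tuned threshold, tuned F1)
```

On this toy data the default cut-off misses one of the two positives (F1 = 2/3), while the sweep settles on a threshold of about 0.35, which separates the classes perfectly (F1 = 1.0).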

The library is fully compatible with scikit-learn estimators and supports both binary and multiclass (one-vs-rest) scenarios.


Requirements

  • Python >= 3.8
  • scikit-learn
  • numpy
  • matplotlib

Installation

pip install threshopt

Features

  • Automatic optimization of decision thresholds
  • Metric-driven optimization (any sklearn-style metric or custom metric)
  • Cross-validated threshold optimization
  • Works with any scikit-learn compatible classifier
  • Built-in metrics for imbalanced classification
  • Optional visualization utilities
  • Multiclass support via one-vs-rest logic

Quickstart

Binary classification

from threshopt import optimize_threshold, optimize_threshold_cv, gmean_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)

# Hold out a validation set: a random forest's probabilities on its own
# training rows are overconfident, so a threshold tuned there is optimistic
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42
)

model = RandomForestClassifier(random_state=42)
model.fit(X_train, y_train)

# Optimize the threshold on the held-out validation set
best_thresh, best_metric = optimize_threshold(
    model, X_val, y_val,
    metric=f1_score,
    plot=False, cm=False
)
print(f"Best threshold: {best_thresh:.2f}")
print(f"Best F1-score: {best_metric:.4f}")

# Optimize threshold using cross-validation
best_thresh_cv, best_metric_cv = optimize_threshold_cv(
    model, X, y,
    metric=gmean_score,
    cv=5,
    plot=False, cm=False
)
print(f"CV best threshold: {best_thresh_cv:.2f}")
print(f"CV best G-Mean: {best_metric_cv:.4f}")
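Once a threshold has been tuned, it has to be applied at prediction time, because model.predict() keeps using its own default cut-off. A minimal sketch, assuming a binary scikit-learn classifier whose predict_proba column 1 is the positive class; the helper name `predict_with_threshold` is hypothetical, and the ">=" comparison is an assumption since the threshopt docs do not state which comparison the library uses.

```python
import numpy as np


def predict_with_threshold(model, X, threshold):
    """Return 0/1 labels using a tuned cut-off instead of predict()'s default."""
    # assumed: column 1 of predict_proba is the positive class, which holds for
    # scikit-learn classifiers trained on the labels {0, 1}
    proba_pos = np.asarray(model.predict_proba(X))[:, 1]
    # assumed rule: positive when the probability reaches the threshold
    return (proba_pos >= threshold).astype(int)
```

For example, `predict_with_threshold(model, X_new, best_thresh)` labels a hypothetical new batch `X_new` with the threshold found in the quickstart.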

Multiclass classification

from threshopt import optimize_threshold, optimize_threshold_cv
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)

# Keep a stratified validation set for threshold tuning
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=42
)

model = RandomForestClassifier(random_state=42)
model.fit(X_train, y_train)

# Optimize thresholds (one per class) on the validation set
best_thresh, best_metric = optimize_threshold(
    model, X_val, y_val,
    metric=f1_score,
    multiclass=True,
    plot=False, cm=False
)
print("Best thresholds per class:", best_thresh)
# In the multiclass case the scores are returned as an array,
# so print them directly rather than with a float format spec
print("Best F1-scores:", best_metric)

# Cross-validated multiclass optimization
best_thresh_cv, best_metric_cv = optimize_threshold_cv(
    model, X, y,
    metric=f1_score,
    cv=5,
    multiclass=True,
    plot=False, cm=False
)
print("CV best thresholds per class:", best_thresh_cv)
print("CV best F1-scores:", best_metric_cv)
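One-vs-rest tuning treats each class in turn as the positive class and every other class as negative, then tunes a separate threshold on that class's probability column. The NumPy-only sketch below illustrates that idea; it is not threshopt's implementation, and the helper names, the 0.01-step grid, the ">=" rule, and the assumption that probability column k belongs to the k-th sorted label are all illustrative choices.

```python
import numpy as np


def binary_f1(y_true, y_pred):
    """F1 score for 0/1 labels."""
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    denom = 2 * tp + fp + fn
    return 0.0 if denom == 0 else float(2 * tp / denom)


def ovr_thresholds(y_true, proba, metric, grid=None):
    """Tune one threshold per class, scoring each class as 'class k vs. the rest'."""
    y_true = np.asarray(y_true)
    proba = np.asarray(proba, dtype=float)
    # assumed: column k of proba belongs to the k-th sorted label, as with
    # scikit-learn's classes_ attribute
    classes = np.unique(y_true)
    if grid is None:
        grid = np.linspace(0.0, 1.0, 101)
    thresholds, best_scores = [], []
    for k, cls in enumerate(classes):
        y_bin = (y_true == cls).astype(int)  # 1 for this class, 0 for all others
        col = proba[:, k]
        best_t, best_m = None, -np.inf
        for t in grid:
            m = metric(y_bin, (col >= t).astype(int))
            if m > best_m:
                best_t, best_m = float(t), float(m)
        thresholds.append(best_t)
        best_scores.append(best_m)
    return np.array(thresholds), np.array(best_scores)


y = np.array([0, 0, 1, 1, 2, 2])
proba = np.array([
    [0.8, 0.1, 0.1],
    [0.6, 0.3, 0.1],
    [0.2, 0.7, 0.1],
    [0.3, 0.4, 0.3],
    [0.1, 0.2, 0.7],
    [0.2, 0.2, 0.6],
])
thr, scores = ovr_thresholds(y, proba, binary_f1)
print("thresholds:", thr)
print("F1 per class:", scores)
```

This sketch only chooses the per-class thresholds; how those thresholds are then combined into a single multiclass prediction (for example when several classes pass their thresholds) is a separate design decision that the threshopt docs do not describe.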

API

optimize_threshold

optimize_threshold(model, X, y_true, metric, multiclass=False,
                   use_predict_if_no_proba=False, plot=True, cm=True, report=True)

Optimizes the decision threshold on the full dataset.

| Parameter | Type | Description |
|---|---|---|
| model | estimator | Trained classifier with predict_proba or decision_function |
| X | array-like | Feature matrix |
| y_true | array-like | True labels |
| metric | callable | Metric function (y_true, y_pred) -> float |
| multiclass | bool | If True, performs one-vs-rest optimization |
| use_predict_if_no_proba | bool | If True, falls back to predict() when the model provides no probability scores |
| plot | bool | If True, plots probability distributions |
| cm | bool | If True, displays the confusion matrix |
| report | bool | If True, prints a classification report |

Returns: a tuple (thresholds, best_scores); both are scalars for binary problems and arrays for multiclass problems.
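The parameter table implies a chain of score sources: probabilities, then decision-function margins, then (optionally) hard predict() labels. A minimal sketch of such a fallback chain for a binary model follows; the function name `positive_scores` and the preference order are assumptions, and this is not threshopt's code.

```python
import numpy as np


def positive_scores(model, X, use_predict_if_no_proba=False):
    """Return one positive-class score per row of X (hypothetical helper)."""
    if hasattr(model, "predict_proba"):
        # probabilities in [0, 1]; column 1 is the positive class for labels {0, 1}
        return np.asarray(model.predict_proba(X))[:, 1]
    if hasattr(model, "decision_function"):
        # margins are unbounded, so a useful threshold grid must span their
        # observed range rather than [0, 1]
        return np.asarray(model.decision_function(X), dtype=float)
    if use_predict_if_no_proba:
        # hard 0/1 labels: only thresholds in (0, 1] change the outcome
        return np.asarray(model.predict(X), dtype=float)
    raise AttributeError("model exposes neither predict_proba nor decision_function")
```

The comments point at the main practical consequence: with margin scores or hard labels, a fixed probability grid such as 0.00 to 1.00 is either the wrong range or nearly meaningless.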


optimize_threshold_cv

optimize_threshold_cv(model, X, y_true, metric, cv=5, multiclass=False,
                      plot=True, cm=True, report=True)

Optimizes the decision threshold using cross-validation, reducing the risk of overfitting to a specific train/test split.

| Parameter | Type | Description |
|---|---|---|
| model | estimator | Trained classifier with predict_proba or decision_function |
| X | array-like | Feature matrix |
| y_true | array-like | True labels |
| metric | callable | Metric function (y_true, y_pred) -> float |
| cv | int | Number of cross-validation folds |
| multiclass | bool | If True, performs one-vs-rest optimization |
| plot | bool | If True, plots probability distributions |
| cm | bool | If True, displays the confusion matrix |
| report | bool | If True, prints a classification report |

Returns: a tuple (thresholds, best_scores); both are scalars for binary problems and arrays for multiclass problems.
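One common way to cross-validate a threshold is to collect out-of-fold probabilities with scikit-learn's cross_val_predict and then sweep thresholds over them, so that every score comes from a model that never saw that row during fitting. The sketch below shows that approach using scikit-learn only; it is not necessarily how optimize_threshold_cv works internally, and the helper name `cv_threshold`, the grid, the ">=" rule, and the logistic-regression pipeline are assumptions for illustration.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import cross_val_predict
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def cv_threshold(model, X, y, metric, cv=5, grid=None):
    """Tune a binary threshold on out-of-fold positive-class probabilities."""
    # cross_val_predict clones and refits `model` on each training fold, so the
    # model passed in does not need to be fitted beforehand
    oof = cross_val_predict(model, X, y, cv=cv, method="predict_proba")[:, 1]
    if grid is None:
        grid = np.linspace(0.01, 0.99, 99)  # assumed grid: 0.01, 0.02, ..., 0.99
    results = [metric(y, (oof >= t).astype(int)) for t in grid]
    best = int(np.argmax(results))  # first (lowest) threshold among ties
    return float(grid[best]), float(results[best])


X, y = load_breast_cancer(return_X_y=True)
clf = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
best_t, best_f1 = cv_threshold(clf, X, y, f1_score, cv=5)
print(f"threshold={best_t:.2f}  out-of-fold F1={best_f1:.4f}")
```

The score reported this way is still slightly optimistic, because the threshold is chosen on the same out-of-fold predictions it is evaluated on; a final check on untouched test data remains advisable.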


Metrics

Built-in metrics

| Metric | Signature | Description |
|---|---|---|
| gmean_score | (y_true, y_pred) -> float | Geometric mean of sensitivity and specificity |
| youden_j_stat | (y_true, y_pred) -> float | Youden's J statistic: sensitivity + specificity - 1 |
| balanced_acc_score | (y_true, y_pred) -> float | Balanced accuracy (wrapper around scikit-learn) |
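The three formulas in the table can be written directly in terms of sensitivity (true positive rate) and specificity (true negative rate). The NumPy-only sketch below is meant to make the definitions concrete; the names `gmean`, `youden_j` and `balanced_acc` are hypothetical stand-ins rather than threshopt's own implementations, and returning 0.0 when a class is absent is an assumption.

```python
import numpy as np


def _rates(y_true, y_pred):
    """Sensitivity (TPR) and specificity (TNR) for 0/1 labels."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    tn = np.sum((y_true == 0) & (y_pred == 0))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    sens = tp / (tp + fn) if tp + fn else 0.0  # assumed: 0.0 if no positives
    spec = tn / (tn + fp) if tn + fp else 0.0  # assumed: 0.0 if no negatives
    return sens, spec


def gmean(y_true, y_pred):
    sens, spec = _rates(y_true, y_pred)
    return float(np.sqrt(sens * spec))


def youden_j(y_true, y_pred):
    sens, spec = _rates(y_true, y_pred)
    return float(sens + spec - 1.0)


def balanced_acc(y_true, y_pred):
    # for binary labels, balanced accuracy is the mean of sensitivity and specificity
    sens, spec = _rates(y_true, y_pred)
    return float((sens + spec) / 2.0)


y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 0, 0, 0, 0, 1, 1]  # sensitivity 3/4, specificity 4/6
print(gmean(y_true, y_pred), youden_j(y_true, y_pred), balanced_acc(y_true, y_pred))
```

All three reward a threshold that balances both error types, which is why they are useful on imbalanced data where plain accuracy can be maximized by predicting only the majority class. In the binary case Youden's J equals 2 × balanced accuracy - 1, so the two select the same threshold.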

Custom metrics

Any function with the following signature can be passed as metric, where y_pred contains predicted labels (not scores):

metric(y_true, y_pred) -> float
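For instance, `functools.partial(sklearn.metrics.fbeta_score, beta=2)` already has this signature and weights recall more heavily than precision. A custom metric can also encode error costs. The sketch below builds a cost-based metric; the names `make_cost_metric` and `negative_cost` and the 5:1 cost ratio are hypothetical, and the cost is negated on the assumption that threshopt picks the threshold with the highest metric value, as the "best" score naming and the built-in metrics (where higher is better) suggest.

```python
import numpy as np


def make_cost_metric(fn_cost=5.0, fp_cost=1.0):
    """Build a (y_true, y_pred) -> float metric from per-error costs.

    The total cost is negated so that a larger return value is better.
    """
    def negative_cost(y_true, y_pred):
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        fn = np.sum((y_true == 1) & (y_pred == 0))  # missed positives
        fp = np.sum((y_true == 0) & (y_pred == 1))  # false alarms
        return -float(fn_cost * fn + fp_cost * fp)
    return negative_cost


metric = make_cost_metric(fn_cost=5.0, fp_cost=1.0)
print(metric([1, 1, 0, 0], [1, 0, 1, 0]))  # one missed positive and one false alarm
```

Because a missed positive costs five times as much as a false alarm here, optimizing this metric tends to push the threshold down, trading extra false positives for fewer false negatives.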

Contributing

Contributions are welcome! Please follow these steps:

  1. Fork the repository
  2. Create a feature branch (git checkout -b feature/my-feature)
  3. Commit your changes (git commit -m 'Add my feature')
  4. Push to the branch (git push origin feature/my-feature)
  5. Open a pull request

For bug reports or feature requests, please open an issue.


License

This project is licensed under the Apache License 2.0 — see the LICENSE file for details.



Download files


Source Distribution

threshopt-1.1.1.tar.gz (13.1 kB)

Uploaded Source

Built Distribution


threshopt-1.1.1-py3-none-any.whl (11.2 kB)

Uploaded Python 3

File details

Details for the file threshopt-1.1.1.tar.gz.

File metadata

  • Download URL: threshopt-1.1.1.tar.gz
  • Upload date:
  • Size: 13.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.12

File hashes

Hashes for threshopt-1.1.1.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 07cdf5a0478aaff549138a067ef868bab262c6cb31cbfec621858c124f2547ea |
| MD5 | b4f88497e6b13ef4ba4cc8fc8a0fa4e3 |
| BLAKE2b-256 | f4473922f0f8390c2570777f41c2a5518403571f00c7453c76910430fda96c6c |


File details

Details for the file threshopt-1.1.1-py3-none-any.whl.

File metadata

  • Download URL: threshopt-1.1.1-py3-none-any.whl
  • Upload date:
  • Size: 11.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.12

File hashes

Hashes for threshopt-1.1.1-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 4f91d76c2af88171f6d3f2d061089b676849cd2ccaab11fb23e3e7251ec1a1e5 |
| MD5 | 3789a622037b7e7a3760d1803041a8b9 |
| BLAKE2b-256 | 247352a92c6e9e9722096ce2022104da7aaa6df68b18e67787a1f97d9e1d964f |

