Automatic threshold optimization for binary and multiclass classifiers.

Project description

threshopt

Threshold Optimization Library for Binary and Multiclass Classification

threshopt is a lightweight Python library that finds the optimal decision threshold for a classifier. Tuning the threshold for a chosen metric, instead of relying on the default cutoff (typically 0.5 for binary classifiers), can improve performance on that metric.
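
To make the idea concrete, here is a minimal sketch of threshold tuning written with scikit-learn and NumPy only. It illustrates the general technique and is not threshopt's implementation; the helper name scan_thresholds, the threshold grid, and the choice of logistic regression are assumptions made for this example.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def scan_thresholds(y_true, scores, metric, thresholds=None):
    """Hypothetical helper: return (best_threshold, best_value) over a grid."""
    if thresholds is None:
        thresholds = np.linspace(0.01, 0.99, 99)
    # Turn scores into hard 0/1 predictions at each cutoff and score them
    values = [metric(y_true, (scores >= t).astype(int)) for t in thresholds]
    best = int(np.argmax(values))  # first threshold reaching the maximum
    return float(thresholds[best]), float(values[best])


X, y = load_breast_cancer(return_X_y=True)
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=0
)

model = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
model.fit(X_train, y_train)

# Probability of the positive class on held-out data
scores = model.predict_proba(X_val)[:, 1]

best_t, best_f1 = scan_thresholds(y_val, scores, f1_score)
default_f1 = f1_score(y_val, (scores >= 0.5).astype(int))
print(f"default 0.50: F1 {default_f1:.4f} | tuned {best_t:.2f}: F1 {best_f1:.4f}")
```

Because the grid contains 0.5, the tuned value on this validation set can never be worse than the default; whether the gain carries over to new data is a separate question that a held-out test set or cross-validation has to answer.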


Features

  • Optimize decision thresholds based on any metric (e.g., accuracy, F1-score, G-Mean, Youden’s J)
  • Supports cross-validation threshold optimization
  • Works with any scikit-learn compatible model
  • Built-in common metrics and support for custom metrics
  • Optional visualization of confusion matrices and prediction score distributions
  • Multiclass and fallback support

Installation

pip install threshopt

Quickstart

Binary classification

from threshopt import optimize_threshold, optimize_threshold_cv, gmean_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

# Load data and hold out a test set
data = load_breast_cancer()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42
)

# Train model on the training split only
model = RandomForestClassifier(random_state=42)
model.fit(X_train, y_train)

# Optimize threshold on the test set
best_thresh, best_val = optimize_threshold(model, X_test, y_test, metric=f1_score)
print(f"Best threshold: {best_thresh:.2f}, F1-score: {best_val:.4f}")

# Optimize threshold with cross-validation
best_thresh_cv, best_val_cv = optimize_threshold_cv(model, X, y, metric=gmean_score, cv=5)
print(f"CV best threshold: {best_thresh_cv:.2f}, CV best metric: {best_val_cv:.4f}")
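
For intuition about cross-validated threshold selection, the sketch below uses out-of-fold probabilities from scikit-learn's cross_val_predict and then scans a threshold grid. This is one plausible approach and is not necessarily what optimize_threshold_cv does internally; the variable names, the grid, and the use of balanced accuracy are assumptions for this example.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import StratifiedKFold, cross_val_predict

X, y = load_breast_cancer(return_X_y=True)
model = RandomForestClassifier(n_estimators=50, random_state=42)

# Out-of-fold probabilities: each sample is scored by a model that never saw it
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
oof_scores = cross_val_predict(model, X, y, cv=cv, method="predict_proba")[:, 1]

# Evaluate the metric at each candidate cutoff on the out-of-fold scores
thresholds = np.linspace(0.05, 0.95, 19)
values = [
    balanced_accuracy_score(y, (oof_scores >= t).astype(int)) for t in thresholds
]

best_idx = int(np.argmax(values))
best_thresh_oof = float(thresholds[best_idx])
best_val_oof = float(values[best_idx])
print(f"OOF best threshold: {best_thresh_oof:.2f}, balanced accuracy: {best_val_oof:.4f}")
```

Scoring on out-of-fold predictions keeps the threshold from being fitted to samples the model has already memorized, which matters for models such as random forests that score their own training data almost perfectly.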

Multiclass classification

from threshopt import optimize_threshold, optimize_threshold_cv, gmean_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.metrics import f1_score

# Load data
data = load_iris()
X, y = data.data, data.target

# Train model
model = RandomForestClassifier(random_state=42)
model.fit(X, y)

# Optimize threshold for multiclass using fallback
best_thresh, best_val = optimize_threshold(model, X, y, metric=f1_score, multiclass=True)
print(f"Best thresholds per class: {best_thresh}, F1-score: {best_val:.4f}")

# Optimize threshold with cross-validation
best_thresh_cv, best_val_cv = optimize_threshold_cv(model, X, y, metric=gmean_score, cv=5, multiclass=True)
print(f"CV best thresholds per class: {best_thresh_cv}, CV best metric: {best_val_cv:.4f}")
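
The description does not spell out how per-class thresholds are turned into predictions or what "fallback" means. The sketch below shows one common rule, stated as an assumption: a class is eligible when its probability clears its own threshold, the most probable eligible class wins, and when no class is eligible the prediction falls back to the plain argmax. The function name is hypothetical and threshopt's behavior may differ.

```python
import numpy as np


def predict_with_class_thresholds(proba, thresholds):
    """Hypothetical rule: thresholded one-vs-rest decision with argmax fallback."""
    proba = np.asarray(proba, dtype=float)
    thresholds = np.asarray(thresholds, dtype=float)

    # A class is eligible when its probability reaches its own threshold
    eligible = proba >= thresholds  # broadcasts to (n_samples, n_classes)

    # Among eligible classes, pick the one with the highest probability
    masked = np.where(eligible, proba, -np.inf)
    preds = masked.argmax(axis=1)

    # Fallback (assumption): if no class is eligible, use the plain argmax
    no_class = ~eligible.any(axis=1)
    preds[no_class] = proba[no_class].argmax(axis=1)
    return preds


proba = np.array([
    [0.50, 0.30, 0.20],  # only class 0 clears its threshold
    [0.35, 0.45, 0.20],  # no class clears its threshold: fallback to argmax
    [0.45, 0.20, 0.35],  # classes 0 and 2 clear; class 0 is more probable
    [0.38, 0.30, 0.32],  # only class 2 clears, although class 0 is the argmax
])
class_thresholds = [0.4, 0.5, 0.3]

preds = predict_with_class_thresholds(proba, class_thresholds)
print(preds)
```

The last row shows why per-class thresholds can change the outcome: a lower threshold for class 2 lets it win even though class 0 has the highest raw probability.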

Metrics

Included metrics:

  • gmean_score: Geometric Mean of sensitivity and specificity
  • youden_j_stat: Youden’s J statistic (sensitivity + specificity - 1)
  • balanced_acc_score: Balanced Accuracy (wrapper around scikit-learn)
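
The definitions above can be checked by hand with scikit-learn's confusion_matrix. This is a sketch of the formulas for binary 0/1 labels, not threshopt's source; the helper names are assumptions chosen so they do not clash with the library's own functions.

```python
import numpy as np
from sklearn.metrics import balanced_accuracy_score, confusion_matrix


def sensitivity_specificity(y_true, y_pred):
    """Return (sensitivity, specificity) for binary labels in {0, 1}."""
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) else 0.0  # true positive rate
    specificity = tn / (tn + fp) if (tn + fp) else 0.0  # true negative rate
    return sensitivity, specificity


def gmean_sketch(y_true, y_pred):
    sens, spec = sensitivity_specificity(y_true, y_pred)
    return float(np.sqrt(sens * spec))


def youden_j_sketch(y_true, y_pred):
    sens, spec = sensitivity_specificity(y_true, y_pred)
    return float(sens + spec - 1)


y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 0, 0, 0, 0, 1, 1]

# Sensitivity is 3/4 and specificity is 4/6 for these labels
print(f"G-Mean: {gmean_sketch(y_true, y_pred):.4f}")
print(f"Youden's J: {youden_j_sketch(y_true, y_pred):.4f}")
print(f"Balanced accuracy: {balanced_accuracy_score(y_true, y_pred):.4f}")
```

For binary labels, balanced accuracy is the arithmetic mean of sensitivity and specificity, while G-Mean is their geometric mean, so G-Mean penalizes an imbalance between the two more strongly.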

You can also pass any metric function with signature metric(y_true, y_pred).
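
As an illustration of that signature, here are two ways to build a custom metric with scikit-learn: binding extra arguments of an existing metric, and writing a cost-based score. The names f2_score and negative_cost and the cost values are assumptions for this example, and the cost is negated on the assumption that the optimizer maximizes the metric.

```python
from functools import partial

from sklearn.metrics import confusion_matrix, fbeta_score

# Option 1: bind extra arguments so the callable is metric(y_true, y_pred)
f2_score = partial(fbeta_score, beta=2)


# Option 2: a cost-based metric (hypothetical costs; higher is better)
def negative_cost(y_true, y_pred, fn_cost=5.0, fp_cost=1.0):
    """Negated misclassification cost for binary labels in {0, 1}."""
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    return -(fn_cost * fn + fp_cost * fp)


y_true = [1, 1, 1, 0, 0, 0]
y_pred = [1, 0, 0, 0, 0, 1]

# Two false negatives and one false positive in this toy example
print(f"F2: {f2_score(y_true, y_pred):.4f}")
print(f"Negative cost: {negative_cost(y_true, y_pred):.1f}")
```

Either callable could then be passed as the metric argument, for example metric=f2_score, in the same way f1_score is passed in the Quickstart.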


Contributing

Contributions are welcome! Please open issues or submit pull requests.


License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.

Download files


Source Distribution

threshopt-1.0.0.tar.gz (7.6 kB)

Uploaded Source

Built Distribution


threshopt-1.0.0-py3-none-any.whl (6.9 kB)

Uploaded Python 3

File details

Details for the file threshopt-1.0.0.tar.gz.

File metadata

  • Download URL: threshopt-1.0.0.tar.gz
  • Upload date:
  • Size: 7.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.2

File hashes

Hashes for threshopt-1.0.0.tar.gz
  • SHA256: bd8dff105cd7cb74825990b041cb19cb82c7d36f4240ca082b8f2f52be1362ee
  • MD5: ac34225b06ff520e470d74bbce7a7034
  • BLAKE2b-256: a62ca590898bee978b07f7234db7d1429ada2d1e692a66a9d2c86649c84a328d
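
A downloaded file can be checked against the SHA256 digest listed above with Python's standard hashlib module. This is a minimal sketch: the function names are assumptions, and the local path in the commented usage line is hypothetical and depends on where the file was saved.

```python
import hashlib

# SHA256 digest listed above for threshopt-1.0.0.tar.gz
EXPECTED_SHA256 = "bd8dff105cd7cb74825990b041cb19cb82c7d36f4240ca082b8f2f52be1362ee"


def sha256_of_file(path, chunk_size=1 << 20):
    """Hash a file in chunks so large files need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_expected(path, expected=EXPECTED_SHA256):
    """Return True when the file's SHA256 equals the expected hex digest."""
    return sha256_of_file(path) == expected.lower()


# Usage (hypothetical local path):
# print(matches_expected("threshopt-1.0.0.tar.gz"))
```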


File details

Details for the file threshopt-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: threshopt-1.0.0-py3-none-any.whl
  • Upload date:
  • Size: 6.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.2

File hashes

Hashes for threshopt-1.0.0-py3-none-any.whl
  • SHA256: ae5b0eda05a73003d2d8fd63c430bede8345cabdcb0dd287e238f5c49bece7da
  • MD5: 9446bff074611cd19374ae1651065574
  • BLAKE2b-256: 95444309b0c3913d0a5d73fa759b32401d058bb9166bb40bc69ff04eb3fa2dde

