Automatic threshold optimization for binary classifiers.

threshopt

Threshold Optimization Library for Binary Classification

threshopt is a lightweight Python library for finding the decision threshold that optimizes a chosen metric for a binary classifier, instead of relying on the default cutoff of 0.5.
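To make the idea concrete, here is a minimal, dependency-free sketch of what moving the threshold means. The function name `apply_threshold` and the scores are illustrative and not part of threshopt; in practice the scores would typically be positive-class probabilities such as `model.predict_proba(X)[:, 1]` from a scikit-learn classifier.

```python
def apply_threshold(scores, threshold=0.5):
    """Convert positive-class scores into 0/1 labels (hypothetical helper)."""
    return [1 if s >= threshold else 0 for s in scores]


# Made-up scores for five samples
scores = [0.10, 0.35, 0.48, 0.62, 0.90]

default_labels = apply_threshold(scores)        # cutoff at 0.5
lowered_labels = apply_threshold(scores, 0.3)   # lower cutoff, more positives

print(default_labels)
print(lowered_labels)
```

Lowering the cutoff trades false negatives for false positives, which is why the best value depends on the metric being optimized.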


Features

  • Optimize decision thresholds based on any metric (e.g. accuracy, F1-score, G-Mean, Youden’s J)
  • Supports cross-validation threshold optimization for robust model tuning
  • Easy integration with any scikit-learn compatible model
  • Built-in common metrics and ability to use custom metrics
  • Visualize confusion matrices and prediction score distributions (optional)
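As a rough illustration of metric-driven threshold search, the sketch below sweeps a grid of candidate thresholds and keeps the one with the highest metric value. It is not taken from threshopt's source: `sweep_threshold`, `accuracy`, the grid of 101 steps, and the toy data are all assumptions made for this example.

```python
def accuracy(y_true, y_pred):
    """Fraction of labels predicted correctly."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def sweep_threshold(y_true, scores, metric, n_steps=101):
    """Try evenly spaced thresholds in [0, 1]; return (best_threshold, best_value).

    Assumes higher metric values are better. Ties keep the lowest threshold.
    """
    best_t, best_v = None, float("-inf")
    for i in range(n_steps):
        t = i / (n_steps - 1)
        y_pred = [1 if s >= t else 0 for s in scores]
        v = metric(y_true, y_pred)
        if v > best_v:
            best_t, best_v = t, v
    return best_t, best_v


# Toy labels and scores, made up for illustration
y_true = [0, 0, 1, 0, 1, 1]
scores = [0.05, 0.22, 0.33, 0.41, 0.70, 0.90]

best_t, best_v = sweep_threshold(y_true, scores, accuracy)
print(f"Best threshold: {best_t:.2f}, accuracy: {best_v:.4f}")
```

Any function with the signature `metric(y_true, y_pred)` can be swapped in for `accuracy` without changing the search loop.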

Installation

pip install -e .

Install in editable mode from the project root directory.

Quickstart

from threshopt import optimize_threshold, optimize_threshold_cv, gmean_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

# Load data and hold out a test set
data = load_breast_cancer()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=42
)

# Train model on the training split only
model = RandomForestClassifier(random_state=42)
model.fit(X_train, y_train)

# Optimize threshold on the held-out test set
best_thresh, best_val = optimize_threshold(model, X_test, y_test, metric=f1_score)
print(f"Best threshold: {best_thresh:.2f}, F1-score: {best_val:.4f}")

# Optimize threshold with 5-fold cross-validation on the full dataset
best_thresh_cv, best_val_cv = optimize_threshold_cv(model, X, y, metric=gmean_score, cv=5)
print(f"CV best threshold: {best_thresh_cv:.2f}, CV best metric: {best_val_cv:.4f}")
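For intuition about what cross-validated threshold selection involves, the sketch below picks a threshold from out-of-fold probabilities using only scikit-learn. It is an illustration under stated assumptions (synthetic data, a logistic regression, a fixed grid of 19 thresholds) and does not describe how `optimize_threshold_cv` is implemented.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import cross_val_predict

# Synthetic, imbalanced data (assumption: stands in for a real dataset)
X, y = make_classification(
    n_samples=300, n_features=10, weights=[0.8, 0.2], random_state=0
)

# Out-of-fold positive-class probabilities: each sample is scored by a
# model that did not see it during fitting
model = LogisticRegression(max_iter=1000)
oof_scores = cross_val_predict(model, X, y, cv=5, method="predict_proba")[:, 1]

# Evaluate F1 at each candidate threshold and keep the best one
thresholds = np.linspace(0.05, 0.95, 19)
values = [
    f1_score(y, (oof_scores >= t).astype(int), zero_division=0)
    for t in thresholds
]
best_idx = int(np.argmax(values))
best_thresh, best_val = float(thresholds[best_idx]), float(values[best_idx])

print(f"Out-of-fold best threshold: {best_thresh:.2f}, F1: {best_val:.4f}")
```

Scoring on out-of-fold predictions keeps the chosen threshold from being tuned on data the model was fitted on.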


Metrics

Included metrics:

  • gmean_score: Geometric Mean of sensitivity and specificity
  • youden_j_stat: Youden’s J statistic (sensitivity + specificity - 1)
  • balanced_acc_score: Balanced Accuracy (wrapper around scikit-learn)
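The three quantities above follow from sensitivity and specificity. The sketch below computes them in plain Python from 0/1 labels. The function names end in `_sketch` to make clear they are illustrative stand-ins, not the library's own implementations, and returning 0.0 for an empty class is an assumption of this example.

```python
from math import sqrt


def rates(y_true, y_pred):
    """Return (sensitivity, specificity) for 0/1 labels."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    sensitivity = tp / (tp + fn) if (tp + fn) else 0.0
    specificity = tn / (tn + fp) if (tn + fp) else 0.0
    return sensitivity, specificity


def gmean_sketch(y_true, y_pred):
    """Geometric mean of sensitivity and specificity."""
    sens, spec = rates(y_true, y_pred)
    return sqrt(sens * spec)


def youden_j_sketch(y_true, y_pred):
    """Youden's J: sensitivity + specificity - 1."""
    sens, spec = rates(y_true, y_pred)
    return sens + spec - 1.0


def balanced_acc_sketch(y_true, y_pred):
    """Balanced accuracy: mean of sensitivity and specificity."""
    sens, spec = rates(y_true, y_pred)
    return (sens + spec) / 2.0


# Toy labels: sensitivity = 3/4, specificity = 2/4
y_true = [1, 1, 1, 1, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 0, 0, 1, 1]

g = gmean_sketch(y_true, y_pred)
j = youden_j_sketch(y_true, y_pred)
b = balanced_acc_sketch(y_true, y_pred)
print(g, j, b)
```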

You can also pass any metric function with signature metric(y_true, y_pred).
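As an example of a custom metric with that signature, the sketch below defines an F2 score, which weights recall more heavily than precision. The function `f2_metric` is hypothetical and written in plain Python. The commented-out call assumes, as in the Quickstart, that `optimize_threshold` accepts the function through its `metric` argument and treats higher values as better.

```python
def f2_metric(y_true, y_pred, beta=2.0):
    """F-beta score for 0/1 labels; beta=2 favors recall over precision."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    b2 = beta ** 2
    denom = (1 + b2) * tp + b2 * fn + fp
    return (1 + b2) * tp / denom if denom else 0.0


# Toy check: tp=2, fn=1, fp=1
score = f2_metric([1, 1, 1, 0, 0], [1, 1, 0, 1, 0])
print(f"F2: {score:.4f}")

# Assumed usage, mirroring the Quickstart call:
# best_thresh, best_val = optimize_threshold(model, X, y, metric=f2_metric)
```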


Contributing

Contributions are welcome! Please open issues or submit pull requests.


License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
