
Automatic threshold optimization for binary classifiers.

Project description

threshopt


Threshold Optimization Library for Binary Classification

threshopt is a lightweight Python library for finding the optimal decision threshold for binary classifiers. Tuning the threshold against a chosen metric, instead of relying on the default 0.5, can improve model performance.
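To make the idea concrete, here is a minimal sketch of a threshold sweep using only NumPy. It is not threshopt's implementation: the helper names `sweep_threshold` and `accuracy`, the 101-point grid, and the toy scores are all assumptions chosen for illustration.

```python
import numpy as np

def sweep_threshold(y_true, scores, metric, thresholds=None):
    """Try each cut-off and return (best_threshold, best_metric_value).

    Hypothetical helper: assumes higher metric values are better and that
    `scores` are probabilities of the positive class.
    """
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    if thresholds is None:
        thresholds = np.linspace(0.0, 1.0, 101)  # assumed grid, step 0.01
    best_t, best_v = None, -np.inf
    for t in thresholds:
        y_pred = (scores >= t).astype(int)  # positive when score reaches the cut-off
        v = metric(y_true, y_pred)
        if v > best_v:  # strict comparison keeps the first best cut-off
            best_t, best_v = float(t), float(v)
    return best_t, best_v

def accuracy(y_true, y_pred):
    """Fraction of labels predicted correctly."""
    return float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))

# Toy data: two positives score below 0.5, so the default cut-off misses them.
y_true = np.array([0, 0, 0, 1, 1, 1])
scores = np.array([0.10, 0.20, 0.30, 0.35, 0.40, 0.90])

best_t, best_v = sweep_threshold(y_true, scores, accuracy)
default_v = accuracy(y_true, (scores >= 0.5).astype(int))
print(f"default 0.5 accuracy: {default_v:.3f}")
print(f"best threshold: {best_t:.2f}, accuracy: {best_v:.3f}")
```

On this toy data the default cut-off classifies four of six samples correctly, while a cut-off just above 0.30 separates the classes perfectly.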


Features

  • Optimize decision thresholds based on any metric (e.g. accuracy, F1-score, G-Mean, Youden’s J)
  • Supports cross-validation threshold optimization for robust model tuning
  • Easy integration with any scikit-learn compatible model
  • Built-in common metrics and ability to use custom metrics
  • Visualize confusion matrices and prediction score distributions (optional)
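As a rough illustration of the cross-validation feature above, the sketch below picks the best threshold on each validation fold and averages the results. How threshopt aggregates across folds is not stated here, so the averaging step, the helper name `cv_threshold`, the threshold grid, and the synthetic dataset are all assumptions.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold

def f1(y_true, y_pred):
    # zero_division=0 avoids warnings when a cut-off yields no positive predictions
    return f1_score(y_true, y_pred, zero_division=0)

def cv_threshold(model, X, y, metric, cv=5, thresholds=None):
    """Hypothetical helper: mean of per-fold best thresholds and metric values."""
    if thresholds is None:
        thresholds = np.linspace(0.05, 0.95, 19)  # assumed grid, step 0.05
    folds = StratifiedKFold(n_splits=cv, shuffle=True, random_state=0)
    fold_thresholds, fold_values = [], []
    for train_idx, val_idx in folds.split(X, y):
        # Fit a fresh copy on the training part, score the held-out part
        fitted = clone(model).fit(X[train_idx], y[train_idx])
        scores = fitted.predict_proba(X[val_idx])[:, 1]
        values = [metric(y[val_idx], (scores >= t).astype(int)) for t in thresholds]
        best = int(np.argmax(values))
        fold_thresholds.append(thresholds[best])
        fold_values.append(values[best])
    return float(np.mean(fold_thresholds)), float(np.mean(fold_values))

# Synthetic, imbalanced data standing in for a real dataset
X, y = make_classification(
    n_samples=400, n_features=10, weights=[0.8, 0.2], random_state=0
)
cv_thresh, cv_f1 = cv_threshold(LogisticRegression(max_iter=1000), X, y, f1)
print(f"mean per-fold threshold: {cv_thresh:.2f}, mean F1: {cv_f1:.4f}")
```

Choosing the threshold on held-out folds rather than on the training data reduces the risk of tuning it to noise in a single split.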

Installation

pip install threshopt

To install in editable mode from a local copy of the source, run the following from the project root directory:

pip install -e .

Quickstart

from threshopt import optimize_threshold, optimize_threshold_cv, gmean_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

# Load data and hold out a test set
data = load_breast_cancer()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42
)

# Train model on the training split only
model = RandomForestClassifier(random_state=42)
model.fit(X_train, y_train)

# Optimize threshold on the held-out test set
best_thresh, best_val = optimize_threshold(model, X_test, y_test, metric=f1_score)
print(f"Best threshold: {best_thresh:.2f}, F1-score: {best_val:.4f}")

# Optimize threshold with cross-validation
best_thresh_cv, best_val_cv = optimize_threshold_cv(model, X, y, metric=gmean_score, cv=5)
print(f"CV best threshold: {best_thresh_cv:.2f}, CV best metric: {best_val_cv:.4f}")
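Once a threshold has been found, it still has to be applied at prediction time, because a scikit-learn classifier's `predict` does not take a threshold argument. The sketch below shows one common way to do that with `predict_proba`. The helper `predict_with_threshold`, the value 0.30 (a stand-in for `best_thresh`), and the synthetic data are assumptions for illustration, not part of threshopt.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic data and a simple model so the example is self-contained
X, y = make_classification(n_samples=300, n_features=8, random_state=0)
model = LogisticRegression(max_iter=1000).fit(X, y)

def predict_with_threshold(model, X, threshold):
    """Hypothetical helper: label a sample 1 when P(class 1) >= threshold."""
    proba = model.predict_proba(X)[:, 1]  # column 1 corresponds to model.classes_[1]
    return (proba >= threshold).astype(int)

threshold = 0.30  # hypothetical value standing in for best_thresh
preds_tuned = predict_with_threshold(model, X, threshold)
preds_default = predict_with_threshold(model, X, 0.5)

print("positives at 0.30:", int(preds_tuned.sum()))
print("positives at 0.50:", int(preds_default.sum()))
```

Lowering the threshold can only add positive predictions, which typically trades precision for recall.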



Metrics

Included metrics:

  • gmean_score: Geometric Mean of sensitivity and specificity
  • youden_j_stat: Youden’s J statistic (sensitivity + specificity - 1)
  • balanced_acc_score: Balanced Accuracy (wrapper around scikit-learn)

You can also pass any metric function with signature metric(y_true, y_pred).
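The definitions above can be written out directly from sensitivity and specificity. The following is a NumPy sketch of those formulas plus one custom metric with the `metric(y_true, y_pred)` signature. The `_sketch` functions are illustrative re-implementations, not threshopt's own code, and `cost_weighted_score` is a hypothetical metric that assumes the optimizer maximizes the value it is given.

```python
import numpy as np

def _counts(y_true, y_pred):
    """Return (tp, fn, tn, fp) for binary 0/1 labels."""
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_pred).astype(bool)
    tp = int(np.sum(y_true & y_pred))
    fn = int(np.sum(y_true & ~y_pred))
    tn = int(np.sum(~y_true & ~y_pred))
    fp = int(np.sum(~y_true & y_pred))
    return tp, fn, tn, fp

def _sens_spec(y_true, y_pred):
    tp, fn, tn, fp = _counts(y_true, y_pred)
    sens = tp / (tp + fn) if (tp + fn) else 0.0  # true positive rate
    spec = tn / (tn + fp) if (tn + fp) else 0.0  # true negative rate
    return sens, spec

def gmean_sketch(y_true, y_pred):
    sens, spec = _sens_spec(y_true, y_pred)
    return (sens * spec) ** 0.5

def youden_j_sketch(y_true, y_pred):
    sens, spec = _sens_spec(y_true, y_pred)
    return sens + spec - 1

def balanced_acc_sketch(y_true, y_pred):
    sens, spec = _sens_spec(y_true, y_pred)
    return (sens + spec) / 2

def cost_weighted_score(y_true, y_pred):
    """Hypothetical custom metric: a missed positive costs 5x a false alarm.

    Negated so that a higher value means a lower total cost.
    """
    _, fn, _, fp = _counts(y_true, y_pred)
    return -(5 * fn + fp)

y_true = [1, 1, 1, 1, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 0, 0, 1, 1]  # sensitivity 0.75, specificity 0.50

print(f"G-Mean:            {gmean_sketch(y_true, y_pred):.4f}")
print(f"Youden's J:        {youden_j_sketch(y_true, y_pred):.2f}")
print(f"Balanced accuracy: {balanced_acc_sketch(y_true, y_pred):.3f}")
print(f"Cost-weighted:     {cost_weighted_score(y_true, y_pred)}")
```

Any function shaped like `cost_weighted_score` could then be passed as the `metric` argument in place of `f1_score`.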


Contributing

Contributions are welcome! Please open issues or submit pull requests.


License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.

Download files


Source Distribution

threshopt-0.2.0.tar.gz (7.6 kB view details)

Uploaded Source

Built Distribution


threshopt-0.2.0-py3-none-any.whl (6.9 kB view details)

Uploaded Python 3

File details

Details for the file threshopt-0.2.0.tar.gz.

File metadata

  • Download URL: threshopt-0.2.0.tar.gz
  • Upload date:
  • Size: 7.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.2

File hashes

Hashes for threshopt-0.2.0.tar.gz
Algorithm Hash digest
SHA256 2265b204d697bf0bab165a20a2a944ae9c5eecadb4378a69c9c7184103f58868
MD5 b2f9d1918dc42bbccd70aba0c9cee53e
BLAKE2b-256 2937a4989603dc07e91a49d8d49690e9d8d94df13872aaf8a3350b62b890f4a2
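A downloaded file can be checked against the SHA256 digest listed above using Python's standard `hashlib` module. This is a minimal sketch: the helper names `sha256_of` and `matches_expected` are assumptions, and the file path is whatever location the archive was saved to.

```python
import hashlib

# SHA256 digest listed above for threshopt-0.2.0.tar.gz
EXPECTED_SHA256 = "2265b204d697bf0bab165a20a2a944ae9c5eecadb4378a69c9c7184103f58868"

def sha256_of(path, chunk_size=65536):
    """Hash a file in chunks so large files need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def matches_expected(path, expected=EXPECTED_SHA256):
    """Return True when the file's SHA256 equals the expected hex digest."""
    return sha256_of(path) == expected.lower()
```

After downloading the source distribution, `matches_expected("threshopt-0.2.0.tar.gz")` should return `True` for an unmodified file.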


File details

Details for the file threshopt-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: threshopt-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 6.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.2

File hashes

Hashes for threshopt-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 2762c580d2baa59b30fc5d24eebe9b46349779e8b47db6e39d1ed95f4fa42c67
MD5 e1989eca704b6d32b0ab5caf36dfb169
BLAKE2b-256 b7780c10f8ce5ca6d9fe791f00b21dd3803b0d819a8852554a78eddf8a3e297b

