
Automatic threshold optimization for binary classifiers.


threshopt

Threshold Optimization Library for Binary Classification

threshopt is a lightweight Python library for finding the decision threshold that maximizes a chosen metric for a binary classifier, so you do not have to rely on the default cut-off of 0.5 on the predicted probability.
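To make the idea concrete: a probabilistic classifier outputs a score per sample (in scikit-learn, typically `model.predict_proba(X)[:, 1]`), and the threshold decides which scores become positive labels. The helper below, `predict_with_threshold`, is hypothetical and not part of threshopt. Treating a score exactly equal to the threshold as positive (`>=`) is an assumption; the library may break ties differently.

```python
import numpy as np


def predict_with_threshold(proba_pos, threshold=0.5):
    """Turn positive-class scores into 0/1 labels (hypothetical helper, not threshopt API).

    A sample is labeled 1 when its score is >= threshold (tie handling is an assumption).
    """
    proba_pos = np.asarray(proba_pos, dtype=float)
    return (proba_pos >= threshold).astype(int)


scores = np.array([0.10, 0.35, 0.48, 0.62, 0.90])
default_labels = predict_with_threshold(scores)        # conventional 0.5 cut-off
lowered_labels = predict_with_threshold(scores, 0.30)  # more permissive cut-off
print(default_labels, lowered_labels)
```

Lowering the threshold turns more borderline samples into positives, trading specificity for sensitivity.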


Features

  • Optimize decision thresholds for any metric (e.g., accuracy, F1-score, G-Mean, Youden’s J)
  • Optimize thresholds with cross-validation for more robust tuning
  • Integrate easily with any scikit-learn-compatible model
  • Use built-in common metrics or supply your own
  • Visualize confusion matrices and prediction score distributions (optional)
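A minimal sketch of metric-driven threshold search, assuming a simple grid of 101 candidate thresholds between 0 and 1. The function `scan_thresholds` and the toy `accuracy` metric are hypothetical illustrations; threshopt's own search strategy and candidate set may differ.

```python
import numpy as np


def scan_thresholds(y_true, proba_pos, metric, thresholds=None):
    """Return (best_threshold, best_score) for a metric(y_true, y_pred).

    Hypothetical sketch: evaluates the metric at every candidate threshold and
    keeps the first threshold that reaches the highest score.
    """
    y_true = np.asarray(y_true)
    proba_pos = np.asarray(proba_pos, dtype=float)
    if thresholds is None:
        thresholds = np.linspace(0.0, 1.0, 101)  # 0.00, 0.01, ..., 1.00
    best_t, best_score = None, -np.inf
    for t in thresholds:
        y_pred = (proba_pos >= t).astype(int)  # binarize scores at this candidate
        score = metric(y_true, y_pred)
        if score > best_score:  # strict '>' keeps the earliest best threshold
            best_t, best_score = float(t), float(score)
    return best_t, best_score


def accuracy(y_true, y_pred):
    """Toy metric with the metric(y_true, y_pred) signature."""
    return float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))


y = np.array([0, 0, 1, 1, 1])
p = np.array([0.215, 0.385, 0.475, 0.735, 0.915])
best_t, best_acc = scan_thresholds(y, p, accuracy)
print(f"threshold={best_t:.2f}, accuracy={best_acc:.2f}")
```

With a fitted scikit-learn classifier you would pass `model.predict_proba(X_val)[:, 1]` as the scores and a function such as `sklearn.metrics.f1_score` as the metric. Using the unique predicted scores as candidates, rather than a fixed grid, is a common alternative that evaluates every distinct cut-off.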

Installation

```bash
pip install -e .
```

Run this from the project root directory to install in editable mode, or install the published release from PyPI with `pip install threshopt`.

Quickstart

```python
from threshopt import optimize_threshold, optimize_threshold_cv, gmean_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

# Load data
data = load_breast_cancer()
X, y = data.data, data.target

# Hold out a test set so the threshold is not tuned on the training data
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42
)

# Train model
model = RandomForestClassifier(random_state=42)
model.fit(X_train, y_train)

# Optimize threshold on the held-out test set
best_thresh, best_val = optimize_threshold(model, X_test, y_test, metric=f1_score)
print(f"Best threshold: {best_thresh:.2f}, F1-score: {best_val:.4f}")

# Optimize threshold with 5-fold cross-validation on the training data
best_thresh_cv, best_val_cv = optimize_threshold_cv(model, X_train, y_train, metric=gmean_score, cv=5)
print(f"CV best threshold: {best_thresh_cv:.2f}, CV best metric: {best_val_cv:.4f}")
```
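One common way to tune a threshold with cross-validation is to collect out-of-fold probabilities with scikit-learn's `cross_val_predict` and search for the best threshold on those pooled scores. This is a sketch of that general technique under stated assumptions (a 101-point grid, a `>=` comparison, balanced accuracy as the metric); it is not necessarily how `optimize_threshold_cv` works internally, which might, for example, average per-fold thresholds instead.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import StratifiedKFold, cross_val_predict

X, y = load_breast_cancer(return_X_y=True)
model = RandomForestClassifier(n_estimators=100, random_state=42)

# Out-of-fold probabilities: each sample is scored by a model that never saw it.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
oof_proba = cross_val_predict(model, X, y, cv=cv, method="predict_proba")[:, 1]

# Assumed grid of candidate thresholds; evaluate the metric at each one.
thresholds = np.linspace(0.0, 1.0, 101)
scores = [balanced_accuracy_score(y, (oof_proba >= t).astype(int)) for t in thresholds]

# Keep the threshold with the highest pooled out-of-fold score.
best_idx = int(np.argmax(scores))
cv_threshold, cv_score = float(thresholds[best_idx]), float(scores[best_idx])
print(f"CV threshold: {cv_threshold:.2f}, balanced accuracy: {cv_score:.4f}")
```

Because every score comes from a held-out fold, the chosen threshold is less likely to overfit than one tuned on training-set predictions.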


Metrics

Included metrics:

  • gmean_score: Geometric Mean of sensitivity and specificity
  • youden_j_stat: Youden’s J statistic (sensitivity + specificity - 1)
  • balanced_acc_score: Balanced Accuracy (wrapper around scikit-learn)
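The two ratio-based metrics above can be computed from confusion-matrix counts. The sketch below assumes 0/1 labels with 1 as the positive class; `gmean_sketch` and `youden_j_sketch` are hypothetical names, not threshopt's implementations.

```python
import numpy as np


def sensitivity_specificity(y_true, y_pred):
    """Return (sensitivity, specificity), assuming 0/1 labels with 1 as positive."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    tn = np.sum((y_true == 0) & (y_pred == 0))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    sens = tp / (tp + fn) if (tp + fn) else 0.0  # true positive rate
    spec = tn / (tn + fp) if (tn + fp) else 0.0  # true negative rate
    return float(sens), float(spec)


def gmean_sketch(y_true, y_pred):
    """Geometric mean of sensitivity and specificity (hypothetical sketch)."""
    sens, spec = sensitivity_specificity(y_true, y_pred)
    return float(np.sqrt(sens * spec))


def youden_j_sketch(y_true, y_pred):
    """Youden's J = sensitivity + specificity - 1 (hypothetical sketch)."""
    sens, spec = sensitivity_specificity(y_true, y_pred)
    return sens + spec - 1.0


y_true = [1, 1, 1, 1, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 0, 0, 1, 1]
g = gmean_sketch(y_true, y_pred)     # sensitivity 0.75, specificity 0.5
j = youden_j_sketch(y_true, y_pred)
print(f"G-Mean={g:.4f}, Youden's J={j:.2f}")
```

Because balanced accuracy is (sensitivity + specificity) / 2, Youden's J equals 2 * balanced accuracy - 1, so the two metrics rank thresholds identically.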

You can also pass any metric function with the signature metric(y_true, y_pred), such as scikit-learn's f1_score.
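As an illustration of a custom metric, the function below scores predictions by their negated misclassification cost, so that higher is better. The name `neg_misclassification_cost` and the 5:1 cost ratio are hypothetical, and the sketch assumes the optimizer maximizes the metric, as the "best" values reported in the Quickstart suggest.

```python
import numpy as np


def neg_misclassification_cost(y_true, y_pred, fn_cost=5.0, fp_cost=1.0):
    """Custom metric with the metric(y_true, y_pred) signature (hypothetical example).

    Returns the negated total cost so that larger values are better.
    The default 5:1 cost ratio is an illustrative assumption, not a threshopt default.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    fn = np.sum((y_true == 1) & (y_pred == 0))  # missed positives
    fp = np.sum((y_true == 0) & (y_pred == 1))  # false alarms
    return -float(fn_cost * fn + fp_cost * fp)


y_true = [1, 0, 1, 0, 1]
y_pred = [0, 0, 1, 1, 1]
cost_score = neg_misclassification_cost(y_true, y_pred)  # one FN, one FP
print(cost_score)
```

It could then be passed as `metric=neg_misclassification_cost`; to change the costs while keeping the two-argument signature, bind them first with `functools.partial(neg_misclassification_cost, fn_cost=10.0)`.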


Contributing

Contributions are welcome! Please open issues or submit pull requests.


License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.



Download files

Download the file for your platform.

Source Distribution

threshopt-0.1.0.tar.gz (6.6 kB)

Uploaded Source

Built Distribution


threshopt-0.1.0-py3-none-any.whl (6.0 kB)

Uploaded Python 3

File details

Details for the file threshopt-0.1.0.tar.gz.

File metadata

  • Download URL: threshopt-0.1.0.tar.gz
  • Upload date:
  • Size: 6.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.2

File hashes

Hashes for threshopt-0.1.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | e85550f73fbaff5a846dc920b4a5463ae03b4aa09ccb6e18795cfbf53fb216b6 |
| MD5 | 46445a1e0af6a694443f3fcf0b9d2f45 |
| BLAKE2b-256 | 94b62c993317c9bc540f98811ee8f43d7f4408ef43a8461d1f8db51912f04984 |


File details

Details for the file threshopt-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: threshopt-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 6.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.2

File hashes

Hashes for threshopt-0.1.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | c6594ce0f948c6be95de0e98b68870bdd036df2c4581157adcc025a7a2197159 |
| MD5 | f2e50eec30e5a8b8d9bd8c7af8d75357 |
| BLAKE2b-256 | 5848d652c276acb785970c0567a08d0b3ac502ba86a8c561e4354de78d60f45f |
