Automatic threshold optimization for classifiers.


Threshold Optimization Library for Binary and Multiclass Classification

threshopt is a lightweight Python library that automatically finds the optimal decision threshold for classification models.
Instead of relying on the default threshold (0.5 on the predicted probability of the positive class), it optimizes the threshold according to a chosen evaluation metric, which can improve model performance, especially on imbalanced datasets.

The library is fully compatible with scikit-learn estimators and supports both binary classification and multiclass classification (through a one-vs-rest fallback).
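To make the idea concrete, here is a minimal sketch of metric-driven threshold selection on a held-out split, written with plain NumPy and scikit-learn rather than threshopt. The helper `sweep_threshold` and its grid of 99 candidate thresholds are assumptions for illustration; the README does not document threshopt's search strategy or whether it compares probabilities with `>=` or `>`.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def sweep_threshold(y_true, proba, metric, thresholds=None):
    """Hypothetical helper (not part of threshopt).

    Predicts 1 when proba >= threshold, scores each candidate with
    metric(y_true, y_pred), and returns the (threshold, score) pair with
    the highest score, assuming higher is better.
    """
    if thresholds is None:
        thresholds = np.linspace(0.01, 0.99, 99)  # assumed candidate grid
    best_t, best_s = 0.5, -np.inf
    for t in thresholds:
        y_pred = (proba >= t).astype(int)
        score = metric(y_true, y_pred)
        if score > best_s:  # strict ">" keeps the first (lowest) best threshold
            best_t, best_s = float(t), float(score)
    return best_t, best_s


X, y = load_breast_cancer(return_X_y=True)
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0
)
clf = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
clf.fit(X_train, y_train)

proba_val = clf.predict_proba(X_val)[:, 1]  # probability of class 1
best_t, best_f1 = sweep_threshold(y_val, proba_val, f1_score)
default_f1 = f1_score(y_val, (proba_val >= 0.5).astype(int))
print(f"default 0.5 -> F1 {default_f1:.4f}; tuned {best_t:.2f} -> F1 {best_f1:.4f}")
```

Holding out `X_val` matters: probabilities scored on the training data are over-confident, so a threshold tuned on them rarely transfers to new data.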


Features

  • Automatic optimization of decision thresholds
  • Metric-driven optimization (any sklearn-style metric or custom metric)
  • Cross-validated threshold optimization
  • Works with any scikit-learn compatible classifier
  • Built-in metrics for imbalanced classification
  • Optional visualization utilities
  • Multiclass support via one-vs-rest fallback logic
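Cross-validated threshold optimization can be sketched with out-of-fold probabilities from scikit-learn's `cross_val_predict`: every sample is scored by a model that never saw it during training, and a single threshold is then chosen on the pooled scores. The helper `cv_threshold` is hypothetical, and pooling out-of-fold predictions is only one common strategy; threshopt may instead, for example, tune a threshold per fold and average the results.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import StratifiedKFold, cross_val_predict


def cv_threshold(estimator, X, y, metric, cv=5, thresholds=None):
    """Hypothetical sketch: pick one threshold from out-of-fold probabilities."""
    if thresholds is None:
        thresholds = np.linspace(0.01, 0.99, 99)  # assumed candidate grid
    splitter = StratifiedKFold(n_splits=cv, shuffle=True, random_state=0)
    # Each row of `oof` comes from the fold model that did not train on it
    oof = cross_val_predict(
        estimator, X, y, cv=splitter, method="predict_proba"
    )[:, 1]
    scores = [metric(y, (oof >= t).astype(int)) for t in thresholds]
    best = int(np.argmax(scores))  # first index of the highest score
    return float(thresholds[best]), float(scores[best])


# Imbalanced toy problem: roughly 10% positives
X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], random_state=0)
t, score = cv_threshold(
    RandomForestClassifier(random_state=0), X, y, balanced_accuracy_score
)
print(f"CV threshold {t:.2f}, balanced accuracy {score:.4f}")
```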

Installation

pip install threshopt

Quickstart

Binary classification

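from threshopt import optimize_threshold, optimize_threshold_cv, gmean_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

# Load data and hold out a validation set for threshold tuning
X, y = load_breast_cancer(return_X_y=True)
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42
)

# Train model on the training split only
model = RandomForestClassifier(random_state=42)
model.fit(X_train, y_train)

# Optimize the threshold on held-out data; a random forest's probabilities
# on its own training data are over-confident, so tuning there would overfit
best_thresh, best_metric = optimize_threshold(
    model,
    X_val,
    y_val,
    metric=f1_score
)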

print(f"Best threshold: {best_thresh:.2f}")
print(f"Best F1-score: {best_metric:.4f}")

# Optimize threshold using cross-validation
best_thresh_cv, best_metric_cv = optimize_threshold_cv(
    model,
    X,
    y,
    metric=gmean_score,
    cv=5
)

print(f"CV best threshold: {best_thresh_cv:.2f}")
print(f"CV best G-Mean: {best_metric_cv:.4f}")
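The Quickstart stops at reporting the threshold. Using it at prediction time means replacing `model.predict` (which, for a binary scikit-learn classifier, amounts to picking the more probable class) with a comparison against the tuned value. This sketch assumes the positive class is the second column of `predict_proba` (`model.classes_[1]`) and uses `>=`; `predict_with_threshold` is a hypothetical helper, and the fixed `0.4` merely stands in for `best_thresh`.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split


def predict_with_threshold(model, X, threshold):
    """Hypothetical helper: label 1 where P(classes_[1]) >= threshold, else 0."""
    proba = model.predict_proba(X)[:, 1]
    return (proba >= threshold).astype(int)


X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, random_state=42
)
model = RandomForestClassifier(random_state=42).fit(X_train, y_train)

best_thresh = 0.4  # stand-in for the value returned by optimize_threshold
y_pred = predict_with_threshold(model, X_test, best_thresh)
print(f"{y_pred.sum()} of {len(y_pred)} test samples predicted positive")
```

Lowering the threshold can only add positive predictions, which is how a tuned threshold trades precision for recall.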

Multiclass classification

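from threshopt import optimize_threshold, optimize_threshold_cv
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

# Load data and hold out a validation set for threshold tuning
X, y = load_iris(return_X_y=True)
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=42
)

# Train model on the training split only
model = RandomForestClassifier(random_state=42)
model.fit(X_train, y_train)

# Optimize thresholds (one per class) on held-out data
best_thresh, best_metric = optimize_threshold(
    model,
    X_val,
    y_val,
    metric=f1_score,
    multiclass=True
)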

print("Best thresholds per class:", best_thresh)
print(f"Best F1-score: {best_metric:.4f}")

# Cross-validated multiclass optimization
best_thresh_cv, best_metric_cv = optimize_threshold_cv(
    model,
    X,
    y,
    metric=f1_score,
    cv=5,
    multiclass=True
)

print("CV best thresholds per class:", best_thresh_cv)
print(f"CV best F1-score: {best_metric_cv:.4f}")
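The README does not spell out how per-class thresholds are combined into a single multiclass prediction. One plausible one-vs-rest rule, consistent with the "fallback" wording but an assumption rather than threshopt's documented behavior, is: among the classes whose probability clears their own threshold, pick the most probable; if none clears its threshold, fall back to the plain argmax. `predict_multiclass_with_thresholds` is a hypothetical name.

```python
import numpy as np


def predict_multiclass_with_thresholds(proba, thresholds, classes):
    """Hypothetical one-vs-rest decision rule with an argmax fallback.

    proba: (n_samples, n_classes) array, e.g. from predict_proba.
    thresholds: one threshold per class, in the same column order.
    """
    proba = np.asarray(proba, dtype=float)
    thresholds = np.asarray(thresholds, dtype=float)
    passed = proba >= thresholds  # broadcasts one threshold per column
    # Hide classes that did not clear their threshold, then take the argmax
    idx = np.argmax(np.where(passed, proba, -np.inf), axis=1)
    # Rows where no class cleared its threshold: fall back to plain argmax
    none_passed = ~passed.any(axis=1)
    idx[none_passed] = np.argmax(proba[none_passed], axis=1)
    return np.asarray(classes)[idx]


proba = np.array([
    [0.50, 0.30, 0.20],  # only class 0 clears its threshold
    [0.40, 0.35, 0.25],  # only class 1 clears it; plain argmax would say 0
    [0.40, 0.32, 0.28],  # nothing clears -> argmax fallback gives class 0
])
thresholds = [0.45, 0.35, 0.30]
print(predict_multiclass_with_thresholds(proba, thresholds, classes=[0, 1, 2]))
# prints [0 1 0]
```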

Metrics

Included metrics:

  • gmean_score: Geometric Mean of sensitivity and specificity
  • youden_j_stat: Youden’s J statistic (sensitivity + specificity - 1)
  • balanced_acc_score: Balanced Accuracy, the mean of sensitivity and specificity (a wrapper around scikit-learn)
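For reference, the two imbalance-oriented metrics above can be written directly from the binary confusion matrix. `gmean_sketch` and `youden_j_sketch` are illustrative re-implementations of the standard definitions, not threshopt's source, and they treat 1 as the positive class.

```python
import numpy as np
from sklearn.metrics import balanced_accuracy_score, confusion_matrix


def sens_spec(y_true, y_pred):
    """Sensitivity (true positive rate) and specificity (true negative rate)."""
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) else 0.0
    specificity = tn / (tn + fp) if (tn + fp) else 0.0
    return sensitivity, specificity


def gmean_sketch(y_true, y_pred):
    """Geometric mean of sensitivity and specificity."""
    sens, spec = sens_spec(y_true, y_pred)
    return float(np.sqrt(sens * spec))


def youden_j_sketch(y_true, y_pred):
    """Youden's J statistic: sensitivity + specificity - 1."""
    sens, spec = sens_spec(y_true, y_pred)
    return float(sens + spec - 1.0)


y_true = [0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 1, 1, 0]
# sensitivity = 1/2, specificity = 3/4
print(gmean_sketch(y_true, y_pred))             # sqrt(0.375) ≈ 0.6124
print(youden_j_sketch(y_true, y_pred))          # 0.25
print(balanced_accuracy_score(y_true, y_pred))  # (0.5 + 0.75) / 2 = 0.625
```

Since balanced accuracy is (sensitivity + specificity) / 2, Youden's J equals 2 * balanced accuracy - 1 for binary problems, so the two metrics select the same threshold.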

You can also pass any metric function with the signature `metric(y_true, y_pred) -> float`, where `y_pred` holds the class labels obtained by applying a candidate threshold.
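Two examples of custom metrics that fit this signature: an F2-score built with `functools.partial`, and a hypothetical cost function that penalizes false negatives five times as heavily as false positives. The cost is negated on the assumption, suggested by the bundled metrics, that threshopt treats higher scores as better; `f2_score`, `negative_cost`, and the cost weights are illustrative choices.

```python
from functools import partial

import numpy as np
from sklearn.metrics import confusion_matrix, fbeta_score

# F2 weights recall above precision; partial() fixes beta so the result has
# the plain metric(y_true, y_pred) -> float signature
f2_score = partial(fbeta_score, beta=2)


def negative_cost(y_true, y_pred, fn_cost=5.0, fp_cost=1.0):
    """Hypothetical cost-sensitive metric: higher (less negative) is better."""
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    return -float(fn_cost * fn + fp_cost * fp)


y_true = np.array([0, 0, 1, 1, 1])
y_pred = np.array([0, 1, 1, 1, 0])
print(f2_score(y_true, y_pred))       # precision = recall = 2/3, so F2 = 2/3
print(negative_cost(y_true, y_pred))  # one FN (5.0) + one FP (1.0) -> -6.0
```

Either function could then be passed as `metric=f2_score` or `metric=negative_cost` in the Quickstart calls.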

Contributing

Contributions are welcome! Please open issues or submit pull requests.


License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.



Download files

Download the file for your platform.

Source Distribution

threshopt-1.1.0.tar.gz (10.3 kB)

Uploaded Source

Built Distribution


threshopt-1.1.0-py3-none-any.whl (9.8 kB)

Uploaded Python 3

File details

Details for the file threshopt-1.1.0.tar.gz.

File metadata

  • Download URL: threshopt-1.1.0.tar.gz
  • Upload date:
  • Size: 10.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.2

File hashes

Hashes for threshopt-1.1.0.tar.gz
Algorithm Hash digest
SHA256 f2095a3b80922479999c8e4cf31ac2e4009a635b2b95b91a6c36180170774260
MD5 1f8aa5eef04723bf2e35cb99b5b6d03c
BLAKE2b-256 e3087291a7b6c656a93f9ddf724a1618c4082134f2bf3d5a7b7fdb3b55f19738


File details

Details for the file threshopt-1.1.0-py3-none-any.whl.

File metadata

  • Download URL: threshopt-1.1.0-py3-none-any.whl
  • Upload date:
  • Size: 9.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.2

File hashes

Hashes for threshopt-1.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 b323db200d4a8f61590b8245df9809d563c5783496a1061bcef0d39f157f90d1
MD5 6a588cb1b1885cd17fc2c0c277bf9757
BLAKE2b-256 9ecad1635f2275d2980175b567fd67ec53ff256513736e53aaee0cd9615bfc35

