Automatic threshold optimization for classifiers.
Threshold Optimization Library for Binary and Multiclass Classification
threshopt is a lightweight Python library that automatically finds the optimal decision threshold for classification models.
Instead of relying on a default threshold (e.g. 0.5), it searches for the threshold that maximizes a chosen evaluation metric. This can improve that metric noticeably, especially on imbalanced datasets, where 0.5 is rarely the best cutoff.
The library is fully compatible with scikit-learn estimators and supports both binary and multiclass (one-vs-rest) scenarios.
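To make the idea concrete, here is a minimal sketch of a metric-driven threshold search in plain Python. It is not threshopt's implementation: the helper names `f1` and `search_threshold`, the `>=` comparison, and the 0.01-step grid are all assumptions chosen for illustration.

```python
def f1(y_true, y_pred):
    """F1 score for binary labels in {0, 1}; returns 0.0 when undefined."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def search_threshold(y_true, scores, metric, grid):
    """Return the (threshold, score) pair that maximizes metric over grid."""
    best_t, best_s = None, float("-inf")
    for t in grid:
        # Hypothetical decision rule: predict 1 when the score reaches t.
        y_pred = [1 if s >= t else 0 for s in scores]
        s = metric(y_true, y_pred)
        if s > best_s:  # strict '>' keeps the lowest threshold on ties
            best_t, best_s = t, s
    return best_t, best_s


# Imbalanced toy data: 2 positives among 8 samples, both scored below 0.5.
y_true = [0, 0, 0, 0, 0, 0, 1, 1]
scores = [0.05, 0.10, 0.12, 0.20, 0.22, 0.35, 0.30, 0.45]
grid = [i / 100 for i in range(1, 100)]

default_f1 = f1(y_true, [1 if s >= 0.5 else 0 for s in scores])
best_t, best_f1 = search_threshold(y_true, scores, f1, grid)
print(f"F1 at 0.5: {default_f1:.2f}, best threshold: {best_t:.2f}, best F1: {best_f1:.2f}")
```

On this toy data the 0.5 cutoff predicts no positives at all, so its F1 is 0, while the search settles on a much lower threshold. With a scikit-learn model, the scores would typically come from `model.predict_proba(X)[:, 1]`.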
Requirements
- Python >= 3.8
- scikit-learn
- numpy
- matplotlib
Installation
pip install threshopt
Features
- Automatic optimization of decision thresholds
- Metric-driven optimization (any sklearn-style metric or custom metric)
- Cross-validated threshold optimization
- Works with any scikit-learn compatible classifier
- Built-in metrics for imbalanced classification
- Optional visualization utilities
- Multiclass support via one-vs-rest logic
Quickstart
Binary classification
from threshopt import optimize_threshold, optimize_threshold_cv, gmean_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import f1_score
X, y = load_breast_cancer(return_X_y=True)
model = RandomForestClassifier(random_state=42)
model.fit(X, y)
# Optimize threshold on the full dataset
best_thresh, best_metric = optimize_threshold(
    model, X, y,
    metric=f1_score,
    plot=False, cm=False
)
print(f"Best threshold: {best_thresh:.2f}")
print(f"Best F1-score: {best_metric:.4f}")
# Optimize threshold using cross-validation
best_thresh_cv, best_metric_cv = optimize_threshold_cv(
    model, X, y,
    metric=gmean_score,
    cv=5,
    plot=False, cm=False
)
print(f"CV best threshold: {best_thresh_cv:.2f}")
print(f"CV best G-Mean: {best_metric_cv:.4f}")
Multiclass classification
from threshopt import optimize_threshold, optimize_threshold_cv
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.metrics import f1_score
X, y = load_iris(return_X_y=True)
model = RandomForestClassifier(random_state=42)
model.fit(X, y)
# Optimize thresholds (one per class)
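best_thresh, best_metric = optimize_threshold(
    model, X, y,
    metric=f1_score,
    multiclass=True,
    plot=False, cm=False
)
print("Best thresholds per class:", best_thresh)
# In multiclass mode the scores are documented as an array (one per class),
# so print them without a float format spec.
print("Best F1-scores per class:", best_metric)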
# Cross-validated multiclass optimization
best_thresh_cv, best_metric_cv = optimize_threshold_cv(
    model, X, y,
    metric=f1_score,
    cv=5,
    multiclass=True,
    plot=False, cm=False
)
print("CV best thresholds per class:", best_thresh_cv)
# As above, the multiclass scores are documented as an array.
print("CV best F1-scores per class:", best_metric_cv)
API
optimize_threshold
optimize_threshold(model, X, y_true, metric, multiclass=False,
                   use_predict_if_no_proba=False, plot=True, cm=True, report=True)
Optimizes the decision threshold on the full dataset.
| Parameter | Type | Description |
|---|---|---|
| model | estimator | Trained classifier with predict_proba or decision_function |
| X | array-like | Feature matrix |
| y_true | array-like | True labels |
| metric | callable | Metric function (y_true, y_pred) -> float |
| multiclass | bool | If True, performs one-vs-rest optimization |
| use_predict_if_no_proba | bool | Fall back to predict() if no probability scores available |
| plot | bool | If True, plots probability distributions |
| cm | bool | If True, displays confusion matrix |
| report | bool | If True, prints classification report |
Returns: (thresholds, best_scores) — scalar for binary, array for multiclass.
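Once a threshold is returned, it still has to be applied to the model's scores to get predictions. The sketch below shows one plausible way to do that. Both helpers are hypothetical and not part of threshopt, and the `>=` comparison and the per-class indicator output for the one-vs-rest case are assumptions, since the README does not say how the library resolves them.

```python
def apply_binary_threshold(pos_scores, threshold):
    """Turn positive-class scores into 0/1 labels (hypothetical helper)."""
    return [1 if s >= threshold else 0 for s in pos_scores]


def apply_ovr_thresholds(score_rows, thresholds):
    """One-vs-rest: return a 0/1 indicator per class for each sample.

    score_rows[i][k] is the score of class k for sample i, and
    thresholds[k] is the cutoff for class k (hypothetical helper).
    """
    return [
        [1 if s >= t else 0 for s, t in zip(row, thresholds)]
        for row in score_rows
    ]


# Binary case: scores for the positive class and a single scalar threshold.
binary_labels = apply_binary_threshold([0.20, 0.35, 0.80], 0.35)

# Multiclass case: one threshold per class; a sample may pass several
# cutoffs or none, so a tie-breaking rule would be needed for hard labels.
ovr_flags = apply_ovr_thresholds(
    [[0.7, 0.2, 0.1], [0.3, 0.3, 0.4]],
    [0.5, 0.25, 0.4],
)
print(binary_labels)
print(ovr_flags)
```

With a fitted scikit-learn classifier, the binary scores would typically be `model.predict_proba(X_new)[:, 1]` and the multiclass rows `model.predict_proba(X_new)`.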
optimize_threshold_cv
optimize_threshold_cv(model, X, y_true, metric, cv=5, multiclass=False,
                      plot=True, cm=True, report=True)
Optimizes the decision threshold using cross-validation, reducing the risk of overfitting to a specific train/test split.
| Parameter | Type | Description |
|---|---|---|
| model | estimator | Trained classifier with predict_proba or decision_function |
| X | array-like | Feature matrix |
| y_true | array-like | True labels |
| metric | callable | Metric function (y_true, y_pred) -> float |
| cv | int | Number of cross-validation folds |
| multiclass | bool | If True, performs one-vs-rest optimization |
| plot | bool | If True, plots probability distributions |
| cm | bool | If True, displays confusion matrix |
| report | bool | If True, prints classification report |
Returns: (thresholds, best_scores) — scalar for binary, array for multiclass.
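The README does not describe how the per-fold results are combined. One common strategy, sketched here as an assumption rather than as threshopt's actual behavior, is to find the best threshold on each held-out fold and average the thresholds and scores. The helpers `accuracy`, `search_threshold`, and `cv_threshold` are hypothetical, and the folds are hard-coded toy data standing in for held-out model scores.

```python
from statistics import mean


def accuracy(y_true, y_pred):
    """Fraction of matching labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def search_threshold(y_true, scores, metric, grid):
    """Return the (threshold, score) pair that maximizes metric over grid."""
    best_t, best_s = None, float("-inf")
    for t in grid:
        y_pred = [1 if s >= t else 0 for s in scores]
        s = metric(y_true, y_pred)
        if s > best_s:  # strict '>' keeps the lowest threshold on ties
            best_t, best_s = t, s
    return best_t, best_s


def cv_threshold(folds, metric, grid):
    """Average the best threshold and best score found on each fold.

    folds is a list of (y_true, scores) pairs, where each pair holds the
    labels and model scores of one held-out validation fold.
    """
    results = [search_threshold(y, s, metric, grid) for y, s in folds]
    return mean(t for t, _ in results), mean(s for _, s in results)


# Two toy validation folds; real scores would come from a model that was
# fitted on the remaining folds.
folds = [
    ([0, 0, 1, 1], [0.10, 0.30, 0.40, 0.80]),
    ([0, 1, 0, 1], [0.20, 0.60, 0.50, 0.90]),
]
grid = [i / 10 for i in range(1, 10)]
cv_t, cv_s = cv_threshold(folds, accuracy, grid)
print(f"CV threshold: {cv_t:.2f}, CV score: {cv_s:.2f}")
```

Averaging across folds smooths out thresholds that only suit one particular split, which is the overfitting risk the description above refers to.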
Metrics
Built-in metrics
| Metric | Signature | Description |
|---|---|---|
| gmean_score | (y_true, y_pred) -> float | Geometric mean of sensitivity and specificity |
| youden_j_stat | (y_true, y_pred) -> float | Sensitivity + specificity - 1 |
| balanced_acc_score | (y_true, y_pred) -> float | Balanced accuracy (wrapper around scikit-learn) |
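The definitions in the table can be written out from confusion-matrix counts. The functions below are an illustrative sketch for binary labels in {0, 1}, not the library's own code. Their names are hypothetical, and returning 0.0 for an empty class is an assumption.

```python
import math


def sensitivity_specificity(y_true, y_pred):
    """Return (sensitivity, specificity) for binary labels in {0, 1}."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    sens = tp / (tp + fn) if tp + fn else 0.0  # true positive rate
    spec = tn / (tn + fp) if tn + fp else 0.0  # true negative rate
    return sens, spec


def gmean(y_true, y_pred):
    """Geometric mean of sensitivity and specificity."""
    sens, spec = sensitivity_specificity(y_true, y_pred)
    return math.sqrt(sens * spec)


def youden_j(y_true, y_pred):
    """Youden's J statistic: sensitivity + specificity - 1."""
    sens, spec = sensitivity_specificity(y_true, y_pred)
    return sens + spec - 1


def balanced_accuracy(y_true, y_pred):
    """Binary balanced accuracy: mean of sensitivity and specificity."""
    sens, spec = sensitivity_specificity(y_true, y_pred)
    return (sens + spec) / 2


# Toy example: sensitivity 1/2, specificity 3/4.
y_true = [1, 1, 0, 0, 0, 0]
y_pred = [1, 0, 0, 0, 0, 1]
print(gmean(y_true, y_pred), youden_j(y_true, y_pred), balanced_accuracy(y_true, y_pred))
```

All three combine sensitivity and specificity, so a classifier that ignores the minority class scores poorly on each of them even when its plain accuracy is high.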
Custom metrics
Any function with the following signature can be passed as metric:
metric(y_true, y_pred) # return a float
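As an illustration of that signature, here is a hypothetical cost-weighted metric that penalizes false negatives more heavily than false positives. The function name and the cost values are made up for this example. It returns a negated cost on the assumption that the optimizer treats higher metric values as better.

```python
from functools import partial


def cost_weighted_score(y_true, y_pred, fn_cost=5.0, fp_cost=1.0):
    """Negated average misclassification cost (higher is better)."""
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    return -(fn_cost * fn + fp_cost * fp) / len(y_true)


# The default costs already match the (y_true, y_pred) -> float signature.
score = cost_weighted_score([1, 1, 0, 0], [1, 0, 1, 0])

# functools.partial fixes different costs while keeping the same signature.
strict_metric = partial(cost_weighted_score, fn_cost=10.0)
strict_score = strict_metric([1, 1, 0, 0], [1, 0, 1, 0])
print(score, strict_score)
```

Either callable could then be passed as the `metric` argument of `optimize_threshold` in place of `f1_score`.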
Contributing
Contributions are welcome! Please follow these steps:
- Fork the repository
- Create a feature branch (git checkout -b feature/my-feature)
- Commit your changes (git commit -m 'Add my feature')
- Push to the branch (git push origin feature/my-feature)
- Open a pull request
For bug reports or feature requests, please open an issue.
License
This project is licensed under the Apache License 2.0 — see the LICENSE file for details.