Automatic threshold optimization for classifiers.
Threshold Optimization Library for Binary and Multiclass Classification
threshopt is a lightweight Python library that automatically finds the optimal decision threshold for classification models.
Instead of relying on a default threshold (e.g., 0.5), it tunes the threshold for a chosen evaluation metric, which can improve performance on that metric, especially on imbalanced datasets.
The library is fully compatible with scikit-learn estimators and supports both binary and multiclass problems (the latter via one-vs-rest fallback logic).
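The core idea can be sketched in a few lines: binarize the positive-class probabilities at each candidate threshold, score the resulting labels with the metric, and keep the best one. This is a minimal illustration, not threshopt's actual implementation; the helper name `find_best_threshold`, the `proba >= t` decision rule, and the choice of candidate thresholds are all assumptions.

```python
import numpy as np
from sklearn.metrics import f1_score


def find_best_threshold(y_true, proba, metric, thresholds=None):
    """Return (threshold, score) that maximizes metric(y_true, y_pred).

    Hypothetical helper for illustration; not part of threshopt.
    """
    if thresholds is None:
        # Labels only change when t crosses an observed score, so the
        # distinct scores are enough as candidates (assumed strategy)
        thresholds = np.unique(proba)
    best_t, best_score = 0.5, -np.inf
    for t in thresholds:
        # Assumed rule: predict the positive class when proba >= t
        y_pred = (proba >= t).astype(int)
        score = metric(y_true, y_pred)
        if score > best_score:
            best_t, best_score = float(t), float(score)
    return best_t, best_score


# Toy data: some positives score below the default 0.5 cut-off
y_true = np.array([0, 0, 0, 0, 1, 1, 1, 1])
proba = np.array([0.05, 0.10, 0.20, 0.30, 0.35, 0.40, 0.45, 0.90])
t, score = find_best_threshold(y_true, proba, f1_score)
print(f"threshold={t:.2f}, F1={score:.3f}")  # threshold=0.35, F1=1.000
```

At the default 0.5 cut-off only one of the four positives would be detected; lowering the threshold to 0.35 recovers all of them without any false positives.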
Features
- Automatic optimization of decision thresholds
- Metric-driven optimization (any sklearn-style metric or custom metric)
- Cross-validated threshold optimization
- Works with any scikit-learn compatible classifier
- Built-in metrics for imbalanced classification
- Optional visualization utilities
- Multiclass support via one-vs-rest fallback logic
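One-vs-rest thresholding can be illustrated by treating each class as its own binary problem (class k versus all others) and tuning a separate threshold on column k of `predict_proba`. The sketch below is a hypothetical illustration of that general technique, not threshopt's code; it assumes class labels are `0..n_classes-1` and match the column order of the probability matrix.

```python
import numpy as np
from sklearn.metrics import f1_score


def best_binary_threshold(y_true, scores, metric):
    """Pick the observed score that maximizes metric on a 0/1 problem."""
    best_t, best_score = 0.5, -np.inf
    for t in np.unique(scores):
        score = metric(y_true, (scores >= t).astype(int))
        if score > best_score:
            best_t, best_score = float(t), float(score)
    return best_t, best_score


def one_vs_rest_thresholds(y_true, proba, metric):
    """Return one threshold per class via one-vs-rest binarization.

    proba has shape (n_samples, n_classes); column k holds P(class k).
    Hypothetical helper; not threshopt's API.
    """
    y_true = np.asarray(y_true)
    thresholds = []
    for k in range(proba.shape[1]):
        # Binary target: 1 for class k, 0 for every other class
        y_k = (y_true == k).astype(int)
        t, _ = best_binary_threshold(y_k, proba[:, k], metric)
        thresholds.append(t)
    return np.array(thresholds)


y_true = np.array([0, 0, 1, 1, 2, 2])
proba = np.array([
    [0.70, 0.20, 0.10],
    [0.50, 0.30, 0.20],
    [0.20, 0.60, 0.20],
    [0.40, 0.45, 0.15],
    [0.10, 0.30, 0.60],
    [0.30, 0.30, 0.40],
])
thr = one_vs_rest_thresholds(y_true, proba, f1_score)
print(thr)
```

Each class ends up with its own cut-off (here 0.5, 0.45 and 0.4), which is why multiclass optimization returns one threshold per class rather than a single value.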
Installation
```bash
pip install threshopt
```
Quickstart
Binary classification
```python
from threshopt import optimize_threshold, optimize_threshold_cv, gmean_score
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

# Load data
X, y = load_breast_cancer(return_X_y=True)

# Hold out a validation set: a threshold tuned on the training data is
# optimistic, because the model has already fit those samples
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42
)

# Train model
model = RandomForestClassifier(random_state=42)
model.fit(X_train, y_train)

# Optimize the threshold on the held-out validation set
best_thresh, best_metric = optimize_threshold(
    model,
    X_val,
    y_val,
    metric=f1_score,
)
print(f"Best threshold: {best_thresh:.2f}")
print(f"Best F1-score: {best_metric:.4f}")

# Optimize the threshold using cross-validation
best_thresh_cv, best_metric_cv = optimize_threshold_cv(
    model,
    X,
    y,
    metric=gmean_score,
    cv=5,
)
print(f"CV best threshold: {best_thresh_cv:.2f}")
print(f"CV best G-Mean: {best_metric_cv:.4f}")
```
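One common way to tune a threshold with cross-validation is to collect out-of-fold probabilities and search over them. The sketch below uses only scikit-learn (`cross_val_predict`, `StratifiedKFold`) to illustrate that general technique; `optimize_threshold_cv` may work differently (for example, by averaging per-fold thresholds), and the 0.05 to 0.95 search grid is an assumption.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold, cross_val_predict

X, y = load_breast_cancer(return_X_y=True)
model = RandomForestClassifier(random_state=42)

# Out-of-fold probabilities: each sample is scored by a model
# that did not see it during training
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
oof_proba = cross_val_predict(model, X, y, cv=cv, method="predict_proba")[:, 1]

# Search the threshold on the pooled out-of-fold scores (assumed grid)
thresholds = np.linspace(0.05, 0.95, 91)
scores = [f1_score(y, (oof_proba >= t).astype(int)) for t in thresholds]
best_idx = int(np.argmax(scores))
best_t = float(thresholds[best_idx])
print(f"OOF threshold: {best_t:.2f}, F1: {scores[best_idx]:.4f}")

# Refit on all data, then apply the tuned threshold to new samples
# (the first five rows stand in for new data here)
model.fit(X, y)
y_pred = (model.predict_proba(X[:5])[:, 1] >= best_t).astype(int)
```

Pooling out-of-fold scores means the threshold is chosen on predictions the model never trained on, which avoids the optimistic bias of tuning on training-set probabilities.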
Multiclass classification
```python
from threshopt import optimize_threshold, optimize_threshold_cv
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

# Load data
X, y = load_iris(return_X_y=True)

# Hold out a validation set for threshold tuning
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=42
)

# Train model
model = RandomForestClassifier(random_state=42)
model.fit(X_train, y_train)

# Optimize thresholds (one per class) on the validation set
best_thresh, best_metric = optimize_threshold(
    model,
    X_val,
    y_val,
    metric=f1_score,
    multiclass=True,
)
print("Best thresholds per class:", best_thresh)
print(f"Best F1-score: {best_metric:.4f}")

# Cross-validated multiclass optimization
best_thresh_cv, best_metric_cv = optimize_threshold_cv(
    model,
    X,
    y,
    metric=f1_score,
    cv=5,
    multiclass=True,
)
print("CV best thresholds per class:", best_thresh_cv)
print(f"CV best F1-score: {best_metric_cv:.4f}")
```
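Per-class thresholds still have to be combined into a single label per sample. The sketch below shows one possible decision rule, assuming the thresholds come as a sequence ordered like the columns of `predict_proba`: among classes whose probability reaches its threshold, pick the one with the largest margin, and fall back to plain argmax when no class qualifies. The helper `predict_with_class_thresholds` is hypothetical, and threshopt's own fallback logic may differ.

```python
import numpy as np


def predict_with_class_thresholds(proba, thresholds):
    """Turn per-class probabilities into labels using per-class thresholds.

    Assumed rule (hypothetical, not necessarily threshopt's): among classes
    whose probability reaches its threshold, pick the largest margin;
    if no class qualifies, fall back to argmax over the probabilities.
    """
    proba = np.asarray(proba, dtype=float)
    thresholds = np.asarray(thresholds, dtype=float)
    margins = proba - thresholds  # broadcast thresholds across rows
    passed = margins >= 0
    # Hide classes below their threshold before taking the largest margin
    masked = np.where(passed, margins, -np.inf)
    labels = masked.argmax(axis=1)
    # Rows where no class cleared its threshold: plain argmax fallback
    none_passed = ~passed.any(axis=1)
    labels[none_passed] = proba[none_passed].argmax(axis=1)
    return labels


proba = np.array([
    [0.50, 0.35, 0.15],  # only class 1 clears its threshold
    [0.70, 0.20, 0.10],  # only class 0 clears its threshold
    [0.40, 0.25, 0.35],  # nothing clears: argmax fallback
    [0.10, 0.32, 0.58],  # classes 1 and 2 clear; class 2 has the larger margin
])
labels = predict_with_class_thresholds(proba, [0.6, 0.3, 0.5])
print(labels)  # [1 0 0 2]
```

Plain argmax would label the first row as class 0; the lower threshold for class 1 flips it, which is the kind of change per-class tuning is meant to enable.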
Metrics
Included metrics:
- `gmean_score`: geometric mean of sensitivity and specificity, sqrt(sensitivity * specificity)
- `youden_j_stat`: Youden’s J statistic (sensitivity + specificity - 1)
- `balanced_acc_score`: balanced accuracy (a wrapper around scikit-learn)
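The first two metrics follow directly from the confusion matrix. The sketch below re-implements their standard definitions for a 0/1 problem to show what they measure; the names `gmean` and `youden_j` are illustrative rather than threshopt's functions, and the library's handling of edge cases (such as a class missing from `y_true`) may differ from the zero fallback assumed here.

```python
import numpy as np
from sklearn.metrics import confusion_matrix


def sensitivity_specificity(y_true, y_pred):
    # labels=[0, 1] keeps the 2x2 layout even if one class is absent
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) else 0.0  # true positive rate
    specificity = tn / (tn + fp) if (tn + fp) else 0.0  # true negative rate
    return sensitivity, specificity


def gmean(y_true, y_pred):
    sens, spec = sensitivity_specificity(y_true, y_pred)
    return float(np.sqrt(sens * spec))


def youden_j(y_true, y_pred):
    sens, spec = sensitivity_specificity(y_true, y_pred)
    return float(sens + spec - 1)


y_true = [0, 0, 0, 0, 1, 1, 1, 1]
y_pred = [0, 0, 0, 1, 1, 1, 1, 0]
print(gmean(y_true, y_pred), youden_j(y_true, y_pred))  # 0.75 0.5
```

Both metrics reward doing well on both classes at once, which is why they are popular for imbalanced data: a classifier that always predicts the majority class has a specificity or sensitivity of zero, so its G-Mean is zero.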
You can also pass any metric function with the signature `metric(y_true, y_pred) -> float`.
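Two custom metrics that fit this signature are sketched below: an F2 score built on scikit-learn's `fbeta_score`, and a cost-based score. The function names and the 5:1 cost ratio are illustrative assumptions; the cost is negated on the assumption that the optimizer treats higher values as better, as the "best metric" outputs in the examples suggest.

```python
import numpy as np
from sklearn.metrics import fbeta_score


def f2_score(y_true, y_pred):
    """F-beta with beta=2, which weights recall above precision."""
    return fbeta_score(y_true, y_pred, beta=2)


def negative_cost(y_true, y_pred):
    """Negated misclassification cost (higher is better).

    Assumed costs: a missed positive costs 5, a false alarm costs 1.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    false_neg = np.sum((y_true == 1) & (y_pred == 0))
    false_pos = np.sum((y_true == 0) & (y_pred == 1))
    return -float(5 * false_neg + 1 * false_pos)


y_true = [0, 0, 1, 1]
y_pred = [0, 1, 1, 0]
print(f2_score(y_true, y_pred), negative_cost(y_true, y_pred))
```

Either function could then be passed as, for example, `metric=f2_score` to `optimize_threshold`. A cost-based metric is useful when the relative price of false negatives and false positives is known, since the tuned threshold then reflects that trade-off directly.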
Contributing
Contributions are welcome! Please open issues or submit pull requests.
License
This project is licensed under the Apache License 2.0; see the LICENSE file for details.
Download files
Source distribution: threshopt-1.1.0.tar.gz
Built distribution: threshopt-1.1.0-py3-none-any.whl
File details
Details for the file threshopt-1.1.0.tar.gz.
File metadata
- Download URL: threshopt-1.1.0.tar.gz
- Upload date:
- Size: 10.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f2095a3b80922479999c8e4cf31ac2e4009a635b2b95b91a6c36180170774260 |
| MD5 | 1f8aa5eef04723bf2e35cb99b5b6d03c |
| BLAKE2b-256 | e3087291a7b6c656a93f9ddf724a1618c4082134f2bf3d5a7b7fdb3b55f19738 |
File details
Details for the file threshopt-1.1.0-py3-none-any.whl.
File metadata
- Download URL: threshopt-1.1.0-py3-none-any.whl
- Upload date:
- Size: 9.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b323db200d4a8f61590b8245df9809d563c5783496a1061bcef0d39f157f90d1 |
| MD5 | 6a588cb1b1885cd17fc2c0c277bf9757 |
| BLAKE2b-256 | 9ecad1635f2275d2980175b567fd67ec53ff256513736e53aaee0cd9615bfc35 |