Automatic threshold optimization for binary classifiers.
threshopt
Threshold Optimization Library for Binary and Multiclass Classification
threshopt is a lightweight Python library that finds the optimal decision threshold for a classifier. Tuning the threshold for a chosen metric can improve performance over relying on the default cutoff (typically 0.5 for binary classifiers).
Features
- Optimize decision thresholds based on any metric (e.g., accuracy, F1-score, G-Mean, Youden’s J)
- Supports cross-validation threshold optimization
- Works with any scikit-learn compatible model
- Built-in common metrics and support for custom metrics
- Optional visualization of confusion matrices and prediction score distributions
- Multiclass and fallback support
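To make the idea of threshold optimization concrete, here is a minimal pure-Python sketch of a threshold sweep on toy data. It is not threshopt's implementation: `sweep_threshold` and `f1` are hypothetical helpers written for illustration, and the evenly spaced grid of candidate thresholds is an assumption.

```python
def f1(y_true, y_pred):
    """F1-score for binary labels, computed from raw counts."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def sweep_threshold(y_true, scores, metric, n_steps=101):
    """Try evenly spaced thresholds in [0, 1] and keep the best one.

    Hypothetical helper for illustration; not threshopt's implementation.
    """
    best_thresh, best_val = 0.5, float("-inf")
    for i in range(n_steps):
        thresh = i / (n_steps - 1)
        # Predict the positive class when the score reaches the threshold
        y_pred = [1 if s >= thresh else 0 for s in scores]
        val = metric(y_true, y_pred)
        if val > best_val:  # strict '>' keeps the lowest threshold on ties
            best_thresh, best_val = thresh, val
    return best_thresh, best_val


# Toy positive-class scores, similar to what predict_proba(X)[:, 1] returns
y_true = [0, 0, 0, 1, 0, 1, 1, 1]
scores = [0.05, 0.10, 0.20, 0.30, 0.35, 0.40, 0.70, 0.90]

default_f1 = f1(y_true, [1 if s >= 0.5 else 0 for s in scores])
best_thresh, best_f1 = sweep_threshold(y_true, scores, f1)
print(f"F1 at 0.5: {default_f1:.4f}")
print(f"Best threshold: {best_thresh:.2f}, F1: {best_f1:.4f}")
```

On this toy data the sweep settles on a threshold well below 0.5, because two of the four positives receive scores under 0.5 and are missed by the default cutoff.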
Installation
pip install threshopt
Quickstart
Binary classification
from threshopt import optimize_threshold, optimize_threshold_cv, gmean_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
# Load data and hold out a test set
data = load_breast_cancer()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=42)
# Train model on the training split only
model = RandomForestClassifier(random_state=42)
model.fit(X_train, y_train)
# Optimize threshold on the test set
best_thresh, best_val = optimize_threshold(model, X_test, y_test, metric=f1_score)
print(f"Best threshold: {best_thresh:.2f}, F1-score: {best_val:.4f}")
# Optimize threshold with cross-validation
best_thresh_cv, best_val_cv = optimize_threshold_cv(model, X, y, metric=gmean_score, cv=5)
print(f"CV best threshold: {best_thresh_cv:.2f}, CV best metric: {best_val_cv:.4f}")
Multiclass classification
from threshopt import optimize_threshold, optimize_threshold_cv, gmean_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.metrics import f1_score
# Load data
data = load_iris()
X, y = data.data, data.target
# Train model
model = RandomForestClassifier(random_state=42)
model.fit(X, y)
# Optimize threshold for multiclass using fallback
best_thresh, best_val = optimize_threshold(model, X, y, metric=f1_score, multiclass=True)
print(f"Best thresholds per class: {best_thresh}, F1-score: {best_val:.4f}")
# Optimize threshold with cross-validation
best_thresh_cv, best_val_cv = optimize_threshold_cv(model, X, y, metric=gmean_score, cv=5, multiclass=True)
print(f"CV best thresholds per class: {best_thresh_cv}, CV best metric: {best_val_cv:.4f}")
Metrics
Included metrics:
- gmean_score: Geometric mean of sensitivity and specificity
- youden_j_stat: Youden’s J statistic (sensitivity + specificity - 1)
- balanced_acc_score: Balanced accuracy (wrapper around scikit-learn)
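For reference, the three definitions above can be written out from sensitivity and specificity. The sketch below is pure Python and uses hypothetical function names (`gmean_sketch`, `youden_j_sketch`, `balanced_accuracy_sketch`) to avoid implying that this is how threshopt implements them; it assumes binary labels encoded as 0 and 1.

```python
import math


def sensitivity_specificity(y_true, y_pred):
    """Return (sensitivity, specificity) for binary 0/1 labels."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    sens = tp / (tp + fn) if tp + fn else 0.0
    spec = tn / (tn + fp) if tn + fp else 0.0
    return sens, spec


def gmean_sketch(y_true, y_pred):
    """Geometric mean of sensitivity and specificity."""
    sens, spec = sensitivity_specificity(y_true, y_pred)
    return math.sqrt(sens * spec)


def youden_j_sketch(y_true, y_pred):
    """Youden's J statistic: sensitivity + specificity - 1."""
    sens, spec = sensitivity_specificity(y_true, y_pred)
    return sens + spec - 1


def balanced_accuracy_sketch(y_true, y_pred):
    """Balanced accuracy: mean of sensitivity and specificity."""
    sens, spec = sensitivity_specificity(y_true, y_pred)
    return (sens + spec) / 2


# One of four positives is caught; all four negatives are correct
y_true = [1, 1, 1, 1, 0, 0, 0, 0]
y_pred = [1, 0, 0, 0, 0, 0, 0, 0]

print(gmean_sketch(y_true, y_pred))              # sensitivity 0.25, specificity 1.0
print(youden_j_sketch(y_true, y_pred))
print(balanced_accuracy_sketch(y_true, y_pred))
```

All three metrics penalize a classifier that ignores one class, which is why they are common targets for threshold tuning on imbalanced data.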
You can also pass any metric function with signature metric(y_true, y_pred).
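As an illustration of that signature, the sketch below builds an F-beta metric that takes only `y_true` and `y_pred`, with `beta` fixed in advance by a small factory. `make_fbeta` and `f2_score` are hypothetical names introduced here, and the code assumes binary labels encoded as 0 and 1.

```python
def make_fbeta(beta):
    """Return an F-beta metric with the signature metric(y_true, y_pred)."""
    b2 = beta ** 2

    def fbeta(y_true, y_pred):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
        denom = (1 + b2) * tp + b2 * fn + fp
        return (1 + b2) * tp / denom if denom else 0.0

    return fbeta


# beta = 2 weights recall more heavily than precision
f2_score = make_fbeta(2.0)

y_true = [1, 1, 1, 0, 0, 0]
y_pred = [1, 1, 1, 1, 1, 0]  # perfect recall, two false positives

f2_value = f2_score(y_true, y_pred)
f1_value = make_fbeta(1.0)(y_true, y_pred)
print(f"F2: {f2_value:.4f}, F1: {f1_value:.4f}")
```

A function built this way could then be passed as `metric=f2_score` to `optimize_threshold`, on the assumption (consistent with the Quickstart examples) that higher metric values are treated as better.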
Contributing
Contributions are welcome! Please open issues or submit pull requests.
License
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.