Automatic threshold optimization for binary classifiers.
threshopt
Threshold Optimization Library for Binary Classification
threshopt is a lightweight Python library designed to help find the optimal decision threshold for binary classifiers, improving model performance by customizing the threshold instead of relying on the default 0.5.
Features
- Optimize decision thresholds based on any metric (e.g. accuracy, F1-score, G-Mean, Youden’s J)
- Supports cross-validation threshold optimization for robust model tuning
- Easy integration with any scikit-learn compatible model
- Built-in common metrics and ability to use custom metrics
- Visualize confusion matrices and prediction score distributions (optional)
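To make the idea of threshold optimization concrete, here is a minimal pure-Python sketch of a threshold sweep. It is not threshopt's implementation: the helper names `f1` and `sweep_thresholds`, the 0.01 to 0.99 grid, and the toy scores are all assumptions made for illustration.

```python
def f1(y_true, y_pred):
    """F1-score for binary labels (1 = positive class)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def sweep_thresholds(y_true, scores, metric, thresholds=None):
    """Return (best_threshold, best_value) over a grid of candidate thresholds."""
    if thresholds is None:
        thresholds = [i / 100 for i in range(1, 100)]  # assumed grid: 0.01 ... 0.99
    best_t, best_v = None, float("-inf")
    for t in thresholds:
        # A sample is predicted positive when its score reaches the threshold.
        y_pred = [1 if s >= t else 0 for s in scores]
        v = metric(y_true, y_pred)
        if v > best_v:  # strict '>' keeps the lowest threshold on ties
            best_t, best_v = t, v
    return best_t, best_v


# Toy data (hypothetical): one positive sample scores below 0.5.
y_true = [0, 0, 0, 0, 1, 0, 1, 1]
scores = [0.05, 0.10, 0.20, 0.30, 0.35, 0.45, 0.60, 0.80]

default_f1 = f1(y_true, [1 if s >= 0.5 else 0 for s in scores])
best_t, best_f1 = sweep_thresholds(y_true, scores, f1)
print(f"F1 at 0.5: {default_f1:.4f}, best threshold: {best_t:.2f}, best F1: {best_f1:.4f}")
```

On this toy data the default 0.5 cutoff misses the positive sample scored 0.35, so lowering the threshold trades one false positive for one recovered true positive and raises F1.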
Installation
pip install -e .
Install in editable mode from the project root directory.
Quickstart
from threshopt import optimize_threshold, optimize_threshold_cv, gmean_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import f1_score
# Load data
data = load_breast_cancer()
X, y = data.data, data.target
# Train model
model = RandomForestClassifier(random_state=42)
model.fit(X, y)
# Optimize threshold (this example reuses the training data; prefer a held-out set in practice)
best_thresh, best_val = optimize_threshold(model, X, y, metric=f1_score)
print(f"Best threshold: {best_thresh:.2f}, F1-score: {best_val:.4f}")
# Optimize threshold with cross-validation
best_thresh_cv, best_val_cv = optimize_threshold_cv(model, X, y, metric=gmean_score, cv=5)
print(f"CV best threshold: {best_thresh_cv:.2f}, CV best metric: {best_val_cv:.4f}")
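A threshold tuned on the same samples the model was fitted on tends to look better than it is. One common way to avoid this, sketched below with scikit-learn only, is to sweep thresholds over out-of-fold probabilities from `cross_val_predict`. This is an illustration of the general technique, not a description of how `optimize_threshold_cv` works internally; the threshold grid and `n_estimators=50` are arbitrary choices.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score
from sklearn.model_selection import cross_val_predict

X, y = load_breast_cancer(return_X_y=True)
model = RandomForestClassifier(n_estimators=50, random_state=42)

# Out-of-fold probabilities: each sample is scored by a model that never saw it.
oof_proba = cross_val_predict(model, X, y, cv=5, method="predict_proba")[:, 1]

# Evaluate F1 on a coarse grid of candidate thresholds (assumed grid).
thresholds = np.linspace(0.05, 0.95, 19)
values = [f1_score(y, (oof_proba >= t).astype(int)) for t in thresholds]

best_idx = int(np.argmax(values))
best_thresh, best_val = float(thresholds[best_idx]), float(values[best_idx])
print(f"Out-of-fold best threshold: {best_thresh:.2f}, F1-score: {best_val:.4f}")
```

The selected threshold can then be applied to a model refitted on the full training data.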
Metrics
Included metrics:
- gmean_score: Geometric Mean of sensitivity and specificity
- youden_j_stat: Youden’s J statistic (sensitivity + specificity - 1)
- balanced_acc_score: Balanced Accuracy (wrapper around scikit-learn)
You can also pass any metric function with signature metric(y_true, y_pred).
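As an illustration of the `metric(y_true, y_pred)` signature, the sketch below defines three custom metrics in plain Python. The function names (`example_gmean`, `example_youden_j`, `example_cost_score`) and the cost values are hypothetical, and the first two only restate the definitions listed above rather than reproduce the library's own code. The cost-based metric returns a negated cost on the assumption that the optimizer looks for the highest metric value.

```python
from math import sqrt


def confusion_counts(y_true, y_pred):
    """Count (tp, fp, tn, fn) for binary labels where 1 is the positive class."""
    tp = fp = tn = fn = 0
    for t, p in zip(y_true, y_pred):
        if t == 1 and p == 1:
            tp += 1
        elif t == 0 and p == 1:
            fp += 1
        elif t == 0 and p == 0:
            tn += 1
        else:
            fn += 1
    return tp, fp, tn, fn


def sensitivity_specificity(y_true, y_pred):
    """Return (sensitivity, specificity), using 0.0 when a class is absent."""
    tp, fp, tn, fn = confusion_counts(y_true, y_pred)
    sens = tp / (tp + fn) if (tp + fn) else 0.0
    spec = tn / (tn + fp) if (tn + fp) else 0.0
    return sens, spec


def example_gmean(y_true, y_pred):
    """Geometric mean of sensitivity and specificity."""
    sens, spec = sensitivity_specificity(y_true, y_pred)
    return sqrt(sens * spec)


def example_youden_j(y_true, y_pred):
    """Youden's J statistic: sensitivity + specificity - 1."""
    sens, spec = sensitivity_specificity(y_true, y_pred)
    return sens + spec - 1


def example_cost_score(y_true, y_pred):
    """Negated misclassification cost (assumed costs: 5 per miss, 1 per false alarm)."""
    _, fp, _, fn = confusion_counts(y_true, y_pred)
    return -(5.0 * fn + 1.0 * fp)


y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 0, 0, 0, 0, 1, 1]
print(example_gmean(y_true, y_pred))
print(example_youden_j(y_true, y_pred))
print(example_cost_score(y_true, y_pred))
```

Any of these could then be passed in place of `f1_score` in the Quickstart, for example `metric=example_youden_j`.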
Contributing
Contributions are welcome! Please open issues or submit pull requests.
License
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.