Tune Decision Thresholds
tunethresholds is a small scikit-learn-style utility for tuning multiclass
classification decision thresholds after a model has already been trained. It
wraps a classifier that exposes predict_proba() and classes_, learns one
multiplicative weight per class on a validation set, and uses those weighted
scores to choose labels.
Why It Exists
Many classifiers predict the class with the largest raw probability. That default can be suboptimal when classes have different costs, frequencies, or validation-set behavior. This package keeps the underlying model fixed and adjusts only the final decision rule: each class probability is multiplied by a learned class weight, then the largest adjusted value wins.
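The decision rule can be illustrated with a minimal NumPy sketch of a weighted argmax. The helper name `weighted_predict` and the probability and weight values below are hypothetical illustrations, not part of the tunethresholds API.

```python
import numpy as np


def weighted_predict(proba, weights, classes):
    """Hypothetical helper: pick, per row, the class with the largest weighted probability.

    proba:   (n_samples, n_classes) array, e.g. from predict_proba().
    weights: (n_classes,) array of non-negative class weights.
    classes: (n_classes,) array of labels in column order, like classes_.
    """
    adjusted = proba * weights  # broadcasting scales each class column by its weight
    return classes[np.argmax(adjusted, axis=1)]


classes = np.array(["a", "b", "c"])
proba = np.array([[0.50, 0.40, 0.10],
                  [0.20, 0.30, 0.50]])

# Equal weights reproduce the default rule: plain argmax of the raw probabilities.
print(weighted_predict(proba, np.ones(3), classes))                 # ['a' 'c']

# Down-weighting class "a" flips the first row to "b" (0.30 < 0.40).
print(weighted_predict(proba, np.array([0.6, 1.0, 1.0]), classes))  # ['b' 'c']
```

The underlying probabilities never change; only the rule that turns them into a label does.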
How It Works
The main API is AdjustedProbabilitiesDerivedModel.
- predict_proba(X) calls the wrapped model's predict_proba(X) and multiplies each class column by its learned weight.
- predict(X) returns the class whose adjusted probability is largest.
- adjust_model_decision_thresholds(...) finds the class weights with scipy.optimize.differential_evolution by maximizing a validation-set metric. The default metric is sklearn.metrics.matthews_corrcoef; a custom metric passed as score_func must accept (y_true, y_pred).
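The weight search can be sketched with SciPy directly. This is an illustration under stated assumptions, not the package's implementation: `fit_class_weights` is a hypothetical helper, and the search settings (`x0`, `maxiter`, `polish=False`, `seed`) are choices made only for this sketch. It applies the fixed weight of 0 for absent classes and the [1e-5, 1.0] bounds listed under Important Behavior.

```python
import numpy as np
from scipy.optimize import differential_evolution
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import matthews_corrcoef


def fit_class_weights(proba_val, y_val, classes, score_func=matthews_corrcoef, seed=0):
    """Hypothetical helper: search one weight per class to maximize score_func(y_true, y_pred).

    Classes absent from y_val keep a fixed weight of 0; the remaining weights
    are searched within [1e-5, 1.0].
    """
    classes = np.asarray(classes)
    present = np.isin(classes, np.unique(y_val))
    n_present = int(present.sum())

    def expand(w_present):
        # Full weight vector in classes_ order; absent classes stay at 0.
        weights = np.zeros(len(classes))
        weights[present] = w_present
        return weights

    def loss(w_present):
        # Weighted-argmax decision rule, scored against the validation labels.
        y_pred = classes[np.argmax(proba_val * expand(w_present), axis=1)]
        return -score_func(y_val, y_pred)  # differential_evolution minimizes

    result = differential_evolution(
        loss,
        bounds=[(1e-5, 1.0)] * n_present,
        x0=np.full(n_present, 0.5),  # equal weights: plain argmax over the present classes
        seed=seed,
        maxiter=30,
        polish=False,  # the objective is piecewise constant, so gradient polishing cannot help
    )
    return expand(result.x)


# Usage on synthetic, imbalanced data; the classifier is fitted once and never retrained.
X, y = make_classification(n_samples=600, n_classes=3, n_informative=4,
                           weights=[0.7, 0.2, 0.1], random_state=0)
clf = LogisticRegression(max_iter=1000).fit(X[:300], y[:300])
X_val, y_val = X[300:], y[300:]
proba_val = clf.predict_proba(X_val)
weights = fit_class_weights(proba_val, y_val, clf.classes_)
print("learned weights:", weights)
```

Seeding the population with equal weights through `x0` means that, in this sketch, the returned weights score at least as well on the validation set as the plain argmax rule, because differential evolution only replaces its best member with a candidate that scores at least as well.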
The adjusted probabilities are intentionally not renormalized. They may not sum to 1, and they should be treated as adjusted scores rather than calibrated probabilities.
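A quick NumPy check shows why the adjusted values are scores rather than probabilities. The probabilities and weights are hypothetical.

```python
import numpy as np

proba = np.array([[0.50, 0.40, 0.10],
                  [0.20, 0.30, 0.50]])
weights = np.array([0.6, 1.0, 1.0])  # hypothetical learned class weights

adjusted = proba * weights
row_sums = adjusted.sum(axis=1)
print(row_sums)  # the rows no longer sum to 1

# Dividing a row by its positive sum keeps that row's argmax, so the predicted
# labels would be the same either way; it would not make the scores calibrated.
renormalized = adjusted / row_sums[:, None]
print(np.argmax(adjusted, axis=1), np.argmax(renormalized, axis=1))
```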
Installation
```bash
pip install tunethresholds
```
The package requires Python 3.8+ and depends on NumPy, SciPy, scikit-learn, and extendanything.
Usage
```python
from sklearn.metrics import accuracy_score
from tunethresholds import AdjustedProbabilitiesDerivedModel

# clf is an already-fitted classifier with predict_proba() and classes_.
# X_val/y_val are held-out validation data; X_test is new data to label.
adjusted_clf = AdjustedProbabilitiesDerivedModel.adjust_model_decision_thresholds(
    model=clf,
    X_validation=X_val,
    y_validation_true=y_val,
    score_func=accuracy_score,
)
y_pred = adjusted_clf.predict(X_test)
adjusted_scores = adjusted_clf.predict_proba(X_test)
```
If validation probabilities have already been computed, pass them directly
instead of X_validation:
```python
# score_func is omitted here, so the default metric (matthews_corrcoef) is used.
adjusted_clf = AdjustedProbabilitiesDerivedModel.adjust_model_decision_thresholds(
    model=clf,
    predicted_probabilities_validation=clf.predict_proba(X_val),
    y_validation_true=y_val,
)
```
Important Behavior
- The wrapped model must expose predict_proba() and classes_.
- Classes absent from y_validation_true are assigned a fixed weight of 0.
- Present-class weights are optimized within [1e-5, 1.0].
- Multiplying a class's probability column by a positive weight does not change that class's one-vs-rest ROC AUC, because it preserves the ranking of examples within that column. However, because the adjusted predict_proba() output is not normalized, tools that require rows to sum to 1, including scikit-learn's multiclass roc_auc_score, will reject it. Do not renormalize to work around this: renormalizing can change the rankings and therefore the ROC AUC.
- The optimizer is run on validation data only; the underlying classifier is not retrained.
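The ROC AUC behavior can be checked on a tiny hand-made example with scikit-learn. The data and weights are hypothetical; the weight 0.25 is a power of two, so the scaling is exact in floating point.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Four examples, three classes; each class has at least one positive and one negative.
y = np.array([0, 1, 2, 1])
proba = np.array([[0.40, 0.20, 0.40],
                  [0.50, 0.50, 0.00],
                  [0.10, 0.10, 0.80],
                  [0.30, 0.60, 0.10]])
weights = np.array([1.0, 1.0, 0.25])  # hypothetical learned class weights


def ovr_aucs(scores):
    """One-vs-rest ROC AUC for each class column, computed as binary problems."""
    return [roc_auc_score(y == c, scores[:, c]) for c in range(scores.shape[1])]


adjusted = proba * weights                                     # weighted scores, rows unnormalized
renormalized = adjusted / adjusted.sum(axis=1, keepdims=True)  # the workaround to avoid

print(ovr_aucs(proba))         # per-class AUCs of the raw probabilities
print(ovr_aucs(adjusted))      # identical: positive weights keep each column's ranking
print(ovr_aucs(renormalized))  # class 0 changes (2/3 -> 1.0): renormalizing reorders column 0

# The multiclass form refuses scores whose rows do not sum to 1.
try:
    roc_auc_score(y, adjusted, multi_class="ovr")
    rejected = False
except ValueError:
    rejected = True
print("multiclass roc_auc_score rejected the adjusted scores:", rejected)
```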
Development
```bash
pip install -r requirements_dev.txt
pip install -e .
pytest
```
Additional local commands are available through the Makefile, including
make test, make lint, make coverage, and make docs.
Changelog
0.0.1
- First release on PyPI.
Download files
File details
Details for the file tunethresholds-0.0.2.tar.gz.
File metadata
- Download URL: tunethresholds-0.0.2.tar.gz
- Upload date:
- Size: 16.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.25
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f45e041fd0b3a6d758299b37760c39f496d62363cea449ee1b7d259c6e1ed4bd |
| MD5 | 726f95571b1e376cf63b2a004f22d2bb |
| BLAKE2b-256 | 2a98406613301fb2d57d37467b4f6f0c942764aee04b0ef6d5bcbc3e1d8f8982 |
File details
Details for the file tunethresholds-0.0.2-py2.py3-none-any.whl.
File metadata
- Download URL: tunethresholds-0.0.2-py2.py3-none-any.whl
- Upload date:
- Size: 7.1 kB
- Tags: Python 2, Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.25
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e3907830f9dcfc4abc9138c9d433eadab9fbc3618b615d5bf8f46ecd1fb4dbcd |
| MD5 | 75cd944722d367491ecdb6f84095c30c |
| BLAKE2b-256 | 83c5448852f8bd109cc4f0972acb33f6c60f66fe36c06eea24fc84ec35786fd6 |