Utilities for computing optimal classification cutoffs for binary and multi-class
Optimal Classification Cut-Offs
Probabilistic classifiers output per-class probabilities, and fixed cutoffs such as 0.5 rarely maximize metrics like accuracy or the F1 score.
This package provides utilities to select optimal probability cutoffs for each class, supporting both multi-class and binary classifiers.
Optimization methods include brute-force search, numerical techniques, and gradient-based approaches.
Binary thresholding at a single cutoff remains fully supported as a special case.
Quick start
from optimal_cutoffs import ThresholdOptimizer
# true binary labels and predicted probabilities
y_true = ...
y_prob = ...
optimizer = ThresholdOptimizer(objective="f1")
optimizer.fit(y_true, y_prob)
y_pred = optimizer.predict(y_prob)
API
get_confusion_matrix(true_labs, pred_prob, threshold)
- Purpose: Compute confusion-matrix counts for a threshold.
- Args: arrays of true labels and probabilities, plus the decision threshold.
- Returns: (tp, tn, fp, fn) counts.
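To make the counting concrete, here is a minimal pure-Python sketch of what such a function might compute. The name `confusion_counts` is hypothetical, and the rule that a probability equal to the threshold counts as positive is an assumption; the package may handle ties differently.

```python
def confusion_counts(true_labs, pred_prob, threshold):
    """Return (tp, tn, fp, fn) for binary labels at a given threshold."""
    tp = tn = fp = fn = 0
    for label, prob in zip(true_labs, pred_prob):
        # Assumption: a probability equal to the threshold is predicted positive.
        predicted = 1 if prob >= threshold else 0
        if predicted == 1 and label == 1:
            tp += 1
        elif predicted == 0 and label == 0:
            tn += 1
        elif predicted == 1 and label == 0:
            fp += 1
        else:
            fn += 1
    return tp, tn, fp, fn


y_true = [0, 0, 1, 1, 1]
y_prob = [0.1, 0.6, 0.4, 0.7, 0.9]
counts = confusion_counts(y_true, y_prob, 0.5)
print(counts)  # (2, 1, 1, 1)
```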
register_metric(name=None, func=None)
- Purpose: Add a metric function to the global registry.
- Args: optional metric name and callable; can also be used as a decorator.
- Returns: the registered function or decorator.
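The dual "direct call or decorator" behaviour can be sketched as follows. This is a stand-in written for illustration, not the package's code: `METRIC_REGISTRY`, `register_metric_sketch`, and the `(tp, tn, fp, fn)` signature of the metric callables are all assumptions.

```python
METRIC_REGISTRY = {}  # hypothetical stand-in for the package's global registry


def register_metric_sketch(name=None, func=None):
    """Register func under name, or return a decorator when func is omitted."""
    if func is not None:
        METRIC_REGISTRY[name or func.__name__] = func
        return func

    def decorator(f):
        METRIC_REGISTRY[name or f.__name__] = f
        return f

    return decorator


# Decorator form. Assumption: metrics take confusion-matrix counts.
@register_metric_sketch("youden_j")
def youden_j(tp, tn, fp, fn):
    """Sensitivity + specificity - 1."""
    sens = tp / (tp + fn) if tp + fn else 0.0
    spec = tn / (tn + fp) if tn + fp else 0.0
    return sens + spec - 1.0


def precision(tp, tn, fp, fn):
    return tp / (tp + fp) if tp + fp else 0.0


# Direct-call form: the name defaults to the function's own name.
register_metric_sketch(func=precision)
```

Either form leaves the original function usable under its own name, because the function itself is returned after it is stored.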
register_metrics(metrics)
- Purpose: Register multiple metric functions at once.
- Args: dictionary mapping names to callables.
- Returns: None.
get_probability(true_labs, pred_prob, objective='accuracy', verbose=False)
- Purpose: Brute-force search for the threshold that maximizes accuracy or F1.
- Args: true labels, predicted probabilities, metric name, and verbosity flag.
- Returns: optimal threshold.
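A brute-force search of this kind can be sketched in pure Python. The helper names are hypothetical, and two choices here are assumptions rather than documented behaviour: the candidate thresholds are the distinct predicted probabilities (the score can only change at those values), and ties are resolved in favour of the smallest threshold.

```python
def accuracy_at(true_labs, pred_prob, threshold):
    correct = sum((p >= threshold) == bool(y) for y, p in zip(true_labs, pred_prob))
    return correct / len(true_labs)


def f1_at(true_labs, pred_prob, threshold):
    pairs = list(zip(true_labs, pred_prob))
    tp = sum(1 for y, p in pairs if p >= threshold and y == 1)
    fp = sum(1 for y, p in pairs if p >= threshold and y == 0)
    fn = sum(1 for y, p in pairs if p < threshold and y == 1)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


SCORERS = {"accuracy": accuracy_at, "f1": f1_at}


def brute_force_threshold(true_labs, pred_prob, objective="accuracy"):
    """Try every distinct probability as a cutoff and keep the best one."""
    score = SCORERS[objective]
    best_t, best_s = None, -1.0
    for t in sorted(set(pred_prob)):
        s = score(true_labs, pred_prob, t)
        if s > best_s:  # strict: the smallest threshold wins ties
            best_t, best_s = t, s
    return best_t, best_s


y_true = [0, 0, 1, 1, 1]
y_prob = [0.1, 0.6, 0.4, 0.7, 0.9]
acc_result = brute_force_threshold(y_true, y_prob, "accuracy")
f1_result = brute_force_threshold(y_true, y_prob, "f1")
```

On this toy data both objectives select 0.4, but they need not agree in general: accuracy weighs both classes equally, while F1 ignores true negatives.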
get_optimal_threshold(true_labs, pred_prob, metric='f1', method='smart_brute')
- Purpose: Optimize any registered metric using different strategies (brute force, minimize, or gradient).
- Args: true labels, probabilities, metric name, and optimization method.
- Returns: optimal threshold.
cv_threshold_optimization(true_labs, pred_prob, metric='f1', method='smart_brute', cv=5, random_state=None)
- Purpose: Estimate thresholds via cross-validation and report per-fold scores.
- Returns: arrays of thresholds and scores.
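The idea behind cross-validated threshold estimation can be sketched as below. This is an illustration under stated assumptions, not the package's procedure: the fold assignment (shuffled indices dealt round-robin), the brute-force fit on the training part, and the F1 scorer are all choices made for this sketch, and every function name is hypothetical.

```python
import random


def f1_at(true_labs, pred_prob, threshold):
    pairs = list(zip(true_labs, pred_prob))
    tp = sum(1 for y, p in pairs if p >= threshold and y == 1)
    fp = sum(1 for y, p in pairs if p >= threshold and y == 0)
    fn = sum(1 for y, p in pairs if p < threshold and y == 1)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def best_threshold(true_labs, pred_prob):
    # Brute force over the distinct probabilities; max() keeps the first best.
    return max(sorted(set(pred_prob)), key=lambda t: f1_at(true_labs, pred_prob, t))


def cv_thresholds(true_labs, pred_prob, cv=5, random_state=None):
    """Fit a threshold on each training split and score it on the held-out fold."""
    idx = list(range(len(true_labs)))
    random.Random(random_state).shuffle(idx)
    folds = [idx[i::cv] for i in range(cv)]  # round-robin fold assignment
    thresholds, scores = [], []
    for held_out in folds:
        held = set(held_out)
        train = [i for i in idx if i not in held]
        t = best_threshold([true_labs[i] for i in train],
                           [pred_prob[i] for i in train])
        thresholds.append(t)
        scores.append(f1_at([true_labs[i] for i in held_out],
                            [pred_prob[i] for i in held_out], t))
    return thresholds, scores


y_true = [0, 0, 0, 0, 1, 0, 1, 1, 1, 1]
y_prob = [0.05, 0.2, 0.3, 0.45, 0.5, 0.55, 0.6, 0.7, 0.8, 0.95]
thresholds, scores = cv_thresholds(y_true, y_prob, cv=5, random_state=0)
```

Because each threshold is scored on data it was not fitted on, the spread of the per-fold scores gives a rough sense of how stable the chosen cutoff is.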
nested_cv_threshold_optimization(true_labs, pred_prob, metric='f1', method='smart_brute', inner_cv=5, outer_cv=5, random_state=None)
- Purpose: Perform nested cross-validation for threshold estimation and unbiased performance evaluation.
- Returns: arrays of outer-fold thresholds and scores.
ThresholdOptimizer(objective='accuracy', verbose=False)
- Purpose: High-level wrapper with fit/predict methods.
- Args: metric name and verbosity flag.
- Returns: fitted instance with threshold_ attribute after calling fit.
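The fit/predict pattern can be illustrated with a small stand-in class. `ThresholdOptimizerSketch` is hypothetical: it assumes a brute-force fit over the distinct probabilities and a `>=` decision rule, neither of which is confirmed for the real class.

```python
class ThresholdOptimizerSketch:
    """Hypothetical stand-in: brute-force fit, then apply the cutoff in predict."""

    def __init__(self, objective="accuracy"):
        if objective not in ("accuracy", "f1"):
            raise ValueError(f"unsupported objective: {objective!r}")
        self.objective = objective
        self.threshold_ = None  # set by fit()

    def _score(self, y_true, y_prob, t):
        pairs = list(zip(y_true, y_prob))
        tp = sum(1 for y, p in pairs if p >= t and y == 1)
        tn = sum(1 for y, p in pairs if p < t and y == 0)
        fp = sum(1 for y, p in pairs if p >= t and y == 0)
        fn = sum(1 for y, p in pairs if p < t and y == 1)
        if self.objective == "accuracy":
            return (tp + tn) / len(pairs)
        denom = 2 * tp + fp + fn
        return 2 * tp / denom if denom else 0.0

    def fit(self, y_true, y_prob):
        # Assumption: candidates are the distinct probabilities; first best wins.
        self.threshold_ = max(sorted(set(y_prob)),
                              key=lambda t: self._score(y_true, y_prob, t))
        return self

    def predict(self, y_prob):
        if self.threshold_ is None:
            raise RuntimeError("call fit() before predict()")
        return [int(p >= self.threshold_) for p in y_prob]


y_true = [0, 0, 1, 1, 1]
y_prob = [0.1, 0.6, 0.4, 0.7, 0.9]
opt = ThresholdOptimizerSketch(objective="f1").fit(y_true, y_prob)
y_pred = opt.predict(y_prob)
```

Storing the learned cutoff in a trailing-underscore attribute follows the scikit-learn convention for values that exist only after fitting.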
Examples
Authors
Suriyan Laohaprapanon and Gaurav Sood