An optimizer for finding optimal thresholds in ordinal classification problems
Optimized Rounder
Optimized Rounder finds the thresholds that best convert continuous predictions into discrete classes in ordinal classification problems. It uses Optuna for the threshold search and supports cross-validation and multiple evaluation metrics.
Installation
pip install optimized-rounder
Features
- Threshold Optimization: Find optimal thresholds to convert continuous predictions to discrete classes
- Multiple Metrics: Support for various evaluation metrics including quadratic kappa, linear kappa, RMSE, accuracy, and F1 scores
- Cross-Validation: Built-in support for K-fold and stratified cross-validation
- Efficient Search: Uses Optuna for efficient Bayesian optimization of thresholds
- Comprehensive Evaluation: Evaluate models using multiple metrics simultaneously
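To make the idea of threshold-based rounding concrete, here is a minimal sketch using only the standard library. It is not the package's implementation: the helper name `apply_thresholds` is hypothetical, and the tie-breaking rule (a score equal to a threshold goes to the upper class) is an assumption.

```python
import bisect

def apply_thresholds(scores, thresholds):
    """Map each continuous score to a class index in 0..len(thresholds)."""
    # thresholds must be sorted ascending; a score equal to a threshold
    # falls into the upper class (an assumed convention).
    return [bisect.bisect_right(thresholds, s) for s in scores]

thresholds = [0.5, 1.5, 2.5]  # 3 cut points separate 4 classes
scores = [-0.3, 0.49, 0.5, 1.7, 2.2, 3.9]
classes = apply_thresholds(scores, thresholds)
print(classes)  # [0, 0, 1, 2, 2, 3]
```

With the cut points fixed at 0.5, 1.5, 2.5 this is ordinary rounding; moving the cut points is what lets a metric such as quadratic kappa improve.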
Quick Start
from optimized_rounder import OptimizedRounder
import numpy as np
from sklearn.metrics import cohen_kappa_score
# Generate synthetic data: integer labels plus Gaussian noise,
# standing in for the continuous output of a regression model
np.random.seed(42)
n_classes = 4
n_samples = 1000
y_true = np.random.randint(0, n_classes, size=n_samples)
X_train = y_true + np.random.normal(0, 0.8, size=n_samples)
# Initialize and fit the optimizer
rounder = OptimizedRounder(n_classes=n_classes, n_trials=100)
rounder.fit(X_train, y_true)
# Get the optimal thresholds
print(f'Optimal thresholds: {rounder.thresholds}')
# Make predictions
y_pred = rounder.predict(X_train)
kappa = cohen_kappa_score(y_true, y_pred, weights='quadratic')
print(f'Quadratic kappa: {kappa:.4f}')
Advanced Usage
With Cross-Validation
# Use 5-fold stratified cross-validation
rounder = OptimizedRounder(
    n_classes=4,
    n_trials=200,
    cv=5,
    stratified=True,
    metric='quadratic_kappa',
    verbose=True
)
rounder.fit(X_train, y_true)
print(f'CV Results: {rounder.cv_results_}')
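As a rough illustration of what stratified folds mean, the sketch below assigns sample indices to folds so that each fold keeps roughly the same class proportions. This is an assumption about the general technique, not the package's code: `stratified_folds` is a hypothetical helper, and the package may instead delegate to an existing splitter.

```python
import random
from collections import defaultdict

def stratified_folds(y, n_splits=5, seed=42):
    """Return n_splits lists of indices with similar class proportions."""
    rng = random.Random(seed)
    # Group sample indices by class label
    by_class = defaultdict(list)
    for idx, label in enumerate(y):
        by_class[label].append(idx)
    folds = [[] for _ in range(n_splits)]
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        # Deal each class's indices round-robin across the folds
        for pos, idx in enumerate(indices):
            folds[pos % n_splits].append(idx)
    return folds

y = [0] * 10 + [1] * 20 + [2] * 5
folds = stratified_folds(y, n_splits=5)
for k, fold in enumerate(folds):
    counts = [sum(1 for i in fold if y[i] == c) for c in range(3)]
    print(f"fold {k}: class counts {counts}")
```

Each fold serves once as the held-out set, while thresholds are searched on the remaining folds.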
Using Different Metrics
# Optimize for F1 weighted score
rounder = OptimizedRounder(
    n_classes=4,
    n_trials=200,
    metric='f1_weighted'
)
rounder.fit(X_train, y_true)
# Comprehensive evaluation (reuses the Quick Start data; substitute a
# held-out set and its labels when one is available)
metrics = rounder.evaluate(X_train, y_true)
for metric_name, value in metrics.items():
    print(f"{metric_name}: {value:.4f}")
API Reference
OptimizedRounder
OptimizedRounder(
    n_classes=None,
    n_trials=200,
    cv=None,
    stratified=True,
    metric='quadratic_kappa',
    verbose=False,
    random_state=42
)
Parameters
- n_classes: Number of target classes (0, 1, 2, ..., n_classes-1)
- n_trials: Number of optimization trials for Optuna
- cv: Number of cross-validation folds or a CV splitter object
- stratified: Whether to use stratified folds (applies only when cv is an integer)
- metric: Metric to optimize ('quadratic_kappa', 'linear_kappa', 'rmse', 'accuracy', 'f1_macro', 'f1_weighted', 'f1_micro')
- verbose: Whether to display Optuna's optimization progress
- random_state: Random seed for reproducibility
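For readers unfamiliar with the default metric, the following is a small pure-Python sketch of Cohen's kappa with quadratic weights, the quantity the Quick Start computes with `cohen_kappa_score(..., weights='quadratic')`. That `'quadratic_kappa'` maps to exactly this definition inside the package is an assumption, and the function name is hypothetical.

```python
def quadratic_weighted_kappa(y_true, y_pred, n_classes):
    """Cohen's kappa with quadratic disagreement weights (i - j)**2."""
    n = len(y_true)
    # Observed confusion matrix: rows are true classes, columns predicted
    observed = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        observed[t][p] += 1
    row = [sum(r) for r in observed]
    col = [sum(observed[i][j] for i in range(n_classes))
           for j in range(n_classes)]
    num = 0.0  # weighted observed disagreement
    den = 0.0  # weighted disagreement expected by chance
    for i in range(n_classes):
        for j in range(n_classes):
            w = (i - j) ** 2  # the usual 1/(n_classes-1)**2 factor cancels
            num += w * observed[i][j]
            den += w * row[i] * col[j] / n
    if den == 0:
        # Happens when all true and predicted labels are one single class
        raise ValueError("kappa is undefined for a single-class input")
    return 1.0 - num / den

print(quadratic_weighted_kappa([0, 1, 2, 3], [0, 1, 2, 3], 4))  # 1.0
```

Because the weight grows with the squared distance between classes, predicting class 3 for a true class 0 is penalized far more than an off-by-one error, which is why this metric suits ordinal targets.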
Methods
- fit(X, y): Find optimal thresholds using Optuna optimization
- predict(X): Convert continuous predictions to discrete classes using optimal thresholds
- fit_predict(X, y): Train the optimizer and return predictions in one step
- coefficients(): Get the optimal thresholds found during training
- evaluate(X, y): Evaluate the model on multiple metrics
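The fit step can be pictured as a search over sorted threshold vectors that maximizes a metric. The sketch below is illustrative only: plain random search stands in for Optuna's sampler, accuracy stands in for the configurable metric, and all function names are hypothetical.

```python
import bisect
import random

def apply_thresholds(scores, thresholds):
    # A score equal to a threshold goes to the upper class (assumption)
    return [bisect.bisect_right(thresholds, s) for s in scores]

def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def random_search_thresholds(scores, y_true, n_classes, n_trials=200, seed=42):
    """Random search over thresholds; a stand-in for Optuna's search."""
    rng = random.Random(seed)
    lo, hi = min(scores), max(scores)
    # Start from plain rounding: cut points at 0.5, 1.5, ...
    best = [k + 0.5 for k in range(n_classes - 1)]
    best_score = accuracy(y_true, apply_thresholds(scores, best))
    for _ in range(n_trials):
        # Sample candidate cut points and keep them in ascending order
        candidate = sorted(rng.uniform(lo, hi) for _ in range(n_classes - 1))
        score = accuracy(y_true, apply_thresholds(scores, candidate))
        if score > best_score:
            best, best_score = candidate, score
    return best, best_score

# Synthetic data in the spirit of the Quick Start example
data_rng = random.Random(0)
y_true = [data_rng.randrange(4) for _ in range(500)]
scores = [y + data_rng.gauss(0, 0.8) for y in y_true]

best_thresholds, best_score = random_search_thresholds(scores, y_true, n_classes=4)
print(f"thresholds: {[round(t, 3) for t in best_thresholds]}")
print(f"accuracy:   {best_score:.4f}")
```

A model-based sampler such as the one Optuna provides typically needs fewer trials than uniform random sampling, but the objective it evaluates has the same shape: thresholds in, metric out.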
Use Cases
- Regression to Classification: Convert regression outputs to discrete classes
- Ordinal Classification: Optimize thresholds for ordinal targets
- Ensemble Calibration: Calibrate probability outputs from ensemble models
- Competition Metrics: Optimize directly for competition metrics like quadratic kappa
License
This project is licensed under the MIT License - see the LICENSE file for details.
File details
Details for the file optimized_rounder-0.1.0.tar.gz.
File metadata
- Download URL: optimized_rounder-0.1.0.tar.gz
- Upload date:
- Size: 6.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.10.14
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 11bcd0e1d8326df8443f42f97f9104700559ed58606842b771438c4b81ae57d0 |
| MD5 | 22ec6dd70c6bfeb7a65c67eac99f430d |
| BLAKE2b-256 | 40f6528a1a6168140edc4ac92cffe98e906c7e90cf7363af775a00d426532aee |
File details
Details for the file optimized_rounder-0.1.0-py3-none-any.whl.
File metadata
- Download URL: optimized_rounder-0.1.0-py3-none-any.whl
- Upload date:
- Size: 6.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.10.14
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 37078e67e3273f7e76616e479d12bbb26510506d500ae9b94d68f195c6e933a4 |
| MD5 | da706234073361257f94d7d6b357e293 |
| BLAKE2b-256 | 07765e42f0ad0a8bc9074bc8bbd2796af9913a6cef9448b8824d7f908a377272 |