
An optimizer for finding optimal thresholds in ordinal classification problems

Project description

Optimized Rounder

License: MIT

An optimizer that finds the thresholds used to convert continuous model outputs into classes in ordinal classification problems. The package uses Optuna to search for those thresholds and supports cross-validation and multiple evaluation metrics.
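Converting a continuous output into a class amounts to finding which interval between consecutive thresholds it falls in. With n_classes classes there are n_classes - 1 thresholds. The sketch below shows that step with NumPy's `np.digitize`. The helper name `apply_thresholds_sketch` is hypothetical and is not the package's own `apply_thresholds` method.

```python
import numpy as np


def apply_thresholds_sketch(preds, thresholds):
    """Map continuous predictions to classes 0..len(thresholds).

    A prediction p gets class k when thresholds[k-1] <= p < thresholds[k];
    values below the first threshold become class 0 and values at or above
    the last threshold become the top class.
    """
    thresholds = np.sort(np.asarray(thresholds, dtype=float))
    return np.digitize(np.asarray(preds, dtype=float), thresholds)


# Plain rounding for 4 classes corresponds to thresholds [0.5, 1.5, 2.5]
preds = np.array([-0.2, 0.7, 1.49, 2.6, 5.0])
print(apply_thresholds_sketch(preds, [0.5, 1.5, 2.5]))  # [0 1 1 3 3]
```

Threshold optimization keeps this mapping and only moves the cut points.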

Installation

pip install optimized-rounder

Features

  • Threshold Optimization: Find optimal thresholds to convert continuous predictions to discrete classes
  • Multiple Metrics: Support for various evaluation metrics including quadratic kappa, linear kappa, RMSE, accuracy, and F1 scores
  • Cross-Validation: Built-in support for K-fold and stratified cross-validation
  • Efficient Search: Uses Optuna for efficient Bayesian optimization of thresholds
  • Comprehensive Evaluation: Evaluate models using multiple metrics simultaneously
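The package runs the threshold search with Optuna. To illustrate the idea without extra dependencies, the sketch below replaces Optuna's sampler with plain random search: it draws candidate threshold vectors, sorts them so the classes stay ordered, and keeps whichever scores the highest quadratic weighted kappa. The function names, the search range (minimum to maximum of the predictions) and the trial count are assumptions, not the package's settings.

```python
import numpy as np


def quadratic_kappa(y_true, y_pred, n_classes):
    # Quadratic weighted kappa: 1 - sum(W * O) / sum(W * E)
    O = np.zeros((n_classes, n_classes))
    np.add.at(O, (y_true, y_pred), 1)                   # observed confusion matrix
    i, j = np.indices((n_classes, n_classes))
    W = (i - j) ** 2 / (n_classes - 1) ** 2             # quadratic disagreement weights
    E = np.outer(O.sum(axis=1), O.sum(axis=0)) / O.sum()  # expected under independence
    return 1.0 - (W * O).sum() / (W * E).sum()


def random_threshold_search(preds, y_true, n_classes, n_trials=500, seed=42):
    """Random search over sorted thresholds (a stand-in for Optuna's sampler)."""
    rng = np.random.default_rng(seed)
    low, high = preds.min(), preds.max()
    best = np.arange(n_classes - 1) + 0.5               # start from plain rounding
    best_score = quadratic_kappa(y_true, np.digitize(preds, best), n_classes)
    for _ in range(n_trials):
        cand = np.sort(rng.uniform(low, high, size=n_classes - 1))
        score = quadratic_kappa(y_true, np.digitize(preds, cand), n_classes)
        if score > best_score:
            best, best_score = cand, score
    return best, best_score


rng = np.random.default_rng(0)
y = rng.integers(0, 4, size=1000)
preds = y + rng.normal(0, 0.9, size=1000)
thr, score = random_threshold_search(preds, y, n_classes=4)
print("thresholds:", np.round(thr, 3), "kappa:", round(score, 4))
```

Starting from the plain-rounding thresholds guarantees the result is never worse than rounding on the fitting data. Optuna's default TPE sampler instead models which regions of the search space scored well and samples more there, so it usually needs fewer trials than random search.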

Quick Start

from oprounder import OptimizedRounder
import numpy as np
from sklearn.metrics import cohen_kappa_score

# Generate synthetic data
np.random.seed(42)
n_classes = 4
n_samples = 1000
y_true = np.random.randint(0, n_classes, size=n_samples)
output = y_true + np.random.normal(0, 0.9, size=n_samples) # dummy model output

# Initialize and fit the optimizer
rounder = OptimizedRounder(n_classes=n_classes, n_trials=100)
rounder.fit(output, y_true)

# Get the optimal thresholds
print(f'Optimal thresholds: {rounder.thresholds}')

# Make predictions with the optimized thresholds
# (scored on the same data used for fitting, so this estimate is optimistic)
y_pred = rounder.predict(output)
kappa = cohen_kappa_score(y_true, y_pred, weights='quadratic')
print(f'Optimized quadratic kappa: {kappa:.4f}')

# Baseline: default thresholds (for 4 classes, plain rounding uses [0.5, 1.5, 2.5])
y_pred_default = rounder.apply_thresholds(output, rounder.default_thresholds)
kappa_default = cohen_kappa_score(y_true, y_pred_default, weights='quadratic')
print(f'Default quadratic kappa: {kappa_default:.4f}')
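For reference, this is the standard definition of the quadratic weighted kappa that `cohen_kappa_score(..., weights='quadratic')` reports. Here O is the K x K confusion matrix (O_ij counts samples with true class i predicted as class j) and n is the number of samples:

```latex
E_{ij} = \frac{\left(\sum_{j'} O_{ij'}\right)\left(\sum_{i'} O_{i'j}\right)}{n},
\qquad
w_{ij} = \frac{(i-j)^2}{(K-1)^2},
\qquad
\kappa_{\text{quad}} = 1 - \frac{\sum_{i,j} w_{ij}\, O_{ij}}{\sum_{i,j} w_{ij}\, E_{ij}}
```

Perfect agreement gives a kappa of 1 and chance-level agreement gives roughly 0. The linear variant uses w_ij = |i - j| / (K - 1); in both cases the normalizing constant cancels in the ratio. Because the weights grow with the distance between classes, moving a threshold changes which errors are near misses and which are costly, which is why tuning the thresholds can raise kappa above plain rounding.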

Advanced Usage

With Cross-Validation

# Use 5-fold stratified cross-validation
rounder = OptimizedRounder(
    n_classes=4,
    n_trials=200,
    cv=5,
    stratified=True,
    metric='quadratic_kappa',
    verbose=True
)

rounder.fit(output, y_true)
print(f'CV Results: {rounder.cv_results_}')
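The README does not specify how folds are combined when `cv` is set. One common scheme, shown here as an assumption, is to fit thresholds on each training split, predict the held-out split to build out-of-fold predictions, and average the per-fold thresholds. To keep the sketch short, the fitter is a deliberately simple stand-in (midpoints between per-class mean predictions) rather than an Optuna search, and the fold assignment is a hand-rolled version of stratified K-fold. All function names are hypothetical.

```python
import numpy as np


def stratified_folds(y, n_splits=5, seed=42):
    """Assign each sample a fold id so every fold has roughly the same class mix."""
    rng = np.random.default_rng(seed)
    fold_ids = np.empty(len(y), dtype=int)
    for cls in np.unique(y):
        idx = rng.permutation(np.flatnonzero(y == cls))
        fold_ids[idx] = np.arange(len(idx)) % n_splits  # deal samples round-robin
    return fold_ids


def midpoint_thresholds(preds, y, n_classes):
    # Cheap stand-in fitter: midpoints between consecutive per-class mean predictions
    means = np.array([preds[y == c].mean() for c in range(n_classes)])
    return np.sort((means[:-1] + means[1:]) / 2)


def cv_thresholds(preds, y, n_classes, n_splits=5, seed=42):
    fold_ids = stratified_folds(y, n_splits, seed)
    oof_pred = np.empty(len(y), dtype=int)
    fold_thresholds = []
    for k in range(n_splits):
        train, val = fold_ids != k, fold_ids == k
        thr = midpoint_thresholds(preds[train], y[train], n_classes)
        fold_thresholds.append(thr)
        oof_pred[val] = np.digitize(preds[val], thr)  # score only held-out rows
    # Element-wise mean of sorted vectors stays sorted
    return np.mean(fold_thresholds, axis=0), oof_pred


rng = np.random.default_rng(0)
y = rng.integers(0, 4, size=1000)
preds = y + rng.normal(0, 0.9, size=1000)
thr, oof = cv_thresholds(preds, y, n_classes=4)
print("averaged thresholds:", np.round(thr, 3), "OOF accuracy:", (oof == y).mean())
```

The out-of-fold score is an honest estimate because no sample is scored with thresholds fitted on it.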

Using Different Metrics

# Optimize for F1 weighted score
rounder = OptimizedRounder(
    n_classes=4,
    n_trials=200,
    metric='f1_weighted'
)

rounder.fit(output, y_true)

# Evaluate on a second set of noisy outputs for the same labels
output_val = y_true + np.random.normal(0, 0.8, size=n_samples)
metrics = rounder.evaluate(output_val, y_true)
for metric_name, value in metrics.items():
    print(f"{metric_name}: {value:.4f}")
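The metric names accepted by `metric` can be computed from scratch as follows. This NumPy sketch takes already-discretized class labels; whether the package's `evaluate` computes RMSE on labels or on raw outputs is not documented, so treat that choice as an assumption. The function names are hypothetical, and the kappa weights are left unnormalized because the scale factor cancels in the ratio.

```python
import numpy as np


def weighted_kappa(y_true, y_pred, n_classes, power):
    O = np.zeros((n_classes, n_classes))
    np.add.at(O, (y_true, y_pred), 1)
    i, j = np.indices(O.shape)
    W = np.abs(i - j) ** power  # power=1: linear, power=2: quadratic
    E = np.outer(O.sum(axis=1), O.sum(axis=0)) / O.sum()
    return 1.0 - (W * O).sum() / (W * E).sum()


def f1_scores(y_true, y_pred, n_classes):
    classes = range(n_classes)
    tp = np.array([np.sum((y_pred == c) & (y_true == c)) for c in classes])
    fp = np.array([np.sum((y_pred == c) & (y_true != c)) for c in classes])
    fn = np.array([np.sum((y_pred != c) & (y_true == c)) for c in classes])
    denom = 2 * tp + fp + fn
    # Per-class F1; classes that never occur in either array score 0
    per_class = np.divide(2 * tp, denom, out=np.zeros(n_classes), where=denom > 0)
    return {
        "f1_macro": per_class.mean(),
        "f1_weighted": np.average(per_class, weights=tp + fn),  # weight by support
        "f1_micro": 2 * tp.sum() / (2 * tp.sum() + fp.sum() + fn.sum()),
    }


def evaluate_sketch(y_true, y_pred, n_classes):
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    results = {
        "quadratic_kappa": weighted_kappa(y_true, y_pred, n_classes, power=2),
        "linear_kappa": weighted_kappa(y_true, y_pred, n_classes, power=1),
        "rmse": float(np.sqrt(np.mean((y_true - y_pred) ** 2))),
        "accuracy": float(np.mean(y_true == y_pred)),
    }
    results.update(f1_scores(y_true, y_pred, n_classes))
    return results


for name, value in evaluate_sketch([0, 1, 2, 3, 3, 2], [0, 1, 1, 3, 2, 2], 4).items():
    print(f"{name}: {value:.4f}")
```

For single-label data, micro-averaged F1 equals accuracy, which is a handy sanity check.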

API Reference

OptimizedRounder

OptimizedRounder(
    n_classes=None,
    n_trials=200,
    cv=None,
    stratified=True,
    metric='quadratic_kappa',
    verbose=False,
    random_state=42
)

Parameters

  • n_classes: Number of target classes; labels are expected to be 0, 1, ..., n_classes-1
  • n_trials: Number of Optuna trials, i.e. candidate threshold sets evaluated
  • cv: Number of cross-validation folds, or a CV splitter object
  • stratified: Whether to use stratified folds (applies only when cv is an integer)
  • metric: Metric to optimize ('quadratic_kappa', 'linear_kappa', 'rmse', 'accuracy', 'f1_macro', 'f1_weighted', 'f1_micro')
  • verbose: Whether to display Optuna's optimization progress
  • random_state: Random seed for reproducibility
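One detail the metric list implies: 'rmse' is an error to be minimized, while every other option is a score to be maximized. Optuna handles this through the `direction` argument of `optuna.create_study` ('maximize' or 'minimize'). The mapping below is a hypothetical illustration of that bookkeeping, not the package's internals.

```python
# Hypothetical mapping from the README's metric names to an optimization direction.
# Error metrics such as RMSE are minimized; agreement scores are maximized.
METRIC_DIRECTIONS = {
    "quadratic_kappa": "maximize",
    "linear_kappa": "maximize",
    "rmse": "minimize",
    "accuracy": "maximize",
    "f1_macro": "maximize",
    "f1_weighted": "maximize",
    "f1_micro": "maximize",
}


def direction_for(metric):
    """Return 'maximize' or 'minimize', rejecting unknown metric names early."""
    try:
        return METRIC_DIRECTIONS[metric]
    except KeyError:
        raise ValueError(
            f"Unknown metric {metric!r}; expected one of {sorted(METRIC_DIRECTIONS)}"
        ) from None


def to_loss(metric, value):
    """Express any metric as a loss, so a single minimizer can handle every option."""
    return value if direction_for(metric) == "minimize" else -value


print(direction_for("rmse"), to_loss("f1_macro", 0.7))
```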

Methods

  • fit(X, y): Search for the optimal thresholds with Optuna
  • predict(X): Convert continuous predictions to discrete classes using the optimal thresholds
  • fit_predict(X, y): Fit the thresholds and return predictions for X in one step
  • coefficients(): Get the optimal thresholds found during training
  • evaluate(X, y): Evaluate the model on multiple metrics
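To show how the listed methods fit together, here is a hypothetical, minimal class that mirrors the method names above. It is not the package's implementation: it replaces Optuna with coordinate-wise grid search (each threshold is scanned between its neighbours while the others stay fixed), optimizes accuracy only, and has no cross-validation. The class name `TinyRounder` and its parameters are invented for illustration.

```python
import numpy as np


class TinyRounder:
    """Hypothetical minimal rounder mirroring the README's method names.

    Not the package's code: coordinate-wise grid search on accuracy
    stands in for Optuna, and there is no cross-validation.
    """

    def __init__(self, n_classes, n_grid=101, n_rounds=3):
        self.n_classes = n_classes
        self.n_grid = n_grid      # candidate values scanned per threshold
        self.n_rounds = n_rounds  # passes over all thresholds
        self.thresholds_ = None

    @staticmethod
    def _accuracy(X, y, thresholds):
        return np.mean(np.digitize(X, thresholds) == y)

    def fit(self, X, y):
        X, y = np.asarray(X, dtype=float), np.asarray(y)
        thr = np.arange(self.n_classes - 1) + 0.5  # start from plain rounding
        last = len(thr) - 1
        for _ in range(self.n_rounds):
            for k in range(len(thr)):
                # Scan threshold k only between its neighbours so the order is kept
                lo = thr[k - 1] if k > 0 else min(X.min(), thr[k])
                hi = thr[k + 1] if k < last else max(X.max(), thr[k])
                grid = np.linspace(lo, hi, self.n_grid)
                scores = [self._accuracy(X, y, np.r_[thr[:k], g, thr[k + 1:]]) for g in grid]
                best = int(np.argmax(scores))
                if scores[best] >= self._accuracy(X, y, thr):  # never accept a worse score
                    thr[k] = grid[best]
        self.thresholds_ = thr
        return self

    def predict(self, X):
        if self.thresholds_ is None:
            raise RuntimeError("call fit() before predict()")
        return np.digitize(np.asarray(X, dtype=float), self.thresholds_)

    def fit_predict(self, X, y):
        return self.fit(X, y).predict(X)

    def coefficients(self):
        return self.thresholds_.copy()

    def evaluate(self, X, y):
        y_pred, y = self.predict(X), np.asarray(y)
        return {
            "accuracy": float(np.mean(y_pred == y)),
            "rmse": float(np.sqrt(np.mean((y_pred - y) ** 2))),
        }


rng = np.random.default_rng(42)
y = rng.integers(0, 4, size=1000)
X = y + rng.normal(0, 0.9, size=1000)
r = TinyRounder(n_classes=4)
pred = r.fit_predict(X, y)
print(r.coefficients(), r.evaluate(X, y))
```

Because `fit` returns `self`, `fit_predict` is simply `fit(X, y).predict(X)`, following the scikit-learn estimator convention.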

Use Cases

  • Regression to Classification: Convert regression outputs to discrete classes
  • Ordinal Classification: Optimize thresholds for ordinal targets
  • Ensemble Post-processing: Convert blended or averaged ensemble outputs into discrete classes
  • Competition Metrics: Optimize directly for competition metrics like quadratic kappa

License

This project is licensed under the MIT License - see the LICENSE file for details.

Download files

Source Distribution

optimized_rounder-0.1.2.tar.gz (6.1 kB)

Uploaded Source

Built Distribution

optimized_rounder-0.1.2-py3-none-any.whl (6.8 kB)

Uploaded Python 3

File details

Details for the file optimized_rounder-0.1.2.tar.gz.

File metadata

  • Download URL: optimized_rounder-0.1.2.tar.gz
  • Upload date:
  • Size: 6.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.14

File hashes

Hashes for optimized_rounder-0.1.2.tar.gz
Algorithm Hash digest
SHA256 4b13ecb4078aff5dc7c7ddc87bf1cbfa50c27e608d2c25c7d58f44ade768d414
MD5 93d429a6336685b8860fb76b98c2698c
BLAKE2b-256 208d39751f1b576c989db7df1a7480630ba7fdbf24481a0d750f18688d74d68a

File details

Details for the file optimized_rounder-0.1.2-py3-none-any.whl.

File hashes

Hashes for optimized_rounder-0.1.2-py3-none-any.whl
Algorithm Hash digest
SHA256 2a2427e143913566727ddf4b47fec65a0dbf033464e926ac0dfd8936b899b10c
MD5 34ded113d8c3fb51f63d278857edd007
BLAKE2b-256 9bf8e7c9b521c643cbe478716e3c198ce7bd2628de87ec2562e32f27089c5630
