
Efficient and reliable machine learning with a reject option and selectiveness

Project description


👁 Overview

scikit-fallback is a scikit-learn-compatible Python package for selective machine learning.

TL;DR

🔙 Augment your classification pipelines with skfb.estimators such as AnomalyFallbackClassifier and ThresholdFallbackClassifier to let them abstain from predictions in cases of uncertainty or anomaly.
📊 Inspect their performance by calculating combined prediction-rejection metrics such as predict_reject_recall_score, visualizing distributions of confidence scores with PairedHistogramDisplay, and using other tools from skfb.metrics.
🎶 Combine your costly ensembles with RoutingClassifier, ThresholdCascadeClassifierCV, and other skfb.ensemble meta-estimators to streamline inference while elevating model performance.
📒 See documentation, tutorials, and examples for more details and motivation.

🤔 Why scikit-fallback?

To fall back (on) means to retreat from making a prediction and to rely on other tools for support. scikit-fallback offers functionality to enhance your machine learning solutions with selectiveness and a reject option.

Machine Learning with Rejections

To allow your classification pipelines to abstain from predictions, you can wrap them with a rejector. Training a rejector means both fitting your model and learning to accept or reject predictions. Evaluation of a rejector depends on fallback mode (inference with or without fallback labels) and measures the ability of the rejector to both accept correct predictions and reject ambiguous ones.
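The core idea can be sketched with plain scikit-learn and NumPy before reaching for the package itself. In the sketch below, the threshold and the fallback label are hand-picked for illustration; the package's rejectors instead take them as parameters or learn them via cross-validation:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Toy data: two well-separated points per class plus two borderline points.
X = np.array([[0.0, 0.0], [4.0, 4.0], [1.0, 1.0], [3.0, 3.0], [2.5, 2.0], [2.0, 2.5]])
y = np.array([0, 1, 0, 1, 0, 1])

clf = LogisticRegression(random_state=0).fit(X, y)

# Reject any input whose top predicted probability falls below a threshold.
threshold = 0.55  # hypothetical, fixed by hand rather than cross-validated
proba = clf.predict_proba(X)
reject = proba.max(axis=1) < threshold

# -1 plays the role of a fallback label for rejected inputs.
y_pred = np.where(reject, -1, clf.predict(X))
```

Accepted inputs keep the base classifier's prediction, while rejected ones can be routed to a stronger model or a human reviewer.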

For example, skfb.estimators.ThresholdFallbackClassifierCV fits a base estimator and then finds the best confidence threshold, such that predictions whose maximum probability falls below it are rejected:

>>> import numpy as np
>>> from sklearn.linear_model import LogisticRegression
>>> from skfb.estimators import ThresholdFallbackClassifierCV
>>> X = np.array([[0, 0], [4, 4], [1, 1], [3, 3], [2.5, 2], [2., 2.5]])
>>> y = np.array([0, 1, 0, 1, 0, 1])
>>> # Train LogisticRegression and let it fallback based on confidence scores.
>>> rejector = ThresholdFallbackClassifierCV(
...     estimator=LogisticRegression(random_state=0),
...     thresholds=(0.5, 0.55, 0.6, 0.65),
...     ambiguity_threshold=0.0,
...     cv=2,
...     fallback_label=-1,
...     fallback_mode="store").fit(X, y)
>>> # If probability is lower than this, predict `fallback_label` = -1.
>>> rejector.threshold_
0.55
>>> # Make predictions and see which inputs were accepted or rejected.
>>> y_pred = rejector.predict(X)
>>> # If `fallback_mode == "store"`, always accept but also mask rejections.
>>> y_pred, y_pred.get_dense_fallback_mask()
(FBNDArray([0, 1, 0, 1, 1, 1]),
    array([False, False, False, False,  True, False]))
>>> # This allows calculation of combined metrics (e.g., predict-reject accuracy).
>>> rejector.score(X, y)
1.0
>>> # Otherwise, allow fallbacks
>>> rejector.set_params(fallback_mode="return").predict(X)
array([ 0,  1,  0,  1, -1,  1])
>>> # and calculate accuracy only on accepted samples,
>>> rejector.score(X, y)
1.0
>>> # or just switch off rejections and fallback to a plain LogisticRegression.
>>> rejector.set_params(fallback_mode="ignore").score(X, y)
0.8333333333333334

See Estimators for more examples of rejection meta-estimators and Combined Metrics for evaluation and inspection tools.

Dynamic Ensembling

While common ensembling methods such as voting and stacking aim to boost predictive performance, they also increase inference costs because every model must run before its outputs are aggregated. Alternatively, we could learn which individual model, or subset of models, in an ensemble should make a given decision, thereby reducing inference overhead while preserving, or sometimes even improving, predictive performance.

For example, skfb.ensemble.ThresholdCascadeClassifierCV builds a cascade from a sequence of models arranged by inference cost (and hence, roughly, by performance, e.g., from weakest but fastest to strongest but slowest). It learns per-model confidence thresholds that determine whether the current model in the sequence makes the prediction or, based on its confidence score for a given input, defers to the next model:

>>> from skfb.ensemble import ThresholdCascadeClassifierCV
>>> from sklearn.datasets import make_classification
>>> from sklearn.ensemble import HistGradientBoostingClassifier
>>> X, y = make_classification(
...     n_samples=1_000, n_features=100, n_redundant=97, class_sep=0.1, flip_y=0.05,
...     random_state=0)
>>> weak = HistGradientBoostingClassifier(max_iter=10, max_depth=2, random_state=0)
>>> okay = HistGradientBoostingClassifier(max_iter=20, max_depth=3, random_state=0)
>>> buff = HistGradientBoostingClassifier(max_iter=99, max_depth=4, random_state=0)
>>> # Train all models and learn thresholds per model s.t. if the current model's max
>>> # confidence score is lower, it defers the decision to the next in the cascade.
>>> cascading = ThresholdCascadeClassifierCV(
...     estimators=[weak, okay, buff],
...     costs=[1.1, 1.2, 1.99],
...     cv_thresholds=5,
...     cv=3,
...     scoring="accuracy",
...     return_earray=True,
...     response_method="predict_proba").fit(X, y)
>>> # Best thresholds for `weak` and `okay`
>>> # (`buff` will always predict if `weak` and `okay` fall back):
>>> cascading.best_thresholds_
array([0.6125, 0.8375])
>>> # If `return_earray` is True, predictions will be of type `skfb.core.FBNDArray`,
>>> # which stores `acceptance_rates`, the ratio of inputs accepted by each model.
>>> cascading.predict(X).acceptance_rates
array([0.659, 0.003, 0.338])
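One way to read these acceptance rates is as a rough inference-cost estimate: weight each model's cost by the fraction of inputs that actually reaches it. The numbers below are taken from the example above, and the cost units are the hypothetical ones passed as `costs`:

```python
import numpy as np

# Per-model costs and accepted fractions from the cascade example above.
costs = np.array([1.1, 1.2, 1.99])
acceptance_rates = np.array([0.659, 0.003, 0.338])

# Every input reaches the first model; each later model only sees
# what all earlier models deferred.
reach = 1.0 - np.concatenate(([0.0], np.cumsum(acceptance_rates[:-1])))

# Average per-sample cost of the cascade vs. running every model on every input.
cascade_cost = float(reach @ costs)   # ~2.18
full_cost = float(costs.sum())        # 4.29
```

Whether the cascade actually saves cost depends on how often the cheap models get to answer; here the second model rarely answers on its own, so most of the savings come from the first one.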

🏗 Installation

scikit-fallback requires:

  • Python (>=3.9, <3.14)
  • scikit-learn (>=1.0)
  • numpy
  • scipy
  • matplotlib (>=3.0) (optional)

and, along with these dependencies, can be installed via pip:

pip install scikit-fallback

Note: when using Python 3.9 with scikit-learn>=1.7, subclassing BaseEstimator and scikit-learn mixins might result in runtime errors related to __sklearn_tags__. Also, if you have scikit-learn<=1.2, you will see warnings about the unavailability of nested or general parameter validation, which you can safely ignore.

🔗 Links

  1. Documentation
  2. Medium Series
  3. Examples & Notebooks: examples/ and https://kaggle.com/sshadylov
  4. Related Research:
    1. Hendrickx, K., Perini, L., Van der Plas, D. et al. Machine learning with a reject option: a survey. Mach Learn 113, 3073–3110 (2024).
    2. Wittawat Jitkrittum, Neha Gupta, Aditya K Menon, Harikrishna Narasimhan, Ankit Rawat, and Sanjiv Kumar. When does confidence-based cascade deferral suffice? NeurIPS, 36, 2024.
    3. And more (coming soon).
