Metrics for conditional coverage estimation

Project description

Covmetrics: conditional coverage metrics

This PyTorch-based package currently contains several conditional coverage metrics, including our metric ERT (Excess risk of the target coverage).

It accompanies our paper Conditional Coverage Diagnostics for Conformal Prediction. Please cite us if you use this repository for research purposes.

Installation

Work in progress: for now, you can use the ERT file to access our metrics; the package will be pip-installable soon!

Using conditional coverage metrics

For a quick usage, you can evaluate a metric as follows:

from covmetrics import ERT 

ERT_value = ERT().evaluate(x, cover, alpha)

Here x is a feature array of shape (n_samples, n_features) (NumPy array, torch tensor, or DataFrame), cover is a vector of shape (n_samples,) containing 0s and 1s, and alpha is the target miscoverage level.
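
In a typical conformal prediction workflow, cover indicates whether each true label fell inside its prediction set. A minimal sketch of how such a vector might be built, assuming hypothetical interval bounds lower and upper produced by a split-conformal method for test labels y_test:

import numpy as np

# Hypothetical prediction intervals [lower, upper] for each test point.
# cover[i] = 1 if the true label y_test[i] lies inside its interval, 0 otherwise.
cover = ((y_test >= lower) & (y_test <= upper)).astype(int)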

The default classifier used to predict the coverage indicators is a LightGBM classifier. You can change this by replacing the model class of the classifier:

from covmetrics import ERT 
from sklearn.linear_model import LogisticRegression

ERT_estimator = ERT(model_cls=LogisticRegression)
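
As illustrated in the custom classifier section below, extra keyword arguments passed to ERT are forwarded to the classifier's constructor. A sketch, assuming LogisticRegression hyperparameters are simply passed through:

from covmetrics import ERT
from sklearn.linear_model import LogisticRegression

# C and max_iter are assumed to be forwarded to LogisticRegression(...)
ERT_estimator = ERT(model_cls=LogisticRegression, C=0.5, max_iter=1000)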

We recommend using our pre-implemented k-fold version to evaluate the conditional miscoverage (the default number of splits is 5):

ERT_value = ERT_estimator.evaluate(x_test, cover_test, alpha, n_splits = 5)

Alternatively, you can train the classifier on one dataset and evaluate it on another:

ERT_estimator.fit(x_train, cover_train)
ERT_value = ERT_estimator.evaluate(x_test, cover_test, alpha, n_splits=0)

Modifying the loss function

The default loss used to evaluate the classifier provides a lower bound on the $L_1$-ERT. You can change the loss as follows:

ERT_estimator.evaluate(x_test, cover_test, alpha, loss=your_loss)

The package already provides several loss functions to evaluate your models. You can import them as follows:

from covmetrics.losses import (
    brier_score,
    logloss,
    L1_miscoverage,
    brier_score_over,
    L1_miscoverage_over,
    logloss_over,
    brier_score_under,
    logloss_under,
    L1_miscoverage_under
)
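
For example, to evaluate the ERT with the Brier score instead of the default loss:

from covmetrics.losses import brier_score

# Use the Brier score as the loss for the coverage classifier.
ERT_value = ERT_estimator.evaluate(x_test, cover_test, alpha, loss=brier_score)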

If you want to evaluate several losses at once, you can use

ERT_value = ERT_estimator.evaluate_multiple_losses(x_test, cover_test, alpha, all_losses = List_of_all_your_losses)

This returns a dictionary with all evaluated losses. By default, if all_losses=None, the metrics evaluated are the $L_1$-ERT, $L_2$-ERT, and KL-ERT.
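
For example, with an explicit selection of the provided losses:

from covmetrics.losses import brier_score, logloss, L1_miscoverage

losses = [brier_score, logloss, L1_miscoverage]
# Returns a dictionary with one entry per evaluated loss.
ERT_values = ERT_estimator.evaluate_multiple_losses(x_test, cover_test, alpha, all_losses=losses)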

Using a custom classifier

You can also use your own classifier with ERT. To do this, define a class with the following methods:

  • __init__(self, **model_kwargs): Initialize your model with any required parameters.
  • fit(self, X, y, **fit_kwargs): Train the model on your data.
  • predict_proba(self, X): Return the predicted probabilities for each class.
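
A minimal sketch of such a class, wrapping (purely as an illustration) scikit-learn's KNeighborsClassifier:

from sklearn.neighbors import KNeighborsClassifier

class MyKNNClassifier:
    # Hypothetical wrapper exposing the interface expected by ERT.

    def __init__(self, **model_kwargs):
        # Forward any keyword arguments to the underlying model.
        self.model = KNeighborsClassifier(**model_kwargs)

    def fit(self, X, y, **fit_kwargs):
        # Train the classifier to predict the coverage indicator.
        self.model.fit(X, y, **fit_kwargs)
        return self

    def predict_proba(self, X):
        # Return class probabilities of shape (n_samples, 2).
        return self.model.predict_proba(X)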

Once your class is defined, you can instantiate and evaluate it with ERT as follows:

ERT_estimator = ERT(your_model_class, one_argument=k, another_argument="p")
ERT_value = ERT_estimator.evaluate(x_test, cover_test, alpha, one_fit_argument=m, another_fit_argument="M")

Evaluating Conditional Coverage Rules

You can evaluate conditional coverage rules by providing alpha as an array that matches the type and length of cover. For example, if cover is a PyTorch tensor, alpha should also be a tensor; if cover is a NumPy array, alpha should be a NumPy array.

Each value of alpha must be between 0 and 1 and represents the conditional miscoverage level, meaning:

$\mathbb{P}(Y \in C(X) \mid X) = 1 - \alpha(X)$

Example usage:

import torch

ERT_estimator = ERT()
tab_alpha = torch.ones(len(cover_test)) * 0.9  # if cover_test is a torch.Tensor
ERT_value = ERT_estimator.evaluate(x_test, cover_test, alpha=tab_alpha)
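
The same pattern works with NumPy arrays; for instance, assuming cover_test is a NumPy array:

import numpy as np

tab_alpha = np.full(len(cover_test), 0.9)  # constant level, matching cover_test's type
ERT_value = ERT_estimator.evaluate(x_test, cover_test, alpha=tab_alpha)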

Other metrics

Other implemented metrics are:

  • WSC (Worst slab coverage).
  • FSC (Feature-stratified coverage).
  • CovGap.
  • WeightedCovGap.
  • SSC (Size-stratified coverage).
  • EOC (Equal opportunity of coverage).
  • Pearson's Correlation.
  • HSIC (Hilbert-Schmidt Independence Criterion).

The WSC metric is a vectorized version of the original code from https://github.com/Shai128/oqr.

from covmetrics import WSC 

WSC_value = WSC().evaluate(x_test, cover_test)

The CovGap metric (or its weighted variant, WeightedCovGap) can be estimated as:

from covmetrics import CovGap 

CovGap_value = CovGap().evaluate(x_test, cover_test, alpha=alpha, weighted=True)

Similar imports can be used for the SSC, FSC, EOC, HSIC, and PearsonCorrelation metrics.
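
For instance, the Pearson correlation metric can be used in the same way (a sketch, assuming it follows the same evaluate(x, cover) pattern as WSC; the exact arguments may differ):

from covmetrics import PearsonCorrelation

# Assumed to follow the same calling convention as WSC.
pearson_value = PearsonCorrelation().evaluate(x_test, cover_test)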

The HSIC metric has been built upon the original code from: https://github.com/danielgreenfeld3/XIC.

Contributors

  • Sacha Braun
  • David Holzmüller


Download files

Download the file for your platform.

Source Distribution

covmetrics-0.1.0.tar.gz (16.3 kB)

Built Distribution

covmetrics-0.1.0-py3-none-any.whl (20.1 kB)

File details

Details for the file covmetrics-0.1.0.tar.gz.

File metadata

  • Download URL: covmetrics-0.1.0.tar.gz
  • Upload date:
  • Size: 16.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: Hatch/1.16.2 cpython/3.11.7 HTTPX/0.28.1

File hashes

Hashes for covmetrics-0.1.0.tar.gz:

  • SHA256: 677eb989e4d9694faea3b7579a6a43a17c1fc8a61f0d6782b7fb1c86a8b07ad2
  • MD5: e340b6f04f3e9dc8a6278ae957251dc5
  • BLAKE2b-256: dc10f9d855e2b4d9c7d4545023395213bb011a1074faa9e8927b5ecae8a1a442

File details

Details for the file covmetrics-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: covmetrics-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 20.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: Hatch/1.16.2 cpython/3.11.7 HTTPX/0.28.1

File hashes

Hashes for covmetrics-0.1.0-py3-none-any.whl:

  • SHA256: a8e5ded6ef744f71bc537a085eae7a42d14098c482d36b1ac9e33f70744a3acd
  • MD5: bad3e929623c08cb01a9b7520baba891
  • BLAKE2b-256: 1da320d3661e089ad188dd36b57b911769ff6baf8f1c723aae3638fd96465aa4
