Compute and plot calibration-sharpness diagrams.

sharpcal: Augmenting Calibration With Sharpness

sharpcal is a simple (work-in-progress) library for making calibration-sharpness plots, based on https://arxiv.org/abs/2406.04068. A quick guide to reading these plots:

  • The X-axis is model confidence, and the Y-axis is model accuracy conditional on confidence; the dashed line corresponds to a perfectly calibrated model.
  • The red curve is a kernel regression estimate of model accuracy conditioned on model confidence (that is, the conditional expectation of correctness given confidence).
  • The area around the red line corresponds to a "sharpness gap", which can intuitively be thought of as leftover (pointwise) generalization error after accounting for (pointwise) calibration error. Once again, see https://arxiv.org/abs/2406.04068 for more details.
  • The boxed values at the top left of the plot are the calibration error (CAL) and the total generalization error (TOT); here, the generalization metric is the Brier score. The kernel regression estimates are computed on random subsamples to reduce their computational cost, so the displayed calibration error is the mean over the subsamples, and the +/- next to it is one standard deviation of the subsampled calibration errors.
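The quantities in the list above can be approximated in plain Python. This is a minimal sketch under stated assumptions, not sharpcal's implementation: it assumes a Nadaraya-Watson estimator with a Gaussian kernel for the red curve, a plug-in squared calibration error for CAL, the Brier score for TOT, and repeated random subsampling for the +/-. The helper names and the `subsample_size`, `n_rounds`, and `seed` parameters are hypothetical, and sharpcal's exact estimators (including how its `n_points` argument is used) may differ. For the Brier score, with confidence f, correctness Y, and g(c) = E[Y | f = c], the total error splits exactly as E[(f - Y)^2] = E[(f - g(f))^2] + E[(g(f) - Y)^2], which is one way to read the sharpness gap as the error left over after calibration error; see the paper for its precise definitions.

```python
import math
import random
from typing import List, Sequence, Tuple


def kernel_regression(conf: Sequence[float], correct: Sequence[int],
                      query: float, bandwidth: float = 0.05) -> float:
    """Nadaraya-Watson estimate of E[correct | confidence = query] (Gaussian kernel).

    Assumes the query lies near the data; below it is always one of the data points,
    so the total kernel weight is at least 1.
    """
    weights = [math.exp(-0.5 * ((query - c) / bandwidth) ** 2) for c in conf]
    return sum(w * y for w, y in zip(weights, correct)) / sum(weights)


def brier_score(conf, correct) -> float:
    """Mean squared difference between confidence and the 0-1 correctness label."""
    return sum((c - y) ** 2 for c, y in zip(conf, correct)) / len(conf)


def calibration_error(conf, correct, bandwidth=0.05) -> float:
    """Plug-in squared calibration error: mean of (c - g_hat(c))^2 over the sample."""
    g_hat = [kernel_regression(conf, correct, c, bandwidth) for c in conf]
    return sum((c - g) ** 2 for c, g in zip(conf, g_hat)) / len(conf)


def subsampled_calibration_error(conf, correct, subsample_size=1000, n_rounds=5,
                                 bandwidth=0.05, seed=0) -> Tuple[float, float]:
    """Mean and (population) standard deviation of the calibration error over random subsamples."""
    rng = random.Random(seed)
    k = min(subsample_size, len(conf))
    estimates: List[float] = []
    for _ in range(n_rounds):
        idx = rng.sample(range(len(conf)), k)  # subsample without replacement
        estimates.append(calibration_error([conf[i] for i in idx],
                                           [correct[i] for i in idx], bandwidth))
    mean = sum(estimates) / n_rounds
    std = math.sqrt(sum((e - mean) ** 2 for e in estimates) / n_rounds)
    return mean, std


# Usage on synthetic data that is calibrated by construction: P(correct | c) = c.
rng = random.Random(1)
conf = [rng.random() for _ in range(500)]
correct = [1 if rng.random() < c else 0 for c in conf]
cal_mean, cal_std = subsampled_calibration_error(conf, correct, subsample_size=200)
print(f"CAL: {cal_mean:.4f} +/- {cal_std:.4f}  TOT (Brier): {brier_score(conf, correct):.4f}")
```

Each kernel regression pass costs O(k^2) for a subsample of size k, which is why subsampling helps on large evaluation sets.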

Installation

sharpcal can be installed via pip (or directly from source):

pip install sharpcal

Getting Started

The following example demonstrates the main utilities of sharpcal on some synthetic data.

import torch
from sharpcal.calibration import SharpCal
from sharpcal.kernels import Gaussian1D
from sharpcal.scores import BrierScore

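# Gaussian kernel (bandwidth 0.05) for the kernel regression estimate
kernel = Gaussian1D(bandwidth=0.05)
# Brier score as the generalization metric
score = BrierScore()
sc = SharpCal(kernel=kernel, score=score, n_points=1000, device="cpu")

# Synthetic binary data: 100 predicted probabilities in [0, 1) and random 0/1 labels
fake_probs, fake_labels = torch.rand(100, 1), torch.rand(100, 1).round().long()
sc.plot_cal_curve(fake_probs, fake_labels, fname=None)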

The same example works unchanged if fake_probs and fake_labels above are replaced with multi-class versions:

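# 100 samples over 10 classes; each row is a softmax distribution
fake_probs = torch.nn.functional.softmax(torch.rand(100, 10), dim=1)
# Labels set to the argmax class with shape (100, 1), so every prediction is correct in this toy example
fake_labels = fake_probs.argmax(dim=1).unsqueeze(dim=1)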

In this case, sharpcal automatically reduces the multi-class calibration problem to binary confidence calibration: each multi-class label is replaced with a 0-1 correctness label (1 if the predicted class matches the label), and the full softmax probabilities are replaced with the maximum predicted probability.
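The reduction described above can be sketched in plain Python. `to_confidence_problem` is a hypothetical helper written for illustration, not part of sharpcal's API; it takes a flat list of integer labels (rather than the (N, 1) tensors used in the example) and breaks argmax ties by picking the first maximal class, which may not match sharpcal's tie-breaking.

```python
from typing import List, Sequence, Tuple


def to_confidence_problem(
    probs: Sequence[Sequence[float]], labels: Sequence[int]
) -> Tuple[List[float], List[int]]:
    """Reduce multi-class predictions to (confidence, correctness) pairs.

    confidence  = max predicted probability of each row
    correctness = 1 if the argmax class equals the label, else 0
    """
    confidences: List[float] = []
    correct: List[int] = []
    for row, label in zip(probs, labels):
        pred = max(range(len(row)), key=row.__getitem__)  # argmax, first max on ties
        confidences.append(row[pred])
        correct.append(int(pred == label))
    return confidences, correct


probs = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6], [0.25, 0.5, 0.25]]
labels = [0, 0, 1]
conf, corr = to_confidence_problem(probs, labels)
print(conf, corr)  # [0.7, 0.6, 0.5] [1, 0, 1]
```

The resulting confidence and correctness lists are exactly the inputs a binary calibration-sharpness plot needs.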

Recreating Paper Experiments

A more advanced example that shows how to recreate the experiments in the paper can be found at examples/recal_comparison.py. This example also illustrates recalibration utilities available under sharpcal.recal.

Citation

@misc{chidambaram2024reassessing,
      title={Reassessing How to Compare and Improve the Calibration of Machine Learning Models}, 
      author={Muthu Chidambaram and Rong Ge},
      year={2024},
      eprint={2406.04068},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Download files

Source Distribution

sharpcal-0.0.2.tar.gz (10.6 kB)

Built Distribution

sharpcal-0.0.2-py3-none-any.whl (10.8 kB)

File details

Details for the file sharpcal-0.0.2.tar.gz.

File metadata

  • Download URL: sharpcal-0.0.2.tar.gz
  • Size: 10.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.9.10

File hashes

Hashes for sharpcal-0.0.2.tar.gz:
  • SHA256: 5c604fdf6d551e458c63d77e3c7dcbd8f5985f97527a17a882ab7a6c1b0dddc8
  • MD5: a7f3a0e533baf3de920585df71c68621
  • BLAKE2b-256: 3d86859ece5d12fa41c7de6f6246aa441218116a6e929d612a20de1622027424

File details

Details for the file sharpcal-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: sharpcal-0.0.2-py3-none-any.whl
  • Size: 10.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.9.10

File hashes

Hashes for sharpcal-0.0.2-py3-none-any.whl:
  • SHA256: a623a9b2c4975b38616bcf9278765c294be591e675af6e9990c1d906b4ff8afb
  • MD5: a1e53081af93bd617db9eccf9b30df54
  • BLAKE2b-256: 60e41705efb1574f6c35a46b1041ab13f1133fe38679a0cfd69d2a0db5710d03
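To check a downloaded file against the published digests, a minimal sketch using Python's standard hashlib module is shown below. The helper names `sha256_of` and `verify` are hypothetical; the expected digests are copied from the SHA256 entries in the file details above.

```python
import hashlib
from pathlib import Path

# Published SHA256 digests for the sharpcal 0.0.2 distribution files.
EXPECTED = {
    "sharpcal-0.0.2.tar.gz": "5c604fdf6d551e458c63d77e3c7dcbd8f5985f97527a17a882ab7a6c1b0dddc8",
    "sharpcal-0.0.2-py3-none-any.whl": "a623a9b2c4975b38616bcf9278765c294be591e675af6e9990c1d906b4ff8afb",
}


def sha256_of(path, chunk_size: int = 1 << 16) -> str:
    """Hex SHA256 digest of a file, read in chunks to bound memory use."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path) -> bool:
    """True only if the file name is known and its SHA256 matches the published digest."""
    expected = EXPECTED.get(Path(path).name)
    return expected is not None and sha256_of(path) == expected
```

After downloading, `verify("sharpcal-0.0.2.tar.gz")` should return True for an untampered file; any other file name returns False.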
