Compute and plot reliability diagrams based on calibration distance.
relplot: Principled Reliability Diagrams
relplot is a Python package for plotting reliability diagrams and measuring calibration error in a theoretically principled way.
The package generates reliability diagrams like the header image (Figure 1 of the paper).
How to Read the Diagram
- The input data is a set of observations: pairs of predicted probabilities and true outcomes $(f_i, y_i) \in [0, 1] \times \{0, 1\}$. For example, $f_i$ may be the forecasted "chance of rain" on day $i$, and $y_i$ the indicator of whether it rained on day $i$.
- The x-axis shows the predicted probabilities, and the y-axis shows an estimate of the true probability conditioned on the predicted probability. Formally, this is a regression of outcomes $y$ on predictions $f$.
- The tick marks show the raw data: the predicted probabilities for up to 100 datapoints, plotted above or below the x-axis according to whether the true outcome was 1 or 0. The thickness of the red regression curve represents the smoothed density of these tick marks, while the height of the curve represents the smoothed fraction of datapoints whose true outcome is 1.
- The SmoothECE (smECE) is a measure of miscalibration: it is essentially the absolute difference between the red regression curve and the diagonal, averaged over x-coordinates distributed like the tick marks (i.e., integrated over the density of predictions). See the paper for full details of the estimator and its properties.
- The smECE is reported with $\pm$ denoting a 95% confidence interval, estimated via bootstrapping. The gray band similarly shows 95% bootstrapped confidence bands around the regression line.
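To make the "difference from the diagonal, integrated over the density of predictions" reading concrete, here is a simplified NumPy sketch of a kernel-smoothed calibration error. It assumes a plain Gaussian kernel with a fixed, user-chosen bandwidth sigma, and the function name smooth_ece_sketch is hypothetical. relplot's own estimator may differ in details (for example, in how it treats the boundaries of [0, 1], and it chooses the bandwidth automatically), so use relplot.smECE for real measurements.

```python
import numpy as np


def smooth_ece_sketch(f, y, sigma=0.05, grid_size=1001):
    """Simplified smoothed calibration error (illustrative; not relplot's estimator).

    Assumes a plain Gaussian kernel with fixed bandwidth `sigma`.
    """
    f = np.asarray(f, dtype=float)
    y = np.asarray(y, dtype=float)
    t = np.linspace(0.0, 1.0, grid_size)  # evaluation grid on [0, 1]

    # Gaussian kernel K(t_j - f_i), shape [grid_size, n]
    k = np.exp(-0.5 * ((t[:, None] - f[None, :]) / sigma) ** 2)
    k /= sigma * np.sqrt(2.0 * np.pi)

    # (1/n) * sum_i K(t - f_i) * (y_i - f_i):
    # the smoothed residual multiplied by the smoothed density of predictions
    g = np.abs(k @ (y - f)) / len(f)

    # Trapezoid rule for the integral of g over [0, 1]
    dt = t[1] - t[0]
    return dt * (g.sum() - 0.5 * (g[0] + g[-1]))


# Always predicting 0.9 when the event happens half the time:
f = np.full(1000, 0.9)
y = np.tile([1, 0], 500)
print(smooth_ece_sketch(f, y))  # about 0.39: the 0.4 gap, minus kernel mass outside [0, 1]
```

The residuals are smoothed rather than the outcomes alone, so a perfectly calibrated dataset gives a value near zero regardless of sigma.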
Formally, the reliability diagram is obtained by kernel smoothing with a careful choice of parameters. The choice of smoothing bandwidth (akin to "bin width") is crucial, but it is made automatically by the code in a theoretically justified way.
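As an illustration of kernel smoothing as a regression of outcomes on predictions, here is a minimal Nadaraya-Watson sketch with a Gaussian kernel and a fixed bandwidth sigma, which plays the role of a bin width. The function name kernel_regression is hypothetical, and relplot's actual smoother (and its automatic bandwidth choice) may differ in details.

```python
import numpy as np


def kernel_regression(f, y, mesh, sigma=0.05):
    """Nadaraya-Watson estimate of E[y | f = t] at each t in `mesh`.

    Illustrative sketch: Gaussian kernel, fixed bandwidth, no boundary correction.
    Assumes sigma is large enough that some weight is nonzero at every mesh point.
    """
    f = np.asarray(f, dtype=float)
    y = np.asarray(y, dtype=float)
    mesh = np.asarray(mesh, dtype=float)
    # Unnormalized Gaussian weights K(t - f_i), shape [len(mesh), n];
    # the normalizing constant cancels in the ratio below.
    w = np.exp(-0.5 * ((mesh[:, None] - f[None, :]) / sigma) ** 2)
    return (w @ y) / w.sum(axis=1)


mesh = np.linspace(0.0, 1.0, 101)
f = np.array([0.2, 0.2, 0.8, 0.8])
y = np.array([0, 0, 1, 1])
mu = kernel_regression(f, y, mesh, sigma=0.1)
print(mu[20], mu[50], mu[80])  # close to 0, 0.5 and 1
```

A smaller sigma follows the data more closely but is noisier; a larger sigma averages over wider neighborhoods, much like wider bins.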
This package is based on the theoretical results in the paper Smooth ECE: Principled Reliability Diagrams via Kernel Smoothing (ICLR 2024).
Installation
Install with pip:
> pip install relplot
Or, clone the repo and install with:
> cd relplot
> pip install .
Getting Started
Basic usage:
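```python
import numpy as np
import relplot as rp

# f: array of predicted probabilities [f_i]
# y: array of binary labels [y_i]
# ... load your own data; synthetic, calibrated data is used here for illustration
rng = np.random.default_rng(0)
f = rng.uniform(size=1000)
y = (rng.uniform(size=1000) < f).astype(int)
calib_error = rp.smECE(f, y)  # compute calibration error (scalar)
fig, ax = rp.rel_diagram(f, y)  # plot
```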
See a quick demo in notebooks/demo.ipynb.
For more control, one can compute the calibration data with relplot.prepare_rel_diagram, and then plot it later with relplot.plot_rel_diagram.
For example:
```python
import matplotlib.pyplot as plt
import relplot as rp

# ... (f, y as in the basic usage above)
diagram = rp.prepare_rel_diagram(f, y)  # compute calibration data (dictionary)
print('calibration error:', diagram['ce'])
plt.plot(diagram['mesh'], diagram['mu'])  # plot the calibration curve manually
fig, ax = rp.plot_rel_diagram(diagram)  # plot the diagram in a new figure
```
The smoothed regression function itself is returned as diagram['mu'],
which specifies values on the grid of x-coordinates in diagram['mesh'].
This can be used for manual re-calibration.
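As a minimal sketch of such manual recalibration, assuming mesh and mu are increasing-grid 1D NumPy arrays taken from diagram['mesh'] and diagram['mu'], new predictions can be mapped through the estimated curve by linear interpolation. The helper name recalibrate is hypothetical and not part of relplot; the toy mesh and mu below stand in for a real diagram.

```python
import numpy as np


def recalibrate(p, mesh, mu):
    """Map raw predicted probabilities p to recalibrated ones.

    mesh: increasing grid of x-coordinates (e.g. diagram['mesh'])
    mu:   estimated true probability at each mesh point (e.g. diagram['mu'])
    """
    # np.interp interpolates linearly and clamps to the end values outside the grid;
    # the clip guards against smoothed values that stray slightly outside [0, 1].
    return np.clip(np.interp(p, mesh, mu), 0.0, 1.0)


# Toy stand-in for a fitted diagram: predicted probabilities are too high.
mesh = np.linspace(0.0, 1.0, 11)
mu = mesh ** 2
print(recalibrate([0.5, 1.0], mesh, mu).tolist())  # [0.25, 1.0]
```

Recalibrating with a curve estimated on the same data used for evaluation will overstate the improvement; a held-out split is the safer choice.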
Manual Bandwidth
To measure SmoothECE with a manual choice of bandwidth (rather than automatic choice), use:
```python
calib_error = relplot.smECE_sigma(f, y, sigma=0.05)
```
Using a manual bandwidth can sometimes be desirable for interpretability: smECE_sigma behaves similarly to binnedECE with bin_width=sigma, at the cost of slightly weaker theoretical guarantees.
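For comparison with smECE_sigma, here is a sketch of the standard binned ECE with a fixed bin width: predictions are grouped into equal-width bins, and the gap between the average outcome and the average prediction in each bin is weighted by that bin's share of the data. The function binned_ece_sketch is hypothetical; relplot.metrics ships its own binnedECE, whose exact signature and binning conventions may differ.

```python
import numpy as np


def binned_ece_sketch(f, y, bin_width=0.05):
    """Binned expected calibration error with equal-width bins on [0, 1]."""
    f = np.asarray(f, dtype=float)
    y = np.asarray(y, dtype=float)
    n_bins = int(np.ceil(1.0 / bin_width - 1e-9))
    # Bin index of each prediction; f == 1.0 falls into the last bin.
    idx = np.minimum((f / bin_width).astype(int), n_bins - 1)
    ece = 0.0
    for b in range(n_bins):
        in_bin = idx == b
        if in_bin.any():
            # weight = fraction of datapoints that land in this bin
            ece += in_bin.mean() * abs(y[in_bin].mean() - f[in_bin].mean())
    return ece


f = np.array([0.1, 0.1, 0.9, 0.9])
y = np.array([0, 0, 1, 1])
print(binned_ece_sketch(f, y, bin_width=0.1))  # about 0.1
```

Unlike the kernel-smoothed version, the result can jump when the bin edges shift slightly, which is part of the motivation for smoothing.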
Data Format
Methods expect inputs in the form of a 1D array of predicted probabilities (f) and a 1D array of binary labels (y), where $f_i \in [0, 1]$ and $y_i \in \{0, 1\}$. We then consider the calibration of the empirical distribution of prediction-outcome pairs $(f_i, y_i)$. This package primarily considers the binary outcome setting, but it can also be used to measure multi-class confidence calibration, as shown below.
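A small validation helper along these lines can catch common format mistakes (2D arrays, mismatched lengths, logits instead of probabilities, labels other than 0/1) before calling relplot. The name check_binary_inputs is hypothetical; it is not part of relplot, and relplot's own input handling may be stricter or looser.

```python
import numpy as np


def check_binary_inputs(f, y):
    """Validate and return (f, y) as 1D float arrays in the expected format."""
    f = np.asarray(f, dtype=float)
    y = np.asarray(y, dtype=float)
    if f.ndim != 1 or y.ndim != 1:
        raise ValueError("f and y must be 1D arrays")
    if f.shape != y.shape:
        raise ValueError(f"length mismatch: {f.shape[0]} predictions, {y.shape[0]} labels")
    # Written as "not all inside" so that NaN values are rejected too
    if not np.all((f >= 0) & (f <= 1)):
        raise ValueError("predicted probabilities must lie in [0, 1]")
    if not np.all((y == 0) | (y == 1)):
        raise ValueError("labels must be 0 or 1")
    return f, y


f, y = check_binary_inputs([0.2, 0.7], [0, 1])
```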
Multi-class Confidence Calibration
In the multi-class setting, confidence calibration can be measured by expressing it as the binary calibration of the distribution on (confidence, accuracy) pairs. A convenience function for this common use case is provided:
```python
import relplot

# f: [N, C] array of logits over C classes
# y: [N, 1] array of true class labels
conf, acc = relplot.multiclass_logits_to_confidences(f, y)  # reduce to binary setting
relplot.rel_diagram(f=conf, y=acc)  # plot confidence calibration diagram
relplot.smECE(f=conf, y=acc)  # compute smECE of confidence calibration
```
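Conceptually, confidence calibration uses the top-class probability as the prediction and top-1 correctness as the binary outcome. A NumPy sketch of that reduction, assuming f holds logits and y holds true integer class labels, is below; it illustrates the idea behind relplot.multiclass_logits_to_confidences, but the function logits_to_confidences_sketch is a hypothetical stand-in, not relplot's source.

```python
import numpy as np


def logits_to_confidences_sketch(logits, labels):
    """Reduce multi-class predictions to (confidence, accuracy) pairs."""
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels).reshape(-1)  # accept shape [N] or [N, 1]
    # Numerically stable softmax over the class axis
    z = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    conf = probs.max(axis=1)  # probability assigned to the predicted (top) class
    acc = (probs.argmax(axis=1) == labels).astype(float)  # 1 if the top class is correct
    return conf, acc


conf, acc = logits_to_confidences_sketch([[2.0, 0.0], [0.0, 3.0]], [[0], [0]])
print(conf.round(3).tolist(), acc.tolist())  # [0.881, 0.953] [1.0, 0.0]
```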
Customization and Usage Tips
The plot made by relplot.rel_diagram can be customized in various ways, as shown below.
See this notebook for examples of more options: notebooks/figure1.ipynb
- For small datasets, you may want to disable bootstrapping (which subsamples the data) by passing the parameter plot_confidence_band=False.
- To override the automatic choice of kernel bandwidth for the diagram, set the parameter kde_bandwidth.
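For intuition about the bootstrapped intervals, a percentile bootstrap resamples the (f_i, y_i) pairs with replacement, recomputes the metric on each resample, and reads off the 2.5% and 97.5% quantiles. The sketch below is generic over the metric; its name and parameters (n_boot, level, seed) are hypothetical, the example metric mean_gap is only for illustration, and relplot's own bootstrap settings may differ.

```python
import numpy as np


def bootstrap_ci(metric, f, y, n_boot=1000, level=0.95, seed=0):
    """Point estimate and percentile bootstrap confidence interval for metric(f, y)."""
    f = np.asarray(f, dtype=float)
    y = np.asarray(y, dtype=float)
    rng = np.random.default_rng(seed)
    n = len(f)
    stats = np.empty(n_boot)
    for b in range(n_boot):
        i = rng.integers(0, n, size=n)  # resample (f_i, y_i) pairs with replacement
        stats[b] = metric(f[i], y[i])
    alpha = (1.0 - level) / 2.0
    lo, hi = np.quantile(stats, [alpha, 1.0 - alpha])
    return metric(f, y), (lo, hi)


# Example metric (illustration only): |mean outcome - mean prediction|
def mean_gap(f, y):
    return abs(y.mean() - f.mean())


rng = np.random.default_rng(1)
f = rng.uniform(size=500)
y = (rng.uniform(size=500) < f).astype(float)
estimate, (lo, hi) = bootstrap_ci(mean_gap, f, y)
print(f"{estimate:.3f} (95% CI {lo:.3f} to {hi:.3f})")
```

Fixing the seed makes the interval reproducible; the cost grows linearly with n_boot because the metric is recomputed on every resample.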
Additional Notebooks and Features
- The header image (Figure 1 of the paper) is generated in notebooks/figure1.ipynb
- The experiments in the paper are reproduced in notebooks/paper_experiments.ipynb
- relplot.metrics contains implementations of various alternate calibration measures, including binnedECE and Laplace kernel calibration, in addition to the recommended calibration measure, SmoothECE (relplot.smECE).
- relplot.rel_diagram_binned plots the "binned" reliability diagram. It is not recommended for general use and is included for comparison.
- relplot.config.use_tex_fonts can be set to True if you have $\LaTeX$ installed.
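Among the alternate measures, a kernel calibration error scores how strongly the residuals y_i - f_i are correlated across nearby predictions, via a kernel matrix. A sketch of one common form (the one used by MMCE-style estimators) with the Laplace kernel k(u, v) = exp(-|u - v| / scale) is below; the function name and the scale parameter are assumptions, and relplot.metrics may use a different kernel scale, normalization, or estimator.

```python
import numpy as np


def laplace_kernel_calibration_sketch(f, y, scale=1.0):
    """sqrt of (1/n^2) * sum_{i,j} r_i r_j k(f_i, f_j), with residuals r = y - f."""
    f = np.asarray(f, dtype=float)
    r = np.asarray(y, dtype=float) - f  # residuals
    K = np.exp(-np.abs(f[:, None] - f[None, :]) / scale)  # Laplace kernel matrix
    quad = r @ K @ r / len(f) ** 2
    # K is positive definite, so quad >= 0 up to rounding error
    return float(np.sqrt(max(quad, 0.0)))


# Constant prediction 0.9 with a 50% event rate: all kernel entries are 1,
# so the value reduces to |mean residual| = 0.4.
print(laplace_kernel_calibration_sketch(np.full(100, 0.9), np.tile([1, 0], 50)))
```

The kernel matrix is n by n, so this direct form is only practical for moderate dataset sizes.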
Citation
If you use relplot in your work, please consider citing:
```bibtex
@inproceedings{blasiok2024smooth,
title={Smooth {ECE}: Principled Reliability Diagrams via Kernel Smoothing},
author={B{\l}asiok, Jaros{\l}aw and Nakkiran, Preetum},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=XwiA1nDahv}
}
```
Acknowledgements
We thank Jason Eisner and Adam Goliński for helpful suggestions on the package and documentation.
Download files
File details
Details for the file relplot-1.0.3.tar.gz.
File metadata
- Download URL: relplot-1.0.3.tar.gz
- Size: 19.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.17
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 838f3681ffcc9ffcdc3521a162fed6908830075e5ba4b21d66a8c1db2bdd4999 |
| MD5 | 661dbbd2b979c1b3672c8005496b0ee2 |
| BLAKE2b-256 | dfdf9becaad4f326a8888b7aaca66cc4383e82ca198be614b12ba0c3f2e41190 |
File details
Details for the file relplot-1.0.3-py3-none-any.whl.
File metadata
- Download URL: relplot-1.0.3-py3-none-any.whl
- Size: 17.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.17
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a807187018207d7215311271d81ec10161f27338b3555f6cfbe1b33a32e2f35e |
| MD5 | 92af23363c1be2dd95fdfec33c8aa77b |
| BLAKE2b-256 | 75122bacb856a27293be5ab420db8ffb52e9fff44f1cc6869d1fd1f78267e132 |