Calculate common OOD detection metrics
OOD Detection Metrics
Functions for computing metrics commonly used in the field of out-of-distribution (OOD) detection.
Installation
pip install ood-metrics
Metric functions
AUROC
Calculate and return the area under the ROC curve (AUROC) from unthresholded prediction scores and binary ground-truth labels.
from ood_metrics import auroc
labels = [0, 0, 0, 1, 0]
scores = [0.1, 0.3, 0.6, 0.9, 1.3]
print(auroc(scores, labels))
# 0.75
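AUROC equals the probability that a randomly chosen sample with label 1 receives a higher score than a randomly chosen sample with label 0, with ties counted as half. The pure-Python sketch below computes the metric that way to show where 0.75 comes from: the single label-1 score (0.9) outranks 3 of the 4 label-0 scores. The helper name `auroc_pairwise` is hypothetical, and this is an illustration of the metric, not the library's implementation.

```python
def auroc_pairwise(scores, labels):
    """AUROC as the fraction of (label-1, label-0) pairs ranked correctly.

    Hypothetical helper for illustration. Assumes label 1 is the positive
    class and both classes are present. Runs in O(P * N) time.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative label")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # a tie counts as half a correct ordering
    return wins / (len(pos) * len(neg))


labels = [0, 0, 0, 1, 0]
scores = [0.1, 0.3, 0.6, 0.9, 1.3]
print(auroc_pairwise(scores, labels))  # 0.75
```

The quadratic pair loop keeps the definition visible; a rank-based formula gives the same value faster on large inputs.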
AUPR
Calculate and return the area under the Precision-Recall curve (AUPR) from unthresholded prediction scores and binary ground-truth labels.
from ood_metrics import aupr
labels = [0, 0, 0, 1, 0]
scores = [0.1, 0.3, 0.6, 0.9, 1.3]
print(aupr(scores, labels))
# 0.25
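One way to arrive at 0.25 is to build the precision-recall points threshold by threshold, starting from (recall 0, precision 1), and integrate them with the trapezoidal rule. The sketch below does this in pure Python under that assumption; the helper name `aupr_trapezoid` is hypothetical. Note that a step-wise summary such as scikit-learn's `average_precision_score` does not interpolate linearly and can give a different number.

```python
def aupr_trapezoid(scores, labels):
    """Area under the precision-recall curve via the trapezoidal rule.

    Hypothetical helper for illustration. Label 1 is the positive class.
    The curve starts at (recall=0, precision=1) and gains one point per
    distinct score threshold, scanned from the highest score to the lowest.
    """
    n_pos = sum(1 for y in labels if y == 1)
    if n_pos == 0:
        raise ValueError("need at least one positive label")
    pairs = sorted(zip(scores, labels), key=lambda p: p[0], reverse=True)
    points = [(0.0, 1.0)]  # (recall, precision)
    tp = fp = 0
    for i, (s, y) in enumerate(pairs):
        if y == 1:
            tp += 1
        else:
            fp += 1
        # Emit a point only after every sample tied at this score is counted.
        if i + 1 == len(pairs) or pairs[i + 1][0] != s:
            points.append((tp / n_pos, tp / (tp + fp)))
    area = 0.0
    for (r0, p0), (r1, p1) in zip(points, points[1:]):
        area += (r1 - r0) * (p0 + p1) / 2
    return area


labels = [0, 0, 0, 1, 0]
scores = [0.1, 0.3, 0.6, 0.9, 1.3]
print(aupr_trapezoid(scores, labels))  # 0.25
```

Here the only segment with nonzero width runs from (recall 0, precision 0) at threshold 1.3 to (recall 1, precision 0.5) at threshold 0.9, so the area is 1 × (0 + 0.5) / 2 = 0.25.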
FPR @ 95% TPR
Return the false positive rate (FPR) at the operating point where the true positive rate (TPR) is at least 95%.
from ood_metrics import fpr_at_95_tpr
labels = [0, 0, 0, 1, 0]
scores = [0.1, 0.3, 0.6, 0.9, 1.3]
print(fpr_at_95_tpr(scores, labels))
# 0.25
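To see where 0.25 comes from, lower the decision threshold one distinct score at a time and stop at the first threshold whose TPR reaches 95%. Because FPR can only grow as the threshold drops, that first point also has the lowest qualifying FPR. The sketch below is a hypothetical helper (`fpr_at_tpr`) that reads off empirical ROC points without interpolating between them; the library may handle in-between TPR values differently.

```python
def fpr_at_tpr(scores, labels, min_tpr=0.95):
    """FPR at the first threshold (high to low) whose TPR >= min_tpr.

    Hypothetical helper for illustration. Uses only the empirical ROC
    points, with no interpolation. Label 1 is the positive class.
    """
    n_pos = sum(1 for y in labels if y == 1)
    n_neg = len(labels) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("need both positive and negative labels")
    pairs = sorted(zip(scores, labels), key=lambda p: p[0], reverse=True)
    tp = fp = 0
    for i, (s, y) in enumerate(pairs):
        if y == 1:
            tp += 1
        else:
            fp += 1
        if i + 1 < len(pairs) and pairs[i + 1][0] == s:
            continue  # wait until all samples tied at this score are counted
        # FPR never decreases as the threshold drops, so the first threshold
        # that reaches the TPR target also has the lowest qualifying FPR.
        if tp / n_pos >= min_tpr:
            return fp / n_neg
    return None  # only reachable when min_tpr > 1


labels = [0, 0, 0, 1, 0]
scores = [0.1, 0.3, 0.6, 0.9, 1.3]
print(fpr_at_tpr(scores, labels))  # 0.25
```

At threshold 0.9 the single label-1 sample is detected (TPR = 1.0), while one of the four label-0 samples (score 1.3) is also above the threshold, so FPR = 1/4.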
Detection Error
Return the misclassification probability when TPR is 95%.
from ood_metrics import detection_error
labels = [0, 0, 0, 1, 0]
scores = [0.1, 0.3, 0.6, 0.9, 1.3]
print(detection_error(scores, labels))
# 0.125
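A common definition of detection error, assuming equal priors for the two classes, is the minimum of 0.5 × (1 − TPR) + 0.5 × FPR over the thresholds that keep TPR at or above 95%. In the example, the best such threshold has TPR = 1.0 and FPR = 0.25, giving 0.5 × 0 + 0.5 × 0.25 = 0.125, which matches the output above. The helper `detection_error_sketch` is hypothetical, and the equal weighting is an assumption, not a documented detail of the library.

```python
def detection_error_sketch(scores, labels, min_tpr=0.95):
    """Min of 0.5*(1 - TPR) + 0.5*FPR over thresholds with TPR >= min_tpr.

    Hypothetical helper for illustration; assumes equal class priors.
    Label 1 is the positive class.
    """
    n_pos = sum(1 for y in labels if y == 1)
    n_neg = len(labels) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("need both positive and negative labels")
    pairs = sorted(zip(scores, labels), key=lambda p: p[0], reverse=True)
    errors = []
    tp = fp = 0
    for i, (s, y) in enumerate(pairs):
        if y == 1:
            tp += 1
        else:
            fp += 1
        if i + 1 < len(pairs) and pairs[i + 1][0] == s:
            continue  # wait until all samples tied at this score are counted
        tpr, fpr = tp / n_pos, fp / n_neg
        if tpr >= min_tpr:
            # Average of the miss rate and the false-alarm rate.
            errors.append(0.5 * (1 - tpr) + 0.5 * fpr)
    return min(errors)


labels = [0, 0, 0, 1, 0]
scores = [0.1, 0.3, 0.6, 0.9, 1.3]
print(detection_error_sketch(scores, labels))  # 0.125
```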
Calculate all stats
Using predictions and labels, return a dictionary containing all novelty detection performance statistics.
from ood_metrics import calc_metrics
labels = [0, 0, 0, 1, 0]
scores = [0.1, 0.3, 0.6, 0.9, 1.3]
print(calc_metrics(scores, labels))
# {
# 'fpr_at_95_tpr': 0.25,
# 'detection_error': 0.125,
# 'auroc': 0.75,
# 'aupr_in': 0.25,
# 'aupr_out': 0.94375
# }
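The two AUPR entries appear to treat each class in turn as the positive one. The values above are consistent with computing `aupr_out` by flipping the labels (1 − label), negating the scores, and then taking the same trapezoidal precision-recall area used for `aupr_in`. The sketch below reproduces both numbers under that assumption. The helpers `pr_area` and `aupr_in_out` are hypothetical and are not part of the `ood_metrics` API.

```python
def pr_area(scores, labels):
    """Trapezoidal PR-curve area, label 1 = positive (hypothetical helper)."""
    n_pos = sum(1 for y in labels if y == 1)
    pairs = sorted(zip(scores, labels), key=lambda p: p[0], reverse=True)
    points = [(0.0, 1.0)]  # (recall, precision)
    tp = fp = 0
    for i, (s, y) in enumerate(pairs):
        tp, fp = (tp + 1, fp) if y == 1 else (tp, fp + 1)
        if i + 1 == len(pairs) or pairs[i + 1][0] != s:
            points.append((tp / n_pos, tp / (tp + fp)))
    return sum((r1 - r0) * (p0 + p1) / 2
               for (r0, p0), (r1, p1) in zip(points, points[1:]))


def aupr_in_out(scores, labels):
    """AUPR with label 1 as positive, then with the roles swapped."""
    return {
        "aupr_in": pr_area(scores, labels),
        # Swap class roles: label 0 becomes positive and low scores rank first.
        "aupr_out": pr_area([-s for s in scores], [1 - y for y in labels]),
    }


labels = [0, 0, 0, 1, 0]
scores = [0.1, 0.3, 0.6, 0.9, 1.3]
print({k: round(v, 5) for k, v in aupr_in_out(scores, labels).items()})
# {'aupr_in': 0.25, 'aupr_out': 0.94375}
```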
Plotting functions
Plot ROC
Plot an ROC curve based on unthresholded predictions and true binary labels.
from ood_metrics import plot_roc
labels = [0, 0, 0, 1, 0]
scores = [0.1, 0.3, 0.6, 0.9, 1.3]
plot_roc(scores, labels)
# Generate Matplotlib ROC curve plot
Plot PR
Plot a Precision-Recall curve based on unthresholded predictions and true binary labels.
from ood_metrics import plot_pr
labels = [0, 0, 0, 1, 0]
scores = [0.1, 0.3, 0.6, 0.9, 1.3]
plot_pr(scores, labels)
# Generate Matplotlib Precision-Recall plot
Plot Barcode
Plot a visualization of inliers and outliers, sorted by their predicted novelty scores.
from ood_metrics import plot_barcode
labels = [0, 0, 0, 1, 0]
scores = [0.1, 0.3, 0.6, 0.9, 1.3]
plot_barcode(scores, labels)
# Shows the sort order of the labels according to the scores.
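The idea behind the barcode can be approximated in plain text: sort the samples by score and write out each one's label. A good detector then shows its label-1 samples clustered at one end. The helper `text_barcode` is hypothetical, and the ascending order is a choice made here. How `plot_barcode` actually renders and orders samples is not described in this README.

```python
def text_barcode(scores, labels):
    """Labels as a string, ordered from lowest to highest score.

    Hypothetical text-only stand-in for a barcode plot. Python's sort is
    stable, so tied scores keep their original relative order.
    """
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    return "".join(str(labels[i]) for i in order)


labels = [0, 0, 0, 1, 0]
scores = [0.1, 0.3, 0.6, 0.9, 1.3]
print(text_barcode(scores, labels))  # 00010
```

The single label-1 sample lands second from the top rather than at the end, which is the same ranking mistake that keeps the AUROC above at 0.75 instead of 1.0.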