
Selective prediction based on uncertainty scores


seapig



seapig provides uncertainty-based selective inference for deep learning models, currently with a focus on latent representations. The library implements a small set of lightweight, composable uncertainty scores that decide whether to accept or reject an individual query sample at prediction time. Thresholds are calibrated on an independent validation set. seapig also provides a wrapper for torchmetrics to evaluate a selective inference system on a test set, and a PyTorch Lightning task for integration into training and evaluation pipelines.

Installation

seapig is available on PyPI and can be installed with pip. For the latest features, install from the GitHub repository. We recommend using a virtual environment to avoid dependency conflicts. The package has a small set of core dependencies plus optional extras for suggested features, development, and documentation. See the installation instructions below.

# minimal installation
pip install seapig
# including optional features
# (quotes stop shells such as zsh from expanding the square brackets)
pip install "seapig[suggested]"

# for contributors / developers (tests, linters, type-checkers)
pip install "seapig[dev]"
# for building documentation (quarto-cli, great-docs, others)
pip install "seapig[docs]"
# or, if you need everything:
pip install "seapig[all]"

seapig is also available on the conda-forge channel and can be installed with conda:

conda install -c conda-forge seapig

or with mamba:

mamba install -c conda-forge seapig

Why selective prediction?

  • Machine learning models often fail silently on out-of-distribution inputs.
  • Selective prediction lets a system abstain from predicting when the input is considered unreliable.
  • seapig uses internal model representations (embeddings) to detect such atypical inputs with interpretable, fast-to-compute scores.

The core idea is to compute a representation for each input, score how similar the representation is to training representations, and reject inputs whose score indicates low support.
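A minimal NumPy sketch of this idea, assuming the score is the mean Euclidean distance from a query embedding to its k nearest reference embeddings. This illustrates the principle only and is not seapig's implementation; the function name knn_distance_score is hypothetical.

```python
import numpy as np

def knn_distance_score(reference: np.ndarray, query: np.ndarray, k: int = 5) -> np.ndarray:
    """Mean Euclidean distance from each query row to its k nearest reference rows.

    Larger values mean the query lies farther from the reference (training)
    embeddings, i.e. it has less support.
    """
    # Pairwise squared distances, shape (Q, N), via ||a||^2 - 2 a.b + ||b||^2
    sq = (
        (query ** 2).sum(axis=1, keepdims=True)
        - 2.0 * query @ reference.T
        + (reference ** 2).sum(axis=1)
    )
    dists = np.sqrt(np.clip(sq, 0.0, None))  # clip tiny negatives caused by round-off
    # Partially sort each row so the first k columns hold the k smallest distances
    nearest = np.partition(dists, k - 1, axis=1)[:, :k]
    return nearest.mean(axis=1)

rng = np.random.default_rng(0)
reference = rng.normal(size=(1000, 32))           # stand-in training embeddings
in_dist = rng.normal(size=(5, 32))                # queries from the same distribution
shifted = rng.normal(loc=3.0, size=(5, 32))       # queries from a shifted distribution
print(knn_distance_score(reference, in_dist).round(2))
print(knn_distance_score(reference, shifted).round(2))
```

The shifted queries receive clearly larger scores than the in-distribution queries, which is what makes a simple threshold on the score usable for rejection.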

From uncertainty scores to selective inference

All uncertainty scores produce a scalar score $s(x)$ for each query $x$. Given a threshold $\lambda$, we derive a binary selection function that indicates which samples to accept. For example, to accept samples whose score does not exceed $\lambda$:

$$g_{\lambda}(x) = \mathbf{1}\{s(x) \le \lambda\}.$$
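In code, $g_{\lambda}$ is an element-wise comparison. A NumPy sketch with made-up scores; selection_mask is a hypothetical helper, not seapig's select() method.

```python
import numpy as np

def selection_mask(scores: np.ndarray, threshold: float) -> np.ndarray:
    """g_lambda(x) = 1{s(x) <= lambda}: True = accept, False = abstain."""
    return scores <= threshold

scores = np.array([0.2, 1.5, 0.9, 3.1])
mask = selection_mask(scores, threshold=1.0)
print(mask)  # accepts 0.2 and 0.9, rejects 1.5 and 3.1
```

A score exactly equal to $\lambda$ is accepted, matching the $\le$ in the definition.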

We recommend calibrating $\lambda$ on an independent calibration set: fix a desired coverage level $q$ (the fraction of accepted samples) and set $\lambda$ to the corresponding empirical $q$-quantile of the calibration scores:

$$\lambda_{q} = Q_q(s_1^{cal}, s_2^{cal}, \dots, s_m^{cal}),$$

where $s_i^{cal}$ are the scores of the calibration samples. The decision function $g_{\lambda}(x)$ can then be applied at inference time to accept or reject predictions. We obtain a selective predictor, $h(x)$, that either produces an output or abstains from prediction, depending on the score of the input:

$$h(x) = \begin{cases} f(x), & \text{if } g_{\lambda}(x)=1,\\ \varnothing, & \text{if } g_{\lambda}(x)=0. \end{cases}$$
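To make the calibration step and $h(x)$ concrete, here is a NumPy sketch under stated assumptions: the calibration scores are synthetic (exponentially distributed), the quantile uses NumPy's default linear interpolation (seapig may use a different quantile rule), and abstention is represented by None. The helper names calibrate_threshold, selective_predict, and sign_classifier are hypothetical, not part of seapig.

```python
import numpy as np

def calibrate_threshold(cal_scores: np.ndarray, q: float) -> float:
    """lambda_q: empirical q-quantile of the calibration scores (linear interpolation assumed)."""
    return float(np.quantile(cal_scores, q))

def selective_predict(f, x, score_fn, threshold):
    """h(x): return f(x) if score_fn(x) <= threshold (g_lambda(x) = 1), otherwise None (abstain)."""
    return f(x) if score_fn(x) <= threshold else None

def sign_classifier(x: float) -> str:
    # Toy predictor f on scalar inputs
    return "positive" if x > 0 else "non-positive"

rng = np.random.default_rng(0)
cal_scores = rng.exponential(size=10_000)    # stand-in calibration scores s_i^cal
lam = calibrate_threshold(cal_scores, q=0.90)
coverage = np.mean(cal_scores <= lam)         # fraction of calibration samples accepted
print(f"lambda_0.90 = {lam:.3f}, calibration coverage = {coverage:.3f}")

# Toy score: the input's magnitude, so a large |x| counts as poorly supported
print(selective_predict(sign_classifier, 0.5, abs, lam))   # accepted: prints positive
print(selective_predict(sign_classifier, 10.0, abs, lam))  # rejected: prints None
```

By construction, about 90% of the calibration scores fall at or below $\lambda_{0.90}$. Coverage on new data matches this only approximately, and only when the new data resemble the calibration set.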

How to use seapig

The code snippets show a typical use of distance-based scores: (1) compute or provide embeddings, (2) fit an uncertainty score, (3) calibrate a threshold on validation data, and (4) accept or reject predictions at inference time. These examples are intentionally minimal so the flow is clear. For a more complete example, see the “Getting Started” tutorial in the documentation.

Precomputed embeddings

If you have precomputed embeddings for your reference, validation, and query sets, you can fit a score directly on the tensors. The example below uses random tensors to illustrate the API.

import torch
from seapig.scores import EuclideanScore
from seapig.utils.progress import disable
disable()  # disable seapig progress bars for this quickstart example
torch.manual_seed(0)
# latent representations as torch.Tensors of shapes (N, D), (M, D), (Q, D)
ref_emb, val_emb, query_emb = torch.randn(1000, 32), torch.randn(200, 32), torch.randn(10, 32)

score = EuclideanScore(k=5, stat="mean")
score.fit(X=ref_emb, Y=val_emb)
score.set_threshold(q=0.90)   # keep ~90% coverage on validation set
sel = score.select(query_emb)
print(sel)
{'score': tensor([6.2651, 5.5944, 6.0226, 5.8903, 6.2928, 4.8388, 5.7290, 5.3641, 5.6599,
        5.9143]), 'selected': tensor([ True,  True,  True,  True, False,  True,  True,  True,  True,  True])}

On-the-fly embedding extraction

If your model can compute embeddings on the fly, you can pass the model and data loaders to fit() instead of precomputed tensors. This requires the model to expose an .embed() method. The example below uses a dummy model and random data to illustrate the API.

import torch
from torch.utils.data import TensorDataset, DataLoader
from seapig.scores import EuclideanScore

ds_train = TensorDataset(torch.randn(1000, 32), torch.randint(0, 2, (1000,)))
ds_val = TensorDataset(torch.randn(200, 32), torch.randint(0, 2, (200,)))
ds_test = TensorDataset(torch.randn(10, 32), torch.randint(0, 2, (10,)))
train_loader = DataLoader(ds_train, batch_size=64)
val_loader = DataLoader(ds_val, batch_size=64)
test_loader = DataLoader(ds_test, batch_size=64)

# model exposes .embed(x) -> (B, D), where x is the input tensor of a batch
class Model(torch.nn.Module):
    def embed(self, x):
        # dummy embedding: one random 32-dim vector per input sample
        return torch.randn(x.shape[0], 32)

model = Model()

score = EuclideanScore(k=3)
score.fit(model=model, loaders={"train": train_loader, "val": val_loader})
score.set_threshold(q=0.80)  # keep ~80% coverage on validation set

sel = score.select(model=model, loader=test_loader)
print(sel)

As in the precomputed-embeddings example, sel is a dict with a 'score' and a 'selected' tensor, holding one entry per test sample (10 here).
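The dummy model above returns random embeddings. In a real network, .embed() would typically return intermediate features, for example the output of the layer just before the classification head. A minimal PyTorch sketch, assuming (as the examples above suggest) that seapig calls .embed() on the input tensor of each batch; the class name Classifier and the layer sizes are hypothetical.

```python
import torch
from torch import nn

class Classifier(nn.Module):
    """Toy classifier whose penultimate features serve as the embedding."""

    def __init__(self, in_dim: int = 32, emb_dim: int = 16, n_classes: int = 2):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(in_dim, 64), nn.ReLU(), nn.Linear(64, emb_dim), nn.ReLU()
        )
        self.head = nn.Linear(emb_dim, n_classes)

    def embed(self, x: torch.Tensor) -> torch.Tensor:
        # (B, in_dim) -> (B, emb_dim): features before the classification head
        return self.backbone(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reuse embed() so predictions and scores are based on the same features
        return self.head(self.embed(x))

model = Classifier().eval()
with torch.no_grad():
    x = torch.randn(10, 32)
    print(model.embed(x).shape)  # torch.Size([10, 16])
    print(model(x).shape)        # torch.Size([10, 2])
```

Routing forward() through embed() keeps the embedding used for scoring tied to the features the classifier actually relies on.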

Using SelectiveInferenceTask with a Lightning module

When working with Lightning modules, you can wrap the model and score into a SelectiveInferenceTask for evaluation and prediction. This integrates selective inference into your training and evaluation pipelines and computes metrics for the full, selected, and rejected samples. The example below uses a dummy model and random data to illustrate the API.

from seapig import SelectiveInferenceTask
from lightning import Trainer, LightningModule
from torchmetrics import Accuracy

# minimal LightningModule
class Model(LightningModule):
    def __init__(self):
        super().__init__()
        self.test_metrics = Accuracy(task="binary")
    def forward(self, x):
        pred = torch.randint(0, 2, (x.shape[0],))
        return pred
    def embed(self, x):
        return torch.randn(x.shape[0], 32)
    def test_step(self, batch, batch_idx, dataloader_idx=0):
        image, label = batch
        pred = self.forward(image)
        self.test_metrics.update(pred, label)
        # log the Metric object itself; Lightning computes and resets it at epoch end
        self.log("test_acc", self.test_metrics)
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        image = batch[0]
        pred = self.forward(image)
        return pred

trainer = Trainer(accelerator="cpu")
model = Model()
# trainer.fit(...) and score.fit(...) are expected to have been called already
sel_task = SelectiveInferenceTask(task=model, score=score)
# evaluate on test set, will return metrics for the full, selected, and rejected samples
metrics = trainer.test(sel_task, dataloaders=test_loader)
# or, for prediction: returns a list with one dict per batch, with keys "predictions", "score", and "selected"
preds = trainer.predict(sel_task, dataloaders=test_loader)
print(preds)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│    full/BinaryAccuracy        0.4000000059604645     │
│  rejected/BinaryAccuracy              0.0            │
│  selected/BinaryAccuracy      0.5714285969734192     │
└───────────────────────────┴───────────────────────────┘


[{'predictions': tensor([0, 0, 1, 1, 1, 1, 0, 0, 0, 1]), 'score': tensor([5.7045, 6.2594, 6.5007, 5.1679, 5.6751, 6.0544, 5.6146, 6.6772, 6.1766,
        5.4252]), 'selected': tensor([ True, False, False,  True,  True,  True,  True, False, False,  True])}]
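The three rows of the test table report the same metric on all test samples, on the accepted samples, and on the rejected samples. A minimal NumPy sketch of that split for accuracy, plus the coverage (fraction accepted). The helper split_accuracy is hypothetical; seapig computes these metrics through torchmetrics, so details such as the handling of empty subsets may differ.

```python
import numpy as np

def split_accuracy(preds, labels, selected):
    """Accuracy on all, accepted, and rejected samples, plus coverage."""
    preds, labels, selected = map(np.asarray, (preds, labels, selected))
    correct = preds == labels

    def acc(mask):
        # NaN for an empty subset, so it is not misreported as 0% or 100%
        return float(correct[mask].mean()) if mask.any() else float("nan")

    return {
        "full": acc(np.ones_like(selected, dtype=bool)),
        "selected": acc(selected),
        "rejected": acc(~selected),
        "coverage": float(selected.mean()),
    }

preds = [1, 0, 1, 1, 0, 1]
labels = [1, 0, 0, 1, 1, 1]
selected = [True, True, False, True, False, True]
print(split_accuracy(preds, labels, selected))
```

In this toy case every error falls among the rejected samples, so selected accuracy is 1.0 and rejected accuracy is 0.0; a useful score pushes results in that direction.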

Available scores

  • k-NN distances: EuclideanScore, CosineScore, MahalanobisScore
  • Logit-based scores: EnergyScore, EntropyScore, LogitScore, MarginScore, SoftmaxScore
  • PCA-based reconstruction: PCAScore
  • PyOD detectors: PyODScore
  • Random baseline: RandomScore
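For orientation, logit-based scores are usually built on the standard quantities sketched below in NumPy. These are common textbook definitions, not seapig's code; EnergyScore, EntropyScore, MarginScore, and SoftmaxScore may differ in sign convention, temperature, or whether larger values mean more or less certainty, so check the API reference before relying on a direction.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def max_softmax(logits):
    # Maximum class probability: high = confident
    return softmax(logits).max(axis=1)

def margin(logits):
    # Gap between the two largest class probabilities: high = confident
    p = np.sort(softmax(logits), axis=1)
    return p[:, -1] - p[:, -2]

def entropy(logits):
    # Predictive entropy: high = uncertain
    p = softmax(logits)
    return -(p * np.log(np.clip(p, 1e-12, None))).sum(axis=1)

def energy(logits, T=1.0):
    # Energy -T * logsumexp(logits / T): more negative = more in-distribution
    z = logits / T
    m = z.max(axis=1, keepdims=True)
    return -T * (m.squeeze(1) + np.log(np.exp(z - m).sum(axis=1)))

logits = np.array([[8.0, 0.0, 0.0],   # confident prediction
                   [1.0, 1.0, 1.0]])  # uniform prediction
for name, fn in [("max_softmax", max_softmax), ("margin", margin),
                 ("entropy", entropy), ("energy", energy)]:
    print(name, fn(logits).round(3))
```

Unlike the k-NN distance scores, these need only the model's output logits, not its embeddings.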

Further reading

Code of Conduct

This project adheres to the Contributor Covenant Code of Conduct. By participating, you are expected to uphold this code.

License

Funding

This research was funded in the course of TRR 391 Spatio-temporal Statistics for the Transition of Energy and Transport (520388526) by the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation).



Download files


Source Distribution

seapig-0.2.0.tar.gz (89.2 kB)

Uploaded Source

Built Distribution


seapig-0.2.0-py3-none-any.whl (53.9 kB)

Uploaded Python 3

File details

Details for the file seapig-0.2.0.tar.gz.

File metadata

  • Download URL: seapig-0.2.0.tar.gz
  • Size: 89.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for seapig-0.2.0.tar.gz
Algorithm Hash digest
SHA256 99719a0bdf340c244d67017fddaa9c3832bd9a0c03e1f17930fc463d9b239927
MD5 b6d7af053853af33633c1e3631d648bd
BLAKE2b-256 204fcdf217cf0676af3624e03283c763b9251a69824bc7be07f590dd7a466782


Provenance

The following attestation bundles were made for seapig-0.2.0.tar.gz:

Publisher: publish.yml on goergen95/seapig


File details

Details for the file seapig-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: seapig-0.2.0-py3-none-any.whl
  • Size: 53.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for seapig-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 c24d7ecc90d09d69017292364b1c567b541e9af4be4c996e13c8a7fbd7aad396
MD5 2f2a20da41d226ff8330b27985262c12
BLAKE2b-256 215aec03a00286b6e2f6cb3a1446572341cc83040c41392d460eb50cb8c592c9


Provenance

The following attestation bundles were made for seapig-0.2.0-py3-none-any.whl:

Publisher: publish.yml on goergen95/seapig

