Post-hoc uncertainty quantification toolkit for PyTorch deep learning models

FailCatcher

FailCatcher is an uncertainty quantification (UQ) toolkit for PyTorch deep learning classification models. It was developed and benchmarked on medical imaging datasets from the MedMNIST collection and on external test sets. Failure detection benchmark results are reported in the Benchmarks README (Benchmarks/README.md).

Preprint: Steinmetz et al., medRxiv 2026 — DOI: 10.64898/2026.05.04.26350496

pip install FailCatcher

The project provides:

  • A reusable Python library (ToolBox/) implementing uncertainty quantification methods for PyTorch classification models.
  • A full benchmarking pipeline (Benchmarks/) covering model training, classification evaluation, and failure detection evaluation on images under a diverse set of distribution shifts.

Repository structure

FailCatcher/
├── ToolBox/                        # UQ library (installable Python package)
│   ├── failure_detection.py        # High-level FailureDetector API
│   ├── UQ_toolbox.py               # Public API aggregator
│   ├── core/                       # Base classes and shared inference utilities
│   ├── methods/                    # UQ method implementations
│   │   ├── tta.py                  # Test-Time Augmentation (TTA) and GPS
│   │   ├── ensemble.py             # Ensemble STD and MC Dropout
│   │   ├── distance.py             # MSR, MLS, and calibration methods
│   │   └── latent.py               # KNN and SHAP latent-space methods
│   ├── search/                     # Greedy Policy Search (GPS) algorithm
│   ├── evaluation/                 # AUROC, AURC, AUGRC metrics and plots
│   ├── visualization/              # Visualization utilities
│   └── tests/                      # Smoke tests and pre-run checks
│
├── Benchmarks/
│   └── medMNIST/
│       ├── launcher_benchmark.py   # Top-level benchmark launcher
│       ├── run_medmnist_benchmark.py # Core benchmark runner (single config)
│       ├── trainings/              # Model training scripts and launchers
│       ├── utils/                  # Data loading, preprocessing, visualization
│       │   ├── train_models_load_datasets.py  # Central data/model utilities
│       │   ├── dataset_utils.py               # External datasets, corruptions
│       │   └── data_preprocessing_classification_evaluation/
│       ├── data/                   # External test datasets (AMOS-2022, MIDOG++, ISIC)
│       ├── models/                 # Trained model checkpoints
│       ├── runs/                   # Training logs and per-run artifacts
│       └── results/                # Benchmark outputs (JSON, figures, cache)
│
└── requirements.txt

Quick start

1. Install

pip install FailCatcher

For development (editable install from source):

git clone https://github.com/pstnmz/FailCatcher && pip install -e FailCatcher/

2. Quick-start tutorial

See tutorial.ipynb for a self-contained end-to-end example on CIFAR-10:
download a model from HuggingFace → inference → MSR uncertainty → AUROC-f / AURC / AUGRC → plots.
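The failure-detection metrics named above can be sketched in a few lines of NumPy and scikit-learn. This is a minimal illustration, not the toolbox's own code (which lives in ToolBox/evaluation/): it assumes a misclassified sample counts as a failure and that a larger score means more uncertainty, and it uses simple discrete approximations of AURC and AUGRC, so values may differ slightly from the library's (for example in how ties are handled). The helper names are hypothetical.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def failure_labels(y_true, y_pred):
    """1 where the classifier is wrong (a failure), 0 where it is right."""
    return (np.asarray(y_true) != np.asarray(y_pred)).astype(int)


def auroc_f(uncertainty, failures):
    """Failure-detection AUROC: how well higher uncertainty ranks failures above successes."""
    return float(roc_auc_score(failures, uncertainty))


def risk_coverage_areas(uncertainty, failures):
    """Discrete AURC and AUGRC, accepting the least uncertain samples first.

    Ties in the uncertainty score are broken by sort order; a reference
    implementation may treat ties (and the integration rule) differently.
    """
    order = np.argsort(uncertainty, kind="stable")           # most confident first
    fails_sorted = np.asarray(failures, dtype=float)[order]
    n = len(fails_sorted)
    cum_fails = np.cumsum(fails_sorted)                       # failures among the first k accepted
    k = np.arange(1, n + 1)
    selective_risk = cum_fails / k                            # error rate among accepted samples
    generalized_risk = cum_fails / n                          # accepted-and-failed share of all samples
    # Averaging over k approximates the area over coverage k/n in [1/n, 1].
    return float(selective_risk.mean()), float(generalized_risk.mean())


if __name__ == "__main__":
    y_true = np.array([0, 1, 1, 0, 1])
    y_pred = np.array([0, 1, 0, 0, 0])
    uncertainty = np.array([0.05, 0.10, 0.60, 0.20, 0.40])   # e.g. MSR scores
    fails = failure_labels(y_true, y_pred)
    aurc, augrc = risk_coverage_areas(uncertainty, fails)
    print(f"AUROC-f={auroc_f(uncertainty, fails):.2f} AURC={aurc:.3f} AUGRC={augrc:.3f}")
    # AUROC-f=1.00 AURC=0.130 AUGRC=0.120
```

Higher AUROC-f is better, while lower AURC and AUGRC are better. AUGRC divides by the total sample count rather than the accepted count, so it does not blow up at low coverage.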

3. Download models and datasets from HuggingFace

Pre-trained model checkpoints and pre-processed external datasets are available on HuggingFace. Run the one-command setup to skip training and manual preprocessing:

python scripts/setup_from_hub.py

You can download models or datasets independently:

python scripts/setup_from_hub.py --models-only
python scripts/setup_from_hub.py --datasets-only --datasets amos22 midog dermamnist-e

4. Train models (alternative to step 3)

See Benchmarks/README.md for the full reproducible training and benchmarking pipeline.

5. Run the benchmark

python Benchmarks/medMNIST/launcher_benchmark.py \
    --python /path/to/your/venv/bin/python \
    --datasets breastmnist organamnist \
    --models resnet18 \
    --setups "" DA \
    --gpu 0

UQ methods

Method              Description
MSR                 Maximum Softmax Response: 1 minus the highest softmax probability (its distance to 1)
MSR-calibrated      MSR computed after temperature or Platt scaling calibration
MLS                 Maximum Logit Score: pre-softmax counterpart of MSR
Ensembling          Standard deviation across the predictions of the 5 cross-validation fold models
TTA                 Test-Time Augmentation: standard deviation over random augmentation passes
GPS                 Greedy Policy Search: TTA with an augmentation policy optimised on the calibration set
KNN-Raw             k-NN distance in the avgpool latent space
KNN-SHAP            k-NN distance on SHAP-weighted latent features
MC Dropout          Monte Carlo Dropout: dropout kept active at inference time
ZScore Aggregation  Z-score normalised aggregation of multiple methods
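Most of the scores in the table reduce to a few array operations once logits, stacked predictions, or latent features are available. The sketch below works on precomputed NumPy arrays; the function names are hypothetical and not FailCatcher's API, and several reductions are assumptions: the Ensembling/TTA/MC Dropout spread is taken as the std of the predicted-class probability, KNN-Raw uses the distance to the k-th nearest training embedding, and the z-score statistics come from a calibration set. Calibration (temperature/Platt), GPS, and SHAP weighting are not shown.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

# Convention: every function returns an uncertainty score where higher = more likely to fail.


def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)  # shift by the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def msr_score(logits):
    """MSR: 1 - maximum softmax probability, shape (N,)."""
    return 1.0 - softmax(logits).max(axis=-1)


def mls_score(logits):
    """MLS: negated maximum logit, so that larger still means more uncertain."""
    return -logits.max(axis=-1)


def spread_score(prob_stack):
    """Std over K stochastic predictions; prob_stack has shape (K, N, C).

    Applies equally to K fold models (Ensembling), K augmented passes (TTA)
    or K dropout-enabled passes (MC Dropout). Assumption: we measure the
    spread of the probability of the class predicted by the mean prediction.
    """
    mean_probs = prob_stack.mean(axis=0)              # (N, C)
    pred = mean_probs.argmax(axis=-1)                 # (N,)
    n = prob_stack.shape[1]
    picked = prob_stack[:, np.arange(n), pred]        # (K, N)
    return picked.std(axis=0)


def knn_score(train_feats, test_feats, k=5):
    """Distance to the k-th nearest training embedding (e.g. avgpool features)."""
    nn = NearestNeighbors(n_neighbors=k).fit(train_feats)
    dist, _ = nn.kneighbors(test_feats)               # sorted ascending, shape (N, k)
    return dist[:, -1]


def zscore_aggregate(test_scores, calib_scores):
    """Mean of per-method z-scores, normalised with calibration-set mean and std."""
    zs = []
    for name, s in test_scores.items():
        mu, sigma = calib_scores[name].mean(), calib_scores[name].std()
        zs.append((s - mu) / (sigma + 1e-12))         # epsilon guards against a constant score
    return np.mean(zs, axis=0)


if __name__ == "__main__":
    rng = np.random.default_rng(42)
    logits = rng.normal(size=(8, 3))
    probs = softmax(rng.normal(size=(5, 8, 3)))       # e.g. 5 fold models, 8 images, 3 classes
    scores = {"msr": msr_score(logits), "ens": spread_score(probs)}
    # Here the test scores double as calibration scores, purely for illustration.
    print(zscore_aggregate(scores, calib_scores=scores))
```

Putting every method on the same "higher = more uncertain" scale is what makes z-score aggregation meaningful; without the sign flip in MLS, it would cancel out MSR instead of reinforcing it.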

Datasets

Internal test sets (MedMNIST)

breastmnist, pneumoniamnist, organamnist, octmnist, pathmnist, bloodmnist, tissuemnist, dermamnist-e

External test sets (not included in the Git repository; see Benchmarks/README.md for setup)

  • AMOS-2022: abdominal CT organ patches mapped to OrganaMNIST classes. Available on HuggingFace, or preprocessed via data/AMOS_2022/read_npz.ipynb.
  • MIDOG++: mitosis detection histology patches, used as an out-of-distribution (OOD) test set for PathMNIST. Available on HuggingFace, or generated by utils/data_preprocessing_classification_evaluation/create_midog_patch_dataset.py.
  • DermaMNIST-E: an extended DermaMNIST with in-distribution (ID) and external-centre splits. Available on HuggingFace or downloadable from Zenodo; loaded by utils/data_preprocessing_classification_evaluation/local_dermamnist_e.py.

Reproducibility

All benchmark results are reproducible from scratch:

  • Random seeds fixed to 42 everywhere (training, CV splits, TTA).
  • 5-fold StratifiedKFold CV is seeded with 42, so the same folds are produced at training and at inference.
  • Model checkpoints, result JSONs, and caches are saved with configuration-specific suffixes.
  • See Benchmarks/README.md for step-by-step instructions.
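A minimal sketch of why a fixed seed makes the folds reproducible: scikit-learn's StratifiedKFold returns identical splits whenever it is rebuilt with the same random_state and the same labels. Passing shuffle=True is an assumption here (a seed only takes effect when shuffling is enabled), and seed_everything and cv_folds are hypothetical helpers; a PyTorch pipeline would additionally seed torch (e.g. torch.manual_seed) for training and TTA.

```python
import random

import numpy as np
from sklearn.model_selection import StratifiedKFold

SEED = 42


def seed_everything(seed=SEED):
    """Seed Python's and NumPy's global RNGs (a PyTorch run would also call torch.manual_seed)."""
    random.seed(seed)
    np.random.seed(seed)


def cv_folds(labels, n_splits=5, seed=SEED):
    """(train_idx, val_idx) pairs of a shuffled, seeded, stratified split."""
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    placeholder_X = np.zeros((len(labels), 1))  # only the labels drive the stratification
    return list(skf.split(placeholder_X, labels))


if __name__ == "__main__":
    seed_everything()
    labels = np.array([0] * 50 + [1] * 20)
    folds_at_training = cv_folds(labels)
    folds_at_inference = cv_folds(labels)  # rebuilt later, e.g. in another process
    identical = all(
        np.array_equal(a_val, b_val)
        for (_, a_val), (_, b_val) in zip(folds_at_training, folds_at_inference)
    )
    print("identical folds:", identical)  # identical folds: True
```

Because the split depends only on the labels and the seed, an inference script can recover which fold model saw which samples without storing the indices.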

Python version and environment

Tested with Python 3.12 and the following key packages:

Package       Version
torch         2.6.0
torchvision   0.21.0
numpy         2.1.3
scikit-learn  1.6.1
monai         1.5.1
medmnist      3.0.1
shap          0.46.0
matplotlib    3.10.0
seaborn       0.13.2
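To compare a local environment against the versions listed above, the standard library's importlib.metadata is enough. This is a hypothetical helper, not part of FailCatcher; it only reports exact version matches and does not judge whether a different version is compatible.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Versions from the table above (distribution names as published on PyPI).
TESTED_VERSIONS = {
    "torch": "2.6.0",
    "torchvision": "0.21.0",
    "numpy": "2.1.3",
    "scikit-learn": "1.6.1",
    "monai": "1.5.1",
    "medmnist": "3.0.1",
    "shap": "0.46.0",
    "matplotlib": "3.10.0",
    "seaborn": "0.13.2",
}


def compare_environment(tested=TESTED_VERSIONS):
    """Map each package to (tested version, installed version or None, exact match?)."""
    report = {}
    for pkg, expected in tested.items():
        try:
            installed = version(pkg)
        except PackageNotFoundError:
            installed = None
        report[pkg] = (expected, installed, installed == expected)
    return report


if __name__ == "__main__":
    print(f"Python {sys.version_info.major}.{sys.version_info.minor} (tested with 3.12)")
    for pkg, (expected, installed, same) in compare_environment().items():
        status = "same" if same else "differs"
        print(f"{pkg:<13} tested {expected:<8} installed {installed or 'missing':<9} {status}")
```

CUDA builds of PyTorch carry a local version suffix (for example 2.6.0+cu124), so they show up as "differs" even when the base version matches.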

License

This project is licensed under CC BY-NC-SA 4.0.

The code is intended for research and academic use only. Commercial use is prohibited.

For commercial use, please contact the author.


Download files

Source Distribution

failcatcher-2.0.1.tar.gz (85.1 kB view details)

Uploaded Source

Built Distribution

failcatcher-2.0.1-py3-none-any.whl (89.7 kB view details)

Uploaded Python 3

File details

Details for the file failcatcher-2.0.1.tar.gz.

File metadata

  • Download URL: failcatcher-2.0.1.tar.gz
  • Size: 85.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.13

File hashes

Hashes for failcatcher-2.0.1.tar.gz
Algorithm Hash digest
SHA256 5994dc70f10f857d6d0d1819a4f3a68143ee24274c4da7073c23ab63522b79b3
MD5 25b6d79d282773c9732fea370360ecb0
BLAKE2b-256 3b52671f085f4c1f0efd6b91dd17399fdb368d33249915ef8fc3360cdfbcc7c8

File details

Details for the file failcatcher-2.0.1-py3-none-any.whl.

File metadata

  • Download URL: failcatcher-2.0.1-py3-none-any.whl
  • Size: 89.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.13

File hashes

Hashes for failcatcher-2.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 bb6aa9050dede60362ebd8471ad5b6bea315ea782f69bcf4d5640a8b56647a85
MD5 87c184de6b1a6bb2a5ad52742e071be1
BLAKE2b-256 c28bae845324a917b757107f72b6994c41c30ffbce83c63f08d0efe0b0d4436c
