Post-hoc uncertainty quantification toolkit for PyTorch deep learning models

FailCatcher

FailCatcher is an uncertainty quantification (UQ) toolkit for PyTorch deep learning classification models, developed and benchmarked on medical imaging datasets from the MedMNIST collection and on external test sets. Failure detection benchmark results can be found in the benchmarks README.

Preprint: Steinmetz et al., medRxiv 2026 — DOI: 10.64898/2026.05.04.26350496

The project provides:

  • A reusable Python library (ToolBox/) implementing multiple UQ methods with a clean, unified API.
  • A full benchmarking pipeline (Benchmarks/), including model training, classification evaluation, and failure detection evaluation on images covering a diverse set of distribution shifts.
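
To make "failure detection evaluation" concrete: the usual question is whether an uncertainty score ranks misclassified samples above correctly classified ones. The following is a minimal sketch of that idea using AUROC, one of the metrics listed under evaluation/. The function name and the toy arrays are hypothetical and written for this illustration; this is not the ToolBox implementation.

```python
import numpy as np


def failure_detection_auroc(uncertainty, is_failure):
    """AUROC of an uncertainty score for separating failures from successes.

    Computed as the probability that a randomly chosen misclassified sample
    receives a higher uncertainty than a randomly chosen correct one
    (ties count as one half).
    """
    uncertainty = np.asarray(uncertainty, dtype=float)
    is_failure = np.asarray(is_failure, dtype=bool)
    fail = uncertainty[is_failure]
    ok = uncertainty[~is_failure]
    if fail.size == 0 or ok.size == 0:
        raise ValueError("need at least one failure and one success")
    # Compare every failure against every success (fine for small arrays).
    greater = (fail[:, None] > ok[None, :]).sum()
    ties = (fail[:, None] == ok[None, :]).sum()
    return float((greater + 0.5 * ties) / (fail.size * ok.size))


# Toy example: labels, predictions, and one uncertainty score per sample.
y_true = np.array([0, 1, 1, 0, 1])
y_pred = np.array([0, 1, 0, 0, 0])
uncertainty = np.array([0.05, 0.10, 0.80, 0.20, 0.15])

auroc = failure_detection_auroc(uncertainty, y_pred != y_true)
print(f"failure-detection AUROC: {auroc:.3f}")
```

An AUROC of 0.5 means the score is no better than chance at flagging errors, and 1.0 means every failure received a higher uncertainty than every correct prediction. The pairwise comparison is quadratic in the number of samples, so a rank-based implementation would be preferable for large test sets.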

Repository structure

FailCatcher/
├── ToolBox/                        # UQ library (installable Python package)
│   ├── failure_detection.py        # High-level FailureDetector API
│   ├── UQ_toolbox.py               # Public API aggregator
│   ├── core/                       # Base classes and shared inference utilities
│   ├── methods/                    # UQ method implementations
│   │   ├── tta.py                  # Test-Time Augmentation (TTA) and GPS
│   │   ├── ensemble.py             # Ensemble STD and MC Dropout
│   │   ├── distance.py             # MSR, MLS, and calibration methods
│   │   └── latent.py               # KNN and SHAP latent-space methods
│   ├── search/                     # Greedy Policy Search (GPS) algorithm
│   ├── evaluation/                 # AUROC, AURC, AUGRC metrics and plots
│   ├── visualization/              # Visualization utilities
│   └── tests/                      # Smoke tests and pre-run checks
│
├── Benchmarks/
│   └── medMNIST/
│       ├── launcher_benchmark.py   # Top-level benchmark launcher
│       ├── run_medmnist_benchmark.py # Core benchmark runner (single config)
│       ├── trainings/              # Model training scripts and launchers
│       ├── utils/                  # Data loading, preprocessing, visualization
│       │   ├── train_models_load_datasets.py  # Central data/model utilities
│       │   ├── dataset_utils.py               # External datasets, corruptions
│       │   └── data_preprocessing_classification_evaluation/
│       ├── data/                   # External test datasets (AMOS-2022, MIDOG++, ISIC)
│       ├── models/                 # Trained model checkpoints
│       ├── runs/                   # Training logs and per-run artifacts
│       └── results/                # Benchmark outputs (JSON, figures, cache)
│
└── requirements.txt

Quick start

1. Install dependencies

pip install -r requirements.txt

For covariate-shift corruption benchmarks, also install:

pip install medmnistc

2. Install the ToolBox library

cd ToolBox
pip install -e .

3. Download models and datasets from HuggingFace

Pre-trained model checkpoints and pre-processed external datasets are available on HuggingFace. Run the one-command setup to skip training and manual preprocessing:

python scripts/setup_from_hub.py

You can download models or datasets independently:

python scripts/setup_from_hub.py --models-only
python scripts/setup_from_hub.py --datasets-only --datasets amos22 midog dermamnist-e

4. Train models (alternative to step 3)

See Benchmarks/README.md for the full reproducible training and benchmarking pipeline.

5. Run the benchmark

python Benchmarks/medMNIST/launcher_benchmark.py \
    --python /path/to/your/venv/bin/python \
    --datasets breastmnist organamnist \
    --models resnet18 \
    --setups "" DA \
    --gpu 0
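
Note that `--setups "" DA` passes two values, the first of which is an empty string. When the launcher is started from Python instead of a shell, that empty string has to remain its own list element. The sketch below assembles the same invocation as an argument list; the helper name `build_benchmark_command` is hypothetical, and only the flags shown in the command above are used. What the empty setup selects is not stated here, so check Benchmarks/README.md before relying on it.

```python
import shlex
import sys


def build_benchmark_command(python_bin, datasets, models, setups, gpu):
    """Assemble the launcher invocation as an argument list (no shell involved)."""
    return [
        python_bin, "Benchmarks/medMNIST/launcher_benchmark.py",
        "--python", python_bin,
        "--datasets", *datasets,
        "--models", *models,
        "--setups", *setups,  # "" stays a separate, empty argument
        "--gpu", str(gpu),
    ]


cmd = build_benchmark_command(
    python_bin=sys.executable,
    datasets=["breastmnist", "organamnist"],
    models=["resnet18"],
    setups=["", "DA"],
    gpu=0,
)

# Shell-quoted form for inspection; the empty string is rendered as ''.
print(shlex.join(cmd))

# To actually launch it from the repository root:
# import subprocess; subprocess.run(cmd, check=True)
```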

UQ methods

Method              Description
MSR                 Maximum Softmax Response: distance between the predicted probability and 1
MSR-calibrated      MSR after temperature / Platt scaling calibration
MLS                 Maximum Logit Score: pre-softmax equivalent of MSR
Ensembling          Standard deviation across 5-fold CV model predictions
TTA                 Test-Time Augmentation: std over random augmentation passes
GPS                 Greedy Policy Search: optimised TTA policy found on the calibration set
KNN-Raw             k-NN distance in avgpool latent space
KNN-SHAP            KNN with SHAP-weighted latent features
MC Dropout          Monte Carlo Dropout at inference time
ZScore Aggregation  Z-score normalised aggregation of multiple methods
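
Several of the scores in the table above reduce to a few lines of array arithmetic. The following NumPy sketch illustrates the general ideas behind MSR (with an optional temperature), MLS, a standard-deviation score over stacked predictions (the pattern shared by ensembling, TTA, and MC Dropout), a k-NN latent distance, and z-score aggregation. All function names, the choice of which probability the standard deviation is taken over, and the value of k are assumptions made for illustration; the ToolBox implementations may differ.

```python
import numpy as np


def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)  # stabilise the exponent
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def msr_uncertainty(logits, temperature=1.0):
    """1 - max softmax probability; temperature > 1 softens the distribution."""
    return 1.0 - softmax(logits / temperature).max(axis=1)


def mls_uncertainty(logits):
    """Negative maximum logit, so that larger values mean less confident."""
    return -logits.max(axis=1)


def stacked_std(stacked_probs):
    """Std across passes of the probability of the mean-predicted class.

    stacked_probs has shape (n_passes, n_samples, n_classes); the passes can
    be ensemble members, augmented views, or stochastic dropout passes.
    """
    pred = stacked_probs.mean(axis=0).argmax(axis=1)
    per_pass = stacked_probs[:, np.arange(stacked_probs.shape[1]), pred]
    return per_pass.std(axis=0)


def knn_distance(train_feats, test_feats, k=5):
    """Mean Euclidean distance to the k nearest training features."""
    d = np.linalg.norm(test_feats[:, None, :] - train_feats[None, :, :], axis=2)
    return np.sort(d, axis=1)[:, :k].mean(axis=1)


def zscore_aggregate(scores, calib_scores):
    """Average of per-method z-scores, normalised with calibration statistics."""
    z = [(s - c.mean()) / (c.std() + 1e-12) for s, c in zip(scores, calib_scores)]
    return np.mean(z, axis=0)


# Toy data standing in for model outputs and latent features.
rng = np.random.default_rng(42)
logits = rng.normal(size=(8, 3))
passes = np.stack(
    [softmax(logits + rng.normal(scale=0.3, size=logits.shape)) for _ in range(5)]
)
feats_train = rng.normal(size=(50, 16))
feats_test = rng.normal(size=(8, 16))

msr = msr_uncertainty(logits)
mls = mls_uncertainty(logits)
std = stacked_std(passes)
knn = knn_distance(feats_train, feats_test)
# Toy shortcut: the same scores double as calibration statistics here.
combined = zscore_aggregate([msr, std, knn], [msr, std, knn])
print(combined.round(3))
```

Every function returns one value per sample, oriented so that a larger value means higher uncertainty. That shared orientation is what makes the z-score average meaningful: each method is first put on a common scale, then the scaled scores are averaged.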

Datasets

Internal test sets (MedMNIST)

breastmnist, pneumoniamnist, organamnist, octmnist, pathmnist, bloodmnist, tissuemnist, dermamnist-e

External test sets (not in git — see Benchmarks/README.md for setup)

  • AMOS-2022: abdominal CT organ patches mapped to OrganAMNIST classes. Available on HuggingFace or preprocessed via data/AMOS_2022/read_npz.ipynb.
  • MIDOG++: mitosis detection histology patches used as an out-of-distribution (OOD) test set for PathMNIST. Available on HuggingFace or generated by utils/data_preprocessing_classification_evaluation/create_midog_patch_dataset.py.
  • DermaMNIST-E: extended DermaMNIST with in-distribution (ID) and external centre splits. Available on HuggingFace or downloaded from Zenodo, and loaded by utils/data_preprocessing_classification_evaluation/local_dermamnist_e.py.

Reproducibility

All benchmark results are reproducible from scratch:

  • Random seeds are fixed to 42 everywhere (training, CV splits, TTA).
  • The same 5-fold StratifiedKFold CV split (seed=42) is used for training and inference.
  • Model checkpoints, result JSONs, and caches are saved with configuration-specific suffixes.
  • See Benchmarks/README.md for step-by-step instructions.
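
As an illustration of the seed-fixing point above, a helper along the following lines is a common pattern. This is a sketch, not the repository's code: the name `seed_everything` is hypothetical, and PyTorch is seeded only if it can be imported.

```python
import random

import numpy as np


def seed_everything(seed=42):
    """Seed the Python, NumPy and (if installed) PyTorch random generators."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        return  # PyTorch not available; nothing more to seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable


# Re-seeding reproduces the same draws.
seed_everything(42)
first = (random.random(), np.random.rand(3))
seed_everything(42)
second = (random.random(), np.random.rand(3))
print(first[0] == second[0], np.array_equal(first[1], second[1]))
```

Fixed seeds make random draws repeatable, but they do not by themselves guarantee bitwise-identical results across different hardware or library versions, since some GPU operations are non-deterministic.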

Python version and environment

Tested with Python 3.12 and the following key packages:

Package       Version
torch         2.6.0
torchvision   0.21.0
numpy         2.1.3
scikit-learn  1.6.1
monai         1.5.1
medmnist      3.0.1
shap          0.46.0
matplotlib    3.10.0
seaborn       0.13.2
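
To compare a local environment against the tested versions above, a small standard-library check can help. This is a sketch, not part of the project: the names `TESTED_VERSIONS` and `compare_with_tested` are hypothetical, and a differing version does not necessarily mean the code will fail.

```python
from importlib import metadata

# Versions copied from the table above.
TESTED_VERSIONS = {
    "torch": "2.6.0",
    "torchvision": "0.21.0",
    "numpy": "2.1.3",
    "scikit-learn": "1.6.1",
    "monai": "1.5.1",
    "medmnist": "3.0.1",
    "shap": "0.46.0",
    "matplotlib": "3.10.0",
    "seaborn": "0.13.2",
}


def compare_with_tested(tested=TESTED_VERSIONS):
    """Return {package: (tested, installed)}; installed is None if missing."""
    report = {}
    for name, wanted in tested.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        report[name] = (wanted, installed)
    return report


for name, (wanted, installed) in compare_with_tested().items():
    status = "ok" if installed == wanted else "differs"
    print(f"{name:<14}{wanted:<10}{installed or 'not installed':<16}{status}")
```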

License

This project is licensed under CC BY-NC-SA 4.0.

The code is intended for research and academic use only. Commercial use is prohibited.

For commercial use, please contact the author.
