Diverse Prototypical Ensembles Improve Robustness to Subpopulation Shift
This repository provides the official implementation and experiments for our ICML 2025 paper:
Diverse Prototypical Ensembles Improve Robustness to Subpopulation Shift
by Minh To, Paul F. R. Wilson, and co-authors.
Overview
Machine learning models often experience significant performance degradation when deployed under distribution shifts. A particularly important and challenging case is subpopulation shift, where the proportions of subgroups vary between training and deployment. Subpopulation shifts occur in many forms, including spurious correlations, attribute or class imbalance, and previously unseen attribute combinations at test time, and can lead to large disparities in model performance across subgroups.
Existing approaches typically modify empirical risk minimization (ERM) using reweighting or group-aware strategies. However, these often rely on prior knowledge of subgroup structure or annotated group membership, which may not be available in practice.
We propose Diverse Prototypical Ensembles (DPE), a simple and scalable framework that improves model robustness to subpopulation shifts without requiring group annotations. DPE replaces the standard linear classification head with an ensemble of prototype-based classifiers, each trained on a different balanced subset of data. Diversity is promoted through an inter-prototype similarity loss, encouraging each classifier to attend to different regions of the feature space.
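The idea can be illustrated with a minimal NumPy sketch. It is not the implementation shipped in this repository: the function names are hypothetical, logits are assumed to be negative scaled Euclidean distances to class prototypes, and the diversity term is assumed to be a squared cosine similarity between same-class prototypes of different ensemble members. The loss used in the paper may take a different form.

```python
import numpy as np

def prototype_logits(features, prototypes, scale=1.0):
    """Logits as negative scaled Euclidean distances to class prototypes.

    features:   (n, d) array of embeddings
    prototypes: (c, d) array, one prototype per class
    """
    diffs = features[:, None, :] - prototypes[None, :, :]   # (n, c, d)
    return -scale * np.linalg.norm(diffs, axis=-1)          # (n, c)

def ensemble_predict(features, ensemble, scale=1.0):
    """Predict classes by averaging softmax probabilities over members.

    ensemble: (m, c, d) array holding m sets of class prototypes
    """
    probs = []
    for prototypes in ensemble:
        logits = prototype_logits(features, prototypes, scale)
        logits -= logits.max(axis=1, keepdims=True)          # numerical stability
        exp = np.exp(logits)
        probs.append(exp / exp.sum(axis=1, keepdims=True))
    return np.mean(probs, axis=0).argmax(axis=1)

def similarity_penalty(ensemble):
    """Mean squared cosine similarity between same-class prototypes
    of different ensemble members (assumed form; needs m >= 2)."""
    m = ensemble.shape[0]
    unit = ensemble / np.linalg.norm(ensemble, axis=-1, keepdims=True)
    # sims[k, i, j]: cosine similarity of members i and j for class k
    sims = np.einsum("ikd,jkd->kij", unit, unit)             # (c, m, m)
    off_diag = ~np.eye(m, dtype=bool)                        # drop self-similarity
    return float((sims[:, off_diag] ** 2).mean())
```

Minimizing a penalty of this kind alongside the classification loss pushes the members' prototypes apart, so that each member covers a different region of the feature space.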
🚀 Installation
First, make sure you have an up-to-date packaging environment:
python3 -m pip install --upgrade pip setuptools wheel
Then install dpe directly from PyPI:
pip install dpe
🎯 Quick Demo
from dpe import DPE

def main():
    dpe = DPE(
        data_dir='/path/to/pre-extracted-features/folder',
        metadata_path='/path/to/metadata.csv',
        num_stages=2,
        device='cuda',
        eval_freq=1,
        train_attr='no',
        seed=0,
    )
    dpe.fit()
    print("Demo completed successfully!")

if __name__ == '__main__':
    main()
Note: The structure of /path/to/pre-extracted-features/folder must include the following files:
- feats_val.npy
- feats_test.npy
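Before calling dpe.fit(), it can help to confirm that the folder is complete. The helper below is a hypothetical convenience and not part of the dpe package; it only checks for the two file names listed above.

```python
from pathlib import Path

# File names taken from the note above; extend the tuple if your setup needs more.
REQUIRED_FILES = ("feats_val.npy", "feats_test.npy")

def missing_feature_files(data_dir, required=REQUIRED_FILES):
    """Return the required file names that are absent from data_dir."""
    root = Path(data_dir)
    return [name for name in required if not (root / name).is_file()]
```

An empty list means every listed file is present; otherwise the returned names show what still has to be extracted.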
👉 For a full list of configurable options, refer to the Args class inside src/dpe/core.py.
👉 A step-by-step demonstration is available in notebooks/03_demo.ipynb.
Notebooks
We provide a collection of Jupyter notebooks under the notebooks/ directory to illustrate key components
of Diverse Prototypical Ensembles (DPE) through visualization, controlled experiments, and ablation studies. These
notebooks provide a walkthrough of the motivation and implementation of our method as described in the paper,
demonstrated on two standard benchmark datasets.
- 00_synthetic.ipynb
  A 2D synthetic experiment that simulates subpopulation shift under controlled conditions.
  This notebook visualizes the limitations of standard classifiers trained on imbalanced subgroups and demonstrates how DPE achieves better coverage and robustness through diversified prototype ensembles.
- 01_waterbirds_with_attribute_annotation.ipynb
  Full pipeline demonstration of DPE on the Waterbirds dataset, using group-annotated validation data.
  This notebook highlights the effectiveness of training diverse classifiers on balanced group subsets, and evaluates per-group accuracy improvements over the ERM baseline.
- 02_celeba_without_attribute_annotation.ipynb
  Application of DPE to the CelebA dataset in a more realistic setting where subgroup labels are not available.
  It shows that even without group supervision, DPE outperforms strong baselines such as Deep Feature Reweighting (DFR) in worst-group accuracy. The notebook also illustrates that increasing the number of DFR heads does not further improve fairness, while DPE consistently improves both robustness and subgroup equity.
- 03_demo.ipynb
  A streamlined demonstration of the DPE training and evaluation workflow using the dpe package.
  This notebook serves as a minimal working example to illustrate the integration of DPE into an applied training loop on the Waterbirds dataset.
Each notebook is self-contained and can be executed independently. These examples serve as a foundation for adapting DPE to other datasets and deployment scenarios.
Reproducing the Paper Results
This section provides the steps and configuration details needed to reproduce the experiments from our ICML 2025 paper.
Data Preparation
We follow the dataset setup instructions from SubpopBench, which provides scripts and guidelines for preparing all datasets used in our experiments (e.g., Waterbirds, CelebA, MetaShift, MultiNLI).
To prepare the data:
- Follow the instructions in the SubpopBench repository to download and preprocess each dataset.
- Make sure the processed datasets are stored under a common root directory (e.g., /datasets).
- Set --data_dir to this root directory when running the training scripts.
Training Pipeline
- Stage-0: Supervised backbone pretraining (ERM or IsoMax).
- Stage-1+: Diverse prototype ensemble training on balanced resampled subsets.
This framework works both with and without access to subgroup annotations.
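The balanced resampling step in Stage-1+ can be sketched as follows. This is an illustrative NumPy version under stated assumptions, not the repository's sampler: every group is downsampled without replacement to the size of the smallest group, and each ensemble member draws its own subset. The function names are hypothetical. When subgroup annotations are unavailable, class labels can be passed as group_ids to obtain class-balanced subsets instead.

```python
import numpy as np

def balanced_subsample(group_ids, rng):
    """Return sorted indices of a subset with equally many samples per group.

    Each group is downsampled (without replacement) to the smallest group's size.
    """
    group_ids = np.asarray(group_ids)
    groups, counts = np.unique(group_ids, return_counts=True)
    n_per_group = counts.min()
    picked = [
        rng.choice(np.flatnonzero(group_ids == g), size=n_per_group, replace=False)
        for g in groups
    ]
    return np.sort(np.concatenate(picked))

def ensemble_subsets(group_ids, num_members, seed=0):
    """Draw one independent balanced subset per ensemble member."""
    rng = np.random.default_rng(seed)
    return [balanced_subsample(group_ids, rng) for _ in range(num_members)]
```

Because the smallest group is always included in full, the members differ only in which samples they draw from the larger groups, which is one source of diversity across the ensemble.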
Stage-0 Training (ERM)
To fine-tune an ImageNet-pretrained ResNet-50 on the MetaShift dataset (located at /datasets/metashift), run:
python main.py \
--epochs 100 \
--loss_name ce \
--dataset_name MetaShift \
--pretrained_imgnet \
--ckpt_dir /checkpoint/ \
--data_dir /datasets
Stage-1+ Training (Diversified Prototypes)
Once Stage-0 is complete, initiate prototype ensemble training using the pretrained backbone:
python main.py \
--dataset_name MetaShift \
--pretrained_path /checkpoint/ckpt_last.pt \
--ckpt_dir /checkpoint \
--loss_name isomax \
--stage 1 \
--num_stages 16 \
--epochs 20 \
--cov_reg 1.e5 \
--batch-size 64 \
--optim sgd \
--lr 1.e-3 \
--train_attr yes \
--train_mode freeze \
--subsample_type group \
--ensemble_criterion wga_val \
--entropic_scale 20 \
-ncbt \
-sit
Launch All Predefined Jobs
To run all supported configurations for available datasets:
sbatch scripts/train_all.sh
sbatch scripts/train_all_pe.sh
Key Arguments
General
- --dataset_name: e.g., Waterbirds, CelebA, MultiNLI, MetaShift
- --model_name: e.g., resnet50, bert-base-uncased
- --epochs, --lr: control training length and learning rate
- --seed: sets random seed for reproducibility
Stage-0
- --loss_name: ce (default)
- --train_mode: full (default) or freeze
Stage-1+
- --stage 1
- --pretrained_path: path to Stage-0 model checkpoint
- --num_stages: number of ensemble heads (default: 16)
- --cov_reg: strength of inter-prototype similarity penalty
- --subsample_type: None or group (group-balanced subsampling if --train_attr yes or class-balanced subsampling if --train_attr no)
- --entropic_scale: IsoMax temperature scaling factor
- --train_mode freeze: freeze backbone, train only prototypes
- -ncbt: disables class-balanced batch construction
- -sit: enables data shuffling at each epoch
- --ensemble_criterion: ensemble member selection criterion (e.g. val_wga: based on the best worst group accuracy on the validation set)
Training Tips
- Metric Logging: W&B logs all ensemble-level metrics under the ensemble_ prefix, such as ensemble_worst_group_acc.
- Covariance Regularization: Tune --cov_reg between 1e4 and 1e6 to control prototype diversity.
- IsoMax Temperature: Use --entropic_scale between 10 and 40 depending on the dataset.
- Balanced Sampling:
  - --subsample_type group ensures subgroup-balanced training when --train_attr yes.
  - --subsample_type class enables class-balanced sampling when --train_attr no.
- Training Schedule:
  - Stage-1+ typically converges within 15–30 epochs.
- Output Directory Layout:
  - Checkpoints: /checkpoint/$USER/$SLURM_JOB_ID/ckpt_*.pt
  - Logs: logs/<jobname>.<id>.log
- Disabling W&B: Use --no_wandb to turn off logging for debugging.
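The tuning ranges above can be turned into a small search grid. The sketch below is only one possible setup: the helper log_grid and the choice of three values per hyperparameter are assumptions, not settings taken from the paper.

```python
from itertools import product

def log_grid(low, high, num):
    """Return `num` values spaced evenly on a log scale from low to high."""
    if num == 1:
        return [low]
    ratio = (high / low) ** (1 / (num - 1))
    return [low * ratio ** i for i in range(num)]

# Ranges follow the tips above; the grid resolution is an arbitrary choice.
cov_regs = log_grid(1e4, 1e6, 3)
entropic_scales = [10, 20, 40]

configs = [
    {"cov_reg": c, "entropic_scale": s}
    for c, s in product(cov_regs, entropic_scales)
]
```

Each entry in configs corresponds to one Stage-1+ run, with the values passed through --cov_reg and --entropic_scale.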
Expected Outputs
Stage-0
- Model checkpoints: ckpt_best_acc.pt, ckpt_best_bal_acc.pt, ckpt_last.pt
- Optional feature dumps: feats_val.npy, feats_test.npy
Stage-1+
- Prototype ensembles: prototype_ensemble_<criterion>.pt
- Distance scale parameters: dist_scales_<criterion>.pt
- Precomputed embeddings: auto-saved to the directory specified by --ckpt_dir
- Logs and visualizations (if W&B is enabled)
These instructions match the setup used to produce results in our ICML 2025 paper. For additional visual analysis and ablation studies, refer to the Notebooks section.
Results
Worst-group accuracy on datasets without subgroup annotations:
| Algorithm | Waterbirds | CelebA | CivilComments | MultiNLI | MetaShift | CheXpert | ImageNetBG | NICO++ | Living17 |
|---|---|---|---|---|---|---|---|---|---|
| ERM* | 77.9±3.0 | 66.5±2.6 | 69.4±1.2 | 66.5±0.7 | 80.0±0.0 | 75.6±0.4 | 86.4±0.8 | 33.3±0.0 | 53.3±0.9 |
| ERM* + DPE (Ours) | 94.1±0.2 | 84.6±0.8 | 68.9±0.6 | 70.9±0.8 | 83.6±0.9 | 76.8±0.1 | 88.1±0.7 | 50.0±0.0 | 63.0±1.7 |
Worst-group accuracy on datasets with subgroup annotation:
| Algorithm | Group Info (Train / Val) | Waterbirds | CelebA | CivilComments | MultiNLI | MetaShift | CheXpert |
|---|---|---|---|---|---|---|---|
| ERM* | ✗ / ✗ | 77.9±3.0 | 66.5±2.6 | 69.4±1.2 | 66.5±0.7 | 80.0±0.0 | 75.6±0.4 |
| ERM* + DPE (ours) | ✗ / ✓✓ | 94.1±0.4 | 90.3±0.7 | 70.8±0.8 | 75.3±0.5 | 91.7±1.3 | 76.0±0.3 |
✗: no group info is required
✓: group info is required for hyperparameter tuning
✓✓: validation data is required for training and hyperparameter tuning
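Worst-group accuracy is the minimum of the per-group accuracies. The following NumPy sketch shows the computation; the function name is hypothetical and the repository's own evaluation code may be organized differently.

```python
import numpy as np

def worst_group_accuracy(y_true, y_pred, group_ids):
    """Return (worst-group accuracy, per-group accuracy dict)."""
    y_true, y_pred, group_ids = map(np.asarray, (y_true, y_pred, group_ids))
    correct = y_true == y_pred
    per_group = {
        g.item(): float(correct[group_ids == g].mean())
        for g in np.unique(group_ids)
    }
    return min(per_group.values()), per_group
```

A model with high average accuracy can still score poorly here if it fails on a small subgroup, which is why the tables report this metric rather than overall accuracy.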
More tables and detailed experimental breakdowns are available at:
https://github.com/anonymous102030411/anon
Citation
@article{to2025diverse,
title={Diverse Prototypical Ensembles Improve Robustness to Subpopulation Shift},
author={To, Minh Nguyen Nhat and Wilson, Paul F. R. and Nguyen, Viet and Harmanani, Mohamed and Cooper, Michael and Fooladgar, Fahimeh and Abolmaesumi, Purang and Mousavi, Parvin and Krishnan, Rahul G},
journal={arXiv preprint arXiv:2505.23027},
year={2025}
}
Acknowledgements
Some of the training and evaluation infrastructure in this repository was adapted from:
We thank the authors for releasing their well-organized benchmark and codebase.