CEL: Counterfactual Explanations Library - A framework for generating probabilistically plausible counterfactual explanations using normalizing flows

PPCEF: Probabilistically Plausible Counterfactual Explanations using Normalizing Flows

A Python framework for generating and evaluating counterfactual explanations in machine learning models. The main contribution is PPCEF (Probabilistically Plausible Counterfactual Explanations using Normalizing Flows), a novel method that uses normalizing flows as density estimators within an optimization framework to generate high-quality, plausible counterfactual explanations.

PPCEF Framework Overview

Abstract

We present PPCEF, a novel method specifically tailored for generating probabilistically plausible counterfactual explanations. This approach utilizes normalizing flows as density estimators within an unconstrained optimization framework, effectively balancing distance, validity, and probabilistic plausibility in the produced counterfactuals. Our method is notable for its computational efficiency and ability to process large and high-dimensional datasets, making it particularly applicable in real-world scenarios. A key aspect of PPCEF is its focus on the plausibility of counterfactuals, ensuring that the generated explanations are coherent and realistic within the context of the original data. Through comprehensive experiments across various datasets and models, we demonstrate that PPCEF can successfully generate high-quality counterfactual explanations, highlighting its potential as a valuable tool in enhancing the interpretability and transparency of machine learning systems.
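The balance of distance, validity, and plausibility described above can be pictured as a soft-penalty objective. The sketch below is a conceptual illustration only, not the library's implementation; the weight `alpha`, the squared distance, and the hinge on the log-density threshold are assumptions made for exposition:

```python
def ppcef_objective(dist_sq, clf_loss, log_prob, log_prob_threshold, alpha=100.0):
    """Illustrative PPCEF-style objective (not the library's actual code).

    dist_sq:            squared distance between x and candidate counterfactual x'
    clf_loss:           classifier loss pushing x' toward the target class (validity)
    log_prob:           log-density of x' under the normalizing flow
    log_prob_threshold: density level a plausible counterfactual should exceed
    """
    # Hinge penalty: active only while x' is less likely than the threshold.
    plausibility_penalty = max(0.0, log_prob_threshold - log_prob)
    # Unconstrained formulation: distance plus alpha-weighted validity and
    # plausibility terms, suitable for joint gradient-based optimization.
    return dist_sq + alpha * (clf_loss + plausibility_penalty)

# A candidate above the density threshold pays only the distance cost:
print(ppcef_objective(dist_sq=1.0, clf_loss=0.0, log_prob=-2.0, log_prob_threshold=-3.0))  # 1.0
```

Because the plausibility term vanishes once the counterfactual clears the density threshold, the optimizer is free to minimize distance alone in the plausible region.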

Key Features

  • Multiple CF Method Families: Local, global, and group counterfactual methods
  • Normalizing Flow Integration: State-of-the-art density estimation for plausibility
  • Comprehensive Metrics: 17+ evaluation metrics for counterfactual quality
  • Hydra Configuration: Flexible experiment management with YAML configs
  • 21 Built-in Datasets: Classification and regression tasks
  • Extensible Architecture: Easy to add new methods, models, and metrics
  • PyTorch-based: Modern deep learning framework
  • Cross-validation Support: Robust evaluation with k-fold CV
  • Preprocessing Pipeline: Composable feature transformations

Installation

Clone the repository and set up the environment:

git clone git@github.com:ofurman/counterfactuals.git
cd counterfactuals
./setup_env.sh

Or install dependencies manually with uv:

uv sync

Requirements: Python >= 3.10

Quick Start

import torch
from counterfactuals.datasets import MethodDataset
from counterfactuals.cf_methods import PPCEF
from counterfactuals.models import MaskedAutoregressiveFlow, MLPClassifier
from counterfactuals.losses import BinaryDiscLoss
from counterfactuals.metrics import evaluate_cf

# Load dataset with preprocessing
dataset = MethodDataset.from_config("config/datasets/moons.yaml")
train_loader = dataset.train_dataloader(batch_size=128, shuffle=True)
test_loader = dataset.test_dataloader(batch_size=128, shuffle=False)

# Train discriminative model (classifier)
disc_model = MLPClassifier(
    input_size=dataset.input_size,
    hidden_layer_sizes=[256, 256],
    target_size=1,
    dropout=0.2,
)
disc_model.fit(train_loader, test_loader, epochs=5000, patience=300, lr=1e-3)

# Train generative model (normalizing flow)
gen_model = MaskedAutoregressiveFlow(
    features=dataset.input_size,
    hidden_features=8,
    context_features=1,
)
gen_model.fit(train_loader, test_loader, num_epochs=1000)

# Generate counterfactuals
cf_method = PPCEF(
    gen_model=gen_model,
    disc_model=disc_model,
    disc_model_criterion=BinaryDiscLoss(),
)
log_prob_threshold = torch.quantile(gen_model.predict_log_prob(test_loader), 0.25)
result = cf_method.explain_dataloader(
    test_loader,
    alpha=100,
    log_prob_threshold=log_prob_threshold,
    epochs=4000,
)

# Evaluate results
X_cf = result.x_origs + result.x_cfs
metrics = evaluate_cf(
    disc_model=disc_model,
    gen_model=gen_model,
    X_cf=X_cf,
    X_train=dataset.X_train,
    y_train=dataset.y_train,
    X_test=result.x_origs,
    y_test=result.y_origs,
    y_target=result.y_cf_targets,
    continuous_features=dataset.numerical_features,
    categorical_features=dataset.categorical_features,
    median_log_prob=log_prob_threshold,
)

Jupyter Notebooks

Example notebooks are available in notebooks/:

| Notebook | Description |
|----------|-------------|
| ppcef.ipynb | Basic PPCEF usage |
| rppcef.ipynb | Regional PPCEF for group explanations |
| categorical_ppcef.ipynb | Handling categorical features |
| toy_example.ipynb | Simple visualization examples |

Library Structure

counterfactuals/
├── cf_methods/           # Counterfactual explanation methods
│   ├── local/            # Instance-level methods (PPCEF, DiCE, WACH, etc.)
│   ├── global_/          # Model-level methods (GLOBE-CE, AReS)
│   └── group/            # Cohort-level methods (RPPCEF, GLANCE)
├── models/               # ML models
│   ├── discriminative/   # Classifiers (MLP, LogisticRegression, NODE)
│   ├── generative/       # Density estimators (MAF, RealNVP, NICE, KDE)
│   └── regression/       # Regressors (MLP, LinearRegression)
├── datasets/             # Dataset loading and configuration
├── preprocessing/        # Feature transformation pipeline
├── dequantization/       # Categorical feature handling for flows
├── losses/               # Loss functions for CF optimization
├── metrics/              # Evaluation metrics
├── pipelines/            # Experiment orchestration
│   ├── nodes/            # Pipeline components
│   └── conf/             # Hydra configuration files
├── plotting/             # Visualization utilities
└── utils.py              # Helper functions

config/
└── datasets/             # Dataset YAML configurations (21 datasets)

docs/
├── library_overview.md   # Comprehensive package documentation
└── ppcef_pipeline.md     # Pipeline guide

Counterfactual Methods

Local Methods (Instance-level)

| Method | Class | Description |
|--------|-------|-------------|
| PPCEF | PPCEF | Probabilistically Plausible CF with normalizing flows (main contribution) |
| PPCEFR | PPCEFR | PPCEF for regression tasks |
| DiCE | DICE | Diverse Counterfactual Explanations |
| CEM | CEM_CF | Contrastive Explanation Method |
| CET | CET | Counterfactual Explanation Tree |
| WACH | WACH | Wachter-style gradient-based CF |
| Artelt | Artelt | Artelt's CF method |
| SACE | SACE, CaseBasedSACE | (Case-based) SACE methods |
| CEGP | CEGP | CF with Gaussian Processes |
| C-CHVAE | CCHVAE | Conditional Heterogeneous VAE |
| DiCoFlex | DiCoFlex | Diverse Counterfactual Flex |
| LiCE | LiCE | LIME-style CF (requires pyomo/onnx/omlt) |

Global Methods (Model-level)

| Method | Class | Description |
|--------|-------|-------------|
| GLOBE-CE | GLOBE_CE | Global Counterfactual Explanations |
| AReS | AReS | Actionable Recourse Summaries |

Group Methods (Cohort-level)

| Method | Class | Description |
|--------|-------|-------------|
| RPPCEF | RPPCEF | Regional PPCEF with shared interventions |
| GLANCE | GLANCE | Group-level CF method |

Datasets

The library includes 21 pre-configured datasets:

Classification: adult, adult_census, audit, bank_marketing, compas, credit_default, diabetes, digits, german_credit, give_me_some_credit, heloc, law, lending_club, mnist, moons, wine, blobs

Regression: concrete, toy_regression, wine_quality_regression, yacht

Dataset configurations are in config/datasets/*.yaml and support:

  • Automatic feature type detection (continuous/categorical)
  • Actionability flags for features
  • Cross-validation splits
  • Train/test split configuration
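For orientation, a dataset configuration might look roughly like the sketch below. This is a hypothetical example: the field names are assumptions chosen to illustrate the capabilities listed above, not the library's actual schema (the real files under config/datasets/ are authoritative).

```yaml
# Hypothetical dataset config -- field names are illustrative assumptions.
file_path: data/moons.csv
target_column: label
features:
  continuous: [x1, x2]
  categorical: []
  actionable: [x1, x2]   # actionability flags per feature
split:
  test_size: 0.2
  cross_validation_folds: 5
```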

Models

Discriminative Models

| Model | Class | Use Case |
|-------|-------|----------|
| MLP Classifier | MLPClassifier | General classification |
| Logistic Regression | LogisticRegression | Binary classification |
| Multinomial LR | MultinomialLogisticRegression | Multiclass classification |
| NODE | NODE | Neural Oblivious Decision Ensembles |

Generative Models

| Model | Class | Description |
|-------|-------|-------------|
| MAF | MaskedAutoregressiveFlow | Primary normalizing flow |
| RealNVP | RealNVP | Real-valued Non-Volume Preserving |
| NICE | NICE | Non-linear Independent Components |
| KDE | KDE | Kernel Density Estimation baseline |

Regression Models

| Model | Class |
|-------|-------|
| MLP Regressor | MLPRegressor |
| Linear Regression | LinearRegression |

Metrics

The library provides comprehensive evaluation metrics:

| Category | Metrics |
|----------|---------|
| Validity | coverage, validity, actionability |
| Sparsity | sparsity |
| Distance | proximity_continuous_euclidean, proximity_continuous_manhattan, proximity_continuous_mad, proximity_categorical_hamming, proximity_categorical_jaccard, proximity_l2_jaccard, proximity_mad_hamming |
| Plausibility | prob_plausibility, log_density_cf, log_density_test |
| Outlier Detection | lof_scores_cf, lof_scores_test, isolation_forest_scores_cf, isolation_forest_scores_test |
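To illustrate what the distance, sparsity, and validity families measure, here is a minimal NumPy sketch of three of the metrics above. These are simplified stand-ins, not the library's implementations:

```python
import numpy as np

def sparsity(x, x_cf, eps=1e-8):
    # Fraction of features changed per counterfactual (lower = sparser edits).
    return (np.abs(x - x_cf) > eps).mean(axis=1)

def proximity_continuous_euclidean(x, x_cf):
    # L2 distance between each original instance and its counterfactual.
    return np.linalg.norm(x - x_cf, axis=1)

def validity(y_pred_cf, y_target):
    # Share of counterfactuals actually classified as the target class.
    return (y_pred_cf == y_target).mean()

x = np.array([[0.0, 1.0], [2.0, 2.0]])
x_cf = np.array([[0.0, 2.0], [1.0, 1.0]])
print(sparsity(x, x_cf))                             # [0.5 1. ]
print(proximity_continuous_euclidean(x, x_cf))       # [1.         1.41421356]
print(validity(np.array([1, 0]), np.array([1, 1])))  # 0.5
```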

Running Experiments

Using Hydra Pipelines

# Run PPCEF pipeline
uv run python counterfactuals/pipelines/run_ppcef_pipeline.py

# With custom configuration
uv run python counterfactuals/pipelines/run_ppcef_pipeline.py \
  dataset.config_path=config/datasets/heloc.yaml \
  disc_model.model=disc_model/mlp_large \
  counterfactuals_params.target_class=1

Available Pipelines

| Pipeline | Method |
|----------|--------|
| run_ppcef_pipeline.py | PPCEF (main) |
| run_ppcefr_pipeline.py | PPCEF for regression |
| run_rppcef_pipeline.py | Regional PPCEF |
| run_dice_pipeline.py | DiCE |
| run_cem_pipeline.py | CEM |
| run_cet_pipeline.py | CET |
| run_cchvae_pipeline.py | C-CHVAE |
| run_wach_pipeline.py | WACH |
| run_artelt_pipeline.py | Artelt |
| run_cegp_pipeline.py | CEGP |
| run_globe_ce_pipeline.py | GLOBE-CE |
| run_ares_pipeline.py | AReS |
| run_glance_pipeline.py | GLANCE |

Documentation

  • docs/library_overview.md: Comprehensive package documentation
  • docs/ppcef_pipeline.md: Pipeline guide

Contributing

Contributions are welcome! Before opening a PR:

  1. Read AGENTS.md and docs/ppcef_pipeline.md to understand the workflow
  2. Use uv for all operations:
    uv sync                     # Install dependencies
    uv run ruff check --fix     # Lint and fix
    uv run pytest               # Run tests
    
  3. Follow the coding standards:
    • Python 3.10+, PEP 8 compliant
    • Full type hints everywhere
    • Google-style docstrings
    • Line length: 100 characters
  4. Keep patches small and well-documented
  5. Update or add tests when behavior changes

To add new dependencies:

uv add <package>

Citation

@inbook{wielopolski2024ppcef,
  author = {Wielopolski, Patryk and Furman, Oleksii and Stefanowski, Jerzy and Zieba, Maciej},
  title  = {Probabilistically Plausible Counterfactual Explanations with Normalizing Flows},
  year   = {2024},
  month  = {10},
  isbn   = {9781643685489},
  doi    = {10.3233/FAIA240584}
}

Contact

For questions or comments, please contact via LinkedIn: TBA
