CEL: Counterfactual Explanations Library - A framework for generating probabilistically plausible counterfactual explanations using normalizing flows
PPCEF: Probabilistically Plausible Counterfactual Explanations using Normalizing Flows
A Python framework for generating and evaluating counterfactual explanations in machine learning models. The main contribution is PPCEF (Probabilistically Plausible Counterfactual Explanations using Normalizing Flows), a novel method that uses normalizing flows as density estimators within an optimization framework to generate high-quality, plausible counterfactual explanations.
Abstract
We present PPCEF, a novel method specifically tailored for generating probabilistically plausible counterfactual explanations. This approach utilizes normalizing flows as density estimators within an unconstrained optimization framework, effectively balancing distance, validity, and probabilistic plausibility in the produced counterfactuals. Our method is notable for its computational efficiency and ability to process large and high-dimensional datasets, making it particularly applicable in real-world scenarios. A key aspect of PPCEF is its focus on the plausibility of counterfactuals, ensuring that the generated explanations are coherent and realistic within the context of the original data. Through comprehensive experiments across various datasets and models, we demonstrate that PPCEF can successfully generate high-quality counterfactual explanations, highlighting its potential as a valuable tool in enhancing the interpretability and transparency of machine learning systems.
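The trade-off described above can be made concrete with a minimal, self-contained sketch (this is *not* the library's implementation): optimize a perturbation that balances a proximity term, a validity term pushing the classifier toward the target class, and a plausibility term from a density model. A fixed Gaussian stands in for the normalizing flow, and the toy classifier and weights below are illustrative assumptions.

```python
# Minimal sketch of the PPCEF idea (not the library's code): find x_cf near x
# that flips a classifier's decision while staying in a high-density region.
# A standard Gaussian stands in for the normalizing-flow density estimator.
import torch

torch.manual_seed(0)

# Toy linear classifier: predicts class 1 when x1 + x2 > 0
w = torch.tensor([1.0, 1.0])

def classifier_logit(x):
    return x @ w

# Stand-in density model: 2-D standard-Gaussian log-density (up to a constant)
def log_prob(x):
    return -0.5 * (x ** 2).sum()

x = torch.tensor([-2.0, -1.0])           # factual point, classified as 0
delta = torch.zeros(2, requires_grad=True)
opt = torch.optim.Adam([delta], lr=0.05)
alpha, beta = 100.0, 1.0                 # illustrative trade-off weights

for _ in range(500):
    opt.zero_grad()
    x_cf = x + delta
    dist = (delta ** 2).sum()                            # proximity term
    validity = torch.nn.functional.binary_cross_entropy_with_logits(
        classifier_logit(x_cf).unsqueeze(0), torch.ones(1)
    )                                                    # push toward class 1
    plausibility = -log_prob(x_cf)                       # stay in-distribution
    loss = dist + alpha * validity + beta * plausibility
    loss.backward()
    opt.step()

x_cf = (x + delta).detach()
print(classifier_logit(x_cf) > 0)  # the counterfactual should now be class 1
```

Increasing `beta` pulls counterfactuals toward dense regions of the data at the cost of distance; the actual method replaces the Gaussian with a trained flow's log-density.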
Table of Contents
- Key Features
- Installation
- Quick Start
- Library Structure
- Counterfactual Methods
- Datasets
- Models
- Metrics
- Running Experiments
- Documentation
- Contributing
- Citation
- Contact
Key Features
- Multiple CF Method Families: Local, global, and group counterfactual methods
- Normalizing Flow Integration: State-of-the-art density estimation for plausibility
- Comprehensive Metrics: 17+ evaluation metrics for counterfactual quality
- Hydra Configuration: Flexible experiment management with YAML configs
- 21 Built-in Datasets: Classification and regression tasks
- Extensible Architecture: Easy to add new methods, models, and metrics
- PyTorch-based: Modern deep learning framework
- Cross-validation Support: Robust evaluation with k-fold CV
- Preprocessing Pipeline: Composable feature transformations
Installation
Clone the repository and set up the environment:
git clone git@github.com:ofurman/counterfactuals.git
cd counterfactuals
./setup_env.sh
Or install dependencies manually with uv:
uv sync
Requirements: Python >= 3.10
Quick Start
import torch
from counterfactuals.datasets import MethodDataset
from counterfactuals.cf_methods import PPCEF
from counterfactuals.models import MaskedAutoregressiveFlow, MLPClassifier
from counterfactuals.losses import BinaryDiscLoss
from counterfactuals.metrics import evaluate_cf
# Load dataset with preprocessing
dataset = MethodDataset.from_config("config/datasets/moons.yaml")
train_loader = dataset.train_dataloader(batch_size=128, shuffle=True)
test_loader = dataset.test_dataloader(batch_size=128, shuffle=False)
# Train discriminative model (classifier)
disc_model = MLPClassifier(
input_size=dataset.input_size,
hidden_layer_sizes=[256, 256],
target_size=1,
dropout=0.2,
)
disc_model.fit(train_loader, test_loader, epochs=5000, patience=300, lr=1e-3)
# Train generative model (normalizing flow)
gen_model = MaskedAutoregressiveFlow(
features=dataset.input_size,
hidden_features=8,
context_features=1,
)
gen_model.fit(train_loader, test_loader, num_epochs=1000)
# Generate counterfactuals
cf_method = PPCEF(
gen_model=gen_model,
disc_model=disc_model,
disc_model_criterion=BinaryDiscLoss(),
)
log_prob_threshold = torch.quantile(gen_model.predict_log_prob(test_loader), 0.25)
result = cf_method.explain_dataloader(
test_loader,
alpha=100,
log_prob_threshold=log_prob_threshold,
epochs=4000,
)
# Evaluate results
X_cf = result.x_origs + result.x_cfs
metrics = evaluate_cf(
disc_model=disc_model,
gen_model=gen_model,
X_cf=X_cf,
X_train=dataset.X_train,
y_train=dataset.y_train,
X_test=result.x_origs,
y_test=result.y_origs,
y_target=result.y_cf_targets,
continuous_features=dataset.numerical_features,
categorical_features=dataset.categorical_features,
median_log_prob=log_prob_threshold,
)
Jupyter Notebooks
Example notebooks are available in notebooks/:
| Notebook | Description |
|---|---|
| `ppcef.ipynb` | Basic PPCEF usage |
| `rppcef.ipynb` | Regional PPCEF for group explanations |
| `categorical_ppcef.ipynb` | Handling categorical features |
| `toy_example.ipynb` | Simple visualization examples |
Library Structure
counterfactuals/
├── cf_methods/ # Counterfactual explanation methods
│ ├── local/ # Instance-level methods (PPCEF, DiCE, WACH, etc.)
│ ├── global_/ # Model-level methods (GLOBE-CE, AReS)
│ └── group/ # Cohort-level methods (RPPCEF, GLANCE)
├── models/ # ML models
│ ├── discriminative/ # Classifiers (MLP, LogisticRegression, NODE)
│ ├── generative/ # Density estimators (MAF, RealNVP, NICE, KDE)
│ └── regression/ # Regressors (MLP, LinearRegression)
├── datasets/ # Dataset loading and configuration
├── preprocessing/ # Feature transformation pipeline
├── dequantization/ # Categorical feature handling for flows
├── losses/ # Loss functions for CF optimization
├── metrics/ # Evaluation metrics
├── pipelines/ # Experiment orchestration
│ ├── nodes/ # Pipeline components
│ └── conf/ # Hydra configuration files
├── plotting/ # Visualization utilities
└── utils.py # Helper functions
config/
└── datasets/ # Dataset YAML configurations (21 datasets)
docs/
├── library_overview.md # Comprehensive package documentation
└── ppcef_pipeline.md # Pipeline guide
Counterfactual Methods
Local Methods (Instance-level)
| Method | Class | Description |
|---|---|---|
| PPCEF | `PPCEF` | Probabilistically Plausible CF with normalizing flows (main contribution) |
| PPCEFR | `PPCEFR` | PPCEF for regression tasks |
| DiCE | `DICE` | Diverse Counterfactual Explanations |
| CEM | `CEM_CF` | Contrastive Explanation Method |
| CET | `CET` | Counterfactual Explanation Tree |
| WACH | `WACH` | Wachter-style gradient-based CF |
| Artelt | `Artelt` | Artelt's CF method |
| SACE | `SACE`, `CaseBasedSACE` | (Case-based) SACE methods |
| CEGP | `CEGP` | CF with Gaussian Processes |
| C-CHVAE | `CCHVAE` | Conditional Heterogeneous VAE |
| DiCoFlex | `DiCoFlex` | Diverse Counterfactual Flex |
| LiCE | `LiCE` | LIME-style CF (requires pyomo/onnx/omlt) |
Global Methods (Model-level)
| Method | Class | Description |
|---|---|---|
| GLOBE-CE | `GLOBE_CE` | Global Counterfactual Explanations |
| AReS | `AReS` | Actionable Recourse Summaries |
Group Methods (Cohort-level)
| Method | Class | Description |
|---|---|---|
| RPPCEF | `RPPCEF` | Regional PPCEF with shared interventions |
| GLANCE | `GLANCE` | Group-level CF method |
Datasets
The library includes 21 pre-configured datasets:
Classification:
adult, adult_census, audit, bank_marketing, compas, credit_default, diabetes, digits, german_credit, give_me_some_credit, heloc, law, lending_club, mnist, moons, wine, blobs
Regression:
concrete, toy_regression, wine_quality_regression, yacht
Dataset configurations are in config/datasets/*.yaml and support:
- Automatic feature type detection (continuous/categorical)
- Actionability flags for features
- Cross-validation splits
- Train/test split configuration
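The authoritative schema is whatever the files in `config/datasets/` actually use; purely as an illustration of the capabilities listed above (field names below are assumptions, check an existing YAML before copying), a dataset config might look like:

```yaml
# Hypothetical sketch only; field names are assumptions, not the real schema.
name: moons
target_column: label
continuous_features: [x1, x2]
categorical_features: []
actionable_features: [x1, x2]   # actionability flags
split:
  test_size: 0.2
  cv_folds: 5                   # cross-validation splits
```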
Models
Discriminative Models
| Model | Class | Use Case |
|---|---|---|
| MLP Classifier | `MLPClassifier` | General classification |
| Logistic Regression | `LogisticRegression` | Binary classification |
| Multinomial LR | `MultinomialLogisticRegression` | Multiclass |
| NODE | `NODE` | Neural Oblivious Decision Ensembles |
Generative Models
| Model | Class | Description |
|---|---|---|
| MAF | `MaskedAutoregressiveFlow` | Primary normalizing flow |
| RealNVP | `RealNVP` | Real-valued Non-Volume Preserving flow |
| NICE | `NICE` | Non-linear Independent Components Estimation |
| KDE | `KDE` | Kernel Density Estimation baseline |
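Whichever estimator is used, its role is the same: score how in-distribution a candidate counterfactual is, with a quantile of the training log-densities as the acceptance threshold (the same rule the Quick Start applies via `torch.quantile`). A stand-alone sketch of that idea, using scikit-learn's `KernelDensity` in place of a flow (this is not the library's `KDE` class):

```python
# Sketch of the plausibility check: fit a density estimator on training data,
# take a quantile of its log-densities as a threshold, and accept only
# counterfactuals scoring above it. sklearn's KernelDensity stands in for a
# normalizing flow here; none of the library's classes are used.
import numpy as np
from sklearn.neighbors import KernelDensity

rng = np.random.default_rng(0)
X_train = rng.normal(size=(500, 2))           # toy training data

kde = KernelDensity(bandwidth=0.5).fit(X_train)
log_probs = kde.score_samples(X_train)        # log-density of each point
threshold = np.quantile(log_probs, 0.25)      # 25th-percentile threshold

x_in = np.array([[0.1, -0.2]])                # near the data: plausible
x_out = np.array([[6.0, 6.0]])                # far from the data: implausible
print(kde.score_samples(x_in)[0] > threshold)   # True
print(kde.score_samples(x_out)[0] > threshold)  # False
```

Flows serve the same purpose but give exact, differentiable log-densities, which is what lets PPCEF use the threshold inside gradient-based optimization.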
Regression Models
| Model | Class |
|---|---|
| MLP Regressor | `MLPRegressor` |
| Linear Regression | `LinearRegression` |
Metrics
The library provides comprehensive evaluation metrics:
| Category | Metrics |
|---|---|
| Validity | coverage, validity, actionability |
| Sparsity | sparsity |
| Distance | proximity_continuous_euclidean, proximity_continuous_manhattan, proximity_continuous_mad, proximity_categorical_hamming, proximity_categorical_jaccard, proximity_l2_jaccard, proximity_mad_hamming |
| Plausibility | prob_plausibility, log_density_cf, log_density_test |
| Outlier Detection | lof_scores_cf, lof_scores_test, isolation_forest_scores_cf, isolation_forest_scores_test |
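To make a few of these concrete, here is a hand-computed sketch of validity, sparsity, and Euclidean proximity on toy arrays. The library's `evaluate_cf` implementations may differ in details (e.g. normalization, or counting unchanged rather than changed features); this is only an illustration of what each category measures.

```python
# Hedged sketch of three metric categories (validity, sparsity, proximity),
# computed by hand; the library's evaluate_cf may normalize differently.
import numpy as np

X_orig = np.array([[0.0, 1.0, 2.0],
                   [1.0, 1.0, 0.0]])
X_cf = np.array([[0.0, 3.0, 2.0],      # one feature changed
                 [2.0, 1.0, 1.0]])     # two features changed
y_target = np.array([1, 1])
y_cf_pred = np.array([1, 0])           # model's predictions on the CFs

validity = (y_cf_pred == y_target).mean()                 # fraction flipped
sparsity = (X_cf != X_orig).mean(axis=1).mean()           # avg frac. changed
proximity = np.linalg.norm(X_cf - X_orig, axis=1).mean()  # avg L2 distance

print(validity)   # 0.5
print(sparsity)   # 0.5  (1/3 and 2/3 of features changed, averaged)
print(proximity)  # ~1.707
```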
Running Experiments
Using Hydra Pipelines
# Run PPCEF pipeline
uv run python counterfactuals/pipelines/run_ppcef_pipeline.py
# With custom configuration
uv run python counterfactuals/pipelines/run_ppcef_pipeline.py \
dataset.config_path=config/datasets/heloc.yaml \
disc_model.model=disc_model/mlp_large \
counterfactuals_params.target_class=1
Available Pipelines
| Pipeline | Method |
|---|---|
| `run_ppcef_pipeline.py` | PPCEF (main) |
| `run_ppcefr_pipeline.py` | PPCEF for regression |
| `run_rppcef_pipeline.py` | Regional PPCEF |
| `run_dice_pipeline.py` | DiCE |
| `run_cem_pipeline.py` | CEM |
| `run_cet_pipeline.py` | CET |
| `run_cchvae_pipeline.py` | C-CHVAE |
| `run_wach_pipeline.py` | WACH |
| `run_artelt_pipeline.py` | Artelt |
| `run_cegp_pipeline.py` | CEGP |
| `run_globe_ce_pipeline.py` | GLOBE-CE |
| `run_ares_pipeline.py` | AReS |
| `run_glance_pipeline.py` | GLANCE |
Documentation
- `docs/library_overview.md` - Comprehensive package documentation
- `docs/ppcef_pipeline.md` - Detailed PPCEF pipeline guide
- `AGENTS.md` - Development guidelines and coding standards
Contributing
Contributions are welcome! Before opening a PR:
- Read `AGENTS.md` and `docs/ppcef_pipeline.md` to understand the workflow
- Use `uv` for all operations:

uv sync                  # Install dependencies
uv run ruff check --fix  # Lint and fix
uv run pytest            # Run tests
- Follow the coding standards:
- Python 3.10+, PEP 8 compliant
- Full type hints everywhere
- Google-style docstrings
- Line length: 100 characters
- Keep patches small and well-documented
- Update or add tests when behavior changes
To add new dependencies:
uv add <package>
Citation
@inbook{wielopolski2024ppcef,
  author = {Wielopolski, Patryk and Furman, Oleksii and Stefanowski, Jerzy and Zieba, Maciej},
  title  = {Probabilistically Plausible Counterfactual Explanations with Normalizing Flows},
  year   = {2024},
  month  = {10},
  isbn   = {9781643685489},
  doi    = {10.3233/FAIA240584}
}
Contact
For questions or comments, please contact via LinkedIn: TBA