CEL: Counterfactual Explanations Library

A comprehensive Python framework for generating and evaluating counterfactual explanations in machine learning models. CEL (Counterfactual Explanations Library) provides a unified interface for multiple state-of-the-art counterfactual methods, including local (instance-level), global (model-level), and group (cohort-level) approaches.

Overview

Counterfactual explanations offer a way to understand machine learning model decisions by explaining what minimal changes would alter a prediction. This library provides a unified framework for generating, evaluating, and comparing different counterfactual explanation methods across various datasets and model types.

The library includes multiple counterfactual methods, from gradient-based approaches like Wachter to advanced methods using normalizing flows for density estimation. It emphasizes plausibility, ensuring that generated explanations are coherent and realistic within the context of the original data.
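To make the idea concrete, here is a minimal, self-contained sketch of the classic Wachter-style gradient-based formulation: optimize a copy of the input so the classifier's prediction flips, while penalizing distance to the original. The toy model, weights, and hyperparameters below are illustrative only, not part of CEL's API:

```python
import torch

# Toy "model": a fixed 2D logistic classifier sigmoid(w.x + b).
w = torch.tensor([1.5, -2.0])
b = torch.tensor(0.5)

def predict_proba(x: torch.Tensor) -> torch.Tensor:
    return torch.sigmoid(x @ w + b)

# Original instance, currently classified as class 0 (p < 0.5).
x_orig = torch.tensor([-1.0, 1.0])
x_cf = x_orig.clone().requires_grad_(True)
opt = torch.optim.Adam([x_cf], lr=0.05)
lam = 0.1  # trades off validity (flipping the prediction) vs. distance

for _ in range(500):
    opt.zero_grad()
    # Push the prediction toward the target class 1 while staying close
    # (in L1 distance) to the original instance.
    loss = (predict_proba(x_cf) - 1.0) ** 2 + lam * torch.norm(x_cf - x_orig, p=1)
    loss.backward()
    opt.step()

print(predict_proba(x_orig).item())        # below 0.5: original class 0
print(predict_proba(x_cf.detach()).item()) # above 0.5: prediction flipped
```

CEL's local methods generalize this template with learned models, plausibility constraints, and richer losses.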

Key Features

  • Multiple CF Method Families: Local, global, and group counterfactual methods
  • Normalizing Flow Integration: State-of-the-art density estimation for plausibility
  • Comprehensive Metrics: 17+ evaluation metrics for counterfactual quality
  • Hydra Configuration: Flexible experiment management with YAML configs
  • 18 Built-in Datasets: Classification and regression tasks
  • Extensible Architecture: Easy to add new methods, models, and metrics
  • PyTorch-based: Modern deep learning framework
  • Cross-validation Support: Robust evaluation with k-fold CV
  • Preprocessing Pipeline: Composable feature transformations
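The plausibility idea behind the normalizing-flow integration can be sketched independently of the library: accept a counterfactual only if its log-density under a model of the training data exceeds a quantile threshold. This toy substitutes scikit-learn's `KernelDensity` for a flow (a simplification; CEL's generative models play the same role via `predict_log_prob`):

```python
import numpy as np
from sklearn.neighbors import KernelDensity

rng = np.random.default_rng(0)
X_train = rng.normal(0.0, 1.0, size=(500, 2))  # stand-in training data

# Fit a density estimator and take the 25th percentile of training
# log-densities as the plausibility threshold (mirrors the torch.quantile
# call in the Quick Start below).
kde = KernelDensity(bandwidth=0.5).fit(X_train)
threshold = np.quantile(kde.score_samples(X_train), 0.25)

x_plausible = np.array([[0.1, -0.2]])   # lies inside the data distribution
x_implausible = np.array([[8.0, 8.0]])  # far outside the data distribution

print(kde.score_samples(x_plausible)[0] > threshold)    # plausible
print(kde.score_samples(x_implausible)[0] > threshold)  # implausible
```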

Installation

Clone the repository and set up the environment:

git clone git@github.com:ofurman/cel.git
cd cel
./setup_env.sh

Or install dependencies manually with uv:

uv sync

Requirements: Python >= 3.10

Quick Start

import torch
from cel.datasets import FileDataset, MethodDataset
from cel.cf_methods import PPCEF
from cel.models import MaskedAutoregressiveFlow, MLPClassifier
from cel.losses import BinaryDiscLoss
from cel.metrics.orchestrator import MetricsOrchestrator
import numpy as np

# Note: You may need to convert data to float32 for PyTorch compatibility
# X_train = dataset.X_train.astype(np.float32)
# y_train = dataset.y_train.astype(np.float32)

# Load dataset with preprocessing
file_dataset = FileDataset(config_path="config/datasets/moons.yaml")
dataset = MethodDataset(file_dataset=file_dataset)
train_loader = dataset.train_dataloader(batch_size=128, shuffle=True)
test_loader = dataset.test_dataloader(batch_size=128, shuffle=False)

# Train discriminative model (classifier)
disc_model = MLPClassifier(
    num_inputs=dataset.X_train.shape[1],
    num_targets=1,
    hidden_layer_sizes=[256, 256],
    dropout=0.2,
)
disc_model.fit(train_loader, test_loader, epochs=5000, patience=300, lr=1e-3)

# Train generative model (normalizing flow)
gen_model = MaskedAutoregressiveFlow(
    features=dataset.X_train.shape[1],
    hidden_features=8,
    context_features=1,
)
gen_model.fit(train_loader, test_loader, epochs=1000)

# Generate counterfactuals
cf_method = PPCEF(
    gen_model=gen_model,
    disc_model=disc_model,
    disc_model_criterion=BinaryDiscLoss(),
)
log_prob_threshold = torch.quantile(gen_model.predict_log_prob(test_loader), 0.25)
result = cf_method.explain_dataloader(
    test_loader,
    alpha=100,
    log_prob_threshold=log_prob_threshold,
    epochs=4000,
)

# Evaluate results using MetricsOrchestrator
orchestrator = MetricsOrchestrator(
    X_cf=result.x_cfs,
    y_target=result.y_cf_targets,
    X_train=dataset.X_train,
    y_train=dataset.y_train,
    X_test=result.x_origs,
    y_test=result.y_origs,
    gen_model=gen_model,
    disc_model=disc_model,
    continuous_features=dataset.numerical_features_indices,
    categorical_features=dataset.categorical_features_indices,
    prob_plausibility_threshold=log_prob_threshold,
    metrics_conf_path="cel/pipelines/conf/metrics/default.yaml",
)
metrics = orchestrator.calculate_all_metrics()

Library Structure

cel/
├── cf_methods/           # Counterfactual explanation methods
│   ├── local/            # Instance-level methods (PPCEF, DiCE, WACH, etc.)
│   ├── global_/          # Model-level methods (GLOBE-CE, AReS)
│   └── group/            # Cohort-level methods (GLANCE, T-CREx)
├── models/               # ML models
│   ├── discriminative/   # Classifiers (MLP, LogisticRegression, NODE)
│   ├── generative/       # Density estimators (MAF, RealNVP, NICE, KDE)
│   └── regression/       # Regressors (MLP, LinearRegression)
├── datasets/             # Dataset loading and configuration
├── preprocessing/        # Feature transformation pipeline
├── dequantization/       # Categorical feature handling for flows
├── losses/               # Loss functions for CF optimization
├── metrics/              # Evaluation metrics
├── pipelines/            # Experiment orchestration
│   ├── nodes/            # Pipeline components
│   └── conf/             # Hydra configuration files
├── plotting/             # Visualization utilities
└── utils.py              # Helper functions

config/
└── datasets/             # Dataset YAML configurations (18 datasets)

docs/
├── library_overview.md   # Comprehensive package documentation
└── ppcef_pipeline.md     # Pipeline guide

Counterfactual Methods

Local Methods (Instance-level)

| Method | Class | Description |
|--------|-------|-------------|
| WACH | `WACH` | Wachter-style gradient-based CF |
| Artelt | `Artelt` | Heuristic-based CF method |
| DiCE | `DICE` | Diverse Counterfactual Explanations |
| CCHVAE | `CCHVAE` | Conditional Heterogeneous VAE |
| PPCEF | `PPCEF` | Probabilistically Plausible CF with normalizing flows |
| CEM | `CEM_CF` | Contrastive Explanation Method |
| CEGP | `CEGP` | Counterfactual with Gaussian Processes |
| CADEX | `CADEX` | Counterfactual explanations via optimization |
| SACE | `SACE` | Several SACE variants |
| CEARM | `CEARM` | Counterfactual explanation through association rule mining |

Global Methods (Model-level)

| Method | Class | Description |
|--------|-------|-------------|
| GLOBE-CE | `GLOBE_CE` | Global Counterfactual Explanations |
| AReS | `AReS` | Actionable Recourse Summaries |

Group Methods (Cohort-level)

| Method | Class | Description |
|--------|-------|-------------|
| GLANCE | `GLANCE` | Group-level CF method |
| T-CREx | `TCREx` | Temporal Counterfactual Rule Extraction |

Datasets

The library includes 18 pre-configured datasets:

Classification (13): adult_census, audit, bank_marketing, blobs, credit_default, digits, german_credit, give_me_some_credit (GMC), heloc, law, lending_club, moons, wine

Regression (5): concrete, diabetes, yacht, synthetic, scm20d

Dataset configurations are in config/datasets/*.yaml and support:

  • Automatic feature type detection (continuous/categorical)
  • Actionability flags for features
  • Cross-validation splits
  • Train/test split configuration
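As an illustration of what such a file might contain (the field names below are hypothetical, not the library's actual schema; consult the shipped files in config/datasets/*.yaml for the real format):

```yaml
# Hypothetical dataset config -- field names are illustrative only;
# see config/datasets/*.yaml in the repository for the actual schema.
name: moons
path: data/moons.csv
target: label
continuous_features: [x1, x2]
categorical_features: []
actionable_features: [x1, x2]
split:
  test_size: 0.2
  random_state: 42
```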

Models

Discriminative Models

| Model | Class | Use Case |
|-------|-------|----------|
| MLP Classifier | `MLPClassifier` | General classification |
| Logistic Regression | `LogisticRegression` | Binary classification |
| Multinomial LR | `MultinomialLogisticRegression` | Multiclass classification |
| NODE | `NODE` | Neural Oblivious Decision Ensembles |

Generative Models

| Model | Class | Description |
|-------|-------|-------------|
| MAF | `MaskedAutoregressiveFlow` | Primary normalizing flow |
| RealNVP | `RealNVP` | Real-valued Non-Volume Preserving |
| NICE | `NICE` | Non-linear Independent Components |
| KDE | `KDE` | Kernel Density Estimation baseline |

Regression Models

| Model | Class |
|-------|-------|
| MLP Regressor | `MLPRegressor` |
| Linear Regression | `LinearRegression` |

Metrics

The library provides comprehensive evaluation metrics:

| Category | Metrics |
|----------|---------|
| Validity | coverage, validity, actionability |
| Sparsity | sparsity |
| Distance | proximity_continuous_euclidean, proximity_continuous_manhattan, proximity_continuous_mad, proximity_categorical_hamming, proximity_categorical_jaccard, proximity_l2_jaccard, proximity_mad_hamming |
| Plausibility | prob_plausibility, log_density_cf, log_density_test |
| Outlier Detection | lof_scores_cf, lof_scores_test, isolation_forest_scores_cf, isolation_forest_scores_test |
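The core metric names map onto simple formulas. As a hedged illustration in plain NumPy (this is not the library's `MetricsOrchestrator` implementation, and `cf_metrics` is a hypothetical helper):

```python
import numpy as np

def cf_metrics(x_orig, x_cf, y_target, predict):
    """Toy versions of core counterfactual quality metrics.

    validity:  fraction of counterfactuals classified as the target class.
    sparsity:  mean number of features changed per counterfactual.
    proximity: mean L2 / L1 distance from each counterfactual to its original.
    """
    validity = (predict(x_cf) == y_target).mean()
    changed = ~np.isclose(x_orig, x_cf)           # which features moved
    sparsity = changed.sum(axis=1).mean()
    proximity_l2 = np.linalg.norm(x_cf - x_orig, axis=1).mean()
    proximity_l1 = np.abs(x_cf - x_orig).sum(axis=1).mean()
    return {"validity": validity, "sparsity": sparsity,
            "proximity_l2": proximity_l2, "proximity_l1": proximity_l1}

# Toy usage with a threshold "model" on the first feature.
predict = lambda X: (X[:, 0] > 0).astype(int)
x_orig = np.array([[-1.0, 0.5], [-2.0, 1.0]])
x_cf = np.array([[0.5, 0.5], [0.3, 1.0]])
metrics = cf_metrics(x_orig, x_cf, y_target=np.ones(2), predict=predict)
# Both CFs are valid, each changes one feature, and the mean L1/L2
# distance to the originals is 1.9.
```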

Running Experiments

Using Hydra Pipelines

# Run PPCEF pipeline
uv run python cel/pipelines/run_ppcef_pipeline.py

# With custom configuration
uv run python cel/pipelines/run_ppcef_pipeline.py \
  dataset.config_path=config/datasets/heloc.yaml \
  disc_model.model=disc_model/mlp_large \
  counterfactuals_params.target_class=1

Available Pipelines

| Pipeline | Method |
|----------|--------|
| `run_ppcef_pipeline.py` | PPCEF |
| `run_dice_pipeline.py` | DiCE |
| `run_cem_pipeline.py` | CEM |
| `run_cchvae_pipeline.py` | C-CHVAE |
| `run_wach_pipeline.py` | WACH |
| `run_artelt_pipeline.py` | Artelt |
| `run_cegp_pipeline.py` | CEGP |
| `run_cadex_pipeline.py` | CADEX |
| `run_sace_pipeline.py` | SACE |
| `run_cearm_pipeline.py` | CEARM |
| `run_globe_ce_pipeline.py` | GLOBE-CE |
| `run_ares_pipeline.py` | AReS |
| `run_glance_pipeline.py` | GLANCE |
| `run_tcrex_pipeline.py` | T-CREx |

Documentation

Live Docs: https://ofurman.github.io/counterfactuals/

Contributing

Contributions are welcome! Before opening a PR:

  1. Read AGENTS.md and docs/ppcef_pipeline.md to understand the workflow
  2. Use uv for all operations:
    uv sync                     # Install dependencies
    uv run ruff check --fix     # Lint and fix
    uv run pytest               # Run tests
    
  3. Follow the coding standards:
    • Python 3.10+, PEP 8 compliant
    • Full type hints everywhere
    • Google-style docstrings
    • Line length: 100 characters
  4. Keep patches small and well-documented
  5. Update or add tests when behavior changes

To add new dependencies:

uv add <package>

Contact

For questions or comments, please contact via LinkedIn: TBA
