CEL: Counterfactual Explanations Library
A comprehensive Python framework for generating and evaluating counterfactual explanations in machine learning models. CEL (Counterfactual Explanations Library) provides a unified interface for multiple state-of-the-art counterfactual methods, including local (instance-level), global (model-level), and group (cohort-level) approaches.
Overview
Counterfactual explanations help interpret a machine learning model's decisions by identifying the minimal changes to an input that would alter its prediction. This library provides a unified framework for generating, evaluating, and comparing different counterfactual explanation methods across various datasets and model types.
The library includes multiple counterfactual methods, from gradient-based approaches like Wachter to advanced methods using normalizing flows for density estimation. It emphasizes plausibility, ensuring that generated explanations are coherent and realistic within the context of the original data.
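As a library-independent illustration of the core idea: for a linear classifier, the smallest (L2) change that flips a decision can be computed in closed form by projecting the instance onto the decision boundary and stepping just past it. All numbers below are toy values, unrelated to CEL's API.

```python
import numpy as np

# Toy linear classifier: predict class 1 when w.x + b > 0.
w = np.array([2.0, -1.0])
b = -0.5

def predict(x):
    return int(w @ x + b > 0)

x = np.array([0.0, 1.0])            # w.x + b = -1.5, so predicted class is 0

# Minimal-L2 counterfactual: project x onto the hyperplane w.x + b = 0
# and overshoot by 1% so the prediction actually flips.
step = (w @ x + b) / (w @ w)        # signed distance scaled by ||w||^2
x_cf = x - 1.01 * step * w

print(predict(x), predict(x_cf))    # 0 1
```

Real counterfactual methods generalize this to non-linear models (via gradients, search, or generative models) and add constraints such as sparsity, actionability, and plausibility.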
Table of Contents
- Key Features
- Installation
- Quick Start
- Library Structure
- Counterfactual Methods
- Datasets
- Models
- Metrics
- Running Experiments
- Documentation
- Contributing
- Citation
- Contact
Key Features
- Multiple CF Method Families: Local, global, and group counterfactual methods
- Normalizing Flow Integration: State-of-the-art density estimation for plausibility
- Comprehensive Metrics: 17+ evaluation metrics for counterfactual quality
- Hydra Configuration: Flexible experiment management with YAML configs
- 18 Built-in Datasets: Classification and regression tasks
- Extensible Architecture: Easy to add new methods, models, and metrics
- PyTorch-based: Modern deep learning framework
- Cross-validation Support: Robust evaluation with k-fold CV
- Preprocessing Pipeline: Composable feature transformations
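As background for the cross-validation bullet above, k-fold CV repeatedly holds out one of k disjoint folds for evaluation. A generic numpy sketch of the index bookkeeping (illustrative only; CEL configures cross-validation through its dataset YAML files rather than a helper like this):

```python
import numpy as np

def kfold_indices(n_samples, k=5, seed=0):
    """Yield (train_idx, test_idx) pairs covering every sample exactly once."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_samples)
    folds = np.array_split(idx, k)
    for i in range(k):
        test_idx = folds[i]
        train_idx = np.concatenate([folds[j] for j in range(k) if j != i])
        yield train_idx, test_idx

for train_idx, test_idx in kfold_indices(10, k=5):
    print(len(train_idx), len(test_idx))   # 8 2 on every fold
```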
Installation
Clone the repository and set up the environment:
```shell
git clone git@github.com:ofurman/cel.git
cd cel
./setup_env.sh
```

Or install dependencies manually with uv:

```shell
uv sync
```
Requirements: Python >= 3.10
Quick Start
```python
import numpy as np
import torch

from cel.cf_methods import PPCEF
from cel.datasets import FileDataset, MethodDataset
from cel.losses import BinaryDiscLoss
from cel.metrics.orchestrator import MetricsOrchestrator
from cel.models import MaskedAutoregressiveFlow, MLPClassifier

# Load dataset with preprocessing
file_dataset = FileDataset(config_path="config/datasets/moons.yaml")
dataset = MethodDataset(file_dataset=file_dataset)
train_loader = dataset.train_dataloader(batch_size=128, shuffle=True)
test_loader = dataset.test_dataloader(batch_size=128, shuffle=False)

# Note: you may need to convert data to float32 for PyTorch compatibility:
# X_train = dataset.X_train.astype(np.float32)
# y_train = dataset.y_train.astype(np.float32)

# Train discriminative model (classifier)
disc_model = MLPClassifier(
    num_inputs=dataset.X_train.shape[1],
    num_targets=1,
    hidden_layer_sizes=[256, 256],
    dropout=0.2,
)
disc_model.fit(train_loader, test_loader, epochs=5000, patience=300, lr=1e-3)

# Train generative model (normalizing flow)
gen_model = MaskedAutoregressiveFlow(
    features=dataset.X_train.shape[1],
    hidden_features=8,
    context_features=1,
)
gen_model.fit(train_loader, test_loader, epochs=1000)

# Generate counterfactuals
cf_method = PPCEF(
    gen_model=gen_model,
    disc_model=disc_model,
    disc_model_criterion=BinaryDiscLoss(),
)
log_prob_threshold = torch.quantile(gen_model.predict_log_prob(test_loader), 0.25)
result = cf_method.explain_dataloader(
    test_loader,
    alpha=100,
    log_prob_threshold=log_prob_threshold,
    epochs=4000,
)

# Evaluate results using MetricsOrchestrator
orchestrator = MetricsOrchestrator(
    X_cf=result.x_cfs,
    y_target=result.y_cf_targets,
    X_train=dataset.X_train,
    y_train=dataset.y_train,
    X_test=result.x_origs,
    y_test=result.y_origs,
    gen_model=gen_model,
    disc_model=disc_model,
    continuous_features=dataset.numerical_features_indices,
    categorical_features=dataset.categorical_features_indices,
    prob_plausibility_threshold=log_prob_threshold,
    metrics_conf_path="cel/pipelines/conf/metrics/default.yaml",
)
metrics = orchestrator.calculate_all_metrics()
```
Library Structure
```
cel/
├── cf_methods/          # Counterfactual explanation methods
│   ├── local/           # Instance-level methods (PPCEF, DiCE, WACH, etc.)
│   ├── global_/         # Model-level methods (GLOBE-CE, AReS)
│   └── group/           # Cohort-level methods (GLANCE, T-CREx)
├── models/              # ML models
│   ├── discriminative/  # Classifiers (MLP, LogisticRegression, NODE)
│   ├── generative/      # Density estimators (MAF, RealNVP, NICE, KDE)
│   └── regression/      # Regressors (MLP, LinearRegression)
├── datasets/            # Dataset loading and configuration
├── preprocessing/       # Feature transformation pipeline
├── dequantization/      # Categorical feature handling for flows
├── losses/              # Loss functions for CF optimization
├── metrics/             # Evaluation metrics
├── pipelines/           # Experiment orchestration
│   ├── nodes/           # Pipeline components
│   └── conf/            # Hydra configuration files
├── plotting/            # Visualization utilities
└── utils.py             # Helper functions
config/
└── datasets/            # Dataset YAML configurations (18 datasets)
docs/
├── library_overview.md  # Comprehensive package documentation
└── ppcef_pipeline.md    # Pipeline guide
```
Counterfactual Methods
Local Methods (Instance-level)
| Method | Class | Description |
|---|---|---|
| WACH | `WACH` | Wachter-style gradient-based CF |
| Artelt | `Artelt` | Heuristic-based CF method |
| DiCE | `DICE` | Diverse Counterfactual Explanations |
| CCHVAE | `CCHVAE` | Conditional Heterogeneous VAE |
| PPCEF | `PPCEF` | Probabilistically Plausible CF with normalizing flows |
| CEM | `CEM_CF` | Contrastive Explanation Method |
| CEGP | `CEGP` | Counterfactual with Gaussian Processes |
| CADEX | `CADEX` | Counterfactual explanations via optimization |
| SACE | `SACE` | Several SACE variants |
| CEARM | `CEARM` | Counterfactual explanation through association rule mining |
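To give a feel for the simplest entry above, the Wachter-style objective minimizes prediction loss toward the target class plus a proximity penalty, loss = (f(x') − y_target)² + λ·d(x, x'). A toy numpy sketch on a hand-written logistic model (illustrative only; the library's WACH class operates on trained PyTorch models):

```python
import numpy as np

# Hand-written logistic model: P(class 1 | x) = sigmoid(w.x + b).
w, b = np.array([2.0, -1.0]), -0.5

def f(x):
    return 1.0 / (1.0 + np.exp(-(w @ x + b)))

def wachter_cf(x0, y_target=1.0, lam=0.1, lr=0.1, steps=500):
    """Gradient descent on (f(x) - y_target)^2 + lam * ||x - x0||^2."""
    x = x0.copy()
    for _ in range(steps):
        p = f(x)
        grad = 2 * (p - y_target) * p * (1 - p) * w + 2 * lam * (x - x0)
        x -= lr * grad
    return x

x0 = np.array([0.0, 1.0])               # f(x0) ≈ 0.18, i.e. class 0
x_cf = wachter_cf(x0)
print(f(x0) < 0.5, f(x_cf) > 0.5)       # True True
```

The λ term trades off flipping the prediction against staying close to the original instance; in the original formulation λ is tuned rather than fixed.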
Global Methods (Model-level)
| Method | Class | Description |
|---|---|---|
| GLOBE-CE | `GLOBE_CE` | Global Counterfactual Explanations |
| AReS | `AReS` | Actionable Recourse Summaries |
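The global-method idea can be illustrated with GLOBE-CE's core mechanism: one shared translation direction, scaled per instance until the prediction flips. The sketch below is a simplification with an assumed direction choice, not the library's implementation:

```python
import numpy as np

# Toy model: class 1 when w.x + b > 0.
w, b = np.array([1.0, 1.0]), -2.0

def predict(X):
    return (X @ w + b > 0).astype(int)

X_neg = np.array([[0.0, 0.0], [1.0, 0.5], [0.2, 1.0]])   # all class 0

# One global direction shared by the whole cohort, scaled per instance
# by the smallest k from a candidate grid that flips the prediction.
delta = w / np.linalg.norm(w)
ks = np.linspace(0, 5, 101)

cfs = []
for x in X_neg:
    for k in ks:
        if predict((x + k * delta)[None])[0] == 1:
            cfs.append(x + k * delta)
            break
cfs = np.array(cfs)
print(predict(cfs))                   # [1 1 1]
```

A single direction makes the recourse summary human-readable ("increase these features"), while per-instance scaling keeps each counterfactual minimal.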
Group Methods (Cohort-level)
| Method | Class | Description |
|---|---|---|
| GLANCE | `GLANCE` | Group-level CF method |
| T-CREx | `TCREx` | Temporal Counterfactual Rule Extraction |
Datasets
The library includes 18 pre-configured datasets:
Classification (13):
adult_census, audit, bank_marketing, blobs, credit_default, digits, german_credit, give_me_some_credit (GMC), heloc, law, lending_club, moons, wine
Regression (5):
concrete, diabetes, yacht, synthetic, scm20d
Dataset configurations are in config/datasets/*.yaml and support:
- Automatic feature type detection (continuous/categorical)
- Actionability flags for features
- Cross-validation splits
- Train/test split configuration
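To make the bullets above concrete, a dataset config might look roughly like the hypothetical sketch below. Every field name here is an assumption for illustration; consult the shipped files under config/datasets/ for the actual schema.

```yaml
# Hypothetical dataset config sketch -- field names are illustrative only.
name: moons
path: data/moons.csv
target: label
features:
  continuous: [x1, x2]
  categorical: []
  actionable: [x1, x2]
split:
  test_size: 0.2
  cv_folds: 5
  random_state: 42
```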
Models
Discriminative Models
| Model | Class | Use Case |
|---|---|---|
| MLP Classifier | `MLPClassifier` | General classification |
| Logistic Regression | `LogisticRegression` | Binary classification |
| Multinomial LR | `MultinomialLogisticRegression` | Multiclass |
| NODE | `NODE` | Neural Oblivious Decision Ensembles |
Generative Models
| Model | Class | Description |
|---|---|---|
| MAF | `MaskedAutoregressiveFlow` | Primary normalizing flow |
| RealNVP | `RealNVP` | Real-valued Non-Volume Preserving |
| NICE | `NICE` | Non-linear Independent Components |
| KDE | `KDE` | Kernel Density Estimation baseline |
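To illustrate what the density baseline in this table provides, a minimal Gaussian KDE assigns higher log-density to points near the training data than to outliers, which is the signal plausibility metrics build on. This hand-rolled sketch is not the library's KDE class:

```python
import numpy as np

def kde_log_density(X_train, X_query, bandwidth=0.5):
    """Log-density of X_query under a Gaussian KDE fit on X_train."""
    d = X_train.shape[1]
    # Pairwise squared distances, shape (n_query, n_train)
    sq = ((X_query[:, None, :] - X_train[None, :, :]) ** 2).sum(-1)
    log_kernel = -0.5 * sq / bandwidth**2
    log_norm = d * np.log(bandwidth * np.sqrt(2 * np.pi)) + np.log(len(X_train))
    # log-sum-exp over training points for numerical stability
    m = log_kernel.max(axis=1, keepdims=True)
    return m[:, 0] + np.log(np.exp(log_kernel - m).sum(axis=1)) - log_norm

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))            # training data around the origin
inlier = np.array([[0.0, 0.0]])
outlier = np.array([[6.0, 6.0]])
print(kde_log_density(X, inlier) > kde_log_density(X, outlier))   # [ True]
```

Normalizing flows such as MAF play the same role but scale better with dimensionality and can condition on the class label.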
Regression Models
| Model | Class |
|---|---|
| MLP Regressor | `MLPRegressor` |
| Linear Regression | `LinearRegression` |
Metrics
The library provides comprehensive evaluation metrics:
| Category | Metrics |
|---|---|
| Validity | coverage, validity, actionability |
| Sparsity | sparsity |
| Distance | proximity_continuous_euclidean, proximity_continuous_manhattan, proximity_continuous_mad, proximity_categorical_hamming, proximity_categorical_jaccard, proximity_l2_jaccard, proximity_mad_hamming |
| Plausibility | prob_plausibility, log_density_cf, log_density_test |
| Outlier Detection | lof_scores_cf, lof_scores_test, isolation_forest_scores_cf, isolation_forest_scores_test |
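A few of these metrics are simple enough to sketch directly. The toy numpy versions below show the intent of validity, sparsity, and Euclidean proximity; the library computes these (and the rest of the table) via MetricsOrchestrator:

```python
import numpy as np

# Toy originals, counterfactuals, targets, and model predictions on the CFs.
X_orig = np.array([[1.0, 0.0, 3.0], [2.0, 1.0, 0.0]])
X_cf   = np.array([[1.0, 0.5, 3.0], [2.0, 1.0, 1.5]])
y_target  = np.array([1, 1])
y_cf_pred = np.array([1, 0])

# Validity: fraction of CFs actually classified as the target class.
validity = (y_cf_pred == y_target).mean()
# Sparsity: average number of features changed per counterfactual.
sparsity = (~np.isclose(X_orig, X_cf)).sum(axis=1).mean()
# Proximity: mean Euclidean distance between originals and CFs.
proximity_l2 = np.linalg.norm(X_cf - X_orig, axis=1).mean()

print(validity, sparsity, proximity_l2)   # 0.5 1.0 1.0
```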
Running Experiments
Using Hydra Pipelines
```shell
# Run the PPCEF pipeline
uv run python cel/pipelines/run_ppcef_pipeline.py

# With custom configuration
uv run python cel/pipelines/run_ppcef_pipeline.py \
    dataset.config_path=config/datasets/heloc.yaml \
    disc_model.model=disc_model/mlp_large \
    counterfactuals_params.target_class=1
```
Available Pipelines
| Pipeline | Method |
|---|---|
| `run_ppcef_pipeline.py` | PPCEF |
| `run_dice_pipeline.py` | DiCE |
| `run_cem_pipeline.py` | CEM |
| `run_cchvae_pipeline.py` | C-CHVAE |
| `run_wach_pipeline.py` | WACH |
| `run_artelt_pipeline.py` | Artelt |
| `run_cegp_pipeline.py` | CEGP |
| `run_cadex_pipeline.py` | CADEX |
| `run_sace_pipeline.py` | SACE |
| `run_cearm_pipeline.py` | CEARM |
| `run_globe_ce_pipeline.py` | GLOBE-CE |
| `run_ares_pipeline.py` | AReS |
| `run_glance_pipeline.py` | GLANCE |
| `run_tcrex_pipeline.py` | T-CREx |
Documentation
Live Docs: https://ofurman.github.io/counterfactuals/
Contributing
Contributions are welcome! Before opening a PR:

- Read AGENTS.md and docs/ppcef_pipeline.md to understand the workflow
- Use uv for all operations:

```shell
uv sync                  # Install dependencies
uv run ruff check --fix  # Lint and fix
uv run pytest            # Run tests
```

- Follow the coding standards:
  - Python 3.10+, PEP 8 compliant
  - Full type hints everywhere
  - Google-style docstrings
  - Line length: 100 characters
- Keep patches small and well-documented
- Update or add tests when behavior changes

To add new dependencies:

```shell
uv add <package>
```
Contact
For questions or comments, please contact via LinkedIn: TBA