A counterfactual explanation library using JAX

Project description

Welcome to cfnet

Key Features

  • fast: code runs significantly faster than existing counterfactual (CF) explanation libraries.
  • scalable: code can be accelerated on CPU, GPU, and TPU.
  • flexible: a flexible API gives researchers full customization.

TODO:

  • implement various CF explanation methods

Install

cfnet is built on top of JAX. It also uses PyTorch to load data.

Running on CPU

If you only need to run cfnet on CPU, install it via pip or from a clone of the GitHub project.

Installation via PyPI:

pip install cfnet

Editable Install:

git clone https://github.com/BirkhoffG/cfnet.git
pip install -e cfnet
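
After either install route, a quick check that the package is visible to the interpreter can save debugging time later. The sketch below is not part of cfnet; the helper name `installed_version` is hypothetical, and it assumes Python 3.8 or newer for `importlib.metadata`.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str = "cfnet") -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


version = installed_version()
print(f"cfnet version: {version}" if version else "cfnet is not installed")
```

For an editable install, the reported version comes from the metadata of the cloned checkout rather than from a released wheel.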

Running on GPU or TPU

To run cfnet on GPU or TPU, first install the library via pip install cfnet.

Then install the GPU or TPU build of JAX that matches your hardware by following the steps in the JAX installation guide.
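
To confirm that the accelerator build of JAX was picked up, it may help to ask JAX which backend it selected. This is a minimal sketch, not part of cfnet: the helper name `jax_backend` is hypothetical, and the import is guarded so the snippet also runs where JAX is missing.

```python
def jax_backend():
    """Return the name of JAX's default backend, or None if JAX is not installed."""
    try:
        import jax
    except ImportError:
        return None
    # default_backend() reports the platform JAX will run on, e.g. "cpu"
    return jax.default_backend()


backend = jax_backend()
print(f"JAX backend: {backend}" if backend else "JAX is not installed")
```

If this reports a CPU backend on a machine with a GPU or TPU, the CPU-only build of JAX is most likely still the one installed.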

A Minimal Example

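from cfnet.utils import load_json
from cfnet.datasets import TabularDataModule
# Deprecated since 0.0.7 and slated for removal in 0.1.0;
# `cfnet.module.PredictiveTrainingModule` is the suggested replacement.
from cfnet.training_module import PredictiveTrainingModule
from cfnet.train import train_model
from cfnet.methods import VanillaCF
from cfnet.evaluate import generate_cf_results_local_exp, benchmark_cfs
# star import; supplies `random` (jax.random), used for the PRNG key in pred_fn
from cfnet.import_essentials import *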

data_configs = {
    "data_dir": "assets/data/s_adult.csv",
    "data_name": "adult",
    "batch_size": 256,
    "continous_cols": ["age","hours_per_week"],
    "discret_cols": ["workclass","education","marital_status","occupation","race","gender"],
    "imutable_cols": ["race","gender"]
}
m_configs = {
    'lr': 0.003,
    "sizes": [50, 10, 50],
    "dropout_rate": 0.3
}
t_configs = {
    'n_epochs': 10,
    'monitor_metrics': 'val/val_loss',
    'logger_name': 'pred',
    'seed': 42,
    "batch_size": 256
}
cf_configs = {
    'n_steps': 1000,
    'lr': 0.001
}

# load data
dm = TabularDataModule(data_configs)

# specify the ML model 
training_module = PredictiveTrainingModule(m_configs)

# train ML model
params, opt_state = train_model(
    training_module, dm, t_configs
)

# define CF Explanation Module
pred_fn = lambda x: training_module.forward(
    params, random.PRNGKey(0), x, is_training=False)
cf_exp = VanillaCF(cf_configs)

# generate cf explanations
cf_results = generate_cf_results_local_exp(cf_exp, dm, pred_fn)

# benchmark different cf explanation methods
benchmark_cfs([cf_results])
/home/birk/mambaforge-pypy3/envs/nbdev2/lib/python3.7/site-packages/ipykernel_launcher.py:36: DeprecatedWarning: PredictiveTrainingModule is deprecated as of 0.0.7 and will be removed in 0.1.0. Use `cfnet.module.PredictiveTrainingModule` instead.
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Epoch 9: 100%|██████████| 96/96 [00:01<00:00, 57.03batch/s, train/train_loss_1=0.0485]
100%|██████████| 1000/1000 [00:08<00:00, 124.53it/s]
                   acc       validity  proximity
adult  VanillaCF   0.826188  0.883675  7.05637
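
The n_steps and lr keys in cf_configs suggest an iterative, gradient-based search, and the benchmark reports validity and proximity. The dependency-free sketch below illustrates that general idea on a toy logistic model. It is not cfnet's implementation: the loss (squared error to the flipped label plus an L1 penalty weighted by a hypothetical `lam`), the function names, the toy weights, and the metric definitions (validity as the fraction of flipped predictions, proximity as the mean L1 distance) are all assumptions made for illustration. The toy uses a larger lr than the example above so that it converges in a few hundred steps.

```python
import math


def predict(w, b, x):
    """Probability of the positive class under a toy logistic model."""
    z = sum(wi * xi for wi, xi in zip(w, x)) + b
    return 1.0 / (1.0 + math.exp(-z))


def sign(v):
    """Return -1, 0 or 1 depending on the sign of v."""
    return (v > 0) - (v < 0)


def vanilla_cf(w, b, x, n_steps=1000, lr=0.1, lam=0.05):
    """Gradient descent on the input to push the prediction to the other class."""
    target = 0.0 if predict(w, b, x) >= 0.5 else 1.0  # flipped label
    cf = list(x)
    for _ in range(n_steps):
        p = predict(w, b, cf)
        # gradient of (p - target)^2 w.r.t. cf is 2 (p - target) p (1 - p) w
        scale = 2.0 * (p - target) * p * (1.0 - p)
        # gradient of lam * ||cf - x||_1 w.r.t. cf is lam * sign(cf - x)
        cf = [c - lr * (scale * wi + lam * sign(c - xi))
              for c, wi, xi in zip(cf, w, x)]
    return cf


def validity(w, b, xs, cfs):
    """Fraction of counterfactuals whose predicted label differs from the input's."""
    flipped = [(predict(w, b, x) >= 0.5) != (predict(w, b, cf) >= 0.5)
               for x, cf in zip(xs, cfs)]
    return sum(flipped) / len(flipped)


def proximity(xs, cfs):
    """Mean L1 distance between each input and its counterfactual."""
    dists = [sum(abs(a - c) for a, c in zip(x, cf)) for x, cf in zip(xs, cfs)]
    return sum(dists) / len(dists)


# hypothetical toy model and inputs, chosen only for illustration
w, b = [2.0, -1.0], 0.0
xs = [[1.0, 0.5], [-1.0, 1.0], [0.2, 0.1]]
cfs = [vanilla_cf(w, b, x) for x in xs]
print("validity:", validity(w, b, xs, cfs))
print("proximity:", proximity(xs, cfs))
```

Raising `lam` keeps counterfactuals closer to the original inputs (lower proximity) at the risk of fewer flipped predictions (lower validity), which is the trade-off the two columns are meant to expose under these assumed definitions.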

Download files

Source Distribution

cfnet-0.0.9.tar.gz (31.5 kB)

Uploaded Source

Built Distribution

cfnet-0.0.9-py3-none-any.whl (38.0 kB)

Uploaded Python 3
