counterfactual explanation using Jax

CFNET

A fast and scalable library for counterfactual explanations in Jax.

Key Features

  • fast: code runs significantly faster than existing CF explanation libraries.
  • scalable: code can be accelerated over CPU, GPU, and TPU.
  • flexible: we provide a flexible API that gives researchers full customization.

TODO:

  • implement various methods of CF explanations

Install

cfnet is built on top of Jax. It also uses PyTorch to load data.

Running on CPU

If you only need to run cfnet on CPU, you can simply install via pip or clone the GitHub project.

Installation via PyPI:

pip install cfnet

Editable Install:

git clone https://github.com/BirkhoffG/cfnet.git
pip install -e cfnet

Running on GPU or TPU

If you wish to run cfnet on GPU or TPU, please first install this library via pip install cfnet.

Then install the GPU or TPU build of Jax that matches your hardware by following the steps in the Jax installation guide.
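
After installing an accelerator build, it can help to confirm which devices Jax actually sees before running cfnet. The snippet below is a small sanity check that uses only Jax's own `jax.devices()` and `jax.default_backend()`; it is not part of cfnet.

```python
import jax

# List every device Jax can see and report the default backend.
devices = jax.devices()
backend = jax.default_backend()  # a platform name such as "cpu"

print(f"default backend: {backend}")
for device in devices:
    print(device)
```

If only CPU devices are listed on a machine with a GPU or TPU, the accelerator build of Jax is most likely not installed correctly.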

A Minimal Example

from cfnet.utils import load_json
from cfnet.datasets import TabularDataModule
from cfnet.training_module import PredictiveTrainingModule
from cfnet.train import train_model
from cfnet.methods import VanillaCF
from cfnet.evaluate import generate_cf_results_local_exp, benchmark_cfs
from cfnet.import_essentials import *
from jax import random  # explicit import for random.PRNGKey used below

data_configs = load_json('assets/configs/data_configs/adult.json')
m_configs = {
    'lr': 0.003,
    "sizes": [50, 10, 50],
    "dropout_rate": 0.3
}
t_configs = {
    'n_epochs': 10,
    'monitor_metrics': 'val/val_loss',
    'logger_name': 'pred'
}
cf_configs = {
    'n_steps': 1000,
    'lr': 0.001
}

# load data
dm = TabularDataModule(data_configs)

# specify the ML model 
training_module = PredictiveTrainingModule(m_configs)

# train ML model
params, opt_state = train_model(
    training_module, dm, t_configs
)

# define CF Explanation Module
pred_fn = lambda x: training_module.forward(
    params, random.PRNGKey(0), x, is_training=False)
cf_exp = VanillaCF(pred_fn, cf_configs)

# generate cf explanations
cf_results = generate_cf_results_local_exp(cf_exp, dm)

# benchmark different cf explanation methods
benchmark_cfs([cf_results])
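
For intuition about what a gradient-based counterfactual method such as VanillaCF does with `n_steps` and `lr`, here is a minimal, self-contained sketch in Jax. It is an illustration under stated assumptions, not cfnet's implementation: the toy logistic model (`W`, `B`), the loss, the weight `lam`, and the names `predict`, `cf_loss`, `step`, and `vanilla_cf` are all hypothetical. The idea is to start from the input, then take gradient steps that push the prediction toward the opposite class while penalizing distance from the original input.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy classifier: logistic regression with fixed weights.
W = jnp.array([1.5, -2.0])
B = 0.25

def predict(x):
    """Return P(y = 1 | x) for a single input vector x."""
    return jax.nn.sigmoid(jnp.dot(W, x) + B)

def cf_loss(cf, x, target, lam):
    """Validity term pulls the prediction toward `target`;
    proximity term keeps the counterfactual close to x."""
    validity = (predict(cf) - target) ** 2
    proximity = jnp.sum((cf - x) ** 2)
    return validity + lam * proximity

@jax.jit
def step(cf, x, target, lr, lam):
    # One plain gradient-descent update on the counterfactual itself.
    grad = jax.grad(cf_loss)(cf, x, target, lam)
    return cf - lr * grad

def vanilla_cf(x, n_steps=1000, lr=0.1, lam=0.01):
    # Aim for the class opposite to the model's current prediction.
    target = 1.0 - jnp.round(predict(x))
    cf = x
    for _ in range(n_steps):
        cf = step(cf, x, target, lr, lam)
    return cf

x = jnp.array([-1.0, 1.0])  # an input the toy model assigns to class 0
cf = vanilla_cf(x)
print("original prediction:      ", float(predict(x)))
print("counterfactual prediction:", float(predict(cf)))
```

The trade-off between the two loss terms is controlled by `lam` in this sketch: a larger value keeps the counterfactual closer to the input at the cost of a less confident flip.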
