A counterfactual explanation library using Jax
Welcome to cfnet
Key Features
- fast: runs significantly faster than existing CF explanation libraries.
- scalable: the same code can be accelerated on CPU, GPU, and TPU (see the sketch after this list).
- flexible: provides a flexible API that lets researchers fully customize each method.
TODO: implement more CF explanation methods.
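The speed and scalability above come from Jax's compilation and vectorization primitives rather than anything cfnet-specific. A minimal illustration of the underlying idea (generic Jax, not cfnet's actual internals):

import jax
import jax.numpy as jnp

# An illustrative loss; cfnet's real CF objectives differ.
loss_fn = lambda x: jnp.sum(x ** 2)

# One gradient-descent step on a counterfactual candidate x.
step = lambda x: x - 0.1 * jax.grad(loss_fn)(x)

# jit compiles the step; vmap batches it over many inputs at once.
# The same compiled code runs unchanged on CPU, GPU, or TPU.
batched_step = jax.jit(jax.vmap(step))
xs = batched_step(jnp.ones((256, 10)))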
Install
cfnet is built on top of Jax. It also uses PyTorch to load data.
Running on CPU
If you only need to run cfnet on CPU, you can simply install it via pip or clone the GitHub project.
Installation via PyPI:
pip install cfnet
Editable Install:
git clone https://github.com/BirkhoffG/cfnet.git
pip install -e cfnet
Running on GPU or TPU
If you wish to run cfnet on GPU or TPU, first install this library via pip install cfnet. Then install the matching GPU or TPU build of Jax by following the steps in the official Jax installation guide.
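For example, on a machine with CUDA, the Jax step has typically looked like the command below; the exact extras name and wheel index depend on your Jax and CUDA versions, so follow the official install guide rather than copying this verbatim:

pip install cfnet
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html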
A Minimal Example
from cfnet.utils import load_json
from cfnet.datasets import TabularDataModule
from cfnet.training_module import PredictiveTrainingModule
from cfnet.train import train_model
from cfnet.methods import VanillaCF
from cfnet.evaluate import generate_cf_results_local_exp, benchmark_cfs
from cfnet.import_essentials import *  # star-imports common dependencies (e.g., jax's `random`, used below)
data_configs = {
    "data_dir": "assets/data/s_adult.csv",
    "data_name": "adult",
    "batch_size": 256,
    # note: the key spellings below follow cfnet's config schema
    "continous_cols": ["age", "hours_per_week"],  # numerical features
    "discret_cols": ["workclass", "education", "marital_status", "occupation", "race", "gender"],  # categorical features
    "imutable_cols": ["race", "gender"]  # features a counterfactual must not change
}
m_configs = {
    'lr': 0.003,  # learning rate of the predictive model
    "sizes": [50, 10, 50],  # hidden-layer sizes
    "dropout_rate": 0.3
}
t_configs = {
    'n_epochs': 10,
    'monitor_metrics': 'val/val_loss',
    'logger_name': 'pred',
    'seed': 42,
    "batch_size": 256
}
cf_configs = {
    'n_steps': 1000,  # optimization steps per counterfactual
    'lr': 0.001
}
# load data
dm = TabularDataModule(data_configs)
# specify the ML model
training_module = PredictiveTrainingModule(m_configs)
# train ML model
params, opt_state = train_model(
training_module, dm, t_configs
)
# define CF Explanation Module
pred_fn = lambda x: training_module.forward(
params, random.PRNGKey(0), x, is_training=False)
cf_exp = VanillaCF(cf_configs)
# generate cf explanations
cf_results = generate_cf_results_local_exp(cf_exp, dm, pred_fn)
# benchmark different cf explanation methods
benchmark_cfs([cf_results])
Epoch 9: 100%|██████████| 96/96 [00:01<00:00, 57.03batch/s, train/train_loss_1=0.0485]
100%|██████████| 1000/1000 [00:08<00:00, 124.53it/s]
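Once training finishes, the pred_fn defined above can also be called directly on a feature matrix. A minimal sketch, assuming the inputs have already been encoded to the model's input width (the width 29 below is a hypothetical placeholder):

import jax.numpy as jnp

x = jnp.zeros((1, 29))  # hypothetical: one already-encoded instance
probs = pred_fn(x)      # forward pass with the trained params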
| Dataset | Method | acc | validity | proximity |
|---|---|---|---|---|
| adult | VanillaCF | 0.826188 | 0.883675 | 7.05637 |
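For reference, validity and proximity here follow their usual meanings in the CF explanation literature: validity is the fraction of counterfactuals that actually obtain the target prediction, and proximity measures how far they move from the original inputs. A sketch of the standard definitions, which may differ in detail from cfnet's implementation:

import jax.numpy as jnp

def validity(y_cf_pred, y_target):
    # Share of counterfactuals whose prediction matches the target class.
    return jnp.mean((y_cf_pred.round() == y_target).astype(jnp.float32))

def proximity(x, x_cf):
    # Mean L1 distance between originals and their counterfactuals.
    return jnp.mean(jnp.sum(jnp.abs(x - x_cf), axis=-1))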