JAX-based Recourse Explanation Library

ReLax

Overview | Installation | Tutorials | Documentation | Citing ReLax

Overview

ReLax (Recourse Explanation Library in Jax) is an efficient and scalable benchmarking library for recourse and counterfactual explanations, built on top of JAX. By leveraging JAX primitives such as vectorization, parallelization, and just-in-time compilation, ReLax offers large speedups when generating individual (local) explanations for predictions made by machine learning models.
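To make the vectorization and just-in-time compilation points concrete, here is a minimal plain-JAX sketch (it does not use ReLax itself): a per-instance prediction function is turned into a batched one with jax.vmap and compiled with jax.jit. The toy logistic-regression weights w and b are hypothetical placeholders.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy classifier: fixed logistic-regression weights (not part of ReLax).
w = jnp.array([0.5, -1.0, 2.0])
b = 0.1

def pred_fn(x):
    """Return the predicted probability of the positive class for ONE instance."""
    return jax.nn.sigmoid(jnp.dot(x, w) + b)

# vmap turns the per-instance function into a batched one;
# jit compiles the batched function with XLA for the available backend.
batched_pred = jax.jit(jax.vmap(pred_fn))

xs = jax.random.normal(jax.random.PRNGKey(0), (1000, 3))
probs = batched_pred(xs)
print(probs.shape)
```

Writing the model logic for a single instance and letting vmap handle batching is the same pattern the ReLax examples below rely on.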

Some of the key features are as follows:

  • 🏃 Fast and scalable recourse generation.

  • 🚀 Accelerated on CPU, GPU, and TPU.

  • 🪓 Comprehensive set of recourse methods implemented for benchmarking.

  • 👐 Customizable API that enables building entire modeling and interpretation pipelines for new recourse algorithms.

Installation

pip install jax-relax
# Or install the latest version of `jax-relax`
pip install git+https://github.com/BirkhoffG/jax-relax.git 

To further unleash the power of accelerators (i.e., GPU/TPU), we suggest first installing this library via pip install jax-relax, and then following the steps in the official JAX installation guidelines to install the JAX build that matches your GPU or TPU.
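After installing an accelerator-enabled JAX build, a quick way to confirm that JAX actually sees the device is to query the backend. This sketch uses only JAX's public API (jax.default_backend and jax.devices), not ReLax:

```python
import jax

# Name of the backend JAX selected, e.g. "cpu", "gpu", or "tpu".
backend = jax.default_backend()
# List of devices visible to that backend.
devices = jax.devices()
print(backend, devices)
```

If this still reports "cpu" on a machine with a GPU, the installed JAX build most likely lacks accelerator support.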

Dive into ReLax

ReLax is a recourse explanation library for explaining (any) JAX-based ML models. We believe it is important to give users the flexibility to choose how they use ReLax. You can

  • use only the methods implemented in ReLax (as a recourse methods library); or
  • build a pipeline with ReLax that defines the data module, trains ML models, and generates counterfactual (CF) explanations (for constructing a recourse benchmarking pipeline).

ReLax as a Recourse Explanation Library

We introduce basic use cases of methods in ReLax for generating recourse explanations. For more advanced usage of methods in ReLax, see this tutorial.

from relax.methods import VanillaCF
from relax import DataModule, MLModule, generate_cf_explanations, benchmark_cfs
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import functools as ft
import jax

Let’s first generate synthetic data:

xs, ys = make_classification(n_samples=1000, n_features=10, random_state=42)
train_xs, test_xs, train_ys, test_ys = train_test_split(xs, ys, random_state=42)

Next, we fit an MLP model on this data. Note that this model can be any model implemented in JAX. We will use the MLModule in ReLax as an example.

model = MLModule()
model.train((train_xs, train_ys), epochs=10, batch_size=64)
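Because the model can be any JAX model, here is a minimal sketch of a hand-written alternative to MLModule: a logistic regression trained with full-batch gradient descent in plain JAX. It assumes that ReLax only needs a pred_fn callable mapping features to class-1 probabilities, which is how model.pred_fn is used in the following snippets; the helper names (init_params, predict_proba, sgd_step) are hypothetical and not part of ReLax.

```python
import jax
import jax.numpy as jnp
from sklearn.datasets import make_classification

LR = 0.1  # learning rate for full-batch gradient descent

def init_params(n_features):
    # Zero-initialised weights and bias for a logistic-regression model.
    return {"w": jnp.zeros(n_features), "b": jnp.array(0.0)}

def predict_proba(params, x):
    # Probability of class 1 for a single instance or a batch of instances.
    return jax.nn.sigmoid(x @ params["w"] + params["b"])

def bce_loss(params, x, y):
    # Mean binary cross-entropy; clipping avoids log(0).
    p = jnp.clip(predict_proba(params, x), 1e-7, 1 - 1e-7)
    return -jnp.mean(y * jnp.log(p) + (1 - y) * jnp.log(1 - p))

@jax.jit
def sgd_step(params, x, y):
    # One gradient-descent update applied to every leaf of the params pytree.
    grads = jax.grad(bce_loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)

xs, ys = make_classification(n_samples=1000, n_features=10, random_state=42)
xs = jnp.asarray(xs, dtype=jnp.float32)
ys = jnp.asarray(ys, dtype=jnp.float32)

params = init_params(xs.shape[1])
for _ in range(200):
    params = sgd_step(params, xs, ys)

# Hypothetical stand-in for model.pred_fn: a closure with an x -> probability signature.
pred_fn = lambda x: predict_proba(params, x)
acc = jnp.mean((pred_fn(xs) > 0.5) == (ys == 1))
print(f"training accuracy: {float(acc):.3f}")
```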

Generating recourse explanations is straightforward. We can simply call the generate_cf method of an implemented recourse method to generate one recourse explanation:

vcf = VanillaCF(config={'n_steps': 1000, 'lr': 0.05})
cf = vcf.generate_cf(test_xs[0], model.pred_fn)
assert cf.shape == test_xs[0].shape
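VanillaCF follows Wachter et al. [1]. As a rough illustration of that idea, and not ReLax's actual implementation, the sketch below runs gradient descent on a candidate counterfactual, minimizing a binary cross-entropy term toward the opposite class plus an L1 distance to the original input. n_steps and lr mirror the config keys above; the weight lambda_, the function vanilla_cf_sketch, and the fixed linear classifier are assumptions made for this example.

```python
import jax
import jax.numpy as jnp

# Hypothetical fixed linear classifier standing in for model.pred_fn.
W = jnp.array([1.5, -2.0, 0.5])
B = -0.2

def pred_fn(x):
    return jax.nn.sigmoid(x @ W + B)

def vanilla_cf_sketch(x, pred_fn, n_steps=1000, lr=0.05, lambda_=0.1):
    """Gradient-based counterfactual in the spirit of Wachter et al. (sketch only)."""
    # Aim for the opposite of the currently predicted class.
    target = 1.0 - jnp.round(pred_fn(x))

    def loss(cf):
        p = jnp.clip(pred_fn(cf), 1e-7, 1 - 1e-7)
        # Binary cross-entropy pulls the prediction toward the target class...
        validity = -(target * jnp.log(p) + (1 - target) * jnp.log(1 - p))
        # ...while the L1 term keeps the counterfactual close to the input.
        proximity = jnp.sum(jnp.abs(cf - x))
        return validity + lambda_ * proximity

    grad_fn = jax.grad(loss)
    # lax.fori_loop keeps the optimisation loop traceable under jit/vmap.
    return jax.lax.fori_loop(0, n_steps, lambda _, cf: cf - lr * grad_fn(cf), x)

x = jnp.array([1.0, 0.5, 0.0])
cf = vanilla_cf_sketch(x, pred_fn)
print(pred_fn(x), pred_fn(cf))
```

The L1 weight trades off validity against proximity: a larger lambda_ keeps the counterfactual closer to x but may stop the prediction from flipping.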

Or generate a batch of recourse explanations with jax.vmap:

generate_fn = ft.partial(vcf.generate_cf, pred_fn=model.pred_fn)
cfs = jax.vmap(generate_fn)(test_xs)
assert cfs.shape == test_xs.shape
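The ft.partial + jax.vmap pattern works because partial binds pred_fn before vmap sees the function, so vmap only maps over the instance argument. The sketch below checks, on a hypothetical one-step generator (not a ReLax method), that the jitted, vmapped version returns the same counterfactuals as a plain Python loop over instances.

```python
import functools as ft
import jax
import jax.numpy as jnp

W = jnp.array([1.5, -2.0, 0.5])

def pred_fn(x):
    return jax.nn.sigmoid(x @ W)

def one_step_cf(x, pred_fn, lr=0.5):
    # Hypothetical per-instance "generator": one gradient step that lowers the prediction.
    return x - lr * jax.grad(pred_fn)(x)

# pred_fn is bound here, so the vmapped function only receives x.
generate_fn = ft.partial(one_step_cf, pred_fn=pred_fn)
xs = jax.random.normal(jax.random.PRNGKey(1), (8, 3))

batched = jax.jit(jax.vmap(generate_fn))(xs)
looped = jnp.stack([generate_fn(x) for x in xs])
print(jnp.allclose(batched, looped, atol=1e-5))
```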

ReLax for Building Recourse Explanation Pipelines

The above example illustrates how to use the decoupled relax.methods to generate recourse explanations. However, this approach requires users to write boilerplate code for tasks such as data preprocessing, model training, and generating recourse explanations with feature constraints.

ReLax additionally offers a framework of one-line calls that streamlines the process and helps users build a standardized pipeline for generating recourse explanations. Three lines of code are enough to benchmark recourse explanations:

data_module = DataModule.from_numpy(xs, ys)
exps = generate_cf_explanations(vcf, data_module, model.pred_fn)
benchmark_cfs([exps])
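benchmark_cfs summarizes the generated explanations with evaluation metrics. As an illustration of two metrics commonly used in the recourse literature, here is a plain-JAX sketch of validity (the share of counterfactuals that flip the predicted class) and proximity (the mean L1 distance to the input). These definitions are assumptions for illustration; the exact metrics and formulas reported by benchmark_cfs may differ.

```python
import jax
import jax.numpy as jnp

def validity(pred_fn, xs, cfs):
    """Fraction of counterfactuals whose predicted class differs from the original."""
    y = jnp.round(pred_fn(xs))
    y_cf = jnp.round(pred_fn(cfs))
    return jnp.mean(y != y_cf)

def proximity(xs, cfs):
    """Mean L1 distance between each input and its counterfactual."""
    return jnp.mean(jnp.sum(jnp.abs(xs - cfs), axis=1))

# Hypothetical linear classifier and hand-picked counterfactuals.
W = jnp.array([1.0, -1.0])
pred_fn = lambda x: jax.nn.sigmoid(x @ W)

xs = jnp.array([[1.0, 0.0], [0.0, 1.0]])
cfs = jnp.array([[-1.0, 0.0], [0.0, 2.0]])
# validity = 0.5 (only the first row flips), proximity = (2 + 1) / 2 = 1.5
print(validity(pred_fn, xs, cfs), proximity(xs, cfs))
```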

See Getting Started with ReLax for an end-to-end example of using ReLax.

Supported Recourse Methods

ReLax currently provides implementations of 9 recourse explanation methods.

| Method | Type | Paper Title | Ref |
|---|---|---|---|
| VanillaCF | Non-Parametric | Counterfactual Explanations without Opening the Black Box: Automated Decisions and the GDPR. | [1] |
| DiverseCF | Non-Parametric | Explaining Machine Learning Classifiers through Diverse Counterfactual Explanations. | [2] |
| ProtoCF | Semi-Parametric | Interpretable Counterfactual Explanations Guided by Prototypes. | [3] |
| CounterNet | Parametric | CounterNet: End-to-End Training of Prediction Aware Counterfactual Explanations. | [4] |
| GrowingSphere | Non-Parametric | Inverse Classification for Comparison-based Interpretability in Machine Learning. | [5] |
| CCHVAE | Semi-Parametric | Learning Model-Agnostic Counterfactual Explanations for Tabular Data. | [6] |
| VAECF | Parametric | Preserving Causal Constraints in Counterfactual Explanations for Machine Learning Classifiers. | [7] |
| CLUE | Semi-Parametric | Getting a CLUE: A Method for Explaining Uncertainty Estimates. | [8] |
| L2C | Parametric | Feature-based Learning for Diverse and Privacy-Preserving Counterfactual Explanations. | [9] |

Citing ReLax

To cite this repository:

@software{relax2023github,
  author = {Hangzhi Guo and Xinchang Xiong and Amulya Yadav},
  title = {{R}e{L}ax: Recourse Explanation Library in Jax},
  url = {http://github.com/birkhoffg/jax-relax},
  version = {0.2.0},
  year = {2023},
}

Download files

Source Distribution

jax-relax-0.2.3.tar.gz (65.8 kB)

Uploaded Source

Built Distribution

jax_relax-0.2.3-py3-none-any.whl (79.9 kB)

Uploaded Python 3

File details

Details for the file jax-relax-0.2.3.tar.gz.

File metadata

  • Download URL: jax-relax-0.2.3.tar.gz
  • Upload date:
  • Size: 65.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.16

File hashes

Hashes for jax-relax-0.2.3.tar.gz
Algorithm Hash digest
SHA256 95f4fb509e9c2a9a61268a24c11abad076b07c0d14cce162f08cca1117dcc48a
MD5 241c702db93aabbb5ae2b2edd0526ba8
BLAKE2b-256 77d4d491be3812c3bedddc96d2fcd5ad6b4a22b838606be42d9d7b24ee55b00e
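To check a downloaded archive against the SHA256 digest listed above, you can stream the file through Python's standard hashlib module. The sketch below first sanity-checks the helper on the standard SHA-256 test vector for b"abc", then compares the sdist digest if jax-relax-0.2.3.tar.gz is present in the current directory; the helper name sha256_hex is illustrative. Alternatively, pip can enforce digests itself via --require-hashes with a requirements-file entry such as jax-relax==0.2.3 --hash=sha256:<digest>.

```python
import hashlib
import io
from pathlib import Path

def sha256_hex(stream, chunk_size=1 << 16):
    """Stream a binary file object through SHA-256 and return the hex digest."""
    digest = hashlib.sha256()
    for chunk in iter(lambda: stream.read(chunk_size), b""):
        digest.update(chunk)
    return digest.hexdigest()

# Sanity check against the well-known SHA-256 test vector for b"abc".
assert sha256_hex(io.BytesIO(b"abc")) == (
    "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
)

# Digest published for the source distribution above.
EXPECTED = "95f4fb509e9c2a9a61268a24c11abad076b07c0d14cce162f08cca1117dcc48a"
archive = Path("jax-relax-0.2.3.tar.gz")
if archive.exists():
    with archive.open("rb") as f:
        print("OK" if sha256_hex(f) == EXPECTED else "MISMATCH: do not install")
```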

File details

Details for the file jax_relax-0.2.3-py3-none-any.whl.

File metadata

  • Download URL: jax_relax-0.2.3-py3-none-any.whl
  • Upload date:
  • Size: 79.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.16

File hashes

Hashes for jax_relax-0.2.3-py3-none-any.whl
Algorithm Hash digest
SHA256 c713b2022df213ebad82eec18bc5e7a4a80a9f463741e642985f8de8b1150c96
MD5 f7b025ff77c3f28b046528453de3db62
BLAKE2b-256 6045aca6f6eeaa36e9adc77143c670159298f2ea4085671477479cc3c087a9ee
