A toolbox for fair and explainable machine learning

ethik

Introduction

ethik is a Python package for performing fair and explainable machine learning.

Overview

Currently, ethik can be used for:

  1. Detecting model bias with respect to one or more (protected) attributes.
  2. Identifying causes for why the model performs poorly on certain inputs.
  3. Visualizing which regions of an image influence a model's predictions.

Installation

:warning: Python 3.6 or above is required :snake:

Via GitHub for the latest development version

$ pip install git+https://github.com/MaxHalford/ethik
# Or through SSH:
$ pip install git+ssh://git@github.com/MaxHalford/ethik.git

Development installation

$ git clone https://github.com/MaxHalford/ethik
$ cd ethik
$ python setup.py develop
$ pip install -r requirements-dev.txt
$ pre-commit install  # For black

User guide

:point_up: Please check out this notebook for more detailed code.

In the following example we'll be using the "Adult" dataset. This dataset contains a binary label indicating whether a person's annual income is larger than $50k. ethik diagnoses a model by looking at the predictions the model makes on a test set. Consequently, you first have to split your dataset into a training set and a test set.

from sklearn import model_selection

# X holds the features and y the binary income label
X_train, X_test, y_train, y_test = model_selection.train_test_split(
    X, y, shuffle=True, random_state=42
)

You then want to train your model on the training set and make predictions on the test set. In this example we'll train a gradient boosting classifier from the LightGBM library. We'll use a variable named y_pred to store the predicted probabilities associated with the True label.

import lightgbm as lgb
import pandas as pd

model = lgb.LGBMClassifier(random_state=42).fit(X_train, y_train)

# We use a named pandas series to make plot labels more explicit
y_pred = model.predict_proba(X_test)[:, 1]
y_pred = pd.Series(y_pred, name='>$50k')

We can now initialize a ClassificationExplainer using the default parameters.

import ethik

explainer = ethik.ClassificationExplainer()

Measuring model bias

ethik can be used to understand how the model predictions vary as a function of one or more features. For example we can look at how the model behaves with respect to the age feature.

explainer.plot_bias(X_test=X_test['age'], y_pred=y_pred)
Age bias

Recall that the target indicates whether a person's annual salary is above $50k. We can see that the model predicts higher probabilities for older people. This isn't a surprising result, and could just as well have been observed by looking at the data. However, the predictions plateau at around 50 years old: although salary is correlated with age, some people retire early or lose their job. Moreover, the model understands that salaries shrink once people reach retirement age. This up-and-down relationship is non-linear by nature, and isn't picked up by summary statistics such as correlation coefficients, odds ratios, and feature importances in general.

Although these observations are quite obvious and rather intuitive, it's always good to confirm what the model is thinking. The point is that the curves produced by plot_bias represent the relationship between a variable and the target according to the model, rather than according to the data.

We can also plot the distribution of predictions for more than one variable. However, because different variables have different scales, we have to use a common measure to display them together. For this purpose we plot the τ ("tau") values. These values lie between -1 and 1 and reflect by how much the variable is shifted from its mean towards its lower and upper quantiles. In the following figure, a tau value of -1 corresponds to just under 20 years old, whereas a tau value of 1 refers to being slightly over 60 years old.
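To make the τ scale concrete, here is a minimal sketch of the mapping described above: τ = 0 sits at the mean, while τ = ±1 reach the upper and lower quantiles. This is our own illustration of the idea (the quantile level `q` and the linear interpolation are assumptions), not ethik's exact implementation.

```python
import numpy as np

def tau_to_value(x, tau, q=0.05):
    """Map a tau in [-1, 1] to a feature value: tau = 0 is the mean,
    tau = -1 the lower quantile, tau = +1 the upper quantile."""
    mean = np.mean(x)
    q_low, q_high = np.quantile(x, [q, 1 - q])
    if tau >= 0:
        return mean + tau * (q_high - mean)
    return mean + tau * (mean - q_low)

ages = np.array([18, 22, 25, 30, 35, 40, 45, 50, 55, 62])
print(tau_to_value(ages, 0.0))   # 38.2 — the mean age
print(tau_to_value(ages, 1.0))   # the 95th-percentile age
print(tau_to_value(ages, -1.0))  # the 5th-percentile age
```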

explainer.plot_bias(X_test=X_test[['age', 'education-num']], y_pred=y_pred)
Age and education bias

We can observe that the model assigns higher probabilities to people with higher degrees, which makes perfect sense. Again, this conveys much more of a story than summary statistics.

Evaluating model reliability

Our methodology can also be used to evaluate the reliability of a model under different scenarios. Evaluation metrics that are commonly used in machine learning only tell part of the story: they measure a model's performance on average. A more interesting approach is to visualize how accurate the model is across the distribution of a variable.

from sklearn import metrics

explainer.plot_performance(
    X_test=X_test['age'],
    y_test=y_test,
    y_pred=y_pred > 0.5,  # metrics.accuracy_score requires y_pred to be binary
    metric=metrics.accuracy_score
)
Age accuracy

In the above figure we can see that the model is more reliable for younger people than for older ones. Having a fine-grained understanding of the accuracy of a model can be extremely helpful in real-life scenarios. Moreover, this can help you understand where the model's errors come from and guide your data science process.
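The idea behind this kind of plot can be sketched in plain numpy: compute the accuracy inside bins of the variable of interest instead of one global score. The data below is simulated (a hypothetical model that is more accurate for people under 40), purely to illustrate the computation.

```python
import numpy as np

rng = np.random.default_rng(42)
age = rng.integers(18, 70, size=1000)
y_true = rng.integers(0, 2, size=1000)

# Simulate a model that is right 90% of the time under 40, 70% above
correct = rng.random(1000) < np.where(age < 40, 0.9, 0.7)
y_hat = np.where(correct, y_true, 1 - y_true)

# Accuracy per age bin, rather than a single global accuracy
bins = [18, 30, 40, 50, 60, 70]
for lo, hi in zip(bins[:-1], bins[1:]):
    mask = (age >= lo) & (age < hi)
    acc = (y_hat[mask] == y_true[mask]).mean()
    print(f"age [{lo}, {hi}): accuracy = {acc:.2f}")
```

A global accuracy score would average these bins together and hide the drop for older people.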

As with the plot_bias method, we can display the performance of the model for multiple variables.

explainer.plot_performance(
    X_test=X_test[['age', 'education-num']],
    y_test=y_test,
    y_pred=y_pred > 0.5,
    metric=metrics.accuracy_score
)
Age and education accuracy

Support for image classification

A special class named ImageClassificationExplainer can be used to analyze image classification models. It has the same API as ClassificationExplainer, but expects to be provided with an array of images. For instance, we can analyze a CNN run on the MNIST dataset, taken from the Keras documentation. The model achieves an accuracy of around 99% on the test set. For the sake of brevity we will skip the exact details of the model training.

from keras.datasets import mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# cnn is the convolutional model whose training details we skip here
cnn.fit(x_train, y_train)
y_pred = cnn.predict_proba(x_test)

x_test is a set of images of shape (10000, 28, 28), whilst y_pred contains the probabilities predicted for each digit by the CNN and is thus of shape (10000, 10). We can use the plot_bias method to display the importance of each pixel for the classifier with respect to each label.
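A quick sanity check on those shapes: y_pred has one column per digit, so the predicted class for each image is the argmax across columns. The two toy probability rows below are made up for illustration.

```python
import numpy as np

# y_pred is assumed to have shape (n_images, 10): one probability per digit
y_pred = np.array([
    [0.05, 0.02, 0.01, 0.02, 0.10, 0.60, 0.05, 0.05, 0.05, 0.05],
    [0.80, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.04],
])
print(y_pred.shape)      # (2, 10)
print(y_pred.argmax(1))  # [5 0] — the predicted digits
```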

import ethik

explainer = ethik.ImageClassificationExplainer()
explainer.plot_bias(x_test, y_pred)
Image bias explanation

This takes around 15 seconds to run on a mid-tier laptop. The previous plot highlights the regions of importance for identifying each digit. More precisely, the intensity of each pixel corresponds to the increase in predicted probability obtained by saturating it. A value of 0.28 means that saturating the pixel increases the probability predicted by the model by 0.28. Note that we do not saturate and desaturate the pixels independently. Instead, our method understands which pixels are linked together and saturates them in a realistic manner. The previous images show that the CNN seems to be using the same visual cues as a human. However, we can see that it uses very specific regions of the images to identify particular digits. For instance, the top-right region of an image seems to trigger the "5" digit, whereas the bottom parts of the images seem to be linked with the "7" digit.
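The core quantity being plotted can be sketched as follows: the importance of a pixel is the change in the model's mean predicted probability when that pixel is pushed towards saturation. This toy version perturbs a single pixel in isolation, which (as noted above) is a simplification of ethik's approach, where linked pixels are saturated together; the fake "model" below is also our own stand-in.

```python
import numpy as np

def pixel_importance(images, predict, pixel, delta=1.0):
    """images: (n, 28, 28) floats in [0, 1]; predict: images -> (n,) probas.
    Returns the change in mean predicted probability after saturating `pixel`."""
    saturated = images.copy()
    saturated[:, pixel[0], pixel[1]] = np.clip(
        saturated[:, pixel[0], pixel[1]] + delta, 0.0, 1.0
    )
    return predict(saturated).mean() - predict(images).mean()

# A fake "model" whose probability is simply the mean intensity of the image
predict = lambda imgs: imgs.mean(axis=(1, 2))
images = np.zeros((5, 28, 28))
print(pixel_importance(images, predict, pixel=(14, 14)))  # 1/784 ≈ 0.00128
```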

Authors

This work is led by members of the Toulouse Institute of Mathematics.

This work is financed by the Centre National de la Recherche Scientifique (CNRS) and is done in the context of the Artificial and Natural Intelligence Toulouse Institute (ANITI) project.

License

This software is released under the GPL license.
