Fairness-aware decision tree and random forest classifiers

fair-trees

This package is the implementation of the paper "Fair tree classifier using strong demographic parity" (Pereira Barata et al., Machine Learning, 2023).

It provides fairness-aware decision tree and random forest classifiers built on a modified scikit-learn tree engine. The splitting criterion jointly optimises predictive performance and statistical parity with respect to one or more sensitive (protected) attributes Z.

Installation

pip install fair-trees

Or from source, run from the root of a local checkout:

pip install -e . --no-build-isolation

Quick start

import numpy as np
import pandas as pd
from fair_trees import FairDecisionTreeClassifier, FairRandomForestClassifier, load_datasets

datasets = load_datasets()
data = datasets["bank_marketing"]   # "adult" is also available

# Preprocessing — the bundled data contains raw DataFrames
X = pd.get_dummies(data["X"]).values.astype(np.float64)
y = pd.factorize(data["y"].iloc[:, 0])[0]
Z = np.column_stack([pd.factorize(data["Z"][col])[0] for col in data["Z"].columns])
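
To make the Z encoding step concrete, here is a minimal numpy-only sketch of the same idea on toy data. The helper name `encode_sensitive` and the toy columns are hypothetical and are not part of fair-trees. It uses `np.unique`, which assigns codes in sorted label order, whereas `pd.factorize` assigns them in order of first appearance; the grouping is the same, only the integer labels may differ.

```python
import numpy as np


def encode_sensitive(columns):
    """Integer-encode each sensitive column and stack to (n_samples, n_attributes).

    Hypothetical helper for illustration only; not part of fair-trees.
    """
    encoded = []
    for values in columns:
        # return_inverse gives, for every row, the index of its label
        # in the sorted array of unique labels.
        _, codes = np.unique(np.asarray(values), return_inverse=True)
        encoded.append(codes.reshape(-1))
    return np.column_stack(encoded)


# Toy data: two sensitive attributes for five people.
age_group = ["young", "old", "old", "young", "mid"]
married = ["yes", "no", "yes", "yes", "no"]

Z_toy = encode_sensitive([age_group, married])
print(Z_toy.shape)  # (5, 2)
print(Z_toy)
```

Each column of the result is one sensitive attribute and each distinct integer within a column is one class of that attribute, which is the layout that `fit(X, y, Z=Z)` is given in the examples here.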

Fairness-aware decision tree

clf = FairDecisionTreeClassifier(
    theta=0.3,          # trade-off: 0 = pure accuracy, 1 = pure fairness
    Z_agg="max",        # how to aggregate across sensitive attributes / classes
    max_depth=5,
)
clf.fit(X, y, Z=Z)

y_prob = clf.predict_proba(X)[:, 1]

Fairness-aware random forest

rf = FairRandomForestClassifier(
    n_estimators=100,
    theta=0.3,
    Z_agg="max",
    max_depth=5,
    random_state=42,
)
rf.fit(X, y, Z=Z)

y_prob = rf.predict_proba(X)[:, 1]

Evaluation

The package does not ship its own metric functions, but the two scores that matter—ROC-AUC (predictive quality) and SDP (statistical parity)—can be computed from scipy in a few lines.

ROC-AUC via the Mann–Whitney U statistic

from scipy.stats import mannwhitneyu

roc_auc = mannwhitneyu(
    y_prob[y == 1],
    y_prob[y == 0],
).statistic / (int((y == 1).sum()) * int((y == 0).sum()))

print(f"ROC-AUC: {roc_auc:.4f}")
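
Why this works: the U statistic of the first sample counts the (positive, negative) pairs in which the positive example scores higher, with ties counted as one half, and dividing by the number of pairs gives ROC-AUC. The sketch below spells that out with plain numpy. The function name `pairwise_auc` is hypothetical, and because it builds an n_pos × n_neg comparison matrix it is meant for small sanity checks, not for full datasets.

```python
import numpy as np


def pairwise_auc(pos_scores, neg_scores):
    """ROC-AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties count as half a correct pair. Illustrative only: memory use grows
    with len(pos_scores) * len(neg_scores).
    """
    pos = np.asarray(pos_scores, dtype=float)[:, None]  # column vector
    neg = np.asarray(neg_scores, dtype=float)[None, :]  # row vector
    wins = (pos > neg).sum()
    ties = (pos == neg).sum()
    return float((wins + 0.5 * ties) / (pos.size * neg.size))


# Toy scores: 7 of the 9 pairs are ranked correctly and 1 pair is tied.
pos = [0.9, 0.8, 0.4]
neg = [0.1, 0.4, 0.7]
print(f"pairwise AUC: {pairwise_auc(pos, neg):.4f}")
```

On small inputs this should agree with the scipy-based expression, assuming a scipy version whose `mannwhitneyu` reports the U statistic of the first sample.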

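Strong demographic parity (SDP) score

SDP measures how little the model's predicted scores are separated by a protected attribute: 1 means the scores cannot rank the groups apart at all, and 0 means the scores separate the groups perfectly. It is defined as: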

SDP = 1 − |AUC_Z − 0.5| × 2

where AUC_Z is computed the same way as ROC-AUC but treating each sensitive attribute/class in Z as the positive label.
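
A quick worked example of the formula, as a small sketch (the function name `sdp_from_auc` is hypothetical). It shows that the score is symmetric around AUC_Z = 0.5: a group ranked systematically lower counts the same as a group ranked systematically higher.

```python
def sdp_from_auc(auc_z):
    """Map a group-vs-rest AUC to an SDP score in [0, 1]."""
    return 1.0 - 2.0 * abs(auc_z - 0.5)


# AUC_Z = 0.5 gives parity (SDP = 1); AUC_Z = 0 or 1 gives SDP = 0.
# 0.25 and 0.75 are equally far from 0.5, so both give SDP = 0.5.
for auc_z in (0.5, 0.25, 0.75, 1.0, 0.0):
    print(f"AUC_Z={auc_z:.2f}  SDP={sdp_from_auc(auc_z):.2f}")
```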

When Z contains multiple columns (attributes) and/or more than two classes per attribute, the per-group AUC values must be aggregated. The Z_agg parameter controls this—matching the logic used inside the splitting criterion:

Z_agg  | Behaviour
-------|----------
"mean" | Average the per-group SDP scores (across classes within an attribute, then across attributes).
"max"  | Take the worst-case (lowest) SDP across all groups, i.e. the group with the highest disparity dominates.

import numpy as np
from scipy.stats import mannwhitneyu


def sdp_score(y_prob, Z, Z_agg="max"):
    """Compute the strong demographic parity (SDP) score.

    Parameters
    ----------
    y_prob : array-like of shape (n_samples,)
        Predicted probabilities for the positive class.
    Z : array-like of shape (n_samples,) or (n_samples, n_attributes)
        Sensitive / protected attribute(s).  Each column is treated as a
        separate attribute; each unique value within a column is a class.
    Z_agg : {"mean", "max"}, default="max"
        Aggregation method across attributes and classes.
        - "mean": average SDP across all groups.
        - "max":  return the worst-case (lowest) SDP.

    Returns
    -------
    float
        SDP in [0, 1].  1 = perfect parity, 0 = maximum disparity.
    """
    Z = np.atleast_2d(np.asarray(Z).T).T          # ensure (n_samples, n_attr)
    y_prob = np.asarray(y_prob)
    sdp_values = []

    for attr_idx in range(Z.shape[1]):
        z_col = Z[:, attr_idx]
        classes = np.unique(z_col)
        attr_sdps = []

        for cls in classes:
            mask_pos = z_col == cls
            mask_neg = ~mask_pos
            if mask_pos.sum() == 0 or mask_neg.sum() == 0:
                continue
            auc_z = mannwhitneyu(
                y_prob[mask_pos],
                y_prob[mask_neg],
            ).statistic / (int(mask_pos.sum()) * int(mask_neg.sum()))
            attr_sdps.append(1 - abs(auc_z - 0.5) * 2)

        if not attr_sdps:
            continue

        if Z_agg == "mean":
            sdp_values.append(np.mean(attr_sdps))
        else:  # "max" → worst case = minimum SDP
            sdp_values.append(np.min(attr_sdps))

    if not sdp_values:
        return 1.0  # no disparity measurable

    if Z_agg == "mean":
        return float(np.mean(sdp_values))
    else:
        return float(np.min(sdp_values))
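
The aggregation step of the function above can also be looked at in isolation. The sketch below takes per-class SDP values that are already computed, grouped by attribute, and applies the two Z_agg rules. The name `aggregate_sdp` and the toy numbers are hypothetical. Unlike `sdp_score`, which treats any value other than "mean" as "max", this sketch rejects unknown values; that is a choice made for the illustration.

```python
import numpy as np


def aggregate_sdp(per_attribute_sdps, Z_agg="max"):
    """Aggregate per-class SDP values, first within then across attributes.

    per_attribute_sdps: one list of per-class SDP values per attribute.
    """
    if Z_agg not in ("mean", "max"):
        raise ValueError(f"Z_agg must be 'mean' or 'max', got {Z_agg!r}")
    # "max" means maximum disparity, which is the minimum SDP.
    reduce = np.mean if Z_agg == "mean" else np.min
    per_attr = [reduce(vals) for vals in per_attribute_sdps if len(vals) > 0]
    if not per_attr:
        return 1.0  # nothing measurable
    return float(reduce(per_attr))


# Toy values: attribute 0 has two classes, attribute 1 has three.
toy = [[0.9, 0.5], [1.0, 0.9, 0.8]]
print("mean:", round(aggregate_sdp(toy, "mean"), 4))
print("max: ", round(aggregate_sdp(toy, "max"), 4))
```

With "mean", each attribute contributes equally regardless of how many classes it has (here 0.7 and 0.9 average to 0.8, whereas a flat average over all five classes would be 0.82). With "max", the single least fair class (0.5) determines the result.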

Putting it all together

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import mannwhitneyu
from fair_trees import FairRandomForestClassifier, load_datasets

# Load and preprocess
datasets = load_datasets()
data = datasets["bank_marketing"]

X = pd.get_dummies(data["X"]).values.astype(np.float64)
y = pd.factorize(data["y"].iloc[:, 0])[0]
Z = np.column_stack([pd.factorize(data["Z"][col])[0] for col in data["Z"].columns])

# Sweep over theta values
thetas = [0, 0.2, 0.4, 0.6, 0.8, 1.0]
aucs, sdps = [], []

for theta in thetas:
    rf = FairRandomForestClassifier(
        n_estimators=100, theta=theta, Z_agg="max", max_depth=5, random_state=42,
    )
    rf.fit(X, y, Z=Z)
    y_prob = rf.predict_proba(X)[:, 1]

    roc_auc = mannwhitneyu(
        y_prob[y == 1], y_prob[y == 0],
    ).statistic / (int((y == 1).sum()) * int((y == 0).sum()))

    sdp = sdp_score(y_prob, Z, Z_agg="max")

    aucs.append(roc_auc)
    sdps.append(sdp)
    print(f"theta={theta:.1f}  ROC-AUC={roc_auc:.4f}  SDP={sdp:.4f}")

# Plot
fig, (ax1, ax3) = plt.subplots(1, 2, figsize=(14, 5))

# Left — Metrics vs. theta (dual axis)
ax1.set_xlabel("theta")
ax1.set_ylabel("ROC-AUC", color="tab:blue")
ax1.plot(thetas, aucs, "o-", color="tab:blue", label="ROC-AUC")
ax1.tick_params(axis="y", labelcolor="tab:blue")

ax2 = ax1.twinx()
ax2.set_ylabel("SDP", color="tab:orange")
ax2.plot(thetas, sdps, "s--", color="tab:orange", label="SDP")
ax2.tick_params(axis="y", labelcolor="tab:orange")

ax1.set_title("Metrics vs. theta")
lines1, labels1 = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax1.legend(lines1 + lines2, labels1 + labels2, loc="lower center")

# Right — ROC-AUC vs. SDP frontier
ax3.plot(sdps, aucs, "o-", color="tab:green")
for i, theta in enumerate(thetas):
    ax3.annotate(f"θ={theta}", (sdps[i], aucs[i]), textcoords="offset points",
                 xytext=(8, 4), fontsize=9)
ax3.set_xlabel("SDP (fairness →)")
ax3.set_ylabel("ROC-AUC (performance →)")
ax3.set_title("Performance–Fairness Frontier")

fig.suptitle("Performance-Fairness Trade-off", fontsize=14)
fig.tight_layout()
plt.savefig("tradeoff.png", dpi=150, bbox_inches="tight")
plt.show()
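
Note that the sweep above scores each forest on the same rows it was fitted on, so both numbers are in-sample. A held-out estimate needs X, y and Z to be split with the same row indices. The following is a minimal numpy-only sketch of such a split; the helper name `train_test_indices`, the 30% test fraction and the toy arrays are assumptions for illustration, and the fair-trees calls appear only as comments.

```python
import numpy as np


def train_test_indices(n_samples, test_fraction=0.3, seed=0):
    """Return (train_idx, test_idx) from one shuffled permutation of the rows."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(n_samples)
    n_test = int(round(n_samples * test_fraction))
    return order[n_test:], order[:n_test]


# Toy stand-ins for the preprocessed X, y and Z arrays.
rng = np.random.default_rng(1)
X_demo = rng.normal(size=(10, 3))
y_demo = rng.integers(0, 2, size=10)
Z_demo = rng.integers(0, 2, size=(10, 2))

train_idx, test_idx = train_test_indices(len(X_demo))

# Index all three arrays with the same indices so rows stay aligned.
X_train, y_train, Z_train = X_demo[train_idx], y_demo[train_idx], Z_demo[train_idx]
X_test, y_test, Z_test = X_demo[test_idx], y_demo[test_idx], Z_demo[test_idx]
print(X_train.shape, X_test.shape)

# With the package installed, the sweep body would then read roughly:
#   rf.fit(X_train, y_train, Z=Z_train)
#   y_prob = rf.predict_proba(X_test)[:, 1]
#   sdp = sdp_score(y_prob, Z_test, Z_agg="max")
```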

[Figure: Performance-Fairness Trade-off. Left panel: ROC-AUC and SDP versus theta. Right panel: ROC-AUC versus SDP frontier.]

Key parameters

Parameter | Default | Description
----------|---------|------------
theta     | 0.0     | Trade-off weight in [0, 1]. 0 = standard tree with no fairness term; 1 = splits optimise only for fairness.
Z_agg     | "max"   | Aggregation over sensitive groups: "mean" (average) or "max" (worst case).
Z         | None    | Sensitive attributes, passed to .fit(). Array of shape (n_samples,) or (n_samples, n_attributes).

All other parameters (max_depth, min_samples_split, n_estimators, etc.) behave identically to their scikit-learn counterparts.

Citation

If you use this software, please cite the paper:

Pereira Barata, A., Takes, F.W., van den Herik, H.J., & Veenman, C. (2023). Fair tree classifier using strong demographic parity. Machine Learning. doi:10.1007/s10994-023-06376-z

See CITATION.cff for a machine-readable citation file.

License

BSD-3-Clause
