
A toolkit for developing group-aware ML methods

Installation

This library requires at least Python 3.12. Install it from PyPI:

pip install fair-forge

or from GitHub:

pip install git+https://github.com/wearepal/fair-forge.git

If you want to use the neural-network-based methods, you need to add the nn extras:

pip install 'fair-forge[nn]'

or

pip install 'fair-forge[nn] @ git+https://github.com/wearepal/fair-forge.git'

Usage

fair-forge provides two main components: metrics and methods. Besides these, there are various utility functions to help with common tasks and also a few example datasets.

The core data type used in fair-forge is the numpy array: all methods and metrics expect numpy arrays as input. If you have data in a different form, it is usually easy to convert it to numpy arrays.
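
As a minimal sketch of such a conversion, the snippet below turns plain Python lists into arrays with the dtypes used in the signatures throughout this README (`np.float32` for features, `np.int32` for labels and groups). The variable names and toy values are made up for illustration. If your data lives in a pandas DataFrame or Series, its `to_numpy()` method serves the same purpose.

```python
import numpy as np

# Hypothetical toy data held in plain Python lists.
features = [[0.5, 1.2], [1.5, 0.3], [0.1, 2.2]]
labels = [0, 1, 1]
group_ids = [0, 0, 1]

# np.asarray converts list-like input; the dtype is set explicitly
# to match the type annotations used in this README.
X = np.asarray(features, dtype=np.float32)
y = np.asarray(labels, dtype=np.int32)
groups = np.asarray(group_ids, dtype=np.int32)
```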

Metrics

There are group-aware metrics and non-group-aware metrics. The non-group-aware metrics are callables with this function signature:

import numpy as np
from numpy.typing import NDArray

type Float = float | np.float16 | np.float32 | np.float64

def tpr(y_true: NDArray[np.int32], y_pred: NDArray[np.int32]) -> Float: ...

In other words, a non-group-aware metric accepts two numpy arrays — one with the true labels and one with the predicted labels — and returns a single Float. The API of the non-group-aware metrics is chosen such that any metric from scikit-learn can be used — for example, accuracy.
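
To make the signature concrete, here is one way a true-positive-rate metric with this shape could be written in plain numpy. This is an illustrative sketch, not the implementation shipped with fair-forge; the handling of the no-positives case (returning NaN) is an assumption.

```python
import numpy as np
from numpy.typing import NDArray


def tpr(y_true: NDArray[np.int32], y_pred: NDArray[np.int32]) -> float:
    """Fraction of actual positives (label 1) that were predicted as positive."""
    positives = y_true == 1
    if not positives.any():
        # The rate is undefined without positive samples (assumed convention).
        return float("nan")
    return float((y_pred[positives] == 1).mean())


y_true = np.array([1, 1, 0, 1, 0], dtype=np.int32)
y_pred = np.array([1, 0, 0, 1, 1], dtype=np.int32)
rate = tpr(y_true, y_pred)  # two of the three positives are recovered
```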

Group-aware metrics take an additional parameter, the group labels:

def cv(
    y_true: NDArray[np.int32],
    y_pred: NDArray[np.int32],
    *,
    groups: NDArray[np.int32],
) -> Float: ...

A very important function is fair_forge.as_group_metric(). It takes non-group-aware metrics and turns them into one or more group-aware metrics. It does this by first computing the metric value for each group separately and then aggregating these per-group values, for example by taking their minimum or their ratio. Here is how one would construct a robust accuracy metric (minimum accuracy across all groups):

import fair_forge as ff
from sklearn.metrics import accuracy_score

# Construct a metric for the minimum accuracy over all groups
(robust_accuracy,) = ff.as_group_metric(
    (accuracy_score,), agg=ff.MetricAgg.MIN
)

# Use it as a group-aware metric
robust_accuracy(y_true=y_true, y_pred=y_pred, groups=groups)
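
The per-group computation followed by aggregation can be sketched in plain numpy. The helper `min_over_groups` below is hypothetical and only mirrors the idea behind the minimum aggregation; it is not how fair-forge implements `as_group_metric`.

```python
from collections.abc import Callable

import numpy as np
from numpy.typing import NDArray

Metric = Callable[[NDArray[np.int32], NDArray[np.int32]], float]


def min_over_groups(metric: Metric):
    """Hypothetical helper: wrap a metric so it returns its minimum over groups."""

    def group_metric(y_true, y_pred, *, groups) -> float:
        # Evaluate the metric separately on the samples of each group.
        per_group = [
            metric(y_true[groups == g], y_pred[groups == g])
            for g in np.unique(groups)
        ]
        # Aggregate the per-group values by taking the worst case.
        return float(min(per_group))

    return group_metric


def accuracy(y_true, y_pred) -> float:
    return float((y_true == y_pred).mean())


y_true = np.array([1, 0, 1, 1, 0, 0], dtype=np.int32)
y_pred = np.array([1, 0, 0, 1, 1, 1], dtype=np.int32)
groups = np.array([0, 0, 0, 1, 1, 1], dtype=np.int32)

# Group 0 has accuracy 2/3 and group 1 has accuracy 1/3.
worst = min_over_groups(accuracy)(y_true, y_pred, groups=groups)
```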

Methods

The group-aware vs non-group-aware distinction also exists for the methods provided in this library. The non-group-aware methods simply follow the scikit-learn API for an estimator (inheriting from BaseEstimator adds some mixin methods which are needed):

from typing import Self

import numpy as np
from numpy.typing import NDArray
from sklearn.base import BaseEstimator

class Method(BaseEstimator):
    def fit(self, X: NDArray[np.float32], y: NDArray[np.int32]) -> Self:
        pass

    def predict(self, X: NDArray[np.float32]) -> NDArray[np.int32]:
        pass

The methods can be used like normal scikit-learn estimators.
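
As a rough illustration, the toy class below follows this interface and is then used like any other scikit-learn estimator. `MajorityClassMethod` is a made-up example and not part of fair-forge. The `clone` call at the end works because `BaseEstimator` supplies `get_params` and `set_params`.

```python
import numpy as np
from sklearn.base import BaseEstimator, clone


class MajorityClassMethod(BaseEstimator):
    """Hypothetical toy method: always predicts the most frequent training label."""

    def fit(self, X, y) -> "MajorityClassMethod":
        labels, counts = np.unique(y, return_counts=True)
        self.majority_ = labels[np.argmax(counts)]
        return self

    def predict(self, X):
        return np.full(X.shape[0], self.majority_, dtype=np.int32)


X = np.array([[0.0], [1.0], [2.0], [3.0]], dtype=np.float32)
y = np.array([0, 1, 1, 1], dtype=np.int32)

method = MajorityClassMethod().fit(X, y)
preds = method.predict(X)

# clone() returns an unfitted copy with the same constructor parameters.
fresh = clone(method)
```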

On the other hand, we have the group-based methods, which take an additional parameter, the group labels:

from sklearn.base import BaseEstimator

class GroupMethod(BaseEstimator):
    def fit(
        self, X: NDArray[np.float32], y: NDArray[np.int32], *, groups: NDArray[np.int32]
    ) -> Self:
        pass

    def predict(self, X: NDArray[np.float32]) -> NDArray[np.int32]:
        pass

These methods can use the group information at training time to produce fairer models.

Besides methods which output a machine learning model, there are also methods which transform the data. These then have a transform method instead of a predict method:

from sklearn.base import BaseEstimator

class GroupBasedTransform(BaseEstimator):
    def fit(
        self, X: NDArray[np.float32], y: NDArray[np.int32], *, groups: NDArray[np.int32]
    ) -> Self:
        pass

    def transform(self, X: NDArray[np.float32]) -> NDArray[np.float32]:
        pass

    def fit_transform(
        self, X: NDArray[np.float32], y: NDArray[np.int32], *, groups: NDArray[np.int32]
    ) -> NDArray[np.float32]:
        pass

(Unfortunately, you have to implement fit_transform manually, because otherwise it will not have the groups parameter.)
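
A minimal sketch of such a transformation, assuming exactly two groups labeled 0 and 1: `GroupMeanAligner` is a hypothetical example, not a fair-forge class. It removes the direction along which the two group means differ, and its hand-written `fit_transform` simply chains `fit` and `transform` while passing `groups` through.

```python
import numpy as np
from sklearn.base import BaseEstimator


class GroupMeanAligner(BaseEstimator):
    """Hypothetical transform: removes the direction separating two group means."""

    def fit(self, X, y, *, groups) -> "GroupMeanAligner":
        # Direction pointing from the mean of group 0 to the mean of group 1.
        diff = X[groups == 1].mean(axis=0) - X[groups == 0].mean(axis=0)
        norm = np.linalg.norm(diff)
        self.direction_ = diff / norm if norm > 0 else np.zeros_like(diff)
        return self

    def transform(self, X):
        # Subtract each sample's projection onto the group-difference direction.
        projection = np.outer(X @ self.direction_, self.direction_)
        return (X - projection).astype(np.float32)

    def fit_transform(self, X, y, *, groups):
        # Written by hand so that `groups` is part of the signature.
        return self.fit(X, y, groups=groups).transform(X)


X = np.array([[0, 0], [0, 2], [2, 0], [2, 2]], dtype=np.float32)
y = np.array([0, 1, 0, 1], dtype=np.int32)
groups = np.array([0, 0, 1, 1], dtype=np.int32)

X_aligned = GroupMeanAligner().fit_transform(X, y, groups=groups)
```

After the transformation, both groups have the same feature mean, so a downstream estimator can no longer separate them by that mean difference.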

Such transformation methods can then be combined with non-group-aware methods with scikit-learn’s Pipeline:

from sklearn import config_context
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC

# Pipeline will only forward the `groups` argument if we
# set `enable_metadata_routing` to `True`.
with config_context(enable_metadata_routing=True):
    estimator = LinearSVC(random_state=42, max_iter=100)
    transform = GroupBasedTransform(random_state=42)
    # We need to explicitly request here that the transformation's
    # `fit` function gets the `groups` argument.
    transform.set_fit_request(groups=True)

    pipeline = Pipeline([("transform", transform), ("estimator", estimator)])

    # This will call `fit_transform` on the transformation.
    pipeline.fit(train_x, train_y, groups=train_groups)
    preds = pipeline.predict(test_x)

Utilities

fair-forge provides many useful components for running experiments and collecting results:

  • example datasets (like Adult)
  • train-test splitting
  • facilities for running multiple methods and evaluating them with multiple metrics

For more information on this, see the documentation.
