PyTorch-based package for constrained training of neural networks

humancompatible-train: a package for constrained machine learning

The toolkit implements algorithms for constrained training of neural networks. It is built on PyTorch, and its API is modeled on PyTorch's.

Table of Contents

  1. Installation
  2. Using the toolkit
  3. Extending the toolkit
  4. Reproducing the Benchmark
  5. License and terms of use
  6. Future work
  7. References

humancompatible-train is still under active development! If you find bugs or have feature requests, please file a GitHub issue.

Installation

Install the package from PyPI:

pip install humancompatible-train

The only dependencies of this package are numpy and torch.

Using the toolkit

The algorithms follow a dual_step() / step() pattern modeled on PyTorch optimizers: dual_step() updates the dual parameters and prepares the primal update (e.g., by saving constraint gradients), and step() updates the primal parameters.
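To illustrate how the work can be split between the two methods, consider the problem of minimizing f(x) subject to c_i(x) <= 0, i = 1, ..., m, and a plain stochastic gradient descent-ascent scheme on its Lagrangian. This is a generic textbook scheme chosen for illustration only, under the assumed "c_i(x) <= 0" sign convention; the algorithms shipped in the package use their own, more elaborate updates. Tildes denote minibatch estimates, and the step sizes η (dual) and τ (primal) are hypothetical symbols not taken from the package.

```latex
\begin{aligned}
\mathcal{L}(x, \lambda) &= f(x) + \sum_{i=1}^{m} \lambda_i\, c_i(x), \qquad \lambda_i \ge 0, \\
\lambda_i^{k+1} &= \max\{0,\ \lambda_i^{k} + \eta\, \tilde c_i(x^{k})\}
  && \text{(dual\_step, once per constraint } i\text{)}, \\
x^{k+1} &= x^{k} - \tau \Big( \tilde\nabla f(x^{k}) + \sum_{i=1}^{m} \lambda_i^{k+1}\, \tilde\nabla c_i(x^{k}) \Big)
  && \text{(step)}.
\end{aligned}
```

In this scheme dual_step needs the constraint value and its gradient, which is why an optimizer may store constraint gradients before step() combines them with the objective gradient.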

In general, a training loop using humancompatible-train looks like this (model, dataloader, constraints, criterion and optimizer are defined elsewhere):

for inputs, labels in dataloader:
    # forward pass
    outputs = model(inputs)
    # evaluate each constraint and compute its gradient
    for constraint in constraints:
        c_eval = constraint(outputs, labels)
        # keep the graph: outputs is reused by the remaining constraints and the loss
        c_eval.backward(retain_graph=True)
        # depending on the optimizer: update dual parameters, save the constraint gradient, or both
        optimizer.dual_step(c_eval)
        optimizer.zero_grad()
    # compute the objective and update the primal parameters
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
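The loop treats each constraint as a callable that maps (outputs, labels) to a scalar tensor. Below is a minimal sketch of such a callable. PositiveRateConstraint and its bound are hypothetical names, the rule it encodes (mean predicted positive probability at most a bound, for a binary classifier that outputs logits) is only an example, and the convention that the constraint holds when the value is <= 0 is an assumption; check which convention a given optimizer expects.

```python
import torch


class PositiveRateConstraint:
    """Hypothetical constraint: mean predicted positive probability <= bound.

    Returns a scalar tensor that is <= 0 when the constraint is satisfied
    (this sign convention is an assumption of the sketch).
    """

    def __init__(self, bound: float):
        self.bound = bound

    def __call__(self, outputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # outputs: raw logits of shape (batch,) or (batch, 1); labels are unused here
        probs = torch.sigmoid(outputs).reshape(-1)
        return probs.mean() - self.bound


constraint = PositiveRateConstraint(bound=0.3)
logits = torch.zeros(4, 1, requires_grad=True)  # sigmoid(0) = 0.5 for every sample
c_eval = constraint(logits, torch.zeros(4))
c_eval.backward()
print(round(c_eval.item(), 6))  # 0.2, i.e. the constraint is violated by 0.2
```

Because the callable returns an ordinary differentiable tensor, c_eval.backward() fills the parameter gradients exactly as the loop above expects.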

Our idea is to

  1. Deviate minimally from the usual PyTorch workflow
  2. Make different algorithms for stochastically constrained stochastic optimization nearly interchangeable in the code.

Code examples

The new API is demonstrated in the notebooks in the examples folder.

The example notebooks have additional dependencies, such as fairret. To install them, run:

pip install "humancompatible-train[examples]"

The legacy API used for the benchmark is presented in examples/_old_/algorithm_demo.ipynb and examples/_old_/constraint_demo.ipynb.

Extending the toolkit

Adding new code

To add a new algorithm, subclass torch.optim.Optimizer and implement the dual_step() / step() interface described above.
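A minimal sketch of such a subclass follows. SimplePrimalDual, its lr and dual_lr arguments, and the duals attribute are hypothetical and not part of humancompatible-train; the update rule is plain projected stochastic gradient descent-ascent on the Lagrangian, used here only to show where dual_step() and step() fit. It assumes constraints of the form c_i(x) <= 0, passed to dual_step() in the same order on every iteration.

```python
import torch
from torch.optim import Optimizer


class SimplePrimalDual(Optimizer):
    """Hypothetical example optimizer, not part of humancompatible-train.

    dual_step(c_eval): projected dual ascent on one multiplier per constraint,
        then stores lambda_i * grad(c_i) for the primal update.
    step(): gradient step on grad(f) + sum_i lambda_i * grad(c_i).
    Multipliers are kept in a plain list, so they are not saved in state_dict().
    """

    def __init__(self, params, lr=0.01, dual_lr=0.01):
        super().__init__(params, dict(lr=lr, dual_lr=dual_lr))
        self.duals = []           # lambda_i, one per constraint
        self._constraint_idx = 0  # index of the constraint seen by the next dual_step

    @torch.no_grad()
    def dual_step(self, c_eval):
        i = self._constraint_idx
        if i == len(self.duals):
            self.duals.append(0.0)
        dual_lr = self.param_groups[0]["dual_lr"]
        # projected ascent: lambda_i <- max(0, lambda_i + dual_lr * c_i)
        self.duals[i] = max(0.0, self.duals[i] + dual_lr * float(c_eval))
        # p.grad currently holds grad(c_i), filled by c_eval.backward()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                acc = self.state[p].setdefault("constraint_grad", torch.zeros_like(p))
                acc.add_(p.grad, alpha=self.duals[i])
        self._constraint_idx += 1

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group["params"]:
                # p.grad holds grad(f); add the stored sum_i lambda_i * grad(c_i)
                grad = torch.zeros_like(p) if p.grad is None else p.grad.clone()
                acc = self.state[p].get("constraint_grad")
                if acc is not None:
                    grad.add_(acc)
                    acc.zero_()
                p.add_(grad, alpha=-group["lr"])
        self._constraint_idx = 0  # the next iteration starts again at constraint 0
        return loss


# Toy problem: minimize (x - 2)^2 subject to x - 1 <= 0 (solution x = 1, lambda = 2).
x = torch.zeros(1, requires_grad=True)
opt = SimplePrimalDual([x], lr=0.1, dual_lr=0.1)
for _ in range(1000):
    c_eval = (x - 1).sum()
    c_eval.backward()
    opt.dual_step(c_eval)
    opt.zero_grad()
    loss = ((x - 2) ** 2).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
print(x.item(), opt.duals[0])  # should be close to 1.0 and 2.0
```

Keeping the constraint gradients in self.state mirrors how built-in PyTorch optimizers store per-parameter buffers; a real implementation would also persist the multipliers so that state_dict() round-trips.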

Reproducing the Benchmark

The code used in our benchmark paper has not yet been migrated to the new API (work in progress).

Basic installation instructions

The code requires Python version 3.11.

  1. Create a virtual environment

bash (Linux)

python3.11 -m venv fairbenchenv
source fairbenchenv/bin/activate

cmd (Windows)

python -m venv fairbenchenv
fairbenchenv\Scripts\activate.bat

  2. Install from source.

git clone https://github.com/humancompatible/train.git
cd train
pip install -r requirements.txt
pip install .

If you wish to edit the code of the algorithms, install as an editable package:

pip install -e .

Warning: we recommend running Stochastic Ghost with the MKL-accelerated build of scipy. To install it, run

pip install --force-reinstall -i https://software.repos.intel.com/python/pypi scipy

after installing requirements.txt; otherwise, Stochastic Ghost runs slower. However, this build is not supported on macOS and may fail on some Windows machines.

Running the algorithms

The benchmark comprises the following algorithms:

  • Stochastic Ghost [2],
  • SSL-ALM [3],
  • Stochastic Switching Subgradient [4].

To reproduce the experiments of the paper, run the following:

cd experiments
python run_folktables.py data=folktables alg=sslalm
python run_folktables.py data=folktables alg=alm
python run_folktables.py data=folktables alg=ghost
python run_folktables.py data=folktables alg=ssg
python run_folktables.py data=folktables alg=sgd     # baseline, no fairness
python run_folktables.py data=folktables alg=fairret # baseline, fairness with regularizer
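If you prefer to launch all six runs from Python, a small driver along these lines could work. It is a hypothetical convenience script, not part of the repository: it assumes it is started from the repository root and simply re-issues the commands above through subprocess, with cwd="experiments" standing in for `cd experiments`.

```python
import subprocess
import sys

# Values of the alg= override used in the commands above ("sgd" and "fairret" are baselines).
ALGORITHMS = ["sslalm", "alm", "ghost", "ssg", "sgd", "fairret"]


def build_commands(dataset="folktables", python=sys.executable):
    """Return one argv list per run, mirroring `python run_folktables.py data=... alg=...`."""
    return [[python, "run_folktables.py", f"data={dataset}", f"alg={alg}"]
            for alg in ALGORITHMS]


def run_all(cwd="experiments", dry_run=True):
    """Print each command; if dry_run is False, also execute it inside `cwd`."""
    for cmd in build_commands():
        print(" ".join(cmd))
        if not dry_run:
            # check=True raises CalledProcessError and stops at the first failing run
            subprocess.run(cmd, cwd=cwd, check=True)


run_all()  # dry run: only prints the six commands
```

Calling run_all(dry_run=False) executes the runs one after another, so the total time is the sum of the individual commands.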

By default, each command starts 10 runs of the algorithm, 30 seconds each. The results are saved to experiments/utils/saved_models and experiments/utils/exp_results.

This repository uses Hydra to manage parameters; see experiments/conf for configuration files.

  • To change experiment parameters, such as the number of runs per algorithm, the run time, or the dataset (currently only Folktables is supported), use experiment.yaml.
  • To change dataset settings, such as the file location, or to make dataset-specific adjustments, such as the configuration of the protected attributes, use data/{dataset_name}.yaml.
  • To change algorithm hyperparameters, use alg/{algorithm_name}.yaml.
  • To change constraint hyperparameters, use constraint/{constraint_name}.yaml.

Producing plots

Plots and tables like those in the paper can be produced with two notebooks: experiments/algo_plots.ipynb generates the convergence plots, and experiments/model_plots.ipynb generates all the others.

License and terms of use

humancompatible-train is provided under the Apache License 2.0.

The benchmark part of the package relies on the Folktables package, provided under the MIT License. It provides code to download data from the American Community Survey (ACS) Public Use Microdata Sample (PUMS) files managed by the US Census Bureau. The data itself is governed by the terms of use provided by the Census Bureau. For more information, see https://www.census.gov/data/developers/about/terms-of-service.html

Future work

  • Add more algorithms
  • Add more examples from different fields where constrained training of DNNs is employed
  • Migrate the benchmark to the new API

References

If you use this work, please cite our paper:

@misc{kliachkin2025benchmarkingstochasticapproximationalgorithms,
      title={Benchmarking Stochastic Approximation Algorithms for Fairness-Constrained Training of Deep Neural Networks}, 
      author={Andrii Kliachkin and Jana Lepšová and Gilles Bareilles and Jakub Mareček},
      year={2025},
      eprint={2507.04033},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2507.04033}, 
}

[1] Ding, Hardt, Miller & Schmidt (2021) Retiring Adult: New Datasets for Fair Machine Learning, Curran Associates, Inc.

[2] Facchinei & Kungurtsev (2023) Stochastic Approximation for Expectation Objective and Expectation Inequality-Constrained Nonconvex Optimization, arXiv.

[3] Huang, Zhang & Alacaoglu (2025) Stochastic Smoothed Primal-Dual Algorithms for Nonconvex Optimization with Linear Inequality Constraints, arXiv.

[4] Huang & Lin (2023) Oracle Complexity of Single-Loop Switching Subgradient Methods for Non-Smooth Weakly Convex Functional Constrained Optimization, Curran Associates, Inc.
