
Prototype set models for supervised learning


Python package proset

proset copyright 2022-2024 by Nikolaus Ruf

Released under the MIT license - see LICENSE file for details

About

This package implements a supervised learning method we call the 'prototype set' or 'proset' algorithm.

The algorithm applies feature selection via an elastic net penalty [1] to a nonlinear distribution model. The model uses locally weighted averaging, similar to the extension of the Nadaraya-Watson estimator [2][3] to conditional distributions [4]. Instead of including a term with unit weight for each training sample, the algorithm selects a subset of representative samples (prototypes) with individual weights. Prototype selection is handled via random subsampling and controlled by a second elastic net penalty term.
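
The idea of locally weighted averaging over prototypes can be illustrated with a minimal sketch. It assumes a Gaussian kernel on a feature-weighted squared Euclidean distance and leaves out the fitting procedure and the penalty terms, so it should not be read as the package's implementation. The function name and all numbers are hypothetical.

```python
import math


def prototype_class_probabilities(x, prototypes, labels, weights, feature_weights, n_classes):
    """Kernel-weighted average of prototype labels (illustrative only)."""
    mass = [0.0] * n_classes
    for proto, label, weight in zip(prototypes, labels, weights):
        # A feature with weight zero drops out of the distance entirely,
        # which is how a sparse penalty can act as feature selection.
        sq_dist = sum((v * (a - b)) ** 2 for v, a, b in zip(feature_weights, x, proto))
        mass[label] += weight * math.exp(-0.5 * sq_dist)
    total = sum(mass)
    if total == 0.0:
        # Every kernel value underflowed: fall back to a uniform estimate.
        return [1.0 / n_classes] * n_classes
    return [m / total for m in mass]


# Hypothetical prototypes: two for class 0, one for class 1.
prototypes = [(0.0, 0.0), (1.0, 1.0), (4.0, 4.0)]
labels = [0, 0, 1]
weights = [1.0, 0.5, 2.0]
feature_weights = [1.0, 0.0]  # the second feature is deselected

probs = prototype_class_probabilities(
    (0.5, 9.0), prototypes, labels, weights, feature_weights, n_classes=2
)
```

In this toy setup the query point lies close to the class 0 prototypes along the first feature, so class 0 receives almost all of the probability mass, and the value of the deselected second feature has no influence.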

Proset models are highly explainable due to their built-in feature selection and geometric properties:

  • Feature selection makes it easier for humans to review the model structure. If the number of relevant features is small, users can assess whether the choice is sensible and study low-dimensional representations like scatter plots or cuts through the decision space.
  • Prototype selection simplifies reviewing the model structure even if the number of features is large. We can perform weighted PCA on the feature matrix for the prototypes and use this to create low-dimensional maps of the data. A check for labeling errors or artifacts in the training data can also start with the smaller set of prototypes.
  • The estimate for a particular sample can be explained by reviewing the prototypes with the highest impact. This is an explanation in terms of similar training instances instead of more abstract properties, which can help nontechnical users to understand and trust the model.
  • Proset rates new samples based on their absolute distance to the prototypes. That means the algorithm can detect whether a new sample is far away from the training data, in which case the estimate should not be relied on.
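
The last two points can be sketched in a few lines. The Gaussian kernel, the helper name `rank_prototypes`, and the value of `familiarity_threshold` are assumptions made for illustration. The package's own explain() method remains the authoritative tool for explaining predictions.

```python
import math


def rank_prototypes(x, prototypes, weights, top=2):
    """Return the total kernel mass and the top prototypes as (index, share) pairs."""
    impacts = []
    for index, (proto, weight) in enumerate(zip(prototypes, weights)):
        sq_dist = sum((a - b) ** 2 for a, b in zip(x, proto))
        impacts.append((index, weight * math.exp(-0.5 * sq_dist)))
    total = sum(impact for _, impact in impacts)
    ranked = sorted(impacts, key=lambda item: item[1], reverse=True)[:top]
    shares = [(index, impact / total if total > 0.0 else 0.0) for index, impact in ranked]
    return total, shares


# Hypothetical prototypes and weights.
prototypes = [(0.0, 0.0), (1.0, 1.0), (4.0, 4.0)]
weights = [1.0, 0.5, 2.0]
familiarity_threshold = 1e-3  # hypothetical cut-off, would need tuning per data set

near_total, near_top = rank_prototypes((0.2, 0.1), prototypes, weights)
far_total, far_top = rank_prototypes((20.0, -20.0), prototypes, weights)

# A sample is flagged as unfamiliar when its total kernel mass is tiny.
near_is_familiar = near_total > familiarity_threshold
far_is_familiar = far_total > familiarity_threshold
```

Because the kernel mass depends on absolute distance rather than on a ratio between classes, the far-away point ends up with a total close to zero and is flagged, while the nearby point is explained mostly by the first prototype.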

A technical report describing the algorithm in detail can be found here:

> technical report (PDF)

The report includes a benchmark study covering hyperparameter selection, comparison to other estimators, and explanatory features.

Installation

Proset can be installed from PyPI via

pip install proset

This installs the package itself without the unit tests and benchmark scripts. If you are interested in these, please clone or download the full source code from GitHub:

> proset on GitHub

Dependencies

Proset requires Python 3.10 or later with the following packages:

  • matplotlib >= 3.8.3
  • numpy >= 1.26.4
  • pandas >= 2.2.1
  • scipy >= 1.12.0
  • scikit-learn >= 1.4.1.post1
  • statsmodels >= 0.14.1
  • xlsxwriter >= 3.1.9
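
To check an existing environment against the minimum versions listed above, a small script along the following lines may help. It is a rough sketch using only the standard library: the helper names are hypothetical, and the comparison looks only at the leading numeric components of a version string, so it ignores suffixes such as `.post1`.

```python
from importlib import metadata

# Minimum versions copied from the list above.
MINIMUM_VERSIONS = {
    "matplotlib": "3.8.3",
    "numpy": "1.26.4",
    "pandas": "2.2.1",
    "scipy": "1.12.0",
    "scikit-learn": "1.4.1.post1",
    "statsmodels": "0.14.1",
    "xlsxwriter": "3.1.9",
}


def release_tuple(version):
    """Keep leading numeric components only, e.g. '1.4.1.post1' -> (1, 4, 1)."""
    parts = []
    for piece in version.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def check_requirements(minimums=MINIMUM_VERSIONS):
    """Map each package name to 'ok', 'too old', or 'missing'."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        is_recent = release_tuple(installed) >= release_tuple(minimum)
        report[name] = "ok" if is_recent else "too old"
    return report


for name, status in check_requirements().items():
    print(f"{name}: {status}")
```

In practice pip resolves these constraints during installation, so a script like this is mainly useful for diagnosing an environment that was assembled by other means.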

Additional packages are required to run the benchmark scripts:

  • mnist >= 0.2.2
  • psutil >= 5.9.8
  • shap >= 0.44.0
  • xgboost >= 2.0.3

Package development relies on the following tools:

  • coverage >= 7.4.3
  • ipython >= 8.22.1
  • pylint >= 3.0.4
  • twine >= 5.0.0

To use tensorflow for model fitting, install

  • tensorflow >= 2.15.0

Use this command to install proset with all extras (no space allowed after comma):

pip install proset[benchmarks,development,tensorflow]

Usage

Proset implements an interface compatible with the machine learning package scikit-learn [5]. You can create an estimator object like this:

from proset import ClassifierModel

model = ClassifierModel()

The model implements the fit(), predict(), predict_proba(), and score() methods required for scikit-learn estimators. It has three additional public methods: export(), explain(), and shrink(). The first creates a report with model parameters, the second explains a particular prediction, and the third reduces the model so that it expects only the active features as input.
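
Code written against the four scikit-learn style methods does not need to know which estimator it receives. The sketch below exercises them through a hypothetical helper `evaluate()`. `MajorityClassifier` is a toy stand-in, not part of proset, included only so the example runs without any installed packages. Passing `ClassifierModel()` together with NumPy arrays instead is the intended use. That the stand-in's score() returns accuracy follows the scikit-learn default for classifiers and is an assumption as far as proset is concerned.

```python
from collections import Counter


class MajorityClassifier:
    """Toy stand-in exposing the same four methods (not part of proset)."""

    def fit(self, X, y):
        counts = Counter(y)
        self.classes_ = sorted(counts)
        self.priors_ = [counts[label] / len(y) for label in self.classes_]
        return self  # scikit-learn convention: fit() returns the estimator

    def predict_proba(self, X):
        # Same class frequencies for every sample, ignoring the features.
        return [list(self.priors_) for _ in X]

    def predict(self, X):
        best = self.classes_[self.priors_.index(max(self.priors_))]
        return [best for _ in X]

    def score(self, X, y):
        predicted = self.predict(X)
        return sum(p == t for p, t in zip(predicted, y)) / len(y)


def evaluate(model, X_train, y_train, X_test, y_test):
    """Fit any estimator with the four methods and collect its outputs."""
    model.fit(X_train, y_train)
    return {
        "labels": model.predict(X_test),
        "probabilities": model.predict_proba(X_test),
        "score": model.score(X_test, y_test),
    }


X_train = [[0.0], [0.1], [0.9], [1.0], [1.1]]
y_train = [0, 0, 1, 1, 1]
result = evaluate(MajorityClassifier(), X_train, y_train, [[0.5], [1.2]], [0, 1])
```

Keeping the evaluation code free of estimator-specific calls makes it straightforward to compare a proset model with other scikit-learn compatible classifiers on the same data.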

The utility submodule has helper functions for selecting hyperparameters and creating diagnostic reports and plots:

import proset.utility as utility

utility.select_hyperparameters(...)

To learn more about using proset, you can...

  • use Python's help() to print the docstring for each function.
  • refer to Chapter 5 'Implementation notes' of the technical report.
  • look at the scripts for the benchmark study, which can serve as a tutorial:

> benchmark scripts

Release history

  • version 0.1.0: implementation of proset classifier using algorithm L-BFGS-B [6] for parameter estimation; helper functions for model fitting and plotting; benchmark code for hyperparameter selection, comparison to other classifiers, and demonstration of explanatory features; first version of technical report.
  • version 0.2.0: measures for faster computation: reduce float arrays to 32-bit precision, make solver tolerance configurable, enable tensorflow [7] as alternative backend for model fitting; reduce memory consumption for scoring; new options for select_hyperparameters(): chunks (macro-batching to reduce memory consumption for training), cv_groups (group related samples during cross-validation); add benchmark cases with greater sample size and feature dimension.
  • version 0.2.1: bugfix: if sample weights are passed for training, these are also used to compute marginal class probabilities.
  • version 0.3.0: instead of splitting training data into chunks that fit in memory, model fitting now supports an upper bound on the number of samples per batch, which is more efficient.
  • version 0.3.1: benchmark scripts cleaned up.
  • version 0.4.0: modified the recommended fit strategy to reduce overfitting when using multiple batches.
  • version 0.5.0: modified the strategy for selecting candidates such that it can be extended to regression.
  • version 0.5.1: cleaned up minor issues related to package dependencies.
  • version 0.6.0: updated requirements to Python 3.10 and compatible packages; changed the definition of the alpha parameters to match the literature (large values indicate dominant L1 penalty).
  • version 0.6.1: minor updates to plot functions.
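
For orientation, the convention mentioned under version 0.6.0 can be sketched as follows. The formula is the common parameterization from the elastic net literature, in which `alpha` mixes the L1 term and the squared L2 term. The exact scaling used inside proset is an assumption here, as are the names `lam` and `alpha` and the function itself.

```python
def elastic_net_penalty(coefficients, lam, alpha):
    """Compute lam * (alpha * ||w||_1 + (1 - alpha) / 2 * ||w||_2^2)."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    l1 = sum(abs(w) for w in coefficients)
    l2_squared = sum(w * w for w in coefficients)
    return lam * (alpha * l1 + 0.5 * (1.0 - alpha) * l2_squared)


w = [0.5, -2.0, 0.0]  # hypothetical coefficient vector
pure_l1 = elastic_net_penalty(w, lam=1.0, alpha=1.0)  # lasso limit
pure_l2 = elastic_net_penalty(w, lam=1.0, alpha=0.0)  # ridge limit
mixed = elastic_net_penalty(w, lam=1.0, alpha=0.95)   # L1 term dominates
```

Under this convention, values of `alpha` close to one favor sparse solutions, which matches the remark that large values indicate a dominant L1 penalty.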

Note on performance

Version 0.2.0 improves compute performance, which was unsatisfactory in version 0.1.0. Across five test cases, training a classifier became faster by a factor ranging from more than two up to nine. To support processing larger data sets, tensorflow can also be used as an alternative backend for training. The memory requirements for training and scoring have been considerably reduced.

Version 0.4.0 changes the algorithm for hyperparameter search used by utility.select_hyperparameters(). With default settings, the new algorithm fits models with a total number of batches around twice as large as before. This leads to an approximate doubling of the corresponding runtime. The upside is that the resulting models tend to have slightly better log-loss.

Contact

Please contact nikolaus.ruf@t-online.de for any questions or suggestions.

References

[1] H. Zou, T. Hastie: 'Regularization and variable selection via the elastic net', Journal of the Royal Statistical Society, Series B, vol. 67, part 2, pp. 301-320, 2005.

[2] E. A. Nadaraya: 'On Estimating Regression', Theory of Probability and Its Applications, vol. 9, issue 1, pp. 141-142, 1964.

[3] G. S. Watson: 'Smooth Regression Analysis', Sankhyā: The Indian Journal of Statistics, Series A, vol. 26, no. 4, pp. 359-372, 1964.

[4] P. Hall, J. Racine, Q. Li: 'Cross-validation and the Estimation of Conditional Probability Densities', Journal of the American Statistical Association, vol. 99, issue 468, pp. 1015-1026, 2004.

[5] F. Pedregosa et al.: 'Scikit-learn: Machine Learning in Python', JMLR 12, pp. 2825-2830, 2011.

[6] R. H. Byrd, P. Lu, J. Nocedal, C. Zhu: 'A Limited Memory Algorithm for Bound Constrained Optimization', SIAM Journal on Scientific Computing, vol. 16, issue 5, pp. 1190-1208, 1995.

[7] M. Abadi et al.: 'TensorFlow: Large-scale machine learning on heterogeneous systems', 2015.
