
Adaptive Hierarchical Shrinkage


Scikit-Learn-compatible implementation of Adaptive Hierarchical Shrinkage

This package contains an implementation of Adaptive Hierarchical Shrinkage that is compatible with Scikit-Learn. It exports two classes:

  • ShrinkageClassifier
  • ShrinkageRegressor

Installation

adhs Package

The adhs package, which contains the implementations of Adaptive Hierarchical Shrinkage, can be installed using:

pip install .

Experiments

Running the scripts in the experiments directory requires some extra dependencies. These can be installed in a new conda environment as follows:

conda create -n shrinkage python=3.10
conda activate shrinkage
pip install .[experiments]

Basic API

This package exports two classes and one function:

  • ShrinkageClassifier
  • ShrinkageRegressor
  • cross_val_shrinkage

ShrinkageClassifier and ShrinkageRegressor

Both classes inherit from ShrinkageEstimator, which extends sklearn.base.BaseEstimator. Adaptive hierarchical shrinkage can be summarized as follows: $$ \hat{f}(\mathbf{x}) = \mathbb{E}_{t_0}[y] + \sum_{l=1}^{L}\frac{\mathbb{E}_{t_l}[y] - \mathbb{E}_{t_{l-1}}[y]}{1 + \frac{g(t_{l-1})}{N(t_{l-1})}} $$ where $t_0, \dots, t_L$ is the path from the root node to the leaf containing $\mathbf{x}$, $N(t)$ is the number of training samples in node $t$, and $g(t_{l-1})$ is some function of the node $t_{l-1}$. Classical hierarchical shrinkage (Agarwal et al. 2022) corresponds to $g(t_{l-1}) = \lambda$, where $\lambda$ is a chosen constant.
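
To make the formula concrete, here is a minimal sketch (not part of the package) of how a shrunken prediction could be computed along a single root-to-leaf path. The names node_means, node_counts and the depth-indexed g are purely illustrative:

from typing import Callable, Sequence

def shrunken_prediction(
    node_means: Sequence[float],   # E_{t_l}[y] for l = 0..L along the path
    node_counts: Sequence[int],    # N(t_l) for l = 0..L along the path
    g: Callable[[int], float],     # g(t_{l-1}), indexed by depth for simplicity
) -> float:
    # Start from the root prediction E_{t_0}[y]
    pred = node_means[0]
    # Add each shrunken increment along the path
    for l in range(1, len(node_means)):
        increment = node_means[l] - node_means[l - 1]
        pred += increment / (1.0 + g(l - 1) / node_counts[l - 1])
    return pred

# Classical hierarchical shrinkage: g(t) is a constant lambda
lmb = 10.0
print(shrunken_prediction([0.4, 0.6, 0.9], [100, 40, 15], g=lambda l: lmb))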

  • __init__() parameters:
    • base_estimator: the estimator around which we "wrap" hierarchical shrinkage. This should be a tree-based estimator: DecisionTreeClassifier, RandomForestClassifier, ... (analogous for Regressors)
    • shrink_mode: 6 options:
      • "no_shrinkage": dummy value. This setting will not influence the base_estimator in any way, and is equivalent to just using the base_estimator by itself. Added for easy comparison between different modes of shrinkage and no shrinkage at all.
      • "hs": classical Hierarchical Shrinkage (from Agarwal et al. 2022): $g(t_{l-1}) = \lambda$.
      • "hs_entropy": Adaptive Hierarchical Shrinkage with added entropy term: $g(t_{l-1}) = \lambda H(t_{l-1})$.
      • "hs_log_cardinality": Adaptive Hierarchical Shrinkage with log of cardinality term: $g(t_{l-1}) = \lambda \log C(t_{l-1})$ where $C(t)$ is the number of unique values in $t$.
      • "hs_permutation": Adaptive Hierarchical Shrinkage with $g(t_{l-1}) = \frac{1}{\alpha(t_{l-1})}$, with $\alpha(t_{l-1}) = 1 - \frac{\Delta_\mathcal{I}(t_{l-1}, { }\pi x(t{l-1})) + \epsilon}{\Delta_\mathcal{I}(t_{l-1}, x(t_{l-1}))+ \epsilon}$
      • "hs_global_permutation": Same as "hs_permutation", but the data is permuted only once for the full dataset rather than once in each node.
    • lmb: $\lambda$ hyperparameter
    • random_state: random state for reproducibility
  • reshrink(shrink_mode, lmb, X): changes the shrinkage mode and/or lambda value of an already-fitted model. Calling reshrink with a given shrink_mode and/or lmb is equivalent to fitting a new model with the same base estimator and the new values, but it can avoid redundant computations in the shrinkage process and is therefore typically more efficient than fitting a new ShrinkageClassifier or ShrinkageRegressor from scratch (see the usage sketch after this list).
  • Other functions: fit(X, y), predict(X), predict_proba(X), score(X, y) work just like with any other sklearn estimator.
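
As a usage illustration, the following sketch shows how these classes are intended to be used. The class names, shrink_mode values, lmb parameter and reshrink signature are those described above; the dataset, the keyword usage of reshrink, and the assumption that everything is importable directly from the adhs package are illustrative:

from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

from adhs import ShrinkageClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Wrap a tree-based estimator with adaptive hierarchical shrinkage
clf = ShrinkageClassifier(
    base_estimator=RandomForestClassifier(random_state=0),
    shrink_mode="hs_entropy",
    lmb=10.0,
)
clf.fit(X_train, y_train)
print("hs_entropy accuracy:", clf.score(X_test, y_test))

# Change the shrinkage settings on the fitted model without retraining the forest
clf.reshrink(shrink_mode="hs", lmb=50.0, X=X_train)
print("hs accuracy:", clf.score(X_test, y_test))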

cross_val_shrinkage

This function can be used to efficiently run cross-validation over the shrink_mode and/or lmb hyperparameters. Because adaptive hierarchical shrinkage is a fully post-hoc procedure, cross-validating these hyperparameters requires no retraining of the base model, and cross_val_shrinkage exploits this property.
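
The exact signature of cross_val_shrinkage is not documented on this page, so the sketch below is only an assumption of how such a call might look; in particular, the keyword names shrink_modes, lmbs and n_splits are hypothetical:

from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

from adhs import ShrinkageRegressor, cross_val_shrinkage

X, y = make_regression(n_samples=500, n_features=10, random_state=0)

reg = ShrinkageRegressor(base_estimator=RandomForestRegressor(random_state=0))

# Hypothetical call: evaluate a grid of shrink_mode / lmb values post hoc,
# reusing the fitted base forests instead of retraining them per candidate.
results = cross_val_shrinkage(
    reg, X, y,
    shrink_modes=["hs", "hs_entropy", "hs_log_cardinality"],  # hypothetical keyword
    lmbs=[0.1, 1.0, 10.0, 100.0],                             # hypothetical keyword
    n_splits=5,                                               # hypothetical keyword
)
print(results)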

Tutorials

  • General usage: Shows how to apply hierarchical shrinkage to a simple dataset and access feature importances.
  • Cross-validating shrinkage parameters: The hyperparameters of adaptive hierarchical shrinkage (i.e. shrink_mode and lmb) can be tuned using cross-validation without retraining the underlying model, because adaptive hierarchical shrinkage is a fully post-hoc procedure. As ShrinkageClassifier and ShrinkageRegressor are valid scikit-learn estimators, you could simply tune these hyperparameters with GridSearchCV, as you would for any other scikit-learn model; however, this retrains the decision tree or random forest for every candidate, which adds unnecessary computational cost. This notebook shows how to use our cross-validation function to cross-validate shrink_mode and lmb without that overhead (a rough sketch of the naive GridSearchCV approach follows below for contrast).
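
For contrast, the naive GridSearchCV approach mentioned above might look like the sketch below; shrink_mode and lmb are the constructor parameters described earlier, while the dataset and grid values are illustrative. Every parameter combination and fold refits the base forest, which is exactly the overhead cross_val_shrinkage avoids:

from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

from adhs import ShrinkageClassifier

X, y = load_breast_cancer(return_X_y=True)

# Naive tuning: GridSearchCV treats ShrinkageClassifier like any other estimator,
# so the base forest is refitted for every shrink_mode / lmb combination and fold.
grid = GridSearchCV(
    ShrinkageClassifier(base_estimator=RandomForestClassifier(random_state=0)),
    param_grid={
        "shrink_mode": ["hs", "hs_entropy", "hs_log_cardinality"],
        "lmb": [0.1, 1.0, 10.0, 100.0],
    },
    cv=5,
)
grid.fit(X, y)
print(grid.best_params_)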
