
Adaptive Hierarchical Shrinkage

Project description

Scikit-Learn-compatible implementation of Adaptive Hierarchical Shrinkage

This directory contains an implementation of Adaptive Hierarchical Shrinkage that is compatible with scikit-learn. It exports two estimator classes:

  • ShrinkageClassifier
  • ShrinkageRegressor

Installation

adhs Package

The adhs package, which contains the implementations of Adaptive Hierarchical Shrinkage, can be installed from PyPI:

$ pip install adhs

Alternatively, the package can be installed by cloning the repository:

$ git clone git@github.com:arnegevaert/adaptive-hierarchical-shrinkage.git
$ cd adaptive-hierarchical-shrinkage
$ pip install .

Note that this will install the version corresponding to the latest commit on the master branch, which may or may not be stable.

Experiments

Running the scripts in the experiments directory requires additional dependencies. These can be installed in a new conda environment as follows:

$ git clone git@github.com:arnegevaert/adaptive-hierarchical-shrinkage.git
$ cd adaptive-hierarchical-shrinkage
$ conda create -n shrinkage python=3.10
$ conda activate shrinkage
$ pip install ".[experiments]"

For more info on reproducing the experiments, see experiments/README.md.

Basic API

This package exports 2 classes and 1 function:

  • ShrinkageClassifier
  • ShrinkageRegressor
  • cross_val_shrinkage

ShrinkageClassifier and ShrinkageRegressor

Both classes inherit from ShrinkageEstimator, which extends sklearn.base.BaseEstimator. Adaptive hierarchical shrinkage can be summarized as follows:

\hat{f}(\mathbf{x}) = \mathbb{E}_{t_0}[y] + \sum_{l=1}^L\frac{\mathbb{E}_{t_l}[y] - \mathbb{E}_{t_{l-1}}[y]}{1 + \frac{g(t_{l-1})}{N(t_{l-1})}}

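where $t_0, \dots, t_L$ are the nodes on the path from the root $t_0$ to the leaf $t_L$ containing $\mathbf{x}$, $\mathbb{E}_{t}[y]$ is the mean response of the training samples in node $t$, $N(t)$ is the number of training samples in node $t$, and $g(t_{l-1})$ is some function of the node $t_{l-1}$. Classical hierarchical shrinkage (Agarwal et al. 2022) corresponds to $g(t_{l-1}) = \lambda$, where $\lambda$ is a chosen constant.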
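To make the formula concrete, the following minimal sketch evaluates it along a single root-to-leaf path, given the mean response and sample count of each node and a precomputed value of $g$ for each internal node. The function name shrunk_prediction and the list-based path representation are hypothetical illustrations, not part of the adhs API.

```python
from typing import Sequence, Tuple


def shrunk_prediction(path: Sequence[Tuple[float, int]], g_values: Sequence[float]) -> float:
    """Evaluate the shrinkage formula along one root-to-leaf path (illustrative sketch).

    path[l] = (E_{t_l}[y], N(t_l)) for l = 0..L, with path[0] the root and path[-1] the leaf.
    g_values[l] = g(t_l) for the internal nodes l = 0..L-1 (the leaf's g is never used).
    """
    if len(g_values) < len(path) - 1:
        raise ValueError("need one g value per internal node on the path")
    prediction = path[0][0]  # start from the root mean E_{t_0}[y]
    for l in range(1, len(path)):
        child_mean = path[l][0]
        parent_mean, parent_n = path[l - 1]
        # Each increment is damped by the factor 1 + g(t_{l-1}) / N(t_{l-1}).
        prediction += (child_mean - parent_mean) / (1 + g_values[l - 1] / parent_n)
    return prediction


# Hypothetical path: root (mean 0.5, 100 samples) -> child (0.8, 40) -> leaf (0.9, 10).
path = [(0.5, 100), (0.8, 40), (0.9, 10)]
lmb = 10.0
print(shrunk_prediction(path, [0.0, 0.0]))  # g = 0: recovers the leaf mean (up to rounding)
print(shrunk_prediction(path, [lmb, lmb]))  # classical HS with g = lambda: about 0.853
```

Setting $g = 0$ recovers the unshrunk tree prediction, while a very large $\lambda$ pulls the prediction towards the root mean.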

  • __init__() parameters:
    • base_estimator: the tree-based estimator around which hierarchical shrinkage is "wrapped", e.g. DecisionTreeClassifier or RandomForestClassifier (or the corresponding regressors)
    • lmb: $\lambda$ hyperparameter
    • random_state: random state for reproducibility
    • shrink_mode: 6 options:
      • "no_shrinkage": dummy value. This setting will not influence the base_estimator in any way, and is equivalent to just using the base_estimator by itself. Added for easy comparison between different modes of shrinkage and no shrinkage at all.
      • "hs": classical Hierarchical Shrinkage (from Agarwal et al. 2022): $g(t_{l-1}) = \lambda$.
      • "hs_entropy": Adaptive Hierarchical Shrinkage with added entropy term: $g(t_{l-1}) = \lambda H(t_{l-1})$.
      • "hs_log_cardinality": Adaptive Hierarchical Shrinkage with log of cardinality term: $g(t_{l-1}) = \lambda \log C(t_{l-1})$ where $C(t)$ is the number of unique values in $t$.
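      • "hs_permutation": Adaptive Hierarchical Shrinkage with:
\begin{aligned}
g(t_{l-1}) &= \frac{1}{\alpha(t_{l-1})}\\
\alpha(t_{l-1}) &= 1 - \frac{\Delta_\mathcal{I}(t_{l-1}, { }_\pi x(t_{l-1})) + \epsilon}{\Delta_\mathcal{I}(t_{l-1}, x(t_{l-1}))+ \epsilon}
\end{aligned}
        Here $\Delta_\mathcal{I}(t, x(t))$ denotes the impurity decrease achieved by the split in node $t$ on the feature values $x(t)$, ${ }_\pi x(t)$ denotes a random permutation of those values, and $\epsilon$ is a small constant that keeps the ratio well defined.
      • "hs_global_permutation": Same as "hs_permutation", but the data is permuted only once for the full dataset rather than once in each node.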
  • reshrink(shrink_mode, lmb, X): changes the shrinkage mode and/or lambda value of an already fitted model. Calling reshrink with new values of shrink_mode and/or lmb is equivalent to fitting a new model with the same base estimator and the new values. Because it avoids redundant computations in the shrinkage process, reshrink can be more efficient than fitting a new ShrinkageClassifier or ShrinkageRegressor.
  • Other methods: fit(X, y), predict(X), predict_proba(X) and score(X, y) work as with any other scikit-learn estimator.
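The shrink_mode options differ only in the function $g$. As a rough illustration, the helpers below compute $g(t_{l-1})$ for each mode from per-node quantities that are assumed to be precomputed (the entropy $H$, the cardinality $C$ and the two impurity decreases $\Delta_\mathcal{I}$), together with the resulting damping factor $1 / (1 + g/N)$. All function names are hypothetical. The natural logarithm and the handling of $\alpha \le 0$ (treated here as complete shrinkage) are assumptions of this sketch, not documented behavior of adhs.

```python
import math

# Hypothetical helpers mirroring the g(t_{l-1}) definitions of each shrink_mode.
# Per-node inputs (entropy, cardinality, impurity decreases) are assumed precomputed.


def g_hs(lmb: float) -> float:
    return lmb  # "hs": constant g = lambda


def g_entropy(lmb: float, entropy: float) -> float:
    return lmb * entropy  # "hs_entropy": lambda * H(t)


def g_log_cardinality(lmb: float, cardinality: int) -> float:
    return lmb * math.log(cardinality)  # "hs_log_cardinality": natural log assumed


def g_permutation(delta_permuted: float, delta_original: float, eps: float = 1e-8) -> float:
    """'hs_permutation': g = 1 / alpha, alpha = 1 - (d_perm + eps) / (d_orig + eps)."""
    alpha = 1 - (delta_permuted + eps) / (delta_original + eps)
    if alpha <= 0:
        # Assumption: a split no better than its permuted version is shrunk away completely.
        return math.inf
    return 1 / alpha


def damping_factor(g: float, n_samples: int) -> float:
    """Weight 1 / (1 + g / N) applied to the increment below a node with N samples."""
    return 1 / (1 + g / n_samples)


print(damping_factor(g_hs(10.0), 10))                       # 0.5
print(damping_factor(g_permutation(0.5, 1.0, eps=0.0), 4))  # alpha = 0.5, so g = 2
print(damping_factor(g_permutation(1.0, 1.0), 4))           # 0.0: the split is fully shrunk
```

An informative split (permuting its feature destroys most of the impurity decrease) gets $\alpha$ close to 1 and therefore little shrinkage; a split that performs about as well on permuted data gets $\alpha$ close to 0 and heavy shrinkage.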

cross_val_shrinkage

This function can be used to efficiently cross-validate the shrink_mode and/or lmb hyperparameters. Because adaptive hierarchical shrinkage is a fully post-hoc procedure, evaluating a different shrink_mode or lmb does not require retraining the base model, and cross_val_shrinkage exploits this property.
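The toy sketch below illustrates why post-hoc shrinkage makes cross-validation cheap: the expensive step (growing the tree) runs once per fold, while each candidate $\lambda$ only re-evaluates the shrinkage formula on the already fitted tree. It is a self-contained illustration in plain Python, not the adhs implementation: build_tree (a 1-D regression tree that splits at the median), predict_hs and cross_val_lambda are hypothetical names, only classical HS ($g = \lambda$) is covered, and the actual cross_val_shrinkage signature may differ.

```python
import random
from statistics import mean, median


def build_tree(xs, ys, depth):
    """Toy 1-D regression tree splitting at the median (stands in for the base model)."""
    node = {"mean": mean(ys), "n": len(ys)}
    if depth == 0 or len(set(xs)) < 2:
        return node
    thr = median(xs)
    left = [(x, y) for x, y in zip(xs, ys) if x <= thr]
    right = [(x, y) for x, y in zip(xs, ys) if x > thr]
    if not left or not right:
        return node
    node["thr"] = thr
    node["left"] = build_tree(*zip(*left), depth - 1)
    node["right"] = build_tree(*zip(*right), depth - 1)
    return node


def predict_hs(tree, x, lmb):
    """Classical hierarchical shrinkage (g = lambda) along the path that x follows."""
    node, pred = tree, tree["mean"]
    while "thr" in node:
        child = node["left"] if x <= node["thr"] else node["right"]
        pred += (child["mean"] - node["mean"]) / (1 + lmb / node["n"])
        node = child
    return pred


def cross_val_lambda(xs, ys, lambdas, n_folds=5, depth=3):
    """Return (best lambda, mean squared error per lambda, number of trees fitted)."""
    folds = [list(range(k, len(xs), n_folds)) for k in range(n_folds)]
    sq_err = {lmb: 0.0 for lmb in lambdas}
    n_fits = 0
    for test_idx in folds:
        held_out = set(test_idx)
        train = [i for i in range(len(xs)) if i not in held_out]
        tree = build_tree([xs[i] for i in train], [ys[i] for i in train], depth)
        n_fits += 1  # the expensive step runs once per fold ...
        for lmb in lambdas:  # ... and every lambda reuses the same fitted tree
            sq_err[lmb] += sum((predict_hs(tree, xs[i], lmb) - ys[i]) ** 2 for i in test_idx)
    mse = {lmb: err / len(xs) for lmb, err in sq_err.items()}
    return min(mse, key=mse.get), mse, n_fits


rng = random.Random(0)
xs = [rng.uniform(0, 1) for _ in range(200)]
ys = [2 * x + rng.gauss(0, 0.5) for x in xs]
best, mse, n_fits = cross_val_lambda(xs, ys, [0.0, 1.0, 10.0, 100.0])
print(best, n_fits)  # n_fits is 5 (one per fold), not 5 * 4
```

Extending the sketch to the adaptive modes would only change the damping term inside predict_hs; the fold loop stays the same.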

Tutorials

  • General usage: Shows how to apply hierarchical shrinkage on a simple dataset and access feature importances.
  • Cross-validating shrinkage parameters: Hyperparameters for (adaptive) hierarchical shrinkage (i.e. shrink_mode and lmb) can be tuned using cross-validation without retraining the underlying model, because (adaptive) hierarchical shrinkage is a fully post-hoc procedure. Since ShrinkageClassifier and ShrinkageRegressor are valid scikit-learn estimators, you could simply tune these hyperparameters using GridSearchCV as with any other scikit-learn model. However, GridSearchCV retrains the decision tree or random forest for every hyperparameter combination, which adds unnecessary computational cost. This notebook shows how to use our cross-validation function to cross-validate shrink_mode and lmb without this overhead.

Experiments

To reproduce the experiments from the paper, see experiments/README.md.


Download files


Source Distribution

adhs-0.1.4.tar.gz (12.8 kB)

Uploaded Source

Built Distribution


adhs-0.1.4-py3-none-any.whl (11.8 kB)

Uploaded Python 3

File details

Details for the file adhs-0.1.4.tar.gz.

File metadata

  • Download URL: adhs-0.1.4.tar.gz
  • Upload date:
  • Size: 12.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.30 (publish, NixOS 25.11)

File hashes

Hashes for adhs-0.1.4.tar.gz
Algorithm Hash digest
SHA256 a59115945d55723966231a8090dfc00c8d511ad722a438cfa26568816eea5ea5
MD5 9979b61c35f04eb860f646295a5038b6
BLAKE2b-256 47934827eff034251ecb360d48ef7d17b6baa98b092874257e9732d30fa45f2e


File details

Details for the file adhs-0.1.4-py3-none-any.whl.

File metadata

  • Download URL: adhs-0.1.4-py3-none-any.whl
  • Upload date:
  • Size: 11.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.30 (publish, NixOS 25.11)

File hashes

Hashes for adhs-0.1.4-py3-none-any.whl
Algorithm Hash digest
SHA256 524a237f5e9b2e0d57c37a08d22455a35e3b95345a22cc6c303c903a6aacb379
MD5 828fc993ac9161b0f5fd9e7fa3aed6f7
BLAKE2b-256 443279b3dae137ebaea18be1d714fa2fb13d2ad4f33776cca5a578a7791b2ddd

