Adaptive Hierarchical Shrinkage
Project description
Scikit-Learn-compatible implementation of Adaptive Hierarchical Shrinkage
This directory contains an implementation of Adaptive Hierarchical Shrinkage that is compatible with Scikit-Learn. It exports 2 classes:
- `ShrinkageClassifier`
- `ShrinkageRegressor`
Installation
adhs Package
The adhs package, which contains the implementations of
Adaptive Hierarchical Shrinkage, can be installed using:
pip install .
Experiments
To be able to run the scripts in the experiments directory, some extra
requirements are needed. These can be installed in a new conda
environment as follows:
conda create -n shrinkage python=3.10
conda activate shrinkage
pip install .[experiments]
Basic API
This package exports 2 classes and 1 function:

- `ShrinkageClassifier`
- `ShrinkageRegressor`
- `cross_val_shrinkage`
ShrinkageClassifier and ShrinkageRegressor
Both classes inherit from `ShrinkageEstimator`, which extends `sklearn.base.BaseEstimator`. Adaptive hierarchical shrinkage can be summarized as follows:
$$
\hat{f}(\mathbf{x}) = \mathbb{E}_{t_0}[y] + \sum_{l=1}^{L} \frac{\mathbb{E}_{t_l}[y] - \mathbb{E}_{t_{l-1}}[y]}{1 + \frac{g(t_{l-1})}{N(t_{l-1})}}
$$
where $g(t_{l-1})$ is some function of the node $t_{l-1}$. Classical hierarchical shrinkage (Agarwal et al. 2022) corresponds to $g(t_{l-1}) = \lambda$, where $\lambda$ is a chosen constant.
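As a minimal illustration of this formula (not the package's internal code), the sketch below computes a shrunk prediction along a single root-to-leaf path; the function and variable names (`shrunk_prediction`, `node_means`, `node_sizes`, `g_values`) are illustrative, not part of the `adhs` API:

```python
def shrunk_prediction(node_means, node_sizes, g_values):
    """Apply the shrinkage formula along one root-to-leaf path.

    node_means[l] is the mean response E_{t_l}[y] at depth l,
    node_sizes[l] is N(t_l), and g_values[l] is g(t_l) for the
    chosen shrink_mode (e.g. a constant lambda for "hs").
    """
    pred = node_means[0]  # start from the root-node mean E_{t_0}[y]
    for l in range(1, len(node_means)):
        step = node_means[l] - node_means[l - 1]
        # Each step deeper in the tree is damped by 1 + g(t_{l-1}) / N(t_{l-1})
        pred += step / (1.0 + g_values[l - 1] / node_sizes[l - 1])
    return pred

# Classical hierarchical shrinkage: g(t) = lambda at every node
print(shrunk_prediction([0.5, 0.7, 0.9], [100, 40, 15], [10.0, 10.0, 10.0]))
```

With a large `lmb`, steps taken deep in the tree (where few samples remain) are damped heavily, pulling the prediction toward the means of larger, shallower nodes.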
`__init__()` parameters:

- `base_estimator`: the estimator around which we "wrap" hierarchical shrinkage. This should be a tree-based estimator: `DecisionTreeClassifier`, `RandomForestClassifier`, ... (analogous for `Regressor`s).
- `shrink_mode`: 6 options:
  - `"no_shrinkage"`: dummy value. This setting will not influence the `base_estimator` in any way, and is equivalent to just using the `base_estimator` by itself. Added for easy comparison between different modes of shrinkage and no shrinkage at all.
  - `"hs"`: classical Hierarchical Shrinkage (from Agarwal et al. 2022): $g(t_{l-1}) = \lambda$.
  - `"hs_entropy"`: Adaptive Hierarchical Shrinkage with added entropy term: $g(t_{l-1}) = \lambda H(t_{l-1})$.
  - `"hs_log_cardinality"`: Adaptive Hierarchical Shrinkage with log of cardinality term: $g(t_{l-1}) = \lambda \log C(t_{l-1})$, where $C(t)$ is the number of unique values in $t$.
  - `"hs_permutation"`: Adaptive Hierarchical Shrinkage with $g(t_{l-1}) = \frac{1}{\alpha(t_{l-1})}$, where $\alpha(t_{l-1}) = 1 - \frac{\Delta_\mathcal{I}(t_{l-1}, \pi_x(t_{l-1})) + \epsilon}{\Delta_\mathcal{I}(t_{l-1}, x(t_{l-1})) + \epsilon}$.
  - `"hs_global_permutation"`: same as `"hs_permutation"`, but the data is permuted only once for the full dataset rather than once in each node.
- `lmb`: the $\lambda$ hyperparameter.
- `random_state`: random state for reproducibility.
- `reshrink(shrink_mode, lmb, X)`: changes the shrinkage mode and/or lambda value in the shrinkage process. Calling `reshrink` with a given value of `shrink_mode` and/or `lmb` on an existing model is equivalent to fitting a new model with the same base estimator but the new, given values for `shrink_mode` and/or `lmb`. This method can avoid redundant computations in the shrinkage process, so it can be more efficient than re-fitting a new `ShrinkageClassifier` or `ShrinkageRegressor`.
- Other functions: `fit(X, y)`, `predict(X)`, `predict_proba(X)`, `score(X, y)` work just like with any other `sklearn` estimator (see the example below).
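A minimal usage sketch follows. The dataset and hyperparameter values are arbitrary choices for illustration, and the keyword arguments to `reshrink` assume the parameter names listed above:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier
from adhs import ShrinkageClassifier

X, y = load_breast_cancer(return_X_y=True)

# Wrap a tree-based estimator in (adaptive) hierarchical shrinkage
clf = ShrinkageClassifier(
    base_estimator=DecisionTreeClassifier(),
    shrink_mode="hs",  # classical hierarchical shrinkage
    lmb=10.0,
)
clf.fit(X, y)
print("hs:", clf.score(X, y))

# Change the shrinkage settings post hoc, without refitting the base tree
clf.reshrink(shrink_mode="hs_entropy", lmb=5.0, X=X)
print("hs_entropy:", clf.score(X, y))
```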
cross_val_shrinkage
This function can be used to efficiently run cross-validation over the `shrink_mode` and/or `lmb` hyperparameters. As adaptive hierarchical shrinkage is a fully post-hoc procedure, cross-validation requires no retraining of the base model; this function exploits that property.
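The exact signature of `cross_val_shrinkage` is not spelled out in this description, so the call below is a hypothetical sketch: the argument names `shrink_modes`, `lmbs`, and `n_splits` are assumptions for illustration, not documented API.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier
from adhs import ShrinkageClassifier, cross_val_shrinkage

X, y = load_breast_cancer(return_X_y=True)
clf = ShrinkageClassifier(base_estimator=DecisionTreeClassifier())

# NOTE: hypothetical call; the keyword names below are assumed, not documented.
scores = cross_val_shrinkage(
    clf, X, y,
    shrink_modes=["hs", "hs_entropy", "hs_log_cardinality"],
    lmbs=[0.1, 1.0, 10.0, 100.0],
    n_splits=5,
)
```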
Tutorials
- General usage: Shows how to apply hierarchical shrinkage on a simple dataset and access feature importances.
- Cross-validating shrinkage parameters: Hyperparameters for adaptive hierarchical shrinkage (i.e. `shrink_mode` and `lmb`) can be tuned using cross-validation without retraining the underlying model, because adaptive hierarchical shrinkage is a fully post-hoc procedure. As `ShrinkageClassifier` and `ShrinkageRegressor` are valid scikit-learn estimators, you could simply tune these hyperparameters using `GridSearchCV`, as you would with any other scikit-learn model (see the sketch after this list). However, that approach retrains the decision tree or random forest for every candidate setting, wasting computation. This notebook shows how to use our cross-validation function to tune `shrink_mode` and `lmb` without that overhead.
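For contrast, here is the naive `GridSearchCV` route mentioned above. It works because the wrappers are valid scikit-learn estimators, but it refits the base tree for every candidate, which is exactly what `cross_val_shrinkage` avoids. The dataset and grid values are illustrative:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier
from adhs import ShrinkageClassifier

X, y = load_breast_cancer(return_X_y=True)

# Standard scikit-learn grid search over the shrinkage hyperparameters.
# Each (shrink_mode, lmb) candidate refits the base tree from scratch.
param_grid = {
    "shrink_mode": ["hs", "hs_entropy", "hs_log_cardinality"],
    "lmb": [0.1, 1.0, 10.0, 100.0],
}
search = GridSearchCV(
    ShrinkageClassifier(base_estimator=DecisionTreeClassifier()),
    param_grid,
    cv=5,
)
search.fit(X, y)
print(search.best_params_)
```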