# Adaptive Hierarchical Shrinkage

Scikit-Learn-compatible implementation of Adaptive Hierarchical Shrinkage.

This directory contains an implementation of Adaptive Hierarchical Shrinkage that is compatible with Scikit-Learn. It exports 2 classes:

- `ShrinkageClassifier`
- `ShrinkageRegressor`
## Installation

### `adhs` package

The `adhs` package, which contains the implementations of Adaptive Hierarchical Shrinkage, can be installed using:

```sh
pip install .
```
### Experiments

To be able to run the scripts in the `experiments` directory, some extra requirements are needed. These can be installed in a new conda environment as follows:

```sh
conda create -n shrinkage python=3.10
conda activate shrinkage
pip install .[experiments]
```
## Basic API

This package exports 2 classes and 1 method:

- `ShrinkageClassifier`
- `ShrinkageRegressor`
- `cross_val_shrinkage`

### `ShrinkageClassifier` and `ShrinkageRegressor`

Both classes inherit from `ShrinkageEstimator`, which extends `sklearn.base.BaseEstimator`. Adaptive hierarchical shrinkage can be summarized as follows:
$$
\hat{f}(\mathbf{x}) = \mathbb{E}_{t_0}[y] + \sum_{l=1}^L \frac{\mathbb{E}_{t_l}[y] - \mathbb{E}_{t_{l-1}}[y]}{1 + \frac{g(t_{l-1})}{N(t_{l-1})}}
$$
where $g(t_{l-1})$ is some function of the node $t_{l-1}$. Classical hierarchical shrinkage (Agarwal et al. 2022) corresponds to $g(t_{l-1}) = \lambda$, where $\lambda$ is a chosen constant.
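As a numeric illustration of the formula, here is a minimal pure-Python sketch (an illustration only, not the package's internal implementation; the helper `shrunk_prediction` and the toy path values are made up for this example):

```python
# Illustrative sketch of hierarchical shrinkage along one root-to-leaf path.
# `path` lists (E_t[y], N(t)) for nodes t_0 .. t_L; `g` maps a node to its
# shrinkage term, so a constant g(t) = lambda reproduces classical HS.

def shrunk_prediction(path, g):
    pred = path[0][0]  # start from the root mean E_{t_0}[y]
    for (p_mean, p_n), (c_mean, _) in zip(path, path[1:]):
        # Each step towards the leaf is shrunk towards the parent node:
        pred += (c_mean - p_mean) / (1 + g(p_mean, p_n) / p_n)
    return pred

# Toy path: root mean 0.5 over 100 samples, internal node mean 0.8 over 20,
# leaf mean 1.0 over 5 samples.
path = [(0.5, 100), (0.8, 20), (1.0, 5)]

no_shrinkage = shrunk_prediction(path, lambda mean, n: 0.0)   # plain leaf mean
classical_hs = shrunk_prediction(path, lambda mean, n: 10.0)  # g(t) = lambda = 10
print(round(no_shrinkage, 4), round(classical_hs, 4))
```

With $\lambda = 0$ the prediction is just the leaf mean; larger $\lambda$ pulls each step back towards its parent, and the pull is weakest in well-populated nodes (large $N(t_{l-1})$).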
- `__init__()` parameters:
  - `base_estimator`: the estimator around which we "wrap" hierarchical shrinkage. This should be a tree-based estimator: `DecisionTreeClassifier`, `RandomForestClassifier`, ... (analogously for `Regressor`s).
  - `shrink_mode`: 6 options:
    - `"no_shrinkage"`: dummy value. This setting does not influence the `base_estimator` in any way and is equivalent to using the `base_estimator` by itself. Added for easy comparison between the different modes of shrinkage and no shrinkage at all.
    - `"hs"`: classical Hierarchical Shrinkage (from Agarwal et al. 2022): $g(t_{l-1}) = \lambda$.
    - `"hs_entropy"`: Adaptive Hierarchical Shrinkage with an added entropy term: $g(t_{l-1}) = \lambda H(t_{l-1})$.
    - `"hs_log_cardinality"`: Adaptive Hierarchical Shrinkage with a log-cardinality term: $g(t_{l-1}) = \lambda \log C(t_{l-1})$, where $C(t)$ is the number of unique values in $t$.
    - `"hs_permutation"`: Adaptive Hierarchical Shrinkage with $g(t_{l-1}) = \frac{1}{\alpha(t_{l-1})}$, where $\alpha(t_{l-1}) = 1 - \frac{\Delta_\mathcal{I}(t_{l-1}, \pi_x(t_{l-1})) + \epsilon}{\Delta_\mathcal{I}(t_{l-1}, x(t_{l-1})) + \epsilon}$.
    - `"hs_global_permutation"`: same as `"hs_permutation"`, but the data is permuted only once for the full dataset rather than once in each node.
  - `lmb`: the $\lambda$ hyperparameter.
  - `random_state`: random state for reproducibility.
- `reshrink(shrink_mode, lmb, X)`: changes the shrinkage mode and/or lambda value of the shrinkage process. Calling `reshrink` with a given value of `shrink_mode` and/or `lmb` on an existing model is equivalent to fitting a new model with the same base estimator but the new, given values for `shrink_mode` and/or `lmb`. This method avoids redundant computations in the shrinkage process, so it can be more efficient than fitting a new `ShrinkageClassifier` or `ShrinkageRegressor`.
- Other functions: `fit(X, y)`, `predict(X)`, `predict_proba(X)`, `score(X, y)` work just like with any other `sklearn` estimator.
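The `shrink_mode` options above differ only in the node-dependent term $g(t)$. As a small illustration (not the package's internals; base-2 entropy and the toy node labels are assumptions made for this sketch), the constant, entropy, and log-cardinality variants can be computed for a single node like this:

```python
import math
from collections import Counter

# Illustrative implementations of g(t) for a node's label sample.

def g_hs(labels, lmb):
    return lmb  # classical HS: constant lambda, independent of the node

def g_entropy(labels, lmb):
    # lambda * H(t), with H the Shannon entropy (base 2 assumed here)
    n = len(labels)
    probs = [c / n for c in Counter(labels).values()]
    return lmb * -sum(p * math.log2(p) for p in probs)

def g_log_cardinality(labels, lmb):
    # lambda * log C(t), with C(t) the number of unique values in the node
    return lmb * math.log(len(set(labels)))

node = [0, 0, 0, 1, 1, 1, 1, 1]  # toy labels falling into node t
lmb = 10.0
print(g_hs(node, lmb))               # 10.0
print(g_entropy(node, lmb))          # ≈ 9.54 (H ≈ 0.954 bits)
print(g_log_cardinality(node, lmb))  # ≈ 6.93 (log 2)
```

Pure, low-entropy nodes thus receive a smaller penalty under `"hs_entropy"` than under classical `"hs"`, so their contribution to the prediction is shrunk less.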
### `cross_val_shrinkage`

This method can be used to efficiently run cross-validation for the `shrink_mode` and/or `lmb` hyperparameters. As adaptive hierarchical shrinkage is a fully post-hoc procedure, cross-validation requires no retraining of the base model, and this function exploits that property.
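The post-hoc structure can be sketched generically as follows. This is an illustration of *why* the procedure is cheap, not the actual `cross_val_shrinkage` API: the function names, signature, and stub objects here are all hypothetical.

```python
from itertools import product

def cross_val_post_hoc(fit, apply_shrinkage, score, folds, grid):
    """Hypothetical sketch: fit the base model once per fold, then score
    every hyperparameter setting by re-applying shrinkage post hoc."""
    results = {setting: [] for setting in grid}
    for train, test in folds:
        model = fit(train)  # expensive step: runs once per fold
        for setting in grid:
            shrunk = apply_shrinkage(model, *setting)  # cheap re-weighting
            results[setting].append(score(shrunk, test))
    return {s: sum(v) / len(v) for s, v in results.items()}

# Stubs standing in for a real base learner and scorer:
fit_calls = []
folds = [("train1", "test1"), ("train2", "test2"), ("train3", "test3")]
grid = list(product(["hs", "hs_entropy"], [0.1, 1.0, 10.0]))

def fit(train):
    fit_calls.append(train)
    return {"trained_on": train}

def apply_shrinkage(model, shrink_mode, lmb):
    return {**model, "shrink_mode": shrink_mode, "lmb": lmb}

def score(model, test):
    return 1.0  # placeholder metric

scores = cross_val_post_hoc(fit, apply_shrinkage, score, folds, grid)
assert len(fit_calls) == len(folds)  # base model fitted once per fold...
assert len(scores) == len(grid)      # ...yet every grid setting is scored
```

A naive `GridSearchCV` over the same grid would instead call `fit` once per fold *per setting*: 18 fits here rather than 3.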
## Tutorials

- **General usage:** shows how to apply hierarchical shrinkage on a simple dataset and access feature importances.
- **Cross-validating shrinkage parameters:** the hyperparameters of adaptive hierarchical shrinkage (i.e. `shrink_mode` and `lmb`) can be tuned using cross-validation without retraining the underlying model, because adaptive hierarchical shrinkage is a fully post-hoc procedure. As `ShrinkageClassifier` and `ShrinkageRegressor` are valid scikit-learn estimators, you could simply tune these hyperparameters using `GridSearchCV`, as you would with any other scikit-learn model. However, this would retrain the decision tree or random forest, which leads to unnecessary performance loss. This notebook shows how to use our cross-validation function to cross-validate `shrink_mode` and `lmb` without this performance loss.