
Bayesian post-hoc regularization for random forests

Project description



Method Description

Random Forests are powerful ensemble learning algorithms widely used in various machine learning tasks. However, they have a tendency to overfit noisy or irrelevant features, which can result in decreased generalization performance. Post-hoc regularization techniques aim to mitigate this issue by modifying the structure of the learned ensemble after its training.

Here, we propose Bayesian post-hoc regularization to leverage the reliable patterns captured by leaf nodes closer to the root, while potentially reducing the impact of more specific and potentially noisy leaf nodes deeper in the tree. This approach allows for a form of pruning that does not alter the general structure of the trees but rather adjusts the influence of leaf nodes based on their proximity to the root node. We have evaluated the performance of our method on various machine learning data sets. Our approach demonstrates competitive performance with the state-of-the-art methods and, in certain cases, surpasses them in terms of predictive accuracy and generalization.
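The core idea can be illustrated with a minimal sketch (not the package implementation; the function name is ours). With a Beta(alpha, beta) prior, a leaf's empirical class probability k/n is replaced by its posterior mean, so small, deep leaves are pulled toward the prior while large leaves near the root are barely changed:

```python
def beta_regularized_proba(k, n, alpha, beta):
    """Posterior mean of P(y=1) at a leaf with k positives out of n samples,
    under a Beta(alpha, beta) prior on the class probability."""
    return (k + alpha) / (n + alpha + beta)

# A large leaf (500/1000 positives) keeps its estimate almost unchanged,
# while a tiny, potentially noisy leaf (2/2 positives) is shrunk strongly
# toward the prior mean of 0.5 instead of claiming certainty.
print(beta_regularized_proba(500, 1000, alpha=10, beta=10))  # 0.5
print(beta_regularized_proba(2, 2, alpha=10, beta=10))       # ~0.545, not 1.0
```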

Code Description

All classes inherit from ShrinkageEstimator, which extends sklearn.base.BaseEstimator. Usage of these classes is entirely analogous and works just like any other sklearn estimator:

  • __init__() parameters:
    • base_estimator: the estimator around which we "wrap" hierarchical shrinkage. This should be a tree-based estimator: DecisionTreeClassifier, RandomForestClassifier, ... (analogous for Regressors)
    • shrink_mode: 2 options:
      • "hs": classical Hierarchical Shrinkage (from Agarwal et al. 2022)
      • "beta": Bayesian post-hoc regularization (from Pfeifer 2023)
    • lmb: lambda hyperparameter
    • alpha: alpha hyperparameter
    • beta: beta hyperparameter
    • random_state: random state for reproducibility
  • Other functions: fit(X, y), predict(X), predict_proba(X), score(X, y) work just like with any other sklearn estimator.
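For intuition on the "hs" mode, here is a minimal sketch of classical Hierarchical Shrinkage as described by Agarwal et al. 2022 (our own illustrative function, not the package code). Walking from the root to a sample's leaf, each refinement of the running estimate is damped by a factor depending on lmb and the parent node's sample count, so splits supported by few samples contribute less:

```python
def hs_leaf_proba(path, lmb):
    """Hierarchically shrunk leaf estimate.
    path: list of (p_node, n_node) pairs from root to leaf, where p_node is
    the node's empirical class probability and n_node its sample count."""
    f = path[0][0]  # start from the root estimate
    for (p_parent, n_parent), (p_child, _) in zip(path, path[1:]):
        # each step toward the leaf is damped; large lmb or small n_parent
        # means the refinement contributes little
        f += (p_child - p_parent) / (1 + lmb / n_parent)
    return f

# Root: 60/100 positives; child: 9/10; leaf: 3/3.
path = [(0.6, 100), (0.9, 10), (1.0, 3)]
print(hs_leaf_proba(path, lmb=0))   # 1.0 -- no shrinkage, plain leaf estimate
print(hs_leaf_proba(path, lmb=10))  # pulled back toward the root estimate
```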

Usage

Install the Python package treesmoothing via pip

pip install treesmoothing

and import the ShrinkageClassifier as

from treesmoothing import ShrinkageClassifier

or install the package locally from source and import the main class:

pip install ./treesmoothing
from treesmoothing import ShrinkageClassifier

Other imports

from imodels.util.data_util import get_clean_dataset
import numpy as np
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as plt
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import roc_auc_score

import sys

Example data set

clf_datasets = [
    ("breast-cancer", "breast_cancer", "imodels")
]

# scoring
#sc = "balanced_accuracy"
sc = "roc_auc"

# number of trees 
ntrees = 10

# Read in data set
X, y, feature_names = get_clean_dataset('breast_cancer', data_source='imodels')

# train-test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20)

scores = {}

scores["vanilla"] = []
scores["hs"] = []
scores["beta"] = []

Vanilla Random Forest

# vanilla RF ##########################################
print("Vanilla Mode")
shrink_mode="vanilla"
#######################################################

clf = RandomForestClassifier(n_estimators=ntrees) 
clf.fit(X_train, y_train)
if sc == "balanced_accuracy":
    pred_vanilla = clf.predict(X_test)
    scores[shrink_mode].append(balanced_accuracy_score(y_test, pred_vanilla))    
if sc == "roc_auc":
    pred_vanilla = clf.predict_proba(X_test)[:,1]
    scores[shrink_mode].append(roc_auc_score(y_test, pred_vanilla))    

Hierarchical Shrinkage from Agarwal et al. 2022

# hs - Hierarchical Shrinkage #########################
print("HS Mode")
shrink_mode="hs"
#######################################################

param_grid = {
    "lmb": [0.001, 0.01, 0.1, 1, 10, 25, 50, 100, 200],
    "shrink_mode": ["hs"]}

grid_search = GridSearchCV(ShrinkageClassifier(RandomForestClassifier(n_estimators=ntrees)),
                           param_grid, cv=5, n_jobs=-1, scoring=sc)

grid_search.fit(X_train, y_train)

best_params = grid_search.best_params_
print(best_params)

clf = ShrinkageClassifier(RandomForestClassifier(n_estimators=ntrees),
                          shrink_mode=shrink_mode, lmb=best_params.get('lmb'))

clf.fit(X_train, y_train)
if sc == "balanced_accuracy":
    pred_hs = clf.predict(X_test)
    scores[shrink_mode].append(balanced_accuracy_score(y_test, pred_hs))      
if sc == "roc_auc":
    pred_hs = clf.predict_proba(X_test)[:,1]
    scores[shrink_mode].append(roc_auc_score(y_test, pred_hs))    

Bayesian post-hoc regularization from Pfeifer 2023

# beta - Bayesian post-hoc regularization #########################
print("Beta Shrinkage")
shrink_mode="beta"
###################################################################

param_grid = {
    "alpha": [1500, 1000, 800, 500, 100, 50, 30, 10, 1],
    "beta": [1500, 1000, 800, 500, 100, 50, 30, 10, 1],
    "shrink_mode": ["beta"]}

grid_search = GridSearchCV(ShrinkageClassifier(RandomForestClassifier(n_estimators=ntrees)),
                           param_grid, cv=5, n_jobs=-1, scoring=sc)

# Tune on the training split only, so the held-out test set stays unseen
grid_search.fit(X_train, y_train)

best_params = grid_search.best_params_
print(best_params)
clf = ShrinkageClassifier(RandomForestClassifier(n_estimators=ntrees),
                          shrink_mode=shrink_mode,
                          alpha=best_params.get('alpha'), beta=best_params.get('beta'))
clf.fit(X_train, y_train)

if sc == "balanced_accuracy":
    pred_beta = clf.predict(X_test)
    scores[shrink_mode].append(balanced_accuracy_score(y_test, pred_beta))      
if sc == "roc_auc":
    pred_beta = clf.predict_proba(X_test)[:,1]
    scores[shrink_mode].append(roc_auc_score(y_test, pred_beta))    

Print the results

print(scores)

Acknowledgement

The TreeSmoothing Python code was written by Bastian Pfeifer and Arne Gevaert. It is based on the Hierarchical Shrinkage implementation within the Python package imodels (https://github.com/csinva/imodels).

Citation

If you find the Bayesian post-hoc regularization method useful, please cite:

@misc{pfeifer2023bayesian,
      title={Bayesian post-hoc regularization of random forests}, 
      author={Bastian Pfeifer},
      year={2023},
      eprint={2306.03702},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Bibtex References

@inproceedings{agarwal2022hierarchical,
  title={Hierarchical Shrinkage: Improving the accuracy and interpretability of tree-based models},
  author={Agarwal, Abhineet and Tan, Yan Shuo and Ronen, Omer and Singh, Chandan and Yu, Bin},
  booktitle={International Conference on Machine Learning},
  pages={111--135},
  year={2022},
  organization={PMLR}
}
