A novel method for learning hard, axis-aligned decision trees with gradient descent.
Project description
🌳 GradTree: Gradient-Based Decision Trees 🌳
🌳 GradTree is a novel approach for learning hard, axis-aligned decision trees with gradient descent!
🔍 What's new?
- Reformulation of decision trees to dense representations
- Approximation of the step function with sigmoids and the entmax function
- Straight-through (ST) operator to retain the inductive bias of hard, axis-aligned splits
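The ideas in the list above can be illustrated with a minimal NumPy sketch of a single internal node. This is not the GradTree implementation. It swaps in sparsemax (the α = 2 member of the entmax family) for the entmax variant used by the package, assumes the entmax step is applied to per-feature scores to choose the split feature, and assumes the convention that a split value of 1 means `x > threshold`. Because NumPy has no autodiff, the straight-through surrogates are returned explicitly. The names `sparsemax`, `st_feature_select`, and `st_split` are hypothetical. In each case the forward value is hard (a one-hot feature choice, a 0/1 split), while the second return value is what gradient descent would use.

```python
import numpy as np

def sparsemax(z):
    """Sparsemax: Euclidean projection of z onto the probability simplex (entmax, alpha=2)."""
    z = np.asarray(z, dtype=float)
    z_sorted = np.sort(z)[::-1]
    cumsum = np.cumsum(z_sorted)
    k = np.arange(1, z.size + 1)
    # Support size: largest k with 1 + k * z_(k) > sum of the k largest entries.
    k_max = k[1 + k * z_sorted > cumsum][-1]
    tau = (cumsum[k_max - 1] - 1.0) / k_max
    return np.maximum(z - tau, 0.0)

def st_feature_select(scores):
    """Hard feature choice with a straight-through surrogate.

    Forward value: one-hot vector of the highest-probability feature.
    Surrogate for the backward pass: the sparsemax probabilities. In an autodiff
    framework this is written as probs + stop_gradient(hard - probs).
    """
    probs = sparsemax(scores)
    hard = np.zeros_like(probs)
    hard[np.argmax(probs)] = 1.0
    return hard, probs

def st_split(x_selected, threshold):
    """Hard split (1.0 if x_selected > threshold, else 0.0) with a sigmoid gradient.

    The forward value is the step function; the returned gradient with respect to
    the threshold is that of sigmoid(x_selected - threshold), as used by ST.
    """
    soft = 1.0 / (1.0 + np.exp(-(x_selected - threshold)))
    hard = np.asarray(x_selected > threshold, dtype=float)
    grad_threshold = -soft * (1.0 - soft)
    return hard, grad_threshold

# One sample with three features and hypothetical learnable node parameters.
x = np.array([0.9, 0.1, 0.7])
feature_scores = np.array([0.2, 1.5, -0.3])
threshold = 0.4

onehot, probs = st_feature_select(feature_scores)
x_selected = onehot @ x  # picks x[1] = 0.1
decision, grad_t = st_split(x_selected, threshold)
```

The split stays exactly axis-aligned and hard in the forward pass, yet the threshold still receives a non-zero gradient from the sigmoid, which is the point of the ST operator.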
📝 Details on the method can be found in the preprint available at: https://arxiv.org/abs/2305.03515
Installation
To download the latest official release of the package, use the pip command below:
pip install GradTree
More details can be found at: https://github.com/s-marton/GradTree
Cite us
@inproceedings{marton2023gradtree,
title={GradTree: Learning Axis-Aligned Decision Trees with Gradient Descent},
author={Marton, Sascha and L{\"u}dtke, Stefan and Bartelt, Christian and Stuckenschmidt, Heiner},
booktitle={NeurIPS 2023 Second Table Representation Learning Workshop},
year={2023}
}
Usage
An example is shown below; it is also available in our Git repository. Please note that a GPU is required to achieve competitive runtimes.
Load Data
import numpy as np
import openml
from sklearn.model_selection import train_test_split
# Load OpenML dataset 40536 together with its categorical-feature mask.
dataset = openml.datasets.get_dataset(40536)
X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)
# Positions of the categorical columns (passed to GradTree via cat_idx).
categorical_feature_indices = [idx for idx, is_categorical in enumerate(categorical_indicator) if is_categorical]
# Split into 64% train, 16% validation and 20% test.
X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train, X_valid, y_train, y_valid = train_test_split(X_temp, y_temp, test_size=0.2, random_state=42)
# Convert the categorical target to integer class codes stored as float64.
y_train = y_train.values.codes.astype(np.float64)
y_valid = y_valid.values.codes.astype(np.float64)
y_test = y_test.values.codes.astype(np.float64)
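The two less obvious steps in the loading code, collecting the positions of the categorical columns and turning a categorical target into float class codes, can be checked on toy data without downloading anything. The values below are made up and only stand in for what `dataset.get_data()` returns.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-ins for the outputs of dataset.get_data().
categorical_indicator = [False, True, False, True]
y = pd.Series(pd.Categorical(["no", "yes", "yes", "no"]))

# Keep the index of every column flagged as categorical.
categorical_feature_indices = [idx for idx, is_cat in enumerate(categorical_indicator) if is_cat]

# Categories are sorted ("no" -> 0, "yes" -> 1); the codes are then cast to float64.
y_encoded = y.values.codes.astype(np.float64)
```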
Preprocessing, Hyperparameters and Training
GradTree requires categorical features to be encoded appropriately. The best results are achieved using Leave-One-Out Encoding for high-cardinality categorical features and One-Hot Encoding for low-cardinality categorical features. Furthermore, all features should be normalized using a quantile transformation. Passing the categorical indices to the model (via `cat_idx` in `args`) will automatically preprocess the data accordingly.
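For intuition about what this preprocessing does, here is a rough pandas/scikit-learn sketch. It is an approximation under stated assumptions, not the package's internal code: `leave_one_out_encode`, `preprocess`, and the cardinality cutoff `max_onehot_cardinality` are hypothetical, everything is fitted on the data passed in (transforming validation and test data would require reusing the fitted category statistics and transformer, which is omitted), and `QuantileTransformer` keeps its default uniform output.

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import QuantileTransformer

def leave_one_out_encode(col, y):
    """Replace each category by the mean target of the other rows with that category."""
    y = pd.Series(np.asarray(y, dtype=float), index=col.index)
    sums = y.groupby(col).transform("sum")
    counts = y.groupby(col).transform("count")
    loo = (sums - y) / (counts - 1)
    # A category seen only once has no other rows; fall back to the global mean.
    return loo.where(counts > 1, y.mean())

def preprocess(X, y, cat_idx, max_onehot_cardinality=10):
    """LOO-encode high-cardinality and one-hot-encode low-cardinality categoricals,
    then quantile-transform every column (fitted on the data passed in)."""
    X = X.copy()
    onehot_cols = []
    for idx in cat_idx:
        name = X.columns[idx]
        if X[name].nunique() <= max_onehot_cardinality:
            onehot_cols.append(name)
        else:
            X[name] = leave_one_out_encode(X[name], y)
    X = pd.get_dummies(X, columns=onehot_cols, dtype=float)
    qt = QuantileTransformer(n_quantiles=min(1000, len(X)), random_state=42)
    return pd.DataFrame(qt.fit_transform(X), columns=X.columns, index=X.index)
```

Since GradTree does this automatically when `cat_idx` is given, a sketch like this is only useful for understanding or for inspecting the transformed features.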
In the following, we train the model using the default parameters. GradTree already achieves strong results with its default parameters, but hyperparameter optimization (HPO) can improve performance further. An appropriate search grid is specified in the model class.
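A simple random search over the `params` dictionary could look like the sketch below. The search-space values are hypothetical and are not the grid defined in the model class; `random_search` and the user-supplied callable `train_and_score` are illustrative names, and candidates are drawn with scikit-learn's `ParameterSampler`.

```python
from sklearn.model_selection import ParameterSampler

# Hypothetical search space over some keys of the params dict; the values are
# illustrative only and are not the grid defined in the GradTree model class.
search_space = {
    "depth": [4, 5, 6, 7, 8],
    "learning_rate_index": [0.005, 0.01, 0.05],
    "learning_rate_values": [0.005, 0.01, 0.05],
    "learning_rate_leaf": [0.0025, 0.005, 0.01],
}

def random_search(train_and_score, base_params, n_iter=10, seed=42):
    """Evaluate n_iter sampled configurations and return (best_score, best_params).

    train_and_score: callable mapping a full params dict to a validation score
    where higher is better (e.g. fit GradTree, then compute F1 on X_valid).
    """
    best_score, best_params = float("-inf"), None
    for candidate in ParameterSampler(search_space, n_iter=n_iter, random_state=seed):
        params = {**base_params, **candidate}  # sampled values override the defaults
        score = train_and_score(params)
        if score > best_score:
            best_score, best_params = score, params
    return best_score, best_params
```

With the training code in this section, `train_and_score` could build `GradTree(params=params, args=args)`, call `fit` with the training and validation splits, and return the validation F1 score. Scoring on the validation split rather than the test split keeps the final test evaluation unbiased.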
from GradTree import GradTree
params = {
'depth': 5,
'learning_rate_index': 0.01,
'learning_rate_values': 0.01,
'learning_rate_leaf': 0.005,
'optimizer': 'SWA',
'cosine_decay_steps': 0,
'initializer': 'RandomNormal',
'loss': 'crossentropy',
'focal_loss': False,
'temperature': 0.0,
'apply_class_balancing': True,
}
args = {
'epochs': 1_000,
'early_stopping_epochs': 25,
'batch_size': 64,
'cat_idx': categorical_feature_indices, # put list of categorical indices
'objective': 'binary',
'metrics': ['F1'], # F1, Accuracy, R2
'random_seed': 42,
'verbose': 1,
}
model_gradtree = GradTree(params=params, args=args)
model_gradtree.fit(X_train=X_train,
y_train=y_train,
X_val=X_valid,
y_val=y_valid)
Evaluate Model
import numpy as np
import sklearn.metrics
# Predicted class probabilities; column 1 holds the positive-class probability.
preds = model_gradtree.predict(X_test)
accuracy = sklearn.metrics.accuracy_score(y_test, np.round(preds[:,1]))
f1_score = sklearn.metrics.f1_score(y_test, np.round(preds[:,1]), average='macro')
roc_auc = sklearn.metrics.roc_auc_score(y_test, preds[:,1], average='macro')
print('Accuracy:', accuracy)
print('F1 Score:', f1_score)
print('ROC AUC:', roc_auc)
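Rounding the positive-class probability acts as a 0.5 decision threshold, with one subtlety: NumPy rounds halves to the nearest even number, so a probability of exactly 0.5 becomes class 0. The toy sketch below uses made-up probabilities to compare it with an explicit threshold, which also makes it easy to tune the cutoff, for example for F1.

```python
import numpy as np
from sklearn import metrics

# Made-up probabilities shaped like preds above: column 1 = positive class.
preds = np.array([[0.9, 0.1], [0.5, 0.5], [0.3, 0.7], [0.45, 0.55]])
y_true = np.array([0.0, 1.0, 1.0, 0.0])

rounded = np.round(preds[:, 1])  # 0.5 rounds to 0.0 (round half to even)
threshold = 0.5                  # explicit, adjustable decision threshold
thresholded = (preds[:, 1] >= threshold).astype(np.float64)

acc = metrics.accuracy_score(y_true, thresholded)
macro_f1 = metrics.f1_score(y_true, thresholded, average="macro")
```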
More
Please note that this is an experimental implementation that has not yet been fully tested. If you encounter any errors or observe unexpected behavior, please let me know.
Download files
Source Distribution
Built Distribution
File details
Details for the file GradTree-0.1.0.tar.gz.
File metadata
- Download URL: GradTree-0.1.0.tar.gz
- Upload date:
- Size: 11.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 3ab23321bfead319567e13b32027d7e493e438d7cdde158b99efa09823bb41eb
MD5 | ff59cdf79144a10654b11e709eb08be7
BLAKE2b-256 | ce0e1de9320988b997f40ebc3772a5ee367b62ef5eb6876053a5de7bddb1a027
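A downloaded file can be checked against the SHA256 digest listed above with Python's standard `hashlib`; `sha256_of_file` is a hypothetical helper name. Alternatively, pip can enforce the digest itself in hash-checking mode via a `--hash=sha256:...` entry in a requirements file.

```python
import hashlib

def sha256_of_file(path, chunk_size=1 << 16):
    """Hash the file in chunks so large downloads are not read into memory at once."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# SHA256 of the source distribution as listed in the table above.
EXPECTED_SDIST_SHA256 = "3ab23321bfead319567e13b32027d7e493e438d7cdde158b99efa09823bb41eb"

# Usage: sha256_of_file("GradTree-0.1.0.tar.gz") == EXPECTED_SDIST_SHA256
```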
File details
Details for the file GradTree-0.1.0-py3-none-any.whl.
File metadata
- Download URL: GradTree-0.1.0-py3-none-any.whl
- Upload date:
- Size: 11.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 6a61a6091b21b42a00eade692c7f0927c2b5de7a83e5805b68d4c4db78e40fd6
MD5 | d03fe536ed22c107c934e8909f479ab5
BLAKE2b-256 | dfcacc014edf6e249b9addfc081ee714078e3f12485a23bb87a3ba7d3e40ab4d