A novel method for learning hard, axis-aligned decision trees with gradient descent.

Project description

🌳 GradTree: Gradient-Based Decision Trees 🌳

🌳 GradTree is a novel approach for learning hard, axis-aligned decision trees with gradient descent!

🔍 What's new?

  • Reformulation of decision trees to dense representations
  • Approximation of the step function with sigmoid and entmax functions
  • A straight-through (ST) operator to retain the inductive bias of hard, axis-aligned splits (a minimal sketch of this idea follows the list)
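
The ST (straight-through) operator uses the hard split decision in the forward pass while letting gradients flow through the soft sigmoid approximation. Below is a minimal TensorFlow sketch of that idea, not GradTree's internal code; the function name, threshold, and temperature argument are illustrative.

import tensorflow as tf

def st_hard_split(x, threshold, temperature=1.0):
    # Soft, differentiable approximation of the step function.
    soft = tf.sigmoid((x - threshold) / temperature)
    # Hard 0/1 split decision for an axis-aligned threshold.
    hard = tf.cast(soft > 0.5, tf.float32)
    # Straight-through trick: the forward pass returns `hard`,
    # while gradients are taken with respect to `soft`.
    return soft + tf.stop_gradient(hard - soft)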

📝 Details on the method can be found in the preprint: https://arxiv.org/abs/2305.03515

Installation

To install the latest official release of the package, use the pip command below:

pip install GradTree

More details can be found at: https://github.com/s-marton/GradTree

Cite us

@inproceedings{marton2023gradtree,
  title={GradTree: Learning Axis-Aligned Decision Trees with Gradient Descent},
  author={Marton, Sascha and L{\"u}dtke, Stefan and Bartelt, Christian and Stuckenschmidt, Heiner},
  booktitle={NeurIPS 2023 Second Table Representation Learning Workshop},
  year={2023}
}

Usage

Example usage is shown in the following and is also available in our git repository. Please note that a GPU is required to achieve competitive runtimes.
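
Before training, you can quickly verify that a GPU is visible. This check assumes a TensorFlow backend; adjust it if your installation differs.

import tensorflow as tf

# An empty list means no GPU is visible and training falls back to the much slower CPU.
print(tf.config.list_physical_devices('GPU'))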

Load Data

import numpy as np
from sklearn.model_selection import train_test_split
import openml

# Fetch an OpenML dataset (ID 40536) together with its categorical-feature indicator.
dataset = openml.datasets.get_dataset(40536)
X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)
categorical_feature_indices = [idx for idx, idx_bool in enumerate(categorical_indicator) if idx_bool]

# Hold out 20% for testing, then 20% of the remainder for validation.
X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train, X_valid, y_train, y_valid = train_test_split(X_temp, y_temp, test_size=0.2, random_state=42)

# Convert the categorical target to numeric class codes.
y_train = y_train.values.codes.astype(np.float64)
y_valid = y_valid.values.codes.astype(np.float64)
y_test = y_test.values.codes.astype(np.float64)

Preprocessing, Hyperparameters and Training

GradTree requires categorical features to be encoded appropriately. The best results are achieved using leave-one-out encoding for high-cardinality categorical features and one-hot encoding for low-cardinality ones. Furthermore, all features should be normalized with a quantile transformation. Passing the categorical indices to the model will automatically preprocess the data accordingly; a manual sketch of this pipeline is shown below.
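
For reference, here is a minimal manual version of that pipeline. It assumes the category_encoders package for leave-one-out encoding, and the cardinality cutoff of 10 is an illustrative choice, not the library's.

import pandas as pd
from category_encoders import LeaveOneOutEncoder
from sklearn.preprocessing import QuantileTransformer

# Split the categorical columns by cardinality (cutoff of 10 is illustrative).
cat_cols = X_train.columns[categorical_feature_indices]
high_card = [c for c in cat_cols if X_train[c].nunique() > 10]
low_card = [c for c in cat_cols if X_train[c].nunique() <= 10]

# Leave-one-out encoding for high-cardinality categoricals (fit on train only).
loo = LeaveOneOutEncoder(cols=high_card)
X_train_enc = loo.fit_transform(X_train, y_train)
X_valid_enc = loo.transform(X_valid)

# One-hot encoding for low-cardinality categoricals; align validation columns to train.
X_train_enc = pd.get_dummies(X_train_enc, columns=low_card)
X_valid_enc = pd.get_dummies(X_valid_enc, columns=low_card)
X_valid_enc = X_valid_enc.reindex(columns=X_train_enc.columns, fill_value=0)

# Quantile transformation to normalize all features.
qt = QuantileTransformer(output_distribution='uniform')
X_train_enc = pd.DataFrame(qt.fit_transform(X_train_enc), columns=X_train_enc.columns, index=X_train_enc.index)
X_valid_enc = pd.DataFrame(qt.transform(X_valid_enc), columns=X_valid_enc.columns, index=X_valid_enc.index)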

In the following, we will train the model using the default parameters. GradTree already achieves strong results with its default parameters, but hyperparameter optimization (HPO) can increase performance even further. An appropriate grid is specified in the model class; a generic search loop is sketched after the training example below.

from GradTree import GradTree

params = {
        'depth': 5,                       # maximum tree depth

        # Separate learning rates for the split feature indices,
        # split values, and leaf parameters.
        'learning_rate_index': 0.01,
        'learning_rate_values': 0.01,
        'learning_rate_leaf': 0.005,

        'optimizer': 'SWA',               # stochastic weight averaging
        'cosine_decay_steps': 0,

        'initializer': 'RandomNormal',

        'loss': 'crossentropy',
        'focal_loss': False,
        'temperature': 0.0,

        'apply_class_balancing': True,
}

args = {
    'epochs': 1_000,
    'early_stopping_epochs': 25,   # early-stopping patience on the validation set
    'batch_size': 64,

    'cat_idx': categorical_feature_indices, # list of categorical feature indices
    'objective': 'binary',

    'metrics': ['F1'], # F1, Accuracy, R2
    'random_seed': 42,
    'verbose': 1,
}

model_gradtree = GradTree(params=params, args=args)

model_gradtree.fit(X_train=X_train,
          y_train=y_train,
          X_val=X_valid,
          y_val=y_valid)
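
If you do run HPO, a generic grid search over a few of the parameters above could look like the following sketch. The grid values here are hypothetical placeholders for illustration, not the recommended grid from the model class.

from sklearn.model_selection import ParameterGrid
from sklearn.metrics import f1_score

# Hypothetical search space for illustration only.
grid = {
    'depth': [4, 5, 6],
    'learning_rate_index': [0.005, 0.01, 0.05],
}

best_score, best_candidate = -1.0, None
for candidate in ParameterGrid(grid):
    model = GradTree(params={**params, **candidate}, args=args)
    model.fit(X_train=X_train, y_train=y_train, X_val=X_valid, y_val=y_valid)
    # Select on validation macro-F1; column 1 holds the positive-class probability.
    score = f1_score(y_valid, np.round(model.predict(X_valid)[:, 1]), average='macro')
    if score > best_score:
        best_score, best_candidate = score, candidate

print('Best candidate:', best_candidate, 'with validation F1:', best_score)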

Evaluate Model

import sklearn.metrics

preds = model_gradtree.predict(X_test)

# Column 1 of the prediction matrix holds the positive-class probability.
accuracy = sklearn.metrics.accuracy_score(y_test, np.round(preds[:, 1]))
f1_score = sklearn.metrics.f1_score(y_test, np.round(preds[:, 1]), average='macro')
roc_auc = sklearn.metrics.roc_auc_score(y_test, preds[:, 1], average='macro')

print('Accuracy:', accuracy)
print('F1 Score:', f1_score)
print('ROC AUC:', roc_auc)

More

Please note that this is an experimental implementation that has not been fully tested yet. If you encounter any errors or observe unexpected behavior, please let me know.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

GradTree-0.1.0.tar.gz (11.5 kB)

Uploaded Source

Built Distribution

GradTree-0.1.0-py3-none-any.whl (11.9 kB)

Uploaded Python 3

File details

Details for the file GradTree-0.1.0.tar.gz.

File metadata

  • Download URL: GradTree-0.1.0.tar.gz
  • Upload date:
  • Size: 11.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.5

File hashes

Hashes for GradTree-0.1.0.tar.gz
Algorithm Hash digest
SHA256 3ab23321bfead319567e13b32027d7e493e438d7cdde158b99efa09823bb41eb
MD5 ff59cdf79144a10654b11e709eb08be7
BLAKE2b-256 ce0e1de9320988b997f40ebc3772a5ee367b62ef5eb6876053a5de7bddb1a027


File details

Details for the file GradTree-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: GradTree-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 11.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.5

File hashes

Hashes for GradTree-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 6a61a6091b21b42a00eade692c7f0927c2b5de7a83e5805b68d4c4db78e40fd6
MD5 d03fe536ed22c107c934e8909f479ab5
BLAKE2b-256 dfcacc014edf6e249b9addfc081ee714078e3f12485a23bb87a3ba7d3e40ab4d

