
A novel method for learning hard, axis-aligned decision trees with gradient descent.


🌳 GradTree: Gradient-Based Decision Trees 🌳

🌳 GradTree is a novel approach for learning hard, axis-aligned decision trees with gradient descent!

🔍 What's new?

  • Reformulation of decision trees as dense representations
  • Approximation of the step function with sigmoids and the entmax function
  • Straight-through (ST) operator to retain the inductive bias of hard, axis-aligned splits
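To make the dense representation and the straight-through idea more concrete, here is a small NumPy sketch. It is not GradTree's implementation. The breadth-first node layout, the convention that a sample goes right when x >= threshold, and the use of a plain argmax over per-node feature scores (standing in for entmax) are assumptions for illustration. Gradients are written out by hand rather than obtained through automatic differentiation.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def st_split(x, threshold):
    """Straight-through split decision for one node (illustrative sketch).

    Forward pass: hard 0/1 decision (1 = go right when x >= threshold).
    Backward pass: the gradient of the hard step w.r.t. the threshold is
    replaced by the gradient of its sigmoid surrogate.
    """
    soft = sigmoid(x - threshold)
    hard = (x >= threshold).astype(float)
    grad_wrt_threshold = -soft * (1.0 - soft)  # d sigmoid(x - t) / d t
    return hard, grad_wrt_threshold


def dense_tree_predict(X, feature_logits, thresholds, leaf_values):
    """Hard forward pass of a complete binary tree stored as dense arrays.

    feature_logits: (n_internal, n_features) scores; each node uses its argmax feature
    thresholds:     (n_internal,) split thresholds
    leaf_values:    (n_leaves,) predictions, with n_leaves = n_internal + 1 = 2**depth
    Nodes are indexed breadth-first: the children of node i are 2i+1 and 2i+2.
    """
    n_leaves = len(leaf_values)
    if len(thresholds) != n_leaves - 1:
        raise ValueError("a complete tree needs exactly n_leaves - 1 internal nodes")
    depth = n_leaves.bit_length() - 1  # n_leaves is assumed to be a power of two
    feature_idx = np.argmax(feature_logits, axis=1)  # one feature per internal node
    rows = np.arange(X.shape[0])
    node = np.zeros(X.shape[0], dtype=int)  # every sample starts at the root
    for _ in range(depth):
        go_right, _ = st_split(X[rows, feature_idx[node]], thresholds[node])
        node = 2 * node + 1 + go_right.astype(int)
    return leaf_values[node - (n_leaves - 1)]  # map leaf node ids to 0..n_leaves-1
```

Because all parameters live in fixed-size arrays, every sample follows the same number of routing steps, which is what makes a batched, GPU-friendly formulation possible.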

📝 Details on the method can be found in the preprint: https://arxiv.org/abs/2305.03515

Installation

To install the latest official release of the package, use the pip command below:

pip install GradTree

More details can be found at: https://pypi.org/project/GradTree/

Usage

An example is shown below and is also available in GradTree_minimal_example.ipynb. Note that a GPU is required to achieve competitive runtimes.

Load Data

import numpy as np
import openml
from sklearn.model_selection import train_test_split

dataset = openml.datasets.get_dataset(40536)
X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)
# Indices of all columns that OpenML marks as categorical
categorical_feature_indices = [idx for idx, is_cat in enumerate(categorical_indicator) if is_cat]

X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

X_train, X_valid, y_train, y_valid = train_test_split(X_temp, y_temp, test_size=0.2, random_state=42)

# Encode the categorical class labels as integer codes, stored as floats
y_train = y_train.values.codes.astype(np.float64)
y_valid = y_valid.values.codes.astype(np.float64)
y_test = y_test.values.codes.astype(np.float64)
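The two successive train_test_split calls above produce a 64/16/20 train/validation/test split: 20% of the rows go to the test set, and the validation set takes 20% of the remaining 80%, i.e. 16% of all rows. A quick check on a 1,000-row index array (the helper split_sizes is hypothetical, written only for this illustration):

```python
import numpy as np
from sklearn.model_selection import train_test_split


def split_sizes(n_rows, test_size=0.2, valid_size=0.2, seed=42):
    """Return (n_train, n_valid, n_test) for the two-stage split used above."""
    idx = np.arange(n_rows)
    idx_temp, idx_test = train_test_split(idx, test_size=test_size, random_state=seed)
    idx_train, idx_valid = train_test_split(idx_temp, test_size=valid_size, random_state=seed)
    return len(idx_train), len(idx_valid), len(idx_test)


print(split_sizes(1000))  # (640, 160, 200)
```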

Preprocessing, Hyperparameters and Training

GradTree requires categorical features to be encoded appropriately. The best results are achieved with Leave-One-Out Encoding for high-cardinality categorical features and One-Hot Encoding for low-cardinality ones. In addition, all features should be normalized with a quantile transformation. If the categorical feature indices are passed to the model (via the cat_idx entry of args), it preprocesses the data accordingly.
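For readers who want to see what such a preprocessing step involves, here is a sketch built on pandas and scikit-learn. It is not GradTree's internal code. The one-hot cardinality cutoff of 10, the uniform output distribution of the quantile transform, and the helper names are assumptions. The leave-one-out encoding shown here applies only to training rows; test rows would normally be encoded with the per-category target means computed on the training set.

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import OneHotEncoder, QuantileTransformer


def leave_one_out_encode(categories, y):
    """Training-time leave-one-out encoding: each row receives the mean target of
    the *other* rows in its category; singleton categories get the global mean."""
    df = pd.DataFrame({
        "cat": pd.Series(categories).astype(str).to_numpy(),
        "y": np.asarray(y, dtype=float),
    })
    grouped = df.groupby("cat")["y"]
    sums = grouped.transform("sum")
    counts = grouped.transform("count")
    loo = (sums - df["y"]) / (counts - 1)  # exclude the row's own target
    return loo.where(counts > 1, df["y"].mean()).to_numpy()


def preprocess_train(X, y, cat_idx, max_onehot_cardinality=10, seed=42):
    """Encode categorical columns (one-hot if low-cardinality, leave-one-out
    otherwise), then quantile-transform every resulting column to [0, 1]."""
    blocks = []
    for i, name in enumerate(X.columns):
        col = X[name]
        if i in cat_idx and col.nunique() <= max_onehot_cardinality:
            onehot = OneHotEncoder(handle_unknown="ignore")
            blocks.append(onehot.fit_transform(col.astype(str).to_frame()).toarray())
        elif i in cat_idx:
            blocks.append(leave_one_out_encode(col, y).reshape(-1, 1))
        else:
            blocks.append(col.to_numpy(dtype=float).reshape(-1, 1))
    X_encoded = np.hstack(blocks)
    quantile = QuantileTransformer(
        n_quantiles=min(1000, X_encoded.shape[0]),
        output_distribution="uniform",  # assumption; "normal" is the other option
        random_state=seed,
    )
    return quantile.fit_transform(X_encoded)
```

Leaving each row's own target out of its encoding reduces the target leakage that a plain per-category mean would introduce on the training set.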

Below, we train the model with its default parameters. GradTree already achieves strong results with the defaults, but hyperparameter optimization (HPO) can improve performance further. A suitable search grid is specified in the model class.
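One simple way to run such an HPO is a random search over a grid. The following is a generic sketch, not GradTree's own tuning code. The helper random_search, the example grid values and the evaluate callback are hypothetical; only the parameter names come from the params dictionary used in this example.

```python
import random


def random_search(base_params, grid, evaluate, n_trials=10, seed=42):
    """Sample n_trials configurations from `grid`, score each with `evaluate`,
    and return the best (params, score) pair (higher score is better)."""
    rng = random.Random(seed)
    best_params, best_score = None, float("-inf")
    for _ in range(n_trials):
        candidate = dict(base_params)  # start from the defaults; base_params is not mutated
        for key, values in grid.items():
            candidate[key] = rng.choice(values)
        score = evaluate(candidate)
        if score > best_score:
            best_params, best_score = candidate, score
    return best_params, best_score


# Hypothetical search space; the values are illustrative, not GradTree's built-in grid.
example_grid = {
    "depth": [4, 5, 6, 7, 8],
    "learning_rate_index": [0.005, 0.01, 0.05],
    "learning_rate_values": [0.005, 0.01, 0.05],
    "dropout": [0.0, 0.25, 0.5],
}

# In practice, evaluate would fit GradTree(params=candidate, args=args) on the
# training split and return a validation score, e.g. F1 on (X_valid, y_valid).
```

Scoring candidates on the validation split and keeping the test split untouched avoids an optimistic bias in the final test metrics.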

from GradTree import GradTree

params = {
    'depth': 5,
    'n_estimators': 2048,

    'learning_rate_weights': 0.005,
    'learning_rate_index': 0.01,
    'learning_rate_values': 0.01,
    'learning_rate_leaf': 0.01,

    'optimizer': 'SWA',
    'cosine_decay_steps': 0,

    'initializer': 'RandomNormal',

    'loss': 'crossentropy',
    'focal_loss': False,
    'temperature': 0.0,

    'from_logits': True,
    'apply_class_balancing': True,

    'dropout': 0.0,

    'selected_variables': 0.8,
    'data_subset_fraction': 1.0,
}

args = {
    'epochs': 1_000,
    'early_stopping_epochs': 25,
    'batch_size': 64,

    'cat_idx': categorical_feature_indices, # put list of categorical indices
    'objective': 'binary',
    
    'metrics': ['F1'], # F1, Accuracy, R2
    'random_seed': 42,
    'verbose': 1,       
}

model_gradtree = GradTree(params=params, args=args)

model_gradtree.fit(X_train=X_train,
                   y_train=y_train,
                   X_val=X_valid,
                   y_val=y_valid)

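The args above pair epochs with early_stopping_epochs, and fit receives a separate validation set, which suggests patience-based early stopping. Which quantity GradTree monitors and whether it restores the best weights is not stated here, so what follows is only a generic sketch. The callbacks train_one_epoch and validation_loss are hypothetical placeholders.

```python
def train_with_early_stopping(train_one_epoch, validation_loss, epochs=1_000, patience=25):
    """Generic patience-based early stopping (illustrative sketch).

    train_one_epoch(): runs one pass over the training data (hypothetical callback)
    validation_loss(): returns the current validation loss (hypothetical callback)
    Training stops once the validation loss has not improved for `patience`
    consecutive epochs. Returns (best_epoch, best_loss, epochs_run).
    A real implementation would typically also restore the weights of best_epoch.
    """
    best_loss, best_epoch = float("inf"), -1
    epochs_without_improvement, epochs_run = 0, 0
    for epoch in range(epochs):
        train_one_epoch()
        epochs_run += 1
        loss = validation_loss()
        if loss < best_loss:
            best_loss, best_epoch, epochs_without_improvement = loss, epoch, 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break
    return best_epoch, best_loss, epochs_run
```

With epochs set to 1,000 and a patience of 25, the epoch limit mainly acts as an upper bound; in most runs the patience criterion ends training earlier.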

Evaluate Model

import numpy as np
import sklearn.metrics

preds = model_gradtree.predict(X_test)  # column 1 is used as the positive-class score below

accuracy = sklearn.metrics.accuracy_score(y_test, np.round(preds[:,1]))
f1_score = sklearn.metrics.f1_score(y_test, np.round(preds[:,1]), average='macro')
roc_auc = sklearn.metrics.roc_auc_score(y_test, preds[:,1], average='macro')

print('Accuracy:', accuracy)
print('F1 Score:', f1_score)
print('ROC AUC:', roc_auc)

More

Please note that this is an experimental implementation that has not been fully tested yet. If you encounter any errors or observe unexpected behavior, please let me know.

The code for reproducing the experiments from the paper is now in a separate folder: ./experiments_paper_gradtree/



Download files


Source Distribution

GradTree-0.0.1.tar.gz (11.5 kB)

Built Distribution

GradTree-0.0.1-py3-none-any.whl (11.9 kB)

File details

Details for the file GradTree-0.0.1.tar.gz.

File metadata

  • Download URL: GradTree-0.0.1.tar.gz
  • Upload date:
  • Size: 11.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.5

File hashes

Hashes for GradTree-0.0.1.tar.gz
  • SHA256: 1a42cdab38e4afdf8dc774c5b724ff135f24930d9f4682c3529bcb34bad5d194
  • MD5: b2662394105a501a0e59a89a05e30901
  • BLAKE2b-256: 425533b4bc7dddfeb235ba5136d5060dcec60cfae32ad140724d9056ea468383


File details

Details for the file GradTree-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: GradTree-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 11.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.5

File hashes

Hashes for GradTree-0.0.1-py3-none-any.whl
  • SHA256: a5ca4d01e1636dd94898204f2dfd4e40a3f98177558abfd1f99a0720bf00d829
  • MD5: bff714bd7fc773f30dbd5ab8f7f44b14
  • BLAKE2b-256: 483173234c02de52adf1798304b1ed9ab7a81d1492780564b7ef1855d665b192

