
Implementation of General Optimal Sparse Decision Tree

Project description

Fast Sparse Decision Tree Optimization via Reference Ensembles

This code creates optimized sparse decision trees. It is a direct competitor of CART [3] and C4.5 [6], as well as DL8.5 [1], BinOct [7], and OSDT [4]. Its advantage over CART and C4.5 is that its trees are globally optimized rather than built greedily from the top down. This makes it slower than CART, but it finds better solutions. On the other hand, it tends to be faster than other optimal decision tree methods because it uses bounds to limit the search space and uses a black-box model (a boosted decision tree) to “guess” information about the optimal tree. It takes only seconds or a few minutes on most datasets.

To make it run faster, use the options to limit the depth of the tree (depth_budget) and increase the regularization parameter above 0.02. If you run the algorithm without a depth constraint or with very small regularization, it will run more slowly.

This work builds on a number of innovations for scalable construction of optimal tree-based classifiers: Scalable Bayesian Rule Lists [8], CORELS [2], OSDT [4], and, most closely, GOSDT [5].



Installation

You can use the following commands to install GOSDT and its dependencies on macOS, Ubuntu, and Windows.
You need Python 3.7 or later to use the gosdt module in your project; note that the prebuilt wheels for this release (0.1.8) require CPython 3.8 or later.

pip3 install attrs packaging editables pandas scikit-learn sortedcontainers gmpy2 matplotlib
pip3 install gosdt

Configuration

The configuration is a JSON object and has the following structure and default values:

{ 
  "regularization": 0.05,
  "depth_budget": 0,
  "reference_LB": false, 
  "path_to_labels": "",
  "time_limit": 0,
  "uncertainty_tolerance": 0.0,
  "upperbound": 0.0,
  "worker_limit": 1,
  "stack_limit": 0,
  "precision_limit": 0,
  "model_limit": 1,
  "verbose": false,
  "diagnostics": false,
  "balance": false,
  "look_ahead": true,
  "similar_support": true,
  "cancellation": true,
  "continuous_feature_exchange": false,
  "feature_exchange": false,
  "feature_transform": true,
  "rule_list": false,
  "non_binary": false,
  "costs": "",
  "model": "",
  "timing": "",
  "trace": "",
  "tree": "",
  "profile": ""
}
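When a configuration is built programmatically, it can help to start from these defaults and reject undocumented keys, so a misspelled key fails loudly before training starts. The sketch below is a hypothetical helper (`make_config` is not part of the gosdt package); it copies the defaults listed above and assumes the configuration is passed as a plain JSON-serializable dictionary, as the example script on this page does.

```python
import json

# Defaults copied from the JSON object above. "allow_small_reg" is documented
# under "Key parameters" but not listed in that JSON, so it is added here.
DEFAULT_CONFIG = {
    "regularization": 0.05,
    "allow_small_reg": False,
    "depth_budget": 0,
    "reference_LB": False,
    "path_to_labels": "",
    "time_limit": 0,
    "uncertainty_tolerance": 0.0,
    "upperbound": 0.0,
    "worker_limit": 1,
    "stack_limit": 0,
    "precision_limit": 0,
    "model_limit": 1,
    "verbose": False,
    "diagnostics": False,
    "balance": False,
    "look_ahead": True,
    "similar_support": True,
    "cancellation": True,
    "continuous_feature_exchange": False,
    "feature_exchange": False,
    "feature_transform": True,
    "rule_list": False,
    "non_binary": False,
    "costs": "",
    "model": "",
    "timing": "",
    "trace": "",
    "tree": "",
    "profile": "",
}


def make_config(**overrides):
    """Return a full configuration: the defaults plus validated overrides."""
    unknown = sorted(set(overrides) - set(DEFAULT_CONFIG))
    if unknown:
        raise KeyError(f"unknown configuration keys: {unknown}")
    config = dict(DEFAULT_CONFIG)  # copy, so the defaults are never mutated
    config.update(overrides)
    return config


# A faster-running setup following the advice at the top of this page:
# a depth limit and regularization above 0.02.
config = make_config(regularization=0.03, depth_budget=4)
print(json.dumps(config, indent=2))
```

The helper only validates key names; it does not check value ranges, which are described per parameter below.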

Key parameters

regularization

  • Values: Decimal within range [0,1]
  • Description: Used to penalize complexity. A complexity penalty is added to the risk as follows:
    ComplexityPenalty = (number of leaves) × regularization
    
  • Default: 0.05
  • Note: We highly recommend setting the regularization to a value larger than 1/num_samples. A small regularization can lead to a longer training time. To use a smaller regularization, you must set the parameter allow_small_reg to true (it is false by default).
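A small arithmetic sketch of the penalty and of why 1/num_samples is the suggested floor. It assumes the default unweighted loss (the fraction of training samples misclassified); the function names are illustrative and not part of the gosdt API. Splitting a leaf adds exactly one leaf, so it adds one `regularization` of penalty. When regularization = 1/n, that extra penalty equals the loss of one misclassified sample, so a split that corrects only one more sample merely ties with not splitting.

```python
def complexity_penalty(n_leaves: int, regularization: float) -> float:
    """ComplexityPenalty = number of leaves x regularization."""
    return n_leaves * regularization


def objective(misclassified: int, n_samples: int, n_leaves: int,
              regularization: float) -> float:
    """Unweighted training loss plus the complexity penalty."""
    loss = misclassified / n_samples
    return loss + complexity_penalty(n_leaves, regularization)


n = 1000
reg = 1 / n  # the recommended floor, 1/num_samples
# A 2-leaf tree that misclassifies 100 samples, versus the 3-leaf tree obtained
# by splitting one of its leaves, which corrects exactly one more sample.
stump = objective(misclassified=100, n_samples=n, n_leaves=2, regularization=reg)
split = objective(misclassified=99, n_samples=n, n_leaves=3, regularization=reg)
# stump and split are equal up to floating-point rounding: the extra leaf
# costs exactly as much as the one sample it corrects.
```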

allow_small_reg

  • Values: true or false
  • Description: Flag for allowing regularization < 1/n, where n = num_samples. If false, regularization values below 1/n are automatically set to 1/n.
  • Default: false
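The automatic adjustment described above amounts to clamping the regularization at 1/num_samples. The function below is a hypothetical sketch that mirrors the documented rule; it is not taken from GOSDT's source.

```python
def effective_regularization(regularization: float, num_samples: int,
                             allow_small_reg: bool = False) -> float:
    """Apply the documented floor of 1/num_samples unless allow_small_reg is set."""
    if not 0.0 <= regularization <= 1.0:
        raise ValueError("regularization must be within [0, 1]")
    if allow_small_reg:
        return regularization
    return max(regularization, 1.0 / num_samples)
```

For the 10459-sample dataset in the example below, 1/n is about 0.0000956, so the example's regularization of 0.001 is left unchanged.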

depth_budget

  • Values: Integers >= 0
  • Description: Used to set the maximum tree depth for solutions, counting a tree consisting of only the root node as depth 1. 0 means unlimited.
  • Default: 0
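To make the depth convention concrete, the sketch below measures depth the way depth_budget counts it (a lone root is depth 1, each split adds one level) and applies the "0 means unlimited" rule. The nested-dict tree format is a hypothetical representation for illustration only, not GOSDT's model format.

```python
from typing import Union

# Hypothetical tree format: a leaf is a predicted class (int); an internal
# node is a dict with a "feature" test and "true"/"false" children.
Tree = Union[int, dict]


def tree_depth(tree: Tree) -> int:
    """Depth with a lone root (a single leaf) counted as depth 1."""
    if not isinstance(tree, dict):
        return 1
    return 1 + max(tree_depth(tree["true"]), tree_depth(tree["false"]))


def within_budget(tree: Tree, depth_budget: int) -> bool:
    """True if the tree respects depth_budget; 0 means unlimited."""
    return depth_budget == 0 or tree_depth(tree) <= depth_budget


stump = {"feature": "x <= 1.5", "true": 0, "false": 1}
```

Under this convention a single split (a stump) has depth 2, so depth_budget = 2 allows exactly one level of splits.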

reference_LB

  • Values: true or false
  • Description: Enables using a vector of misclassifications from another (reference) model to lower-bound our own misclassifications.
  • Default: false
  • Note: If reference_LB is set to true, you must provide a valid path_to_labels.

path_to_labels

  • Values: String representing a path to a file.
  • Description: This file must be a single-column CSV containing a class prediction for each training observation (in the same order as the training data, using the same class labels as the training data, and predicting every class present in the training set at least once). Typically, this CSV is obtained by fitting a gradient boosted decision tree model on the training data and saving its training-set predictions as a CSV file.
  • Example for a dataset with classes 1 and 0:
    predicted_class
    0
    1
    1
    1
    0
    
  • Default: Empty string
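The requirements above can be checked before the file is handed to GOSDT. The sketch below is a hypothetical helper (`write_reference_labels` is not part of gosdt) that uses only the standard library: it verifies one prediction per training row, no unknown labels, and every training class predicted at least once, then writes the single-column CSV with a header row.

```python
import csv
import os
import tempfile


def write_reference_labels(path, predictions, training_labels,
                           column="predicted_class"):
    """Write reference-model predictions as a single-column CSV after checks."""
    predictions = list(predictions)
    training_labels = list(training_labels)
    if len(predictions) != len(training_labels):
        raise ValueError("need exactly one prediction per training observation")
    predicted, present = set(predictions), set(training_labels)
    if predicted - present:
        raise ValueError(f"labels not in training data: {sorted(predicted - present, key=str)}")
    if present - predicted:
        raise ValueError(f"classes never predicted: {sorted(present - predicted, key=str)}")
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([column])  # header row, as in the example above
        writer.writerows([p] for p in predictions)


# Usage: the five-row example above, written to a temporary directory.
out_path = os.path.join(tempfile.mkdtemp(), "warm_label.csv")
write_reference_labels(out_path, [0, 1, 1, 1, 0], [0, 1, 0, 1, 0])
with open(out_path, newline="") as f:
    rows = list(csv.reader(f))
```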

time_limit

  • Values: Decimal greater than or equal to 0
  • Description: A time limit after which the algorithm terminates. If the time limit is reached, the algorithm terminates with an error.
  • Special Cases: When set to 0, no time limit is imposed.
  • Default: 0

More parameters

Flags

balance

  • Values: true or false
  • Description: Enables overriding the sample importance by equalizing the importance of each present class
  • Default: false
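One common way to "equalize the importance of each present class" is to give every class the same total sample weight. The sketch below shows that scheme; whether GOSDT normalizes sample importance exactly this way is an assumption, and `balanced_weights` is a hypothetical name.

```python
from collections import Counter


def balanced_weights(labels):
    """Per-sample weights giving every present class the same total weight.

    Each class receives total weight 1/k (k = number of classes present), so
    the weights sum to 1 and a rare class counts as much as a common one.
    """
    counts = Counter(labels)
    k = len(counts)
    return [1.0 / (k * counts[label]) for label in labels]


labels = [0, 0, 0, 1]
weights = balanced_weights(labels)
```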

cancellation

  • Values: true or false
  • Description: Enables propagating task cancellations up the dependency graph
  • Default: true

look_ahead

  • Values: true or false
  • Description: Enables the one-step look-ahead bound implemented via scopes
  • Default: true

similar_support

  • Values: true or false
  • Description: Enables the similar support bound implemented via the distance index
  • Default: true

feature_exchange

  • Values: true or false
  • Description: Enables pruning of pairs of features using subset comparison
  • Default: false

continuous_feature_exchange

  • Values: true or false
  • Description: Enables pruning of pairs of continuous feature thresholds using subset comparison
  • Default: false

feature_transform

  • Values: true or false
  • Description: Enables equivalence discovery through simple feature transformations
  • Default: true

rule_list

  • Values: true or false
  • Description: Enables rule-list constraints on models
  • Default: false

non_binary

  • Values: true or false
  • Description: Enables non-binary encoding
  • Default: false

diagnostics

  • Values: true or false
  • Description: Enables printing a diagnostic trace to standard output when an error is encountered
  • Default: false

verbose

  • Values: true or false
  • Description: Enables printing of configuration, progress, and results to standard output
  • Default: false

Tuners

uncertainty_tolerance

  • Values: Decimal within range [0,1]
  • Description: Used to allow early termination of the algorithm. Any models produced as a result are guaranteed to score within the lower bound and upper bound at the time of termination. However, the algorithm does not guarantee that the produced models include the optimal model unless the uncertainty has reached 0.
  • Default: 0.0
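A sketch of the stopping rule, assuming the uncertainty is the absolute gap between the current upper and lower bounds on the objective (the Output section below prints this gap in parentheses after the bounds). The function names are hypothetical, and whether GOSDT compares an absolute or a relative gap is an assumption here.

```python
def uncertainty(lower: float, upper: float) -> float:
    """Gap between the best known objective (upper) and the proven lower bound."""
    return max(0.0, upper - lower)


def can_stop(lower: float, upper: float, tolerance: float) -> bool:
    """Early termination is allowed once the gap is within uncertainty_tolerance."""
    return uncertainty(lower, upper) <= tolerance


def proven_optimal(lower: float, upper: float) -> bool:
    """Only a zero gap certifies that the produced model is optimal."""
    return uncertainty(lower, upper) == 0.0
```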

upperbound

  • Values: Decimal within range [0,1]
  • Description: Used to limit the risk of models in the search space. This can be used to ensure that no models are produced if even the optimal model exceeds a desired maximum risk. It also accelerates learning if the upper bound is taken from the risk of a nearly optimal model.
  • Special Cases: When set to 0, the bound is not activated.
  • Default: 0.0

Limits

model_limit

  • Values: Integer greater than or equal to 0
  • Description: The maximum number of models that will be extracted into the output.
  • Special Cases: When set to 0, no output is produced.
  • Default: 1

precision_limit

  • Values: Integer greater than or equal to 0
  • Description: The maximum number of significant figures considered when converting ordinal features into binary features.
  • Special Cases: When set to 0, no limit is imposed.
  • Default: 0
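A sketch of what limiting significant figures does to candidate thresholds. It assumes, for illustration only, that thresholds are midpoints between consecutive distinct values of an ordinal feature; how GOSDT actually generates thresholds is not specified here, and both function names are hypothetical.

```python
def round_significant(value: float, precision_limit: int) -> float:
    """Round to precision_limit significant figures; 0 leaves the value unchanged."""
    if precision_limit < 0:
        raise ValueError("precision_limit must be >= 0")
    if precision_limit == 0:
        return value
    return float(f"{value:.{precision_limit}g}")


def candidate_thresholds(values, precision_limit: int):
    """Midpoints between consecutive distinct values, rounded and deduplicated."""
    distinct = sorted(set(values))
    midpoints = [(lo + hi) / 2 for lo, hi in zip(distinct, distinct[1:])]
    return sorted({round_significant(m, precision_limit) for m in midpoints})
```

In this sketch, fewer significant figures merge nearby thresholds into one binary feature, trading resolution for a smaller search space.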

stack_limit

  • Values: Integer greater than or equal to 0
  • Description: The maximum number of bytes considered for use when allocating local buffers for worker threads.
  • Special Cases: When set to 0, all local buffers will be allocated from the heap.
  • Default: 0

worker_limit

  • Values: Integer greater than or equal to 0
  • Description: The maximum number of threads allocated to executing the algorithm.
  • Special Cases: When set to 0, a single thread is created for each core detected on the machine.
  • Default: 1
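The special case for 0 can be read as "one thread per detected core". The sketch below resolves a worker_limit value to a thread count in Python; it is a hypothetical illustration of the documented rule, and GOSDT's own core detection may differ.

```python
import os


def resolve_worker_count(worker_limit: int) -> int:
    """Number of threads implied by worker_limit (0 = one per detected core)."""
    if worker_limit < 0:
        raise ValueError("worker_limit must be >= 0")
    if worker_limit == 0:
        # os.cpu_count() can return None when the core count is unknown.
        return os.cpu_count() or 1
    return worker_limit
```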

Files

costs

  • Values: string representing a path to a file.
  • Description: This file must contain a CSV representing the cost matrix for calculating loss.
    • The first row is a header listing every class that is present in the training data
    • Each subsequent row contains the cost incurred by predicting class i when the true class is j, where i is the row index and j is the column index
    • Example where each false negative costs 0.1 and each false positive costs 0.2 (and correct predictions cost 0.0):
      negative,positive
      0.0,0.1
      0.2,0.0
      
    • Example for multi-class objectives:
      class-A,class-B,class-C
      0.0,0.1,0.3
      0.2,0.0,0.1
      0.8,0.3,0.0
      
    • Note: Cost values are not normalized, so high cost values lower the relative weight of regularization
  • Special Case: When set to empty string, a default cost matrix is used which represents unweighted training misclassification.
  • Default: Empty string
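The sketch below reads a cost matrix in the format described above and totals the cost of a set of predictions, using row = predicted class and column = true class. The helper names are hypothetical, and how GOSDT itself aggregates these costs into its loss (for example, whether it averages over samples) is not specified here.

```python
import csv
import io


def read_cost_matrix(text):
    """Parse a header row of class names followed by a square matrix of costs."""
    rows = list(csv.reader(io.StringIO(text)))
    header, body = rows[0], rows[1:]
    if len(body) != len(header) or any(len(r) != len(header) for r in body):
        raise ValueError("cost matrix must be square and match the header")
    return header, [[float(c) for c in r] for r in body]


def total_cost(header, matrix, predicted, actual):
    """Sum of matrix[i][j] with i = predicted class and j = true class."""
    index = {label: i for i, label in enumerate(header)}
    return sum(matrix[index[p]][index[t]] for p, t in zip(predicted, actual))


# The binary example above: false negatives cost 0.1, false positives 0.2.
header, matrix = read_cost_matrix("negative,positive\n0.0,0.1\n0.2,0.0\n")
cost = total_cost(header, matrix,
                  predicted=["negative", "positive", "negative"],
                  actual=["positive", "negative", "negative"])
```

Here the three predictions are one false negative, one false positive, and one correct prediction, for a total of 0.1 + 0.2 + 0.0.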

model

  • Values: string representing a path to a file.
  • Description: The output models will be written to this file.
  • Special Case: When set to empty string, no model will be stored.
  • Default: Empty string

profile

  • Values: string representing a path to a file.
  • Description: Various analytics will be logged to this file.
  • Special Case: When set to empty string, no analytics will be stored.
  • Default: Empty string

timing

  • Values: string representing a path to a file.
  • Description: The training time will be appended to this file.
  • Special Case: When set to empty string, no training time will be stored.
  • Default: Empty string

trace

  • Values: string representing a path to a directory.
  • Description: snapshots used for trace visualization will be stored in this directory
  • Special Case: When set to empty string, no snapshots are stored.
  • Default: Empty string

tree

  • Values: string representing a path to a directory.
  • Description: snapshots used for trace-tree visualization will be stored in this directory
  • Special Case: When set to empty string, no snapshots are stored.
  • Default: Empty string

Example

The GOSDT source code repository (https://github.com/ubc-systopia/gosdt-guesses/) contains example code and datasets for running GOSDT with threshold guessing, lower bound guessing, and depth limits. The example Python file is available at https://github.com/ubc-systopia/gosdt-guesses/gosdt/example.py, and a tutorial IPython notebook is available at https://github.com/ubc-systopia/gosdt-guesses/gosdt/tutorial.ipynb.

The script below runs only from inside a clone of the Git repository (its imports refer to the repository's model package), but it serves as an example of how to use gosdt.

import pandas as pd
import numpy as np
import time
import pathlib
from sklearn.ensemble import GradientBoostingClassifier
from model.threshold_guess import compute_thresholds
from model.gosdt import GOSDT

# read the dataset
df = pd.read_csv("experiments/datasets/fico.csv")
X, y = df.iloc[:,:-1].values, df.iloc[:,-1].values
h = df.columns[:-1]

# GBDT parameters for threshold and lower bound guesses
n_est = 40
max_depth = 1

# guess thresholds
X = pd.DataFrame(X, columns=h)
print("X:", X.shape)
print("y:",y.shape)
X_train, thresholds, header, threshold_guess_time = compute_thresholds(X, y, n_est, max_depth)
y_train = pd.DataFrame(y)

# guess lower bound
start_time = time.perf_counter()
clf = GradientBoostingClassifier(n_estimators=n_est, max_depth=max_depth, random_state=42)
clf.fit(X_train, y_train.values.flatten())
warm_labels = clf.predict(X_train)
elapsed_time = time.perf_counter() - start_time
lb_time = elapsed_time

# save the labels from lower bound guesses as a tmp file and return the path to it.
labelsdir = pathlib.Path('/tmp/warm_lb_labels')
labelsdir.mkdir(exist_ok=True, parents=True)
labelpath = labelsdir / 'warm_label.tmp'
labelpath = str(labelpath)
# write a single-column CSV with a header row and no index column
pd.DataFrame(warm_labels, columns=["class_labels"]).to_csv(labelpath, header=True, index=False)


# train GOSDT model
config = {
    "regularization": 0.001,
    "depth_budget": 5,
    "reference_LB": True,  # use the GBDT labels saved above as a reference lower bound
    "path_to_labels": labelpath,
    "time_limit": 60,
    "similar_support": False,
}

model = GOSDT(config)

model.fit(X_train, y_train)

print("evaluate the model, extracting tree and scores", flush=True)

# get the results
train_acc = model.score(X_train, y_train)
n_leaves = model.leaves()
n_nodes = model.nodes()
train_time = model.utime  # renamed so the time module imported above is not shadowed

print("Model training time: {}".format(train_time))
print("Training accuracy: {}".format(train_acc))
print("# of leaves: {}".format(n_leaves))
print(model.tree)

Output

X: (10459, 23)
y: (10459,)
gosdt reported successful execution
training completed. 1.658/0.098/1.756 (user, system, wall), mem=364 MB
bounds: [0.290914..0.290914] (0.000000) loss=0.282914, iterations=13569
evaluate the model, extracting tree and scores
Model training time: 1.6584229469299316
Training accuracy: 0.7170857634573095
# of leaves: 8
if ExternalRiskEstimate<=67.5 = 1 and MSinceMostRecentInqexcl7days<=-7.5 = 1 then:
    predicted class: 1
    misclassification penalty: 0.027
    complexity penalty: 0.001

else if ExternalRiskEstimate<=67.5 != 1 and MSinceMostRecentInqexcl7days<=-7.5 = 1 then:
    predicted class: 0
    misclassification penalty: 0.006
    complexity penalty: 0.001

else if ExternalRiskEstimate<=74.5 = 1 and MSinceMostRecentInqexcl7days<=-7.5 != 1 and MSinceMostRecentInqexcl7days<=0.5 = 1 and PercentTradesWBalance<=80.5 = 1 then:
    predicted class: 1
    misclassification penalty: 0.071
    complexity penalty: 0.001

else if ExternalRiskEstimate<=74.5 != 1 and MSinceMostRecentInqexcl7days<=-7.5 != 1 and MSinceMostRecentInqexcl7days<=0.5 = 1 and PercentTradesWBalance<=80.5 = 1 then:
    predicted class: 0
    misclassification penalty: 0.061
    complexity penalty: 0.001

else if ExternalRiskEstimate<=78.5 = 1 and MSinceMostRecentInqexcl7days<=-7.5 != 1 and MSinceMostRecentInqexcl7days<=0.5 = 1 and PercentTradesWBalance<=80.5 != 1 then:
    predicted class: 1
    misclassification penalty: 0.033
    complexity penalty: 0.001

else if ExternalRiskEstimate<=78.5 != 1 and MSinceMostRecentInqexcl7days<=-7.5 != 1 and MSinceMostRecentInqexcl7days<=0.5 = 1 and PercentTradesWBalance<=80.5 != 1 then:
    predicted class: 0
    misclassification penalty: 0.005
    complexity penalty: 0.001

else if ExternalRiskEstimate<=67.5 = 1 and MSinceMostRecentInqexcl7days<=-7.5 != 1 and MSinceMostRecentInqexcl7days<=0.5 != 1 then:
    predicted class: 1
    misclassification penalty: 0.026
    complexity penalty: 0.001

else if ExternalRiskEstimate<=67.5 != 1 and MSinceMostRecentInqexcl7days<=-7.5 != 1 and MSinceMostRecentInqexcl7days<=0.5 != 1 then:
    predicted class: 0
    misclassification penalty: 0.054
    complexity penalty: 0.001
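The printed numbers are consistent with the objective described under regularization. The short check below recomputes them from values copied out of this output; it assumes that `loss` on the bounds line is the unweighted training misclassification rate and that each leaf's misclassification penalty is that leaf's share of it, which is an interpretation of the log rather than documented behavior.

```python
# Numbers copied from the output above.
reported_loss = 0.282914    # "loss=" on the bounds line
reported_bound = 0.290914   # both ends of "bounds: [...]"
n_leaves = 8                # "# of leaves: 8"
regularization = 0.001      # from the example config
leaf_penalties = [0.027, 0.006, 0.071, 0.061, 0.033, 0.005, 0.026, 0.054]

# Objective = misclassification loss + n_leaves * regularization.
objective = reported_loss + n_leaves * regularization

# With the default unweighted costs, training accuracy should be 1 - loss.
accuracy = 1.0 - reported_loss

# Per-leaf penalties are printed to 3 decimals, so their sum can differ from
# the loss by at most 8 * 0.0005 from rounding alone.
leaf_total = sum(leaf_penalties)

print(f"objective  = {objective:.6f} (reported bound {reported_bound})")
print(f"accuracy   = {accuracy:.6f}")
print(f"leaf total = {leaf_total:.3f} (loss {reported_loss})")
```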


FAQs

If you run into any issues when running GOSDT, consult the FAQs first.


License

This software is licensed under a 3-clause BSD license (see the LICENSE file for details).


Related Work

[1] Aglin, G.; Nijssen, S.; and Schaus, P. 2020. Learning optimal decision trees using caching branch-and-bound search. In AAAI Conference on Artificial Intelligence, volume 34, 3146–3153.

[2] Angelino, E.; Larus-Stone, N.; Alabi, D.; Seltzer, M.; and Rudin, C. 2018. Learning Certifiably Optimal Rule Lists for Categorical Data. Journal of Machine Learning Research, 18(234): 1–78.

[3] Breiman, L.; Friedman, J.; Olshen, R. A.; and Stone, C. J. 1984. Classification and Regression Trees. CRC Press.

[4] Hu, X.; Rudin, C.; and Seltzer, M. 2019. Optimal sparse decision trees. In Advances in Neural Information Processing Systems, 7267–7275.

[5] Lin, J.; Zhong, C.; Hu, D.; Rudin, C.; and Seltzer, M. 2020. Generalized and scalable optimal sparse decision trees. In International Conference on Machine Learning (ICML), 6150–6160.

[6] Quinlan, J. R. 1993. C4.5: Programs for Machine Learning. Morgan Kaufmann.

[7] Verwer, S.; and Zhang, Y. 2019. Learning optimal classification trees using a binary linear program formulation. In AAAI Conference on Artificial Intelligence, volume 33, 1625–1632.

[8] Yang, H.; Rudin, C.; and Seltzer, M. 2017. Scalable Bayesian rule lists. In International Conference on Machine Learning (ICML), 3921–3930. PMLR.




Download files


Source Distributions

No source distribution files are available for this release.

Built Distributions


gosdt-0.1.8-cp312-abi3-win_amd64.whl (800.8 kB)

CPython 3.12+, Windows x86-64

gosdt-0.1.8-cp311-abi3-win_amd64.whl (800.8 kB)

CPython 3.11+, Windows x86-64

gosdt-0.1.8-cp310-abi3-win_amd64.whl (800.8 kB)

CPython 3.10+, Windows x86-64

gosdt-0.1.8-cp39-abi3-win_amd64.whl (801.5 kB)

CPython 3.9+, Windows x86-64

gosdt-0.1.8-cp38-abi3-win_amd64.whl (821.4 kB)

CPython 3.8+, Windows x86-64

gosdt-0.1.8-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (506.3 kB)

CPython 3.8+, manylinux: glibc 2.17+ x86-64

gosdt-0.1.8-cp38-abi3-macosx_12_0_x86_64.whl (639.0 kB)

CPython 3.8+, macOS 12.0+ x86-64

File details

Details for the file gosdt-0.1.8-cp312-abi3-win_amd64.whl.

File metadata

  • Download URL: gosdt-0.1.8-cp312-abi3-win_amd64.whl
  • Upload date:
  • Size: 800.8 kB
  • Tags: CPython 3.12+, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.11

File hashes

Hashes for gosdt-0.1.8-cp312-abi3-win_amd64.whl
Algorithm Hash digest
SHA256 b7969dae61c7ac4a8cf1cbda1d237c556ebfadcf18a9314ecf3a59f156fe9a01
MD5 f3778a30bdf74b9e97170712508fd1f3
BLAKE2b-256 d0ef5a86a1b846f4589b93f297bea93bf84e4b3f4a97f73ceded2da44e556701


File details

Details for the file gosdt-0.1.8-cp311-abi3-win_amd64.whl.

File metadata

  • Download URL: gosdt-0.1.8-cp311-abi3-win_amd64.whl
  • Upload date:
  • Size: 800.8 kB
  • Tags: CPython 3.11+, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.11

File hashes

Hashes for gosdt-0.1.8-cp311-abi3-win_amd64.whl
Algorithm Hash digest
SHA256 459824a8d3a8ffcb2a5e597c07659be2af0f537ccf911f04a54d83f4da4ca310
MD5 47c0c4f5692cf9d7abb1801ee9970e3c
BLAKE2b-256 1ebe0e2b954a2197ed0bfd77790f1d6a1a6888dff2c7b42893be2d1412925974


File details

Details for the file gosdt-0.1.8-cp310-abi3-win_amd64.whl.

File metadata

  • Download URL: gosdt-0.1.8-cp310-abi3-win_amd64.whl
  • Upload date:
  • Size: 800.8 kB
  • Tags: CPython 3.10+, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.11

File hashes

Hashes for gosdt-0.1.8-cp310-abi3-win_amd64.whl
Algorithm Hash digest
SHA256 c36591c45aa40f503b6b6249b671fd8229589e863aefb251c3eb3e3d65858976
MD5 b1f749cd8bc63c2c9006a570e76c26c4
BLAKE2b-256 0151183471962d77f3ff6781b23e74abe65a2d8a18c0a3f02d3f8859e860c2ac


File details

Details for the file gosdt-0.1.8-cp39-abi3-win_amd64.whl.

File metadata

  • Download URL: gosdt-0.1.8-cp39-abi3-win_amd64.whl
  • Upload date:
  • Size: 801.5 kB
  • Tags: CPython 3.9+, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.11

File hashes

Hashes for gosdt-0.1.8-cp39-abi3-win_amd64.whl
Algorithm Hash digest
SHA256 7702f5cfdb71011686cfbce2c5592a4a936135103d354ed7af0645efc17f9148
MD5 c3052db80baa6dd913cac43acbade50c
BLAKE2b-256 3dbce0a8b9474d739354c3a25ce47103dbdf029c34946dfe25bc73ac11439f06


File details

Details for the file gosdt-0.1.8-cp38-abi3-win_amd64.whl.

File metadata

  • Download URL: gosdt-0.1.8-cp38-abi3-win_amd64.whl
  • Upload date:
  • Size: 821.4 kB
  • Tags: CPython 3.8+, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.11

File hashes

Hashes for gosdt-0.1.8-cp38-abi3-win_amd64.whl
Algorithm Hash digest
SHA256 6fe515f63d88cc05437dd72c8c01d4fefb54d87afc0ce1e04b99dcd2c0da3f4b
MD5 5cb67fd62f55ad512a5a0cf1c38d1909
BLAKE2b-256 dffa094687be03144f68e5c707b1eb38c635d0d47c8e2fd859eb77ce7a112260


File details

Details for the file gosdt-0.1.8-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for gosdt-0.1.8-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 f9f8e8b5e6276f15a850df4d55d3ccb5ed2b5722be7bc3bc373f127f208900f3
MD5 41b60f6a18b354821fc6306afe65a7da
BLAKE2b-256 33e14a72c9559af8f5b0f8d271e22ef8be76700564adf2f71f994eafb9afad25


File details

Details for the file gosdt-0.1.8-cp38-abi3-macosx_12_0_x86_64.whl.

File metadata

  • Download URL: gosdt-0.1.8-cp38-abi3-macosx_12_0_x86_64.whl
  • Upload date:
  • Size: 639.0 kB
  • Tags: CPython 3.8+, macOS 12.0+ x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.11

File hashes

Hashes for gosdt-0.1.8-cp38-abi3-macosx_12_0_x86_64.whl
Algorithm Hash digest
SHA256 082365876c499c3690d22850a7bc733c8abdd49d94c15992f1374d04054d8afb
MD5 da2809a5d630480a547d55aa7105c645
BLAKE2b-256 db1d82d12f0f48e43436e5d0e6aed0ea042c5e1931f4c7fce8ac748bcdca3a86

