
Transfer parameters from LightGBM to differentiable decision trees!


TreeGrad

TreeGrad implements a naive approach to converting a gradient boosted tree model into an online-trainable model. It does this by creating differentiable tree models which can be learned via automatic differentiation frameworks. TreeGrad is, in essence, an implementation of Kontschieder, Peter, et al., "Deep Neural Decision Forests," with extensions.

To install

python setup.py install

or alternatively from PyPI:

pip install treegrad

Run tests:

python -m nose2
Citation

@inproceedings{siu2019transferring,
  title={Transferring Tree Ensembles to Neural Networks},
  author={Siu, Chapman},
  booktitle={International Conference on Neural Information Processing},
  pages={471--480},
  year={2019},
  organization={Springer}
}

Link: https://arxiv.org/abs/1904.11132

Usage

import treegrad as tgd

mod = tgd.TGDClassifier(num_leaves=31, max_depth=-1, learning_rate=0.1,
                        n_estimators=100, autograd_config={'refit_splits': False})
mod.fit(X, y)          # initial fit of the boosted tree ensemble
mod.partial_fit(X, y)  # further online training via automatic differentiation
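
For context, here is a minimal end-to-end sketch; the dataset and the train/update split are illustrative assumptions, not part of the TreeGrad documentation:

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
import treegrad as tgd

# simulate an online setting: fit on one batch, then update on a later batch
X_train, X_new, y_train, y_new = train_test_split(
    *load_breast_cancer(return_X_y=True), random_state=0)

mod = tgd.TGDClassifier(num_leaves=31, max_depth=-1, learning_rate=0.1,
                        n_estimators=100, autograd_config={'refit_splits': False})
mod.fit(X_train, y_train)      # initial fit of the boosted ensemble
mod.partial_fit(X_new, y_new)  # gradient-based update on the new batch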

Requirements

The requirements for this package are:

  • lightgbm
  • scikit-learn
  • autograd

Future plans:

  • Add an implementation of neural architecture search for decision boundary splits (requires a bit of clean-up - TBA)
    • This can be done fairly easily using the objects in tree_utils.py - the challenge is getting it to work sanely with the scikit-learn interface.
  • GPU-enabled automatic differentiation - see notebooks/ for progress on a TensorFlow 2.0 port via Colab
  • Support additional XGBoost/LightGBM features such as monotone constraints
  • Support RegressorMixin

Results

When decision splits are reset and subsequently re-learned, TreeGrad can be competitive with popular implementations (albeit an order of magnitude slower). Below is the test-set accuracy on UCI benchmark datasets for boosted ensemble models (100 trees):

Dataset   TreeGrad   LightGBM   scikit-learn (GradientBoostingClassifier)
adult     0.860      0.873      0.874
covtype   0.832      0.835      0.826
dna       0.950      0.949      0.946
glass     0.766      0.813      0.719
madelon   0.882      0.881      0.866
soybean   0.936      0.936      0.917
yeast     0.591      0.573      0.542

Implementation

To understand the implementation of TreeGrad, we interpret a decision tree as a three-layer neural network, where the layers are as follows:

  1. Node layer, which determines the decision boundaries
  2. Routing layer, which determines which nodes are used to route to the final leaf nodes
  3. Leaf layer, the layer which determines the final predictions

In the node layer, the decision boundaries can be interpreted as the axis-parallel decision boundaries of a typical linear classifier, i.e. a fully connected dense layer.

The routing layer requires a binary routing matrix to which a global product routing is applied: the probability of reaching a leaf is the product of the decision probabilities of the nodes along its root-to-leaf path.

The leaf layer is your typical fully connected dense layer.

This approach is the same as the one taken by Kontschieder, Peter, et al. "Deep neural decision forests."
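
As a concrete illustration, below is a minimal NumPy sketch of this three-layer view for a single soft decision tree of depth 2 (3 internal nodes, 4 leaves). All shapes, parameter values, and the routes matrix are illustrative assumptions; this is not TreeGrad's actual code:

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def softmax(z):
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
n_features, n_nodes, n_leaves = 5, 3, 4

# 1. Node layer: soft axis-parallel splits as a dense layer. For a tree
#    transferred from a fitted ensemble, each row of W would be (near)
#    one-hot, selecting the split feature, with b holding the threshold.
W = rng.normal(size=(n_features, n_nodes))
b = rng.normal(size=n_nodes)

# 2. Routing layer: a fixed binary matrix encoding, for each leaf, which
#    internal nodes lie on its path and in which direction
#    (+1 = go left, -1 = go right, 0 = node not on this leaf's path).
routes = np.array([
    [+1, +1, -1, -1],  # root: left for leaves 0-1, right for leaves 2-3
    [+1, -1,  0,  0],  # left child: splits leaves 0 and 1
    [ 0,  0, +1, -1],  # right child: splits leaves 2 and 3
])

def leaf_probabilities(X):
    d = sigmoid(X @ W + b)  # P(route left) at each internal node
    probs = np.ones((X.shape[0], n_leaves))
    # global product routing: multiply decision probabilities along each path
    for node in range(n_nodes):
        for leaf in range(n_leaves):
            if routes[node, leaf] == +1:
                probs[:, leaf] *= d[:, node]
            elif routes[node, leaf] == -1:
                probs[:, leaf] *= 1.0 - d[:, node]
    return probs  # each row sums to 1: a soft assignment of samples to leaves

# 3. Leaf layer: a dense layer mapping soft leaf membership to class scores.
leaf_values = rng.normal(size=(n_leaves, 2))  # two-class toy example

X = rng.normal(size=(8, n_features))
predictions = softmax(leaf_probabilities(X) @ leaf_values)
print(predictions.shape)  # (8, 2)

Because every step is a differentiable operation (dense layers, sigmoids, products), the whole tree can be trained end-to-end with an automatic differentiation framework such as autograd.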
