
Project description

linear-tree

A python library to build Model Trees with Linear Models at the leaves.

linear-tree also provides implementations of LinearForest and LinearBoost, inspired by the works listed in the References section.

Overview

Linear Trees combine the learning ability of Decision Trees with the predictive and explanatory power of Linear Models. As in tree-based algorithms, the data are split according to simple decision rules. The goodness of splits is evaluated in terms of the gain obtained by fitting Linear Models in the nodes. This implies that the models in the leaves are linear, instead of the constant approximations used in classical Decision Trees.
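
The snippet below is a minimal sketch of this idea, not the library's internal code: a candidate split is scored by the drop in loss obtained when each child gets its own linear fit instead of a single linear fit on the parent node (mean squared error is used here as the criterion, and both children are assumed to be non-empty).

from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

def split_gain(X, y, feature, threshold):
    # Loss of a single linear model fitted on the whole node.
    parent_loss = mean_squared_error(y, LinearRegression().fit(X, y).predict(X))
    # Weighted loss of one linear model per child node.
    mask = X[:, feature] <= threshold
    children_loss = 0.0
    for side in (mask, ~mask):
        model = LinearRegression().fit(X[side], y[side])
        children_loss += side.mean() * mean_squared_error(y[side], model.predict(X[side]))
    # The split with the largest gain is chosen, as in classical decision trees.
    return parent_loss - children_loss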

Linear Forests generalize the well-known Random Forests by combining Linear Models with them. The key idea is to use the strength of Linear Models to improve the non-parametric learning ability of tree-based algorithms. First, a Linear Model is fitted on the whole dataset; then a Random Forest is trained on the same dataset, using the residuals of the previous step as the target. The final predictions are the sum of the raw linear predictions and the residuals modeled by the Random Forest.
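
As a rough illustration, the two stages can be reproduced by hand with plain scikit-learn estimators (a sketch of the idea, not the LinearForestRegressor implementation):

from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression

X, y = make_regression(n_samples=200, n_features=4, random_state=0)

# Stage 1: linear model on the raw target.
linear = LinearRegression().fit(X, y)

# Stage 2: random forest trained on the residuals of the linear model.
forest = RandomForestRegressor(random_state=0).fit(X, y - linear.predict(X))

# Final prediction = raw linear prediction + forest-modeled residuals.
y_pred = linear.predict(X) + forest.predict(X)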

Linear Boosting is a two-stage learning process. First, a linear model is trained on the initial dataset to obtain predictions. Second, the residuals of the previous step are modeled with a decision tree using all the available features. The tree identifies the path leading to the highest error (i.e. the worst leaf). The leaf contributing most to the error is used to generate a new binary feature for the first stage. The iterations continue until a stopping criterion is met.
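
A single iteration of this process can be sketched with standard scikit-learn components (an illustration only, not the LinearBoostRegressor code; the tree depth and stopping rule here are arbitrary):

import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor

X, y = make_regression(n_samples=200, n_features=4, random_state=0)

# Stage 1: linear model on the original features.
linear = LinearRegression().fit(X, y)
residuals = y - linear.predict(X)

# Stage 2: a decision tree models the residuals using all features.
tree = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, residuals)

# Identify the worst leaf: the one contributing most to the residual error.
leaf_ids = tree.apply(X)
worst_leaf = max(set(leaf_ids), key=lambda l: np.sum(residuals[leaf_ids == l] ** 2))

# Membership in the worst leaf becomes a new binary feature for the next
# linear fit; the loop repeats until a stopping criterion is met.
new_feature = (leaf_ids == worst_leaf).astype(float)
linear = LinearRegression().fit(np.column_stack([X, new_feature]), y)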

linear-tree is developed to be fully integrable with scikit-learn. LinearTreeRegressor and LinearTreeClassifier are provided as scikit-learn BaseEstimators that build a decision tree using linear estimators. LinearForestRegressor and LinearForestClassifier use the RandomForest from sklearn to model the residuals. LinearBoostRegressor and LinearBoostClassifier are also available as TransformerMixins, so they can be integrated into any pipeline for automated feature engineering. All the models available in sklearn.linear_model can be used as base learners.
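
Because the estimators follow the scikit-learn interface, they work with the usual model-selection utilities out of the box; for example, cross-validating a LinearTreeRegressor (a minimal sketch using only the constructor shown in the Usage section below):

from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score
from lineartree import LinearTreeRegressor

X, y = make_regression(n_samples=100, n_features=4,
                       n_informative=2, random_state=0)
# Any scikit-learn utility that accepts a BaseEstimator can be used directly.
scores = cross_val_score(LinearTreeRegressor(base_estimator=LinearRegression()),
                         X, y, cv=5)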

Installation

pip install --upgrade linear-tree

The module depends on NumPy, SciPy, and scikit-learn (>= 0.24.2). Python 3.6 or above is supported.

Usage

Linear Tree Regression
from sklearn.linear_model import LinearRegression
from lineartree import LinearTreeRegressor
from sklearn.datasets import make_regression
X, y = make_regression(n_samples=100, n_features=4,
                       n_informative=2, n_targets=1,
                       random_state=0, shuffle=False)
regr = LinearTreeRegressor(base_estimator=LinearRegression())
regr.fit(X, y)
Linear Tree Classification
from sklearn.linear_model import RidgeClassifier
from lineartree import LinearTreeClassifier
from sklearn.datasets import make_classification
X, y = make_classification(n_samples=100, n_features=4,
                           n_informative=2, n_redundant=0,
                           random_state=0, shuffle=False)
clf = LinearTreeClassifier(base_estimator=RidgeClassifier())
clf.fit(X, y)
Linear Forest Regression
from sklearn.linear_model import LinearRegression
from lineartree import LinearForestRegressor
from sklearn.datasets import make_regression
X, y = make_regression(n_samples=100, n_features=4,
                       n_informative=2, n_targets=1,
                       random_state=0, shuffle=False)
regr = LinearForestRegressor(base_estimator=LinearRegression())
regr.fit(X, y)
Linear Forest Classification
from sklearn.linear_model import LinearRegression
from lineartree import LinearForestClassifier
from sklearn.datasets import make_classification
X, y = make_classification(n_samples=100, n_features=4,
                           n_informative=2, n_redundant=0,
                           random_state=0, shuffle=False)
clf = LinearForestClassifier(base_estimator=LinearRegression())
clf.fit(X, y)
Linear Boosting Regression
from sklearn.linear_model import LinearRegression
from lineartree import LinearBoostRegressor
from sklearn.datasets import make_regression
X, y = make_regression(n_samples=100, n_features=4,
                       n_informative=2, n_targets=1,
                       random_state=0, shuffle=False)
regr = LinearBoostRegressor(base_estimator=LinearRegression())
regr.fit(X, y)
Linear Boosting Classification
from sklearn.linear_model import RidgeClassifier
from lineartree import LinearBoostClassifier
from sklearn.datasets import make_classification
X, y = make_classification(n_samples=100, n_features=4,
                           n_informative=2, n_redundant=0,
                           random_state=0, shuffle=False)
clf = LinearBoostClassifier(base_estimator=RidgeClassifier())
clf.fit(X, y)

More examples are available in the notebooks folder.

Check the API Reference to see the parameter configurations and the available methods.
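
For instance, tree growth can be constrained at construction time; the parameter names below (max_depth, min_samples_leaf) follow the scikit-learn convention, but check the API Reference for the exact names and defaults supported by the library:

from sklearn.linear_model import LinearRegression
from lineartree import LinearTreeRegressor

# Example configuration: limit the tree depth and the minimum number of
# samples required in each leaf (see the API Reference for details).
regr = LinearTreeRegressor(base_estimator=LinearRegression(),
                           max_depth=3,
                           min_samples_leaf=10)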

Examples

Show the linear tree learning path:

[figure: plot tree]

Linear Tree Regressor at work:

[figure: linear tree regressor]

Linear Tree Classifier at work:

[figure: linear tree classifier]

Extract and examine coefficients at the leaves:

[figure: leaf coefficients]

Impact of the features automatically generated with Linear Boosting:

[figure: linear_boost_importances]

Comparing predictions of Linear Forest and Random Forest:

[figure: linear_forest_predictions]

References

  • Regression-Enhanced Random Forests. Haozhe Zhang, Dan Nettleton, Zhengyuan Zhu.
  • Explainable boosted linear regression for time series forecasting. Igor Ilic, Berk Gorgulu, Mucahit Cevik, Mustafa Gokce Baydogan.

Download files


Source Distribution

linear-tree-0.3.5.tar.gz (18.9 kB, source)

Built Distribution

linear_tree-0.3.5-py3-none-any.whl (21.1 kB, Python 3)

File details

Details for the file linear-tree-0.3.5.tar.gz.

File metadata

  • Download URL: linear-tree-0.3.5.tar.gz
  • Upload date:
  • Size: 18.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.2.0.post20200714 requests-toolbelt/0.9.1 tqdm/4.47.0 CPython/3.7.7

File hashes

Hashes for linear-tree-0.3.5.tar.gz

  • SHA256: 2db9fc976bcd693a66d8d92fdd7f97314125b3330eea4778885bfe62190d586c
  • MD5: 0b3627a1e1f04a59a277efce28a8a304
  • BLAKE2b-256: 63b90e5237f1573f219c6b0bcce8b2ebc7186e1a5d35c97bc7645deda13e1fad


File details

Details for the file linear_tree-0.3.5-py3-none-any.whl.

File metadata

  • Download URL: linear_tree-0.3.5-py3-none-any.whl
  • Upload date:
  • Size: 21.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.2.0.post20200714 requests-toolbelt/0.9.1 tqdm/4.47.0 CPython/3.7.7

File hashes

Hashes for linear_tree-0.3.5-py3-none-any.whl

  • SHA256: a87766a40cf27eefed0866e3f7bd086f91fb0e0a73e49b5169f6b7606ade7361
  • MD5: a59c59546a7bcb05e2e2889f38f51586
  • BLAKE2b-256: 77508ea8a700140100353feade3b2f77e6999172adb9287d77c5ddc910599fb0

