Python package treeplot visualizes a tree based on a random forest or XGBoost model.
treeplot
- treeplot is a Python package for easily plotting the tree learned by decision tree, random forest, XGBoost and gradient boosting models. Explainable machine learning models are becoming more important in many domains, and the most popular, classical explainable models are still tree-based, such as decision trees and random forests. A learned tree can be visualized and then explained, but simply plotting it can be a challenge: there are configuration issues with dot files, path locations to Graphviz, differences across operating systems, and differences across editors such as Jupyter Notebook, Colab and Spyder. This frustration led to this library, an easy way to plot decision trees 🌲. Under the hood, it performs a number of checks, downloads Graphviz, sets the path and then plots the tree.
Have fun!
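Much of the setup described above comes down to finding Graphviz's `dot` executable and, if needed, making its folder visible on `PATH`. The sketch below illustrates that idea with the standard library only. `find_dot_executable` and `add_to_path` are hypothetical helper names, not treeplot's internals, and the Graphviz download step is not covered.

```python
import os
import shutil
from typing import Optional


def find_dot_executable() -> Optional[str]:
    """Return the full path of Graphviz's `dot` program, or None if it is not on PATH.

    Hypothetical helper for illustration; not part of treeplot's API.
    """
    return shutil.which("dot")


def add_to_path(directory: str) -> None:
    """Prepend `directory` to PATH for the current Python process only.

    Hypothetical helper: e.g. point it at the bin/ folder of a Graphviz install.
    Does nothing if the directory is already listed.
    """
    current = os.environ.get("PATH", "")
    if directory not in current.split(os.pathsep):
        os.environ["PATH"] = directory + os.pathsep + current if current else directory


dot = find_dot_executable()
if dot is None:
    print("Graphviz 'dot' not found on PATH")
else:
    print("Graphviz 'dot' found at", dot)
```

Changes to `os.environ` only affect the running process and the subprocesses it starts, so a per-session fix like this avoids editing system-wide settings.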
Functions in treeplot
Treeplot can plot the tree for random forest, decision tree, XGBoost and gradient boosting models:
- treeplot.plot() : Generic function to plot the tree of any of the four models with default settings
- treeplot.plot_tree() : Plot the decision tree model. Parameters can be specified.
- treeplot.randomforest() : Plot the random forest model. Parameters can be specified.
- treeplot.xgboost() : Plot the xgboost model. Parameters can be specified.
- treeplot.import_example('iris') : Import example dataset
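Because `treeplot.plot()` accepts any of the supported models, it has to decide which specific function to call. One plausible way to do such dispatching is by class name, sketched below. `choose_plot_function` is a hypothetical name, and routing gradient boosting models to `randomforest()` is an assumption that this page does not confirm.

```python
def choose_plot_function(model) -> str:
    """Return the name of the treeplot function a fitted model would be routed to.

    Hypothetical sketch: dispatches on the class name only, so it needs no
    scikit-learn or xgboost import.
    """
    name = type(model).__name__
    if name.startswith("XGB"):
        return "xgboost"          # XGBClassifier, XGBRegressor, ...
    if name.startswith("DecisionTree"):
        return "plot_tree"        # DecisionTreeClassifier / DecisionTreeRegressor
    if name.startswith(("RandomForest", "GradientBoosting")):
        # Assumption: scikit-learn tree ensembles go to randomforest().
        return "randomforest"
    raise TypeError(f"unsupported model type: {name}")
```

Dispatching on class names keeps the optional dependencies optional; a stricter variant would use `isinstance` checks against the imported classes.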
Installation
- Install treeplot from PyPI (recommended). treeplot is compatible with Python 3.6+ and runs on Linux, MacOS X and Windows.
- It is distributed under the MIT license.
Quick Start
pip install treeplot
- Alternatively, install treeplot from the GitHub source:
git clone https://github.com/erdogant/treeplot.git
cd treeplot
pip install .
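To confirm that the installation worked and to see which release is active (the files listed on this page are version 0.1.13), the standard library's `importlib.metadata` can be queried. Note that this module requires Python 3.8+, newer than the 3.6 minimum stated above; `installed_version` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str = "treeplot") -> Optional[str]:
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


print("treeplot version:", installed_version())
```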
Import treeplot package
import treeplot
Example RandomForest:
# Load example dataset
X, y = treeplot.import_example()
# Learn model
from sklearn.ensemble import RandomForestClassifier
model = RandomForestClassifier(n_estimators=100, max_depth=2, random_state=0).fit(X, y)
# Make plot
ax = treeplot.plot(model)
# or directly
ax = treeplot.randomforest(model)
# If more parameters need to be specified, use the model-specific function:
ax = treeplot.randomforest(model, export='pdf')
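A random forest is an ensemble of many decision trees (100 in the example above), so a tree diagram can show only one member of the ensemble at a time. As a Graphviz-free cross-check, scikit-learn's own `export_text` prints a fitted tree as indented rules. This sketch uses scikit-learn's API rather than treeplot's, and loads iris with `sklearn.datasets.load_iris` instead of `treeplot.import_example()`.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import export_text

# Fit a small forest on the iris data bundled with scikit-learn (no download needed).
iris = load_iris()
X, y = iris.data, iris.target
model = RandomForestClassifier(n_estimators=100, max_depth=2, random_state=0).fit(X, y)

# The fitted member trees are stored in estimators_.
print("number of trees:", len(model.estimators_))

# Text rendering of the first tree; no Graphviz required.
first_tree = model.estimators_[0]
report = export_text(first_tree, feature_names=list(iris.feature_names))
print(report)
```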
Example XGBoost:
# Load example dataset
X, y = treeplot.import_example()
# Learn model
from xgboost import XGBClassifier
model = XGBClassifier(n_estimators=100, max_depth=2, random_state=0).fit(X, y)
# Make plot
ax = treeplot.plot(model)
# or directly
ax = treeplot.xgboost(model)
# If more parameters need to be specified, use the model-specific function:
ax = treeplot.xgboost(model, plottype='vertical')
Maintainers
- Erdogan Taskesen, github: erdogant
Contribute
- Contributions are welcome.
Licence
See LICENSE for details.
Download files
Source Distribution: treeplot-0.1.13.tar.gz (8.7 kB)
Built Distribution: treeplot-0.1.13-py3-none-any.whl (8.5 kB)
File details
Details for the file treeplot-0.1.13.tar.gz
File metadata
- Download URL: treeplot-0.1.13.tar.gz
- Upload date:
- Size: 8.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.23.0 setuptools/46.2.0.post20200511 requests-toolbelt/0.9.1 tqdm/4.46.0 CPython/3.8.2
File hashes
Algorithm | Hash digest
---|---
SHA256 | 58f8f8873514a249d5cfa576929bf93890b25a37570500f761fe7292a221580e
MD5 | 8e9d82ecbf5154ef02026a5c8cd0eaaa
BLAKE2b-256 | e921ad6b14550175c4cdf88792cc60845af2c9fd4cfc86626e860ca729b43d38
File details
Details for the file treeplot-0.1.13-py3-none-any.whl
File metadata
- Download URL: treeplot-0.1.13-py3-none-any.whl
- Upload date:
- Size: 8.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.23.0 setuptools/46.2.0.post20200511 requests-toolbelt/0.9.1 tqdm/4.46.0 CPython/3.8.2
File hashes
Algorithm | Hash digest
---|---
SHA256 | ebe2f2c475cd013a5310d40a7606f31354710fc15655ae715ef2a469c257c930
MD5 | 1cb8908c5c3e1533a130957107adbecc
BLAKE2b-256 | a792d4af3d9a597dfe7b421446c000cefaaf401d0fa2832e4626955463ef3397
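The published digests let you check that a downloaded file is intact. Below is a minimal sketch using Python's `hashlib`, reading the file in chunks so that memory use stays bounded. `sha256_of` and `verify_download` are hypothetical helper names; the expected value is the SHA256 listed for the source distribution.

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, read in fixed-size chunks."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path, expected_sha256):
    """True if the file's SHA256 matches the published digest (case-insensitive)."""
    return sha256_of(path) == expected_sha256.strip().lower()


# SHA256 published above for treeplot-0.1.13.tar.gz.
EXPECTED_SDIST_SHA256 = "58f8f8873514a249d5cfa576929bf93890b25a37570500f761fe7292a221580e"

# Example, after downloading the sdist into the current directory:
# print(verify_download("treeplot-0.1.13.tar.gz", EXPECTED_SDIST_SHA256))
```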