A simple interface to modify scikit-learn's generated DOT representation of a decision tree.

Project description

DrawScikitTree

A simple interface for modifying the DOT string that scikit-learn generates for a trained decision tree. Basic functions include changing the shape and color of each node and tracing the decision paths taken by test samples.
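To make the idea of editing a DOT string concrete, here is a minimal sketch using only the standard library. It is not this package's implementation: the DOT text is a hand-written, simplified stand-in for scikit-learn's output, and `set_node_attrs` is a hypothetical helper introduced only for illustration.

```python
import re

# Hand-written DOT text that loosely resembles scikit-learn's export (assumption).
dot_data = """digraph Tree {
node [shape=box] ;
0 [label="x[3] <= 0.8"] ;
1 [label="class 0"] ;
2 [label="class 1"] ;
0 -> 1 ;
0 -> 2 ;
}"""


def set_node_attrs(dot, node_id, **attrs):
    """Hypothetical helper: append attributes to the statement defining one node."""
    extra = ", ".join(f'{key}="{value}"' for key, value in attrs.items())
    # Match only the node definition line, e.g. `1 [label="..."]`, not edge lines.
    pattern = re.compile(rf'^({node_id} \[label="[^"]*")\]', flags=re.MULTILINE)
    return pattern.sub(lambda m: f"{m.group(1)}, {extra}]", dot, count=1)


# Color node 1 red and draw it as an ellipse; other nodes are left untouched.
new_dot = set_node_attrs(dot_data, 1, style="filled", fillcolor="red", shape="ellipse")
print(new_dot)
```

The resulting string can be rendered with any Graphviz front end, since only standard node attributes (`style`, `fillcolor`, `shape`) are added.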

Installation

pip install draw-scikit-tree

Basic usage

The example below uses the classic iris dataset.

from sklearn.datasets import load_iris
from sklearn import tree
iris = load_iris()
X, y = iris.data, iris.target
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, y)

Next, use the trained classifier to initialize the TreeGraph object.

from DrawScikitTree import TreeGraph
treeGraph = TreeGraph(clf, impurity=False, label="none", fontname="Arial")

To trace the decision paths taken by some test samples, use the .trace_paths(X_sample) method.

import numpy as np
import graphviz

# Get some random samples
random_indices = np.random.randint(X.shape[0], size=5)
X_sample = X[random_indices, :]

# Setting verbose=True will print out the decision paths for each sample
treeGraph.trace_paths(X_sample, color="red", verbose=True)

# Display the modified tree (display() is available in Jupyter/IPython sessions)
new_dot_data = treeGraph.export()
graph = graphviz.Source(new_dot_data)
display(graph)
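For readers who want to see which nodes a sample visits without drawing anything, the sketch below uses scikit-learn's own `decision_path` and `apply` methods. It illustrates the underlying idea only; it is not necessarily how `trace_paths` is implemented, and the fixed random seed is an assumption added for reproducibility.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0).fit(X, y)

# Pick five random samples (seeded so the run is repeatable).
rng = np.random.default_rng(0)
X_sample = X[rng.integers(X.shape[0], size=5)]

# decision_path returns a sparse indicator matrix: entry (i, j) is nonzero
# when sample i passes through node j.
node_indicator = clf.decision_path(X_sample)
# apply returns the id of the leaf each sample ends up in.
leaf_ids = clf.apply(X_sample)

paths = []
for i in range(X_sample.shape[0]):
    start, end = node_indicator.indptr[i], node_indicator.indptr[i + 1]
    node_ids = node_indicator.indices[start:end].tolist()
    paths.append(node_ids)
    print(f"sample {i}: nodes {node_ids} -> leaf {leaf_ids[i]}")
```

Each list starts at the root (node 0) and ends at the sample's leaf, which is the set of nodes a path-tracing visualization would highlight.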

For more examples, see the examples provided with the project.

Download files

Source Distribution

draw_scikit_tree-0.1.1.tar.gz (255.0 kB)

Uploaded Source

Built Distribution

draw_scikit_tree-0.1.1-py3-none-any.whl (8.2 kB)

Uploaded Python 3

File details

Details for the file draw_scikit_tree-0.1.1.tar.gz.

File metadata

  • Download URL: draw_scikit_tree-0.1.1.tar.gz
  • Size: 255.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.7.13

File hashes

Hashes for draw_scikit_tree-0.1.1.tar.gz
Algorithm Hash digest
SHA256 9455c900669a8008558625a8d18610d0bb4236b1c7ca221d553de91dd9462e9c
MD5 d3d57ce2f77f78a2ba824cdc8e175fbb
BLAKE2b-256 1d1f586e36eb41ae33f099c008225d8f3f3aeed8e2bc98b43fd4673cfd80be67

File details

Details for the file draw_scikit_tree-0.1.1-py3-none-any.whl.

File hashes

Hashes for draw_scikit_tree-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 be4d3ca79fafa6142886811d19991f81b183ee874524d4ac42241916b21a8b6c
MD5 358c66333b81ee5bdde80d12239ca86f
BLAKE2b-256 2dfab298727ed5348f7e298b5e10a73b29890b10b692acc90b2a0fad645c3731
