
A simple interface to modify scikit-learn's generated DOT representation of a decision tree.



DrawScikitTree

A simple interface to modify scikit-learn's generated DOT string representation of a trained decision tree. Basic functions include changing the shape and color of each node and tracing the decision paths taken by test samples.
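For context, scikit-learn's sklearn.tree.export_graphviz(clf, out_file=None) returns the tree as DOT text with one statement per node, such as 0 [label="..."] ;. The sketch below, which assumes that line format, shows roughly what restyling a node involves: the hypothetical helper style_nodes appends shape and fill attributes to the selected node statements. It is an illustration only, not DrawScikitTree's own implementation.

```python
import re

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Assumed format of a node statement in scikit-learn's DOT output:
#   <node_id> [label="..."] ;
NODE_LINE = re.compile(r"^(\d+) \[(.*)\] ;$")


def style_nodes(dot, node_ids, shape="ellipse", fillcolor="lightblue"):
    """Hypothetical helper: give the listed nodes a new shape and fill color."""
    wanted = set(node_ids)
    out = []
    for line in dot.splitlines():
        match = NODE_LINE.match(line)
        if match and int(match.group(1)) in wanted:
            node_id, attrs = match.groups()
            # Node-level attributes override the graph-wide "node [shape=box]" default.
            line = f'{node_id} [{attrs}, shape={shape}, style=filled, fillcolor="{fillcolor}"] ;'
        out.append(line)
    return "\n".join(out) + "\n"


X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0).fit(X, y)

dot = export_graphviz(clf, out_file=None)  # out_file=None returns the DOT text
styled = style_nodes(dot, node_ids=[0])
print(next(line for line in styled.splitlines() if line.startswith("0 [")))
```

Edge statements (0 -> 1 ...) do not match NODE_LINE, so only node boxes are touched.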

Installation

pip install draw-scikit-tree

Basic usage

This example uses the classic iris dataset.

from sklearn.datasets import load_iris
from sklearn import tree
iris = load_iris()
X, y = iris.data, iris.target
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, y)

Next, use the trained classifier to initialize the TreeGraph object.

from DrawScikitTree import TreeGraph
treeGraph = TreeGraph(clf, impurity=False, label="none", fontname="Arial")

To trace the decision paths taken by some test samples, use the .trace_paths(X_sample) method.

import numpy as np
import graphviz
from IPython.display import display  # built in inside Jupyter; imported for other IPython shells

# Get some random samples
random_indices = np.random.randint(X.shape[0], size=5)
X_sample = X[random_indices, :]

# Setting verbose=True will print out the decision paths for each sample
treeGraph.trace_paths(X_sample, color="red", verbose=True)

# Displaying the newly modified tree
new_dot_data = treeGraph.export()
graph = graphviz.Source(new_dot_data)
display(graph)
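How trace_paths works internally is not shown on this page. As a minimal sketch using only scikit-learn: DecisionTreeClassifier.decision_path reports which nodes each sample visits, the split at each visited node can be read from clf.tree_.feature and clf.tree_.threshold, and the matching edges in the exported DOT string can then be recolored. The edge regex and the color/penwidth values are assumptions for illustration.

```python
import re

import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0).fit(X, y)
tree = clf.tree_

rng = np.random.default_rng(0)
X_sample = X[rng.integers(X.shape[0], size=5)]

# Sparse (n_samples, n_nodes) indicator: row i marks every node sample i visits.
indicator = clf.decision_path(X_sample).tocsr()

paths = []
for i, x in enumerate(X_sample):
    row = indicator.indices[indicator.indptr[i]:indicator.indptr[i + 1]]
    # Children get larger ids than their parent, so sorting gives root-to-leaf order.
    nodes = sorted(int(n) for n in row)
    paths.append(nodes)
    for node, child in zip(nodes, nodes[1:]):
        f, t = tree.feature[node], tree.threshold[node]
        # scikit-learn sends samples with x[f] <= t to the left child.
        op = "<=" if child == tree.children_left[node] else ">"
        print(f"sample {i}, node {node}: x[{f}] = {x[f]:.2f} {op} {t:.2f}")
    print(f"sample {i}: reaches leaf {nodes[-1]}")

# Hypothetical post-processing: recolor every edge that some sample traversed.
used = {(a, b) for p in paths for a, b in zip(p, p[1:])}
EDGE_LINE = re.compile(r"^(\d+) -> (\d+)(?: \[(.*)\])? ;$")

lines = []
for line in export_graphviz(clf, out_file=None).splitlines():
    m = EDGE_LINE.match(line)
    if m and (int(m.group(1)), int(m.group(2))) in used:
        extra = f"{m.group(3)}, " if m.group(3) else ""
        line = f'{m.group(1)} -> {m.group(2)} [{extra}color="red", penwidth=2] ;'
    lines.append(line)
highlighted = "\n".join(lines) + "\n"
```

Passing highlighted to graphviz.Source renders it the same way as in the example above.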

For more examples, check out the examples.



Download files

Download the file for your platform.

Source Distribution

draw_scikit_tree-0.1.3.tar.gz (255.0 kB)

Uploaded Source

Built Distribution


draw_scikit_tree-0.1.3-py3-none-any.whl (8.1 kB)

Uploaded Python 3

File details

Details for the file draw_scikit_tree-0.1.3.tar.gz.

File metadata

  • Download URL: draw_scikit_tree-0.1.3.tar.gz
  • Size: 255.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.7.13

File hashes

Hashes for draw_scikit_tree-0.1.3.tar.gz
  • SHA256: 6a6ec080a109a7d1e8f4df55774ebc889d3a666964f98f9ab81949cdec246731
  • MD5: dfc2b7fcf98f91875c3c6ea3de090d39
  • BLAKE2b-256: edd7fba770c212c6e2653f9dc81316e92761ce1d0f3bb44198d4a3fa76a400e4


File details

Details for the file draw_scikit_tree-0.1.3-py3-none-any.whl.


File hashes

Hashes for draw_scikit_tree-0.1.3-py3-none-any.whl
  • SHA256: f39d3cb13f6de143f0db0774e7b7ff0c7834bb5e16ed3a7e910e428f89a569a0
  • MD5: 8275132998e236ea61fdac5111a41048
  • BLAKE2b-256: 7d140f806dc4e518bb3b679279d39027c472238a09ce5b5f28a49c24f5478cd4

