Scaffolds out code for decision tree models that can learn to find relationships between the attributes of an object.
Decision Tree Writer
This package allows you to train a binary decision tree classifier on a list of labeled dictionaries or class instances, and then writes a new .py file containing the code for the trained model.
Installation
Simply run
py -m pip install decision-tree-writer
from the command line (Windows), or
python3 -m pip install decision-tree-writer
(Unix/macOS), and you're ready to go!
Usage
1) Train the model
Use the DecisionTreeWriter class to train a model on a data set and write the code to a new file in a specified folder (by default, the same folder as your code):
from decision_tree_writer import DecisionTreeWriter
# Here we're using some of the famous iris data set for an example.
# You could alternatively make an Iris class with the same
# attributes as the keys of each of these dictionaries.
iris_data = [
{ "species": "setosa", "sepal_length": 5.2, "sepal_width": 3.5,
"petal_length": 1.5, "petal_width": 0.2},
{ "species": "setosa", "sepal_length": 5.2, "sepal_width": 4.1,
"petal_length": 1.5, "petal_width": 0.1},
{ "species": "setosa", "sepal_length": 5.4, "sepal_width": 3.7,
"petal_length": 1.5, "petal_width": 0.2},
{ "species": "versicolor", "sepal_length": 6.2, "sepal_width": 2.2,
"petal_length": 4.5, "petal_width": 1.5},
{ "species": "versicolor", "sepal_length": 5.7, "sepal_width": 2.9,
"petal_length": 4.2, "petal_width": 1.3},
{ "species": "versicolor", "sepal_length": 5.6, "sepal_width": 2.9,
"petal_length": 3.6, "petal_width": 1.3},
{ "species": "virginica", "sepal_length": 7.2, "sepal_width": 3.2,
"petal_length": 6.0, "petal_width": 1.8},
{ "species": "virginica", "sepal_length": 6.1, "sepal_width": 2.6,
"petal_length": 5.6, "petal_width": 1.4},
{ "species": "virginica", "sepal_length": 6.8, "sepal_width": 3.0,
"petal_length": 5.5, "petal_width": 2.1}
]
# Create the writer. You must specify which attribute or key is the label of the data items.
# You can also specify the max branching depth of the tree (default [and max] is 998)
# or how many data items there must be to make a new branch (default is 1)
writer = DecisionTreeWriter(label_name="species")
# Trains a new model and saves it to a new .py file
writer.create_tree(iris_data, True, "Iris Classifier")
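The comment in the example above mentions that class instances can be used in place of dictionaries. Here is a minimal sketch of that alternative, assuming a hypothetical Iris dataclass (the class itself and the conversion step are illustrative and not part of this package):

```python
from dataclasses import dataclass


@dataclass
class Iris:
    # Attribute names mirror the dictionary keys used in iris_data.
    species: str
    sepal_length: float
    sepal_width: float
    petal_length: float
    petal_width: float


# Two rows copied from the data set above.
rows = [
    {"species": "setosa", "sepal_length": 5.2, "sepal_width": 3.5,
     "petal_length": 1.5, "petal_width": 0.2},
    {"species": "virginica", "sepal_length": 7.2, "sepal_width": 3.2,
     "petal_length": 6.0, "petal_width": 1.8},
]

# Unpack each dictionary into an Iris instance.
iris_objects = [Iris(**row) for row in rows]
print(iris_objects[0].species)
```

A list like iris_objects would then be passed to create_tree in place of iris_data, with label_name="species" left unchanged.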
2) Using the new decision tree
In the specified folder a new Python file containing one function will appear. The function is named after your decision tree model, with a UUID appended to ensure the name is unique. The generated code will look something like this:
from decision_tree_writer.BaseDecisionTree import *
# class-like syntax because it acts like it's instantiating a class.
def IrisClassifier__0c609d3a_741e_4770_8bce_df246bad054d() -> 'BaseDecisionTree':
    """
    IrisClassifier__0c609d3a_741e_4770_8bce_df246bad054d
    has been trained to identify the species of a given object.
    """
    tree = BaseDecisionTree(None, dict,
                            'IrisClassifier__0c609d3a_741e_4770_8bce_df246bad054d')
    tree.root = Branch(lambda x: x['sepal_length'] <= 5.5)
    tree.root.l = Leaf('setosa')
    tree.root.r = Branch(lambda x: x['petal_length'] <= 5.0)
    tree.root.r.l = Leaf('versicolor')
    tree.root.r.r = Leaf('virginica')
    return tree
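To make the structure of the generated tree easier to follow, here is a minimal, self-contained sketch of how such a tree could be walked. The Branch and Leaf classes below are simplified stand-ins written for illustration, and the classify helper is hypothetical. None of this is the package's actual implementation. It assumes that items passing a branch's test go to the l child and all others go to the r child, which is consistent with the training data shown earlier.

```python
class Leaf:
    """Terminal node that stores a label (illustrative stand-in)."""

    def __init__(self, label):
        self.label = label


class Branch:
    """Internal node holding a test and two children (illustrative stand-in)."""

    def __init__(self, test):
        self.test = test
        self.l = None  # child followed when the test is true (assumption)
        self.r = None  # child followed when the test is false (assumption)


def classify(node, item):
    # Walk down the tree until a Leaf is reached, then return its label.
    while isinstance(node, Branch):
        node = node.l if node.test(item) else node.r
    return node.label


# Same shape as the generated tree above.
root = Branch(lambda x: x['sepal_length'] <= 5.5)
root.l = Leaf('setosa')
root.r = Branch(lambda x: x['petal_length'] <= 5.0)
root.r.l = Leaf('versicolor')
root.r.r = Leaf('virginica')

print(classify(root, {"sepal_length": 6.2, "petal_length": 4.5}))
```

Each lambda is a single threshold test on one attribute, so classifying an item takes at most as many comparisons as the tree is deep.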
Important note: if you train your model with class instance data you will have to import that class in the new file. That might look like:
from decision_tree_writer.BaseDecisionTree import *
from wherever import Iris
def IrisClassifier__0c609d3a_741e_4770_8bce_df246bad054d() -> 'BaseDecisionTree':
    tree = BaseDecisionTree(None, Iris,
                            'IrisClassifier__0c609d3a_741e_4770_8bce_df246bad054d')
Now just use the factory function to create an instance of the model.
The model has two important methods: classify_one, which takes a data item of the same type as you trained the model with and returns what it thinks is the correct label for it, and classify_many, which does the same for a list of data items and returns a list of labels.
Example:
tree = IrisClassifier__0c609d3a_741e_4770_8bce_df246bad054d()
print(tree.classify_one(
{ "sepal_length": 5.4, "sepal_width": 3.2,
"petal_length": 1.6, "petal_width": 0.3})) # output: 'setosa'
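As a rough illustration of how classify_many relates to classify_one, and of how predictions might be checked against known labels, here is a self-contained sketch. The two functions below are plain stand-ins that hard-code the rules of the generated tree shown earlier, and the accuracy helper is hypothetical. With the real package you would call the methods on the tree object instead.

```python
def classify_one(item):
    # Stand-in that mirrors the rules of the generated tree above.
    if item["sepal_length"] <= 5.5:
        return "setosa"
    if item["petal_length"] <= 5.0:
        return "versicolor"
    return "virginica"


def classify_many(items):
    # Apply classify_one to every item and collect the labels in order.
    return [classify_one(item) for item in items]


def accuracy(predicted, expected):
    # Fraction of predictions that match the known labels.
    matches = sum(p == e for p, e in zip(predicted, expected))
    return matches / len(expected)


# Three labeled rows taken from the training data above.
labeled = [
    {"species": "setosa", "sepal_length": 5.2, "sepal_width": 3.5,
     "petal_length": 1.5, "petal_width": 0.2},
    {"species": "versicolor", "sepal_length": 6.2, "sepal_width": 2.2,
     "petal_length": 4.5, "petal_width": 1.5},
    {"species": "virginica", "sepal_length": 6.1, "sepal_width": 2.6,
     "petal_length": 5.6, "petal_width": 1.4},
]

predictions = classify_many(labeled)
print(predictions)  # ['setosa', 'versicolor', 'virginica']
print(accuracy(predictions, [row["species"] for row in labeled]))  # 1.0
```

Because these rows come from the training data, a perfect score here says nothing about how well the model generalizes. Data the model has not seen gives a more honest estimate.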
Bugs or questions
If you find any problems with this package or have any questions, please create an issue on this package's GitHub repo.