Scaffolds out code for decision tree models that can learn to find relationships between the attributes of an object.
Decision Tree Writer
This package lets you train a binary decision tree classifier (every branch splits two ways) on a list of labeled dicts or class instances, and then writes the code for the trained model to a new .py file.
Installation
Simply run `py -m pip install decision-tree-writer` from the command line (Windows) or `python3 -m pip install decision-tree-writer` (Unix/macOS), and you're ready to go!
Usage
1) Train the model
Use the DecisionTreeWriter class to train a model on a data set and write its code to a new file in a specified folder (by default, the same folder as your code):
```python
from decision_tree_writer import DecisionTreeWriter

# Here we're using part of the famous iris data set as an example.
# You could alternatively make an Iris class with the same
# attributes as the keys of each of these dictionaries.
iris_data = [
    { "species": "setosa", "sepal_length": 5.2, "sepal_width": 3.5, "petal_length": 1.5, "petal_width": 0.2},
    { "species": "setosa", "sepal_length": 5.2, "sepal_width": 4.1, "petal_length": 1.5, "petal_width": 0.1},
    { "species": "setosa", "sepal_length": 5.4, "sepal_width": 3.7, "petal_length": 1.5, "petal_width": 0.2},
    { "species": "versicolor", "sepal_length": 6.2, "sepal_width": 2.2, "petal_length": 4.5, "petal_width": 1.5},
    { "species": "versicolor", "sepal_length": 5.7, "sepal_width": 2.9, "petal_length": 4.2, "petal_width": 1.3},
    { "species": "versicolor", "sepal_length": 5.6, "sepal_width": 2.9, "petal_length": 3.6, "petal_width": 1.3},
    { "species": "virginica", "sepal_length": 7.2, "sepal_width": 3.2, "petal_length": 6.0, "petal_width": 1.8},
    { "species": "virginica", "sepal_length": 6.1, "sepal_width": 2.6, "petal_length": 5.6, "petal_width": 1.4},
    { "species": "virginica", "sepal_length": 6.8, "sepal_width": 3.0, "petal_length": 5.5, "petal_width": 2.1},
]

# Create the writer. You must specify which attribute or key is the label of the data items.
# You can also specify the maximum branching depth of the tree (the default, and the maximum, is 998)
# or the minimum number of data items needed to create a new branch (default is 1).
writer = DecisionTreeWriter(label_name="species")

# Train a new model and save its code to a new .py file
writer.create_tree(iris_data, True, "Iris Classifier")
```
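This README does not say which splitting criterion the package uses, so the following is only a mental model: a minimal sketch of a CART-style learner that picks `item[key] <= threshold` tests by Gini impurity, using midpoints between neighboring values as candidate thresholds. The functions `gini`, `best_split`, `build` and `predict`, and the `max_depth` / `min_items` parameters (meant to mirror the two options described in the comments above), are all hypothetical and are not the package's API. It assumes numeric attributes and a non-empty list of dicts.

```python
from collections import Counter
from typing import Any, Optional

Row = dict[str, Any]
# A node is either ("leaf", label) or ("branch", key, threshold, left, right).
Tree = tuple


def gini(rows: list[Row], label: str) -> float:
    """Gini impurity of the labels in rows: 0.0 means every row has the same label."""
    counts = Counter(row[label] for row in rows)
    n = len(rows)
    return 1.0 - sum((c / n) ** 2 for c in counts.values())


def best_split(rows: list[Row], label: str) -> Optional[tuple[str, float]]:
    """Find the test 'row[key] <= threshold' with the lowest weighted Gini impurity.
    Returns None when no split improves on the unsplit rows."""
    best, best_score = None, gini(rows, label)
    n = len(rows)
    for key in rows[0]:
        if key == label:
            continue
        values = sorted({row[key] for row in rows})
        # Candidate thresholds: midpoints between consecutive distinct values,
        # so both sides of every candidate split are non-empty.
        for lo, hi in zip(values, values[1:]):
            threshold = (lo + hi) / 2
            left = [r for r in rows if r[key] <= threshold]
            right = [r for r in rows if r[key] > threshold]
            score = (len(left) * gini(left, label) + len(right) * gini(right, label)) / n
            if score < best_score:
                best, best_score = (key, threshold), score
    return best


def build(rows: list[Row], label: str, depth: int = 0,
          max_depth: int = 998, min_items: int = 1) -> Tree:
    """Grow the tree recursively. Stop at max_depth, when fewer than min_items
    rows remain, or when no split helps; a leaf predicts the majority label."""
    majority = Counter(r[label] for r in rows).most_common(1)[0][0]
    if depth >= max_depth or len(rows) < min_items:
        return ("leaf", majority)
    split = best_split(rows, label)
    if split is None:
        return ("leaf", majority)
    key, threshold = split
    left = [r for r in rows if r[key] <= threshold]
    right = [r for r in rows if r[key] > threshold]
    return ("branch", key, threshold,
            build(left, label, depth + 1, max_depth, min_items),
            build(right, label, depth + 1, max_depth, min_items))


def predict(tree: Tree, item: Row) -> Any:
    """Walk from the root: go left when item[key] <= threshold, otherwise right."""
    while tree[0] == "branch":
        _, key, threshold, left, right = tree
        tree = left if item[key] <= threshold else right
    return tree[1]
```

On the nine sample rows above, `build(iris_data, "species")` chooses `sepal_length <= 5.5` at the root and `petal_length <= 5.0` below it, the same tests as the generated example that follows. That agreement on a tiny data set is not evidence that the package itself uses Gini impurity.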
2) Using the new decision tree
A new Python file containing a single function will then appear in the specified folder. Its name is the name you gave your decision tree model (with spaces removed) followed by a UUID, which guarantees that it is unique. The generated code will look something like this:
```python
from decision_tree_writer.BaseDecisionTree import *

# Class-like syntax because it acts like it's instantiating a class.
def IrisClassifier__0c609d3a_741e_4770_8bce_df246bad054d() -> 'BaseDecisionTree':
    """
    IrisClassifier__0c609d3a_741e_4770_8bce_df246bad054d has been trained to identify the species of a given object.
    """
    tree = BaseDecisionTree(None, dict, 'IrisClassifier__0c609d3a_741e_4770_8bce_df246bad054d')
    tree.root = Branch(lambda x: x['sepal_length'] <= 5.5)
    tree.root.l = Leaf('setosa')
    tree.root.r = Branch(lambda x: x['petal_length'] <= 5.0)
    tree.root.r.l = Leaf('versicolor')
    tree.root.r.r = Leaf('virginica')
    return tree
```
Important note: if you train your model on class instances, you will have to import that class in the new file. That might look like:
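```python
from decision_tree_writer.BaseDecisionTree import *
from wherever import Iris  # 'wherever' stands for the module that defines your Iris class

def IrisClassifier__0c609d3a_741e_4770_8bce_df246bad054d() -> 'BaseDecisionTree':
    tree = BaseDecisionTree(None, Iris, 'IrisClassifier__0c609d3a_741e_4770_8bce_df246bad054d')
    # ... the rest of the generated tree follows as before
```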
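As suggested in the training example, the dicts could be replaced by instances of an Iris class whose attributes match the dict keys. A minimal sketch, assuming a plain dataclass is acceptable input (the README only says "class instances"); the `Iris` class and the `iris_objects` name are hypothetical:

```python
from dataclasses import dataclass


@dataclass
class Iris:
    """Hypothetical class whose attributes mirror the keys of the iris dicts."""
    species: str
    sepal_length: float
    sepal_width: float
    petal_length: float
    petal_width: float


iris_dicts = [
    {"species": "setosa", "sepal_length": 5.2, "sepal_width": 3.5, "petal_length": 1.5, "petal_width": 0.2},
    {"species": "virginica", "sepal_length": 7.2, "sepal_width": 3.2, "petal_length": 6.0, "petal_width": 1.8},
]

# Each dict's keys match the dataclass fields, so it can be unpacked directly.
iris_objects = [Iris(**d) for d in iris_dicts]
```

Presumably you would then pass `iris_objects` to `create_tree` in place of `iris_data`, keeping `label_name="species"`, and import `Iris` in the generated file as described above.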
Now just use the factory function to create an instance of the model.
The model has two important methods: `classify_one`, which takes a data item of the same type the model was trained on and returns the label it predicts for that item, and `classify_many`, which does the same for a list of data items and returns a list of labels.
Example:
```python
# First import the factory function from the generated file.
tree = IrisClassifier__0c609d3a_741e_4770_8bce_df246bad054d()
print(tree.classify_one({ "sepal_length": 5.4, "sepal_width": 3.2, "petal_length": 1.6, "petal_width": 0.3}))  # prints: setosa
```
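To make the structure of the generated code concrete, here is a self-contained sketch of how a tree of Branch and Leaf nodes could answer `classify_one` and `classify_many`. The `Branch`, `Leaf` and `MiniTree` classes below are hypothetical stand-ins written for this illustration, not the package's `BaseDecisionTree` classes. The convention that `l` is followed when the predicate is true is inferred from the generated example, where every setosa row has `sepal_length <= 5.5`.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union


@dataclass
class Leaf:
    label: Any  # the prediction returned when a walk ends here


@dataclass
class Branch:
    predicate: Callable[[Any], bool]
    l: Optional["Node"] = None  # followed when predicate(item) is True
    r: Optional["Node"] = None  # followed when predicate(item) is False


Node = Union[Branch, Leaf]


class MiniTree:
    """Hypothetical stand-in for a generated model."""

    def __init__(self) -> None:
        self.root: Optional[Node] = None

    def classify_one(self, item: Any) -> Any:
        node = self.root
        # Descend until a Leaf is reached, choosing a side at each Branch.
        while isinstance(node, Branch):
            node = node.l if node.predicate(item) else node.r
        return node.label

    def classify_many(self, items: list) -> list:
        return [self.classify_one(item) for item in items]


# Rebuild the same tree as the generated example.
tree = MiniTree()
tree.root = Branch(lambda x: x['sepal_length'] <= 5.5)
tree.root.l = Leaf('setosa')
tree.root.r = Branch(lambda x: x['petal_length'] <= 5.0)
tree.root.r.l = Leaf('versicolor')
tree.root.r.r = Leaf('virginica')
```

In this sketch `classify_many` is simply `classify_one` applied to each item, which matches how the README describes the relationship between the two methods.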
Bugs or questions
If you find any problems with this package or have any questions, please create an issue on this package's GitHub repo.
Download files
Source distribution: decision-tree-writer-0.1.1.tar.gz
Built distribution: decision_tree_writer-0.1.1-py3-none-any.whl

Hashes for decision-tree-writer-0.1.1.tar.gz

| Algorithm | Hash digest |
|---|---|
| SHA256 | 31510be393fb372ff6be1ae770f2bec59b9f5fa1615d96872664915bdccf879d |
| MD5 | 33a8822f30f73f01e4a8d7b69157aad1 |
| BLAKE2b-256 | bc56f102342813bc225c1f42bc73ac65084f6016bfdab8f7f523191561aa097e |

Hashes for decision_tree_writer-0.1.1-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | 4f2abb8d78ed5a8421078e759f022b027bc0eecd5628eced5cc52232e1de026d |
| MD5 | 3862634cb11749885d9c59676b357e4d |
| BLAKE2b-256 | 827498b626828cd63f9d43cb311118a6141c0085ed4b814da5ffddcbe63f01c9 |
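To check a downloaded file against the SHA256 digests listed above, one option is Python's standard-library `hashlib`. A minimal sketch; the helper names `sha256_of` and `matches_published` are hypothetical:

```python
import hashlib


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of the file at path, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published(path: str, expected_hex: str) -> bool:
    """Compare against a digest copied from the table (case and surrounding whitespace ignored)."""
    return sha256_of(path) == expected_hex.strip().lower()


# Usage, assuming the wheel was downloaded into the current folder:
# matches_published("decision_tree_writer-0.1.1-py3-none-any.whl",
#                   "4f2abb8d78ed5a8421078e759f022b027bc0eecd5628eced5cc52232e1de026d")
```

pip can also enforce this itself through its hash-checking mode, using `--hash=sha256:...` entries in a requirements file together with `--require-hashes`.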