Scaffolds out code for decision tree models that can learn to find relationships between multiple attributes of objects.
Decision Tree Writer
This package trains a binary decision tree classifier on a list of labeled dictionaries or class instances, then writes a new .py file containing the code for the trained decision tree model.
Installation
Simply run py -m pip install decision-tree-writer from the command line (Windows) or python3 -m pip install decision-tree-writer (Unix/macOS) and you're ready to go!
Usage
See the source code for an example program, or follow the brief tutorial below:
0) Gather training data
Models are trained on a list of labeled dictionaries or objects. The algorithm only looks at attributes/keys that have numeric or Boolean values, so nested objects and strings are simply ignored. If you train with dictionaries, every item must have the same set of keys with numeric or Boolean values; keys whose values are strings, objects, or anything else may differ from item to item. Similarly, if you train with a list of objects, they must all have the same attributes with integer, floating-point, or Boolean values. Finally, every item in the data set must have a label attribute/key with the same name; the label can have any value, as shown in this example data set:
# Here we're using some of the famous iris data set for an example.
iris_data = [
    {"species": "setosa", "sepal_length": 5.2, "sepal_width": 3.5,
     "petal_length": 1.5, "petal_width": 0.2},
    {"species": "setosa", "sepal_length": 5.2, "sepal_width": 4.1,
     "petal_length": 1.5, "petal_width": 0.1},
    {"species": "setosa", "sepal_length": 5.4, "sepal_width": 3.7,
     "petal_length": 1.5, "petal_width": 0.2},
    {"species": "versicolor", "sepal_length": 6.2, "sepal_width": 2.2,
     "petal_length": 4.5, "petal_width": 1.5},
    {"species": "versicolor", "sepal_length": 5.7, "sepal_width": 2.9,
     "petal_length": 4.2, "petal_width": 1.3},
    {"species": "versicolor", "sepal_length": 5.6, "sepal_width": 2.9,
     "petal_length": 3.6, "petal_width": 1.3},
    {"species": "virginica", "sepal_length": 7.2, "sepal_width": 3.2,
     "petal_length": 6.0, "petal_width": 1.8},
    {"species": "virginica", "sepal_length": 6.1, "sepal_width": 2.6,
     "petal_length": 5.6, "petal_width": 1.4},
    {"species": "virginica", "sepal_length": 6.8, "sepal_width": 3.0,
     "petal_length": 5.5, "petal_width": 2.1},
]
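The requirements above can be checked before training. The following is a minimal sketch, not part of the package: the helper names numeric_keys and check_data_set are hypothetical, and the check only covers dictionary data.

```python
def numeric_keys(item: dict, label_name: str) -> set:
    """Return the keys whose values are numeric or Boolean, excluding the label."""
    return {
        key for key, value in item.items()
        if key != label_name and isinstance(value, (bool, int, float))
    }


def check_data_set(data_set: list, label_name: str) -> None:
    """Raise ValueError if the items do not share a label key and numeric keys.

    Hypothetical helper for illustration; the package may validate differently.
    """
    if not data_set:
        raise ValueError("the data set is empty")
    expected = numeric_keys(data_set[0], label_name)
    for index, item in enumerate(data_set):
        if label_name not in item:
            raise ValueError(f"item {index} has no {label_name!r} label")
        if numeric_keys(item, label_name) != expected:
            raise ValueError(f"item {index} has different numeric keys")


sample = [
    {"species": "setosa", "sepal_length": 5.2, "petal_length": 1.5, "note": "a"},
    {"species": "virginica", "sepal_length": 6.8, "petal_length": 5.5},
]
# Passes silently: the extra string-valued key "note" is ignored.
check_data_set(sample, label_name="species")
```

String-valued keys such as "note" do not take part in the comparison, which mirrors the rule that only numeric and Boolean values are considered.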
You could alternatively make an Iris class with the same attributes as the keys of each of these dictionaries:
from dataclasses import dataclass

@dataclass
class Iris:
    species: str
    sepal_length: float
    sepal_width: float
    petal_length: float
    petal_width: float

# And then instantiate nine Iris objects with the previous data
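One way to build the objects is to unpack each dictionary into the class constructor. This is a small sketch that assumes the dictionary keys match the dataclass field names exactly; the list name rows is only for illustration and holds two of the items from the example data.

```python
from dataclasses import dataclass


@dataclass
class Iris:
    species: str
    sepal_length: float
    sepal_width: float
    petal_length: float
    petal_width: float


# Two rows copied from the example data set above.
rows = [
    {"species": "setosa", "sepal_length": 5.2, "sepal_width": 3.5,
     "petal_length": 1.5, "petal_width": 0.2},
    {"species": "virginica", "sepal_length": 6.8, "sepal_width": 3.0,
     "petal_length": 5.5, "petal_width": 2.1},
]

# Create one Iris object per dictionary by unpacking its keys as keyword arguments.
iris_objects = [Iris(**row) for row in rows]
print(iris_objects[0].species)
```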
1) Train the model
Use the DecisionTreeWriter class to train a model on a data set and write the code to a new file in a specified folder (by default, the same folder as your code). The label of an item in the training data set is a specified attribute or key (in this example, the key "species"):
from decision_tree_writer import DecisionTreeWriter
# Create the writer.
# You must specify which attribute or key is the label of the data items.
# You can also specify the max branching depth of the tree (default [and max] is 998)
# or how many data items there must be to make a new branch (default is 1).
writer = DecisionTreeWriter(label_name="species")
# Trains a new model and saves it to a new .py file.
writer.create_tree(data_set=iris_data,
                   look_for_correlations=True,
                   tree_name="Iris Classifier")
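This README does not describe how the writer chooses each branch condition. As a general illustration only, decision tree learners commonly pick the threshold that minimizes the weighted Gini impurity of the two resulting groups. The sketch below shows that generic technique; the functions gini and best_threshold are hypothetical and are not the package's API or necessarily its algorithm.

```python
from collections import Counter


def gini(labels: list) -> float:
    """Gini impurity of a list of labels (0.0 means all labels are the same)."""
    total = len(labels)
    if total == 0:
        return 0.0
    return 1.0 - sum((count / total) ** 2 for count in Counter(labels).values())


def best_threshold(items: list, key: str, label_name: str) -> tuple:
    """Return (threshold, weighted impurity) for the best `item[key] <= threshold` split."""
    values = sorted({item[key] for item in items})
    best = (None, float("inf"))
    # Candidate thresholds are the midpoints between neighboring distinct values.
    for low, high in zip(values, values[1:]):
        threshold = (low + high) / 2
        left = [item[label_name] for item in items if item[key] <= threshold]
        right = [item[label_name] for item in items if item[key] > threshold]
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(items)
        if score < best[1]:
            best = (threshold, score)
    return best


sample = [
    {"species": "setosa", "petal_length": 1.5},
    {"species": "setosa", "petal_length": 1.5},
    {"species": "versicolor", "petal_length": 4.5},
    {"species": "versicolor", "petal_length": 4.2},
]
threshold, impurity = best_threshold(sample, "petal_length", "species")
print(threshold, impurity)
```

On this tiny sample the chosen threshold falls between 1.5 and 4.2 and separates the two species completely, so the weighted impurity is zero.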
2) Examining the trained decision tree code
A new Python file containing one function will appear in the specified folder. Its name is the name you gave your decision tree model plus a UUID, which ensures the name is unique. The generated code will look something like this:
from decision_tree_writer.BaseDecisionTree import *
# class-like syntax because it acts like it's instantiating a class.
def Iris_Classifier__0c609d3a_741e_4770_8bce_df246bad054d() -> 'BaseDecisionTree':
    """
    Iris_Classifier__0c609d3a_741e_4770_8bce_df246bad054d
    has been trained to identify the species of a given dictionary object.
    """
    tree = BaseDecisionTree(dict,
        'Iris_Classifier__0c609d3a_741e_4770_8bce_df246bad054d')
    tree.root = Branch(lambda x: x['sepal_length'] <= 5.5)
    tree.root.l = Leaf('setosa')
    tree.root.r = Branch(lambda x: x['petal_length'] <= 5.0)
    tree.root.r.l = Leaf('versicolor')
    tree.root.r.r = Leaf('virginica')
    return tree
The trained model is built by adding decision Branches that terminate in classification Leaves to a BaseDecisionTree. When the tree classifies an input, it hands the input to its root Branch, which applies its comparison function: if the comparison evaluates to True, the left node (root.l) is selected, and if it is False, the right node (root.r) is selected. In this case, the root checks whether the iris's sepal_length is less than or equal to 5.5. If the selected node is a Leaf, the tree returns the label on that Leaf. If the node is a Branch, the tree runs that Branch's comparison function on the input to select one of its subnodes, and repeats until a Leaf is reached.
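The traversal just described can be sketched in a few lines. The classes SketchBranch and SketchLeaf and the function classify below are hypothetical stand-ins written for this illustration; they are not the package's Branch, Leaf, or BaseDecisionTree implementations, whose internals are not shown in this README.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class SketchLeaf:
    label: Any  # the classification returned when this node is reached


@dataclass
class SketchBranch:
    test: Callable[[Any], bool]  # comparison function applied to the input
    l: Any = None  # node chosen when test(item) is True
    r: Any = None  # node chosen when test(item) is False


def classify(node: Any, item: Any) -> Any:
    """Walk from `node` down to a leaf and return that leaf's label."""
    while isinstance(node, SketchBranch):
        node = node.l if node.test(item) else node.r
    return node.label


# Same shape as the generated iris tree shown above.
root = SketchBranch(lambda x: x["sepal_length"] <= 5.5)
root.l = SketchLeaf("setosa")
root.r = SketchBranch(lambda x: x["petal_length"] <= 5.0)
root.r.l = SketchLeaf("versicolor")
root.r.r = SketchLeaf("virginica")

print(classify(root, {"sepal_length": 5.4, "petal_length": 1.6}))  # setosa
print(classify(root, {"sepal_length": 6.8, "petal_length": 5.5}))  # virginica
```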
Important note: if you train your model with class instance data you may have to change the import statement for that class in the new file if the class is in a file in a different directory from where you have the model's file placed. The code for a model trained with objects would start like:
from decision_tree_writer.BaseDecisionTree import *
# Please fix this import statement if necessary
from sample_data.flowers import Iris
def Iris_Classifier__0c609d3a_741e_4770_8bce_df246bad054d() -> 'BaseDecisionTree':
    tree = BaseDecisionTree(Iris,
        'Iris_Classifier__0c609d3a_741e_4770_8bce_df246bad054d')
3) Using the new decision tree
Now just import and call the factory function to create an instance of the trained model.
The model has two important methods: classify_one, which takes a data item of the same type you trained the model with and returns what it thinks is the correct label for it, and classify_many, which does the same for a list of data items and returns a list of labels.
Example:
# The module and function names match the generated file shown above.
from Iris_Classifier__0c609d3a_741e_4770_8bce_df246bad054d import *

tree = Iris_Classifier__0c609d3a_741e_4770_8bce_df246bad054d()

print(tree.classify_one(
    {"sepal_length": 5.4, "sepal_width": 3.2,
     "petal_length": 1.6, "petal_width": 0.3}))  # output: setosa

print(tree.classify_many(
    [
        {"sepal_length": 5.4, "sepal_width": 3.2,
         "petal_length": 1.6, "petal_width": 0.3},
        {"sepal_length": 5.7, "sepal_width": 2.9,
         "petal_length": 4.2, "petal_width": 1.3},
    ]
))  # output: ['setosa', 'versicolor']
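Because classify_many returns a plain list of labels, you can compare it against held-out labeled data to estimate how well the model performs. The sketch below assumes only that behavior; the accuracy helper is hypothetical, and the predicted list is hard-coded here as a stand-in for the result of tree.classify_many(...), so the snippet runs without a generated model file.

```python
def accuracy(predicted: list, labeled_items: list, label_name: str) -> float:
    """Fraction of predictions that match the label stored on each item."""
    if len(predicted) != len(labeled_items):
        raise ValueError("predicted and labeled_items must have the same length")
    correct = sum(
        1 for guess, item in zip(predicted, labeled_items)
        if guess == item[label_name]
    )
    return correct / len(labeled_items)


held_out = [
    {"species": "setosa", "sepal_length": 5.4, "petal_length": 1.6},
    {"species": "versicolor", "sepal_length": 5.7, "petal_length": 4.2},
    {"species": "virginica", "sepal_length": 6.1, "petal_length": 4.9},
]

# Stand-in for `tree.classify_many(held_out)`; the last guess is deliberately wrong.
predicted = ["setosa", "versicolor", "versicolor"]

score = accuracy(predicted, held_out, label_name="species")
print(f"{score:.2f}")
```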
Bugs or questions
If you find any problems with this package or have any questions, please create an issue on this package's GitHub repo.