
Scaffolds out code for decision tree models that can learn to find relationships between multiple attributes of objects.

Project description

Decision Tree Writer


This package allows you to train a binary decision tree classifier on a list of labeled dictionaries or class instances, and then writes a new .py file containing the code for the trained model.

Installation

Simply run py -m pip install decision-tree-writer from the command line (Windows) or python3 -m pip install decision-tree-writer (Unix/macOS) and you're ready to go!

Usage

Please see the source code for an example program, or follow the brief tutorial below:

0) Gather training data

Models are trained on a list of labeled dictionaries or objects. The algorithm only looks at attributes/keys with numeric or Boolean values; strings, nested objects, and any other values are ignored. If you train it with dictionaries, every item must have the same set of keys with numeric or Boolean values (items may differ in their keys whose values are strings, objects, or anything else). Similarly, if you give it a list of objects, they must all have the same attributes with integer, floating-point, or Boolean values. Finally, every item in the data set must have a label attribute/key with the same name, and its value can be of any type (as shown in this example data set):

# Here we're using some of the famous iris data set for an example.
iris_data = [
    { "species": "setosa", "sepal_length": 5.2, "sepal_width": 3.5, 
                            "petal_length": 1.5, "petal_width": 0.2},
    { "species": "setosa", "sepal_length": 5.2, "sepal_width": 4.1, 
                            "petal_length": 1.5, "petal_width": 0.1},
    { "species": "setosa", "sepal_length": 5.4, "sepal_width": 3.7, 
                            "petal_length": 1.5, "petal_width": 0.2},
    { "species": "versicolor", "sepal_length": 6.2, "sepal_width": 2.2, 
                            "petal_length": 4.5, "petal_width": 1.5},
    { "species": "versicolor", "sepal_length": 5.7, "sepal_width": 2.9, 
                            "petal_length": 4.2, "petal_width": 1.3},
    { "species": "versicolor", "sepal_length": 5.6, "sepal_width": 2.9, 
                            "petal_length": 3.6, "petal_width": 1.3},
    { "species": "virginica", "sepal_length": 7.2, "sepal_width": 3.2, 
                            "petal_length": 6.0, "petal_width": 1.8},
    { "species": "virginica", "sepal_length": 6.1, "sepal_width": 2.6, 
                            "petal_length": 5.6, "petal_width": 1.4},
    { "species": "virginica", "sepal_length": 6.8, "sepal_width": 3.0, 
                            "petal_length": 5.5, "petal_width": 2.1}
    ]

You could alternatively make an Iris class with the same attributes as the keys of each of these dictionaries:

from dataclasses import dataclass
@dataclass
class Iris:
    species: str
    sepal_length: float
    sepal_width: float
    petal_length: float
    petal_width: float
# ...and then instantiate nine Iris objects from the previous data:
iris_objects = [Iris(**item) for item in iris_data]
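The data rules above can be checked before training. The sketch below is illustrative only: numeric_features and check_consistent are hypothetical helpers, not part of decision-tree-writer, and they assume that "numeric" means any instance of numbers.Number (which includes bool). They accept either dictionaries or plain class instances such as the Iris dataclass.

```python
from __future__ import annotations

from numbers import Number
from typing import Any


def numeric_features(item: Any) -> dict[str, Any]:
    """Return only the numeric or Boolean fields of a dict or an object.

    Hypothetical helper mirroring the documented rule that strings and
    nested objects are ignored; not part of decision-tree-writer.
    """
    fields = item if isinstance(item, dict) else vars(item)
    # bool is a subclass of int, so Booleans pass the Number check too.
    return {k: v for k, v in fields.items() if isinstance(v, Number)}


def check_consistent(data: list, label_name: str) -> list[str]:
    """Check every item has the label and the same numeric fields.

    Returns the sorted feature names (the label itself is excluded).
    """
    if not data:
        raise ValueError("empty data set")
    keys = sorted(k for k in numeric_features(data[0]) if k != label_name)
    for item in data:
        fields = item if isinstance(item, dict) else vars(item)
        if label_name not in fields:
            raise ValueError(f"item is missing label {label_name!r}: {item}")
        item_keys = sorted(k for k in numeric_features(item) if k != label_name)
        if item_keys != keys:
            raise ValueError(f"numeric fields differ: {item_keys} != {keys}")
    return keys
```

For example, check_consistent(iris_data, "species") should return the four measurement keys in sorted order.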

1) Train the model

Use the DecisionTreeWriter class to train a model on a data set and write its code to a new file in a specified folder (by default, the same folder as your code). The label of an item in the training data set is the value of a specified attribute or key (in this example, the key "species"):

from decision_tree_writer import DecisionTreeWriter

# Create the writer. 
# You must specify which attribute or key is the label of the data items.
# You can also specify the max branching depth of the tree (default [and max] is 998)
# or how many data items there must be to make a new branch (default is 1).
writer = DecisionTreeWriter(label_name="species")

# Trains a new model and saves it to a new .py file.
writer.create_tree(data_set=iris_data,
                   look_for_correlations=True,
                   tree_name="Iris Classifier")
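This page does not say which criterion create_tree uses to choose a split. As an illustration only, a common choice for classification trees is Gini impurity; the sketch below swaps in that technique. The hypothetical helper best_split tries every observed value of every numeric feature as an `x[feature] <= threshold` test (the comparison form seen in the generated code) and keeps the split with the lowest weighted impurity.

```python
from __future__ import annotations

from collections import Counter


def gini(labels: list) -> float:
    """Gini impurity: 1 - sum(p_k ** 2) over the label proportions p_k."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def best_split(data: list[dict], label_name: str, features: list[str]):
    """Return (feature, threshold, weighted_gini) for the best `<=` split.

    Illustrative only; decision-tree-writer's real criterion is not
    documented here. The "left" side is where x[feature] <= threshold is
    True, matching how a Branch picks root.l. Returns None if no split
    separates the data.
    """
    n = len(data)
    best = None
    for feature in features:
        for threshold in sorted({item[feature] for item in data}):
            left = [it[label_name] for it in data if it[feature] <= threshold]
            right = [it[label_name] for it in data if it[feature] > threshold]
            if not left or not right:
                continue  # a split that sends everything one way is useless
            score = len(left) / n * gini(left) + len(right) / n * gini(right)
            if best is None or score < best[2]:
                best = (feature, threshold, score)
    return best
```

Applying a function like this recursively to each side of the chosen split, and stopping at a maximum depth or when too few items remain (the two limits mentioned in the comments above), yields a tree of Branches and Leaves. Note that this sketch only uses observed values as thresholds, while the generated example below uses 5.5 and 5.0, so the package may choose thresholds differently (for instance, midpoints).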

2) Examining the trained decision tree code

A new Python file containing a single function will appear in the specified folder. It will be named after your decision tree model, plus a UUID to ensure the name is unique (spaces in the model name become underscores). The generated code will look something like this:

from decision_tree_writer.BaseDecisionTree import *

# class-like syntax because it acts like it's instantiating a class.
def Iris_Classifier__0c609d3a_741e_4770_8bce_df246bad054d() -> 'BaseDecisionTree':
    """
    Iris_Classifier__0c609d3a_741e_4770_8bce_df246bad054d 
    has been trained to identify the species of a given dictionary object.
    """
    tree = BaseDecisionTree(dict,
            'Iris_Classifier__0c609d3a_741e_4770_8bce_df246bad054d')
    tree.root = Branch(lambda x: x['sepal_length'] <= 5.5)
    tree.root.l = Leaf('setosa')
    tree.root.r = Branch(lambda x: x['petal_length'] <= 5.0)
    tree.root.r.l = Leaf('versicolor')
    tree.root.r.r = Leaf('virginica')
    
    return tree
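The generated name can be approximated as follows. This is a hypothetical reconstruction inferred only from the example output above (the UUID in that name is a version 4 UUID): characters that are not letters, digits, or underscores become underscores, and a random UUID with its hyphens replaced by underscores is appended after a double underscore. The helper generated_name is not part of the package.

```python
import re
import uuid


def generated_name(tree_name: str) -> str:
    """Build an identifier like Iris_Classifier__0c609d3a_741e_....

    Hypothetical: mimics the naming seen in the example, not the package's
    actual code.
    """
    # Replace spaces and other non-word characters with underscores.
    safe_name = re.sub(r"\W", "_", tree_name)
    # uuid4() gives 36 characters; swap its four hyphens for underscores.
    suffix = str(uuid.uuid4()).replace("-", "_")
    return f"{safe_name}__{suffix}"
```

Because a fresh UUID is drawn on every call, training the same model twice produces two differently named files rather than overwriting the first.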

The trained model is built by adding decision Branches, which terminate in classification Leaves, to a BaseDecisionTree. When the tree classifies an input, it hands the input to its root Branch, which applies its comparison function: if the comparison evaluates to True, it selects the root's left node (root.l), and if it is False, the right node (root.r). In this case, the root checks whether the iris's sepal_length is less than or equal to 5.5. If the selected node is a Leaf, the tree returns the label on that Leaf. If the node is a Branch, it runs that Branch's comparison function on the input to select one of its subnodes, repeating this until a Leaf is reached.
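The traversal described above fits in a few lines. This is a minimal re-implementation for illustration, not the package's actual BaseDecisionTree, Branch, or Leaf classes: the l and r attribute names and the example tree follow the generated code, while the field names test and label and the free function classify_one are assumptions.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Optional, Union


@dataclass
class Leaf:
    label: Any  # the class label returned when this Leaf is reached


@dataclass
class Branch:
    test: Callable[[Any], bool]  # comparison applied to the input
    l: Optional[Node] = None  # followed when test(x) is True
    r: Optional[Node] = None  # followed when test(x) is False


Node = Union[Branch, Leaf]


def classify_one(root: Node, x: Any) -> Any:
    """Follow Branch comparisons from the root until a Leaf is reached."""
    node = root
    while isinstance(node, Branch):
        node = node.l if node.test(x) else node.r
    return node.label


# Rebuild the example tree from the generated code above.
root = Branch(lambda x: x["sepal_length"] <= 5.5)
root.l = Leaf("setosa")
root.r = Branch(lambda x: x["petal_length"] <= 5.0)
root.r.l = Leaf("versicolor")
root.r.r = Leaf("virginica")

print(classify_one(root, {"sepal_length": 5.4, "petal_length": 1.6}))  # setosa
```

Each classification costs one comparison per level, so a prediction touches at most as many Branches as the tree is deep.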

Important note: if you train your model with class instances, you may have to fix the import statement for that class in the new file if the class is defined in a different directory from the one where the model's file is placed. The code for a model trained with objects would start like this:

from decision_tree_writer.BaseDecisionTree import *

# Please fix this import statement if necessary
from sample_data.flowers import Iris

def Iris_Classifier__0c609d3a_741e_4770_8bce_df246bad054d() -> 'BaseDecisionTree':
    tree = BaseDecisionTree(Iris, 
                'Iris_Classifier__0c609d3a_741e_4770_8bce_df246bad054d')
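Since the generated file's name contains a UUID and the file may live in a different folder from the code that uses it, one option is to load the model by its file path using the standard library's importlib. The helper load_tree_factory is hypothetical, not part of decision-tree-writer. Executing the file still runs its own import statements, so a class import like the one above must still resolve.

```python
from __future__ import annotations

import importlib.util
from pathlib import Path
from typing import Any, Callable


def load_tree_factory(path: str | Path, factory_name: str) -> Callable[[], Any]:
    """Load a generated model file by path and return its factory function.

    Hypothetical helper. The module's own imports (for example
    `from sample_data.flowers import Iris`) must be importable from the
    current sys.path.
    """
    path = Path(path)
    spec = importlib.util.spec_from_file_location(path.stem, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the generated file's code
    return getattr(module, factory_name)
```

Usage would look like load_tree_factory(path_to_generated_file, "Iris_Classifier__0c609d3a_741e_4770_8bce_df246bad054d")(), where path_to_generated_file is wherever the writer saved the model.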

3) Using the new decision tree

Now just import and call the factory function to create an instance of the trained model. The model has two important methods: classify_one, which takes a data item of the same type you trained the model with and returns the label it predicts for that item, and classify_many, which does the same for a list of data items and returns a list of labels.

Example:

from Iris_Classifier__0c609d3a_741e_4770_8bce_df246bad054d import *
tree = Iris_Classifier__0c609d3a_741e_4770_8bce_df246bad054d()
print(tree.classify_one(
            { "sepal_length": 5.4, "sepal_width": 3.2, 
                "petal_length": 1.6, "petal_width": 0.3})) # output: setosa

print(tree.classify_many(
    [
        { "sepal_length": 5.4, "sepal_width": 3.2, 
            "petal_length": 1.6, "petal_width": 0.3},
        { "sepal_length": 5.7, "sepal_width": 2.9, 
            "petal_length": 4.2, "petal_width": 1.3},
    ]
)) # output: ['setosa', 'versicolor']
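A natural next step is to measure how often the model is right on labeled data it was not trained on. The sketch below is a hypothetical helper, accuracy, built only on the documented classify_many behavior (a list of items in, a list of labels out). It assumes dictionary data and strips the label key before classifying, matching the unlabeled inputs in the example above.

```python
from __future__ import annotations

from typing import Any, Callable, Sequence


def accuracy(classify_many: Callable[[list], list], data: Sequence[dict],
             label_name: str) -> float:
    """Fraction of labeled items whose predicted label matches the true one.

    `classify_many` is any function with the documented behavior, such as
    tree.classify_many. Hypothetical helper, not part of the package.
    """
    if not data:
        raise ValueError("need at least one labeled item")
    # Remove the label so the model only sees the features.
    inputs = [{k: v for k, v in item.items() if k != label_name} for item in data]
    predictions: list[Any] = classify_many(inputs)
    truths = [item[label_name] for item in data]
    correct = sum(p == t for p, t in zip(predictions, truths))
    return correct / len(data)
```

For example, accuracy(tree.classify_many, held_out_irises, "species"), where held_out_irises is a hypothetical list of labeled iris dictionaries that were not used for training. Scoring on the training items themselves would tend to overstate accuracy.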

Bugs or questions

If you find any problems with this package or have any questions, please create an issue on this package's GitHub repo.

Download files

Source Distribution

decision-tree-writer-0.5.1.tar.gz (14.7 kB)

Uploaded Source

Built Distribution

decision_tree_writer-0.5.1-py3-none-any.whl (14.6 kB)

Uploaded Python 3

File details

Details for the file decision-tree-writer-0.5.1.tar.gz.

File metadata

  • Download URL: decision-tree-writer-0.5.1.tar.gz
  • Upload date:
  • Size: 14.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.8.1 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7

File hashes

Hashes for decision-tree-writer-0.5.1.tar.gz
Algorithm Hash digest
SHA256 9c48a01e890752feac076497b3daa8d41d92a117dff75a468a9cd698e9811ff7
MD5 d9a91628985c73a9af17ed6284891e39
BLAKE2b-256 5f1906f50e0ce82ccc63a77a843f2455c1e93fec2e2960679b9d1d98dfe35511

File details

Details for the file decision_tree_writer-0.5.1-py3-none-any.whl.

File metadata

  • Download URL: decision_tree_writer-0.5.1-py3-none-any.whl
  • Upload date:
  • Size: 14.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.8.1 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7

File hashes

Hashes for decision_tree_writer-0.5.1-py3-none-any.whl
Algorithm Hash digest
SHA256 5460c2c023498e329e2a135e01c7c6bc3ab2e4c5f91a2ac4ec79f4a0edfe1f7f
MD5 bf701b1d04cd5966acd9e246aecec9f8
BLAKE2b-256 1ac4b5166960d219f53c4e1408b05d41d0bff790f7227768677b4735bce68866
