Skip to main content

Python package for the Trinary Tree algorithm

Project description

Trinary Tree

PyPI version License: MIT

The Trinary Tree is a algorithm based on the Classification and Regression Tree (CART). It provides a novel way to handle missing data by assigning missing values to a third node in the originally binary split of the CART. For details on the algorithm, se the arXiV preprint at https://arxiv.org/abs/2309.03561

Note that this code is in no way optimized for speed and training of the trees takes a lot of time compared to other tree packages. The package is a proof-of-concept and it is recommended to re-implement the algorithm if it is to be used in settings where computational speed matters.

Installation

You can install the trinary_tree package via pip:

pip install trinary_tree

or via GitHub

pip install git+https://github.com/henningzakrisson/trinary_tree.git

Usage example

Fitting a Trinary Tree and a Binary Tree using the majority rule algorithm to a dataset with missing values.

# Import packages
from trinary_tree import BinaryTree, TrinaryTree
from sklearn.model_selection import train_test_split
import numpy as np

# Generate data
rng = np.random.default_rng(seed=11)
X = rng.normal(size=(1000,2))
mu = 10*(X[:,0]>0) + X[:,1]*2
y = rng.normal(mu,1)

# Censor data
censor = rng.choice(np.prod(X.shape), int(0.2*np.prod(X.shape)), replace=False)
X_censored = X.flatten()
X_censored[censor] = np.nan
X_censored = X_censored.reshape(X.shape)

# Train trees
X_train, X_test, y_train, y_test = train_test_split(X_censored,y)
tree_binary = BinaryTree(max_depth = 1)
tree_trinary = TrinaryTree(max_depth = 1)
tree_binary.fit(X_train,y_train)
tree_trinary.fit(X_train,y_train)

# Calculate MSE
mse_binary = np.mean((y_test - tree_binary.predict(X_test))**2)
mse_trinary = np.mean((y_test - tree_trinary.predict(X_test))**2)
print(f"Binary tree MSE: {mse_binary:.3f}")
print(f"Trinary tree MSE: {mse_trinary:.3f}")

Contact

If you have any questions, feel free to contact me here.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

trinary_tree-0.1.23.tar.gz (12.4 kB view hashes)

Uploaded Source

Built Distribution

trinary_tree-0.1.23-py3-none-any.whl (20.3 kB view hashes)

Uploaded Python 3

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page