Python package for the Trinary Tree algorithm

Trinary Tree

PyPI version License: MIT

The Trinary Tree is an algorithm based on the Classification and Regression Tree (CART). It provides a novel way to handle missing data: observations with a missing value in the split variable are assigned to a third node, added to the otherwise binary split of the CART. For details on the algorithm, see the arXiv preprint at https://arxiv.org/abs/2309.03561
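
To make the idea concrete, here is a minimal sketch of a single three-way split on one feature. It is an illustration only and not the package's implementation: the helper name `trinary_split_predict`, the fixed `threshold` argument, and the choice of predicting the mean response in each node are all assumptions made for this example.

```python
import numpy as np

def trinary_split_predict(x_train, y_train, x_new, threshold):
    """Depth-1 three-way split on one feature (hypothetical helper, not the package API)."""
    x_train = np.asarray(x_train, dtype=float)
    y_train = np.asarray(y_train, dtype=float)
    x_new = np.asarray(x_new, dtype=float)

    # Partition the training data into left, right, and a third "missing" node
    missing = np.isnan(x_train)
    left = ~missing & (x_train < threshold)
    right = ~missing & (x_train >= threshold)

    # Assumption: each node predicts the mean response of its training
    # observations, falling back to the overall mean if the node is empty
    overall = y_train.mean()

    def node_mean(mask):
        return y_train[mask].mean() if mask.any() else overall

    left_mean, right_mean, missing_mean = node_mean(left), node_mean(right), node_mean(missing)

    # Route new observations: NaN goes to the third node, the rest left or right
    observed_pred = np.where(x_new < threshold, left_mean, right_mean)
    return np.where(np.isnan(x_new), missing_mean, observed_pred)

# Toy data: two observations on each side of the threshold and two missing
x_toy = [-1.0, -2.0, 1.0, 3.0, np.nan, np.nan]
y_toy = [0.0, 2.0, 10.0, 12.0, 5.0, 7.0]
pred_trinary = trinary_split_predict(x_toy, y_toy, [-0.5, 2.0, np.nan], threshold=0.0)
print(pred_trinary)  # node means: left 1.0, right 11.0, missing node 6.0
```

In this sketch, the missing observations get their own prediction instead of borrowing the prediction of the left or right node.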

Installation

You can install the trinary_tree package via pip:

pip install trinary_tree

or directly from GitHub:

pip install git+https://github.com/henningzakrisson/trinary_tree.git

Usage example

The following example fits a Trinary Tree and a Binary Tree (using the majority rule algorithm) to a dataset with missing values and compares their mean squared error on a test set.

# Import packages
from trinary_tree import BinaryTree, TrinaryTree
from sklearn.model_selection import train_test_split
import numpy as np

# Generate data
rng = np.random.default_rng(seed=11)
X = rng.normal(size=(1000,2))
mu = 10*(X[:,0]>0) + X[:,1]*2
y = rng.normal(mu,1)

# Censor data
censor = rng.choice(np.prod(X.shape), int(0.2*np.prod(X.shape)), replace=False)
X_censored = X.flatten()
X_censored[censor] = np.nan
X_censored = X_censored.reshape(X.shape)

# Train trees
X_train, X_test, y_train, y_test = train_test_split(X_censored, y, random_state=11)
tree_binary = BinaryTree(max_depth = 1)
tree_trinary = TrinaryTree(max_depth = 1)
tree_binary.fit(X_train,y_train)
tree_trinary.fit(X_train,y_train)

# Calculate MSE
mse_binary = np.mean((y_test - tree_binary.predict(X_test))**2)
mse_trinary = np.mean((y_test - tree_trinary.predict(X_test))**2)
print(f"Binary tree MSE: {mse_binary:.3f}")
print(f"Trinary tree MSE: {mse_trinary:.3f}")
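
For comparison with the third-node approach, the majority rule can be sketched for a single split as follows. This is a sketch under stated assumptions and not the code behind `BinaryTree`: it assumes that "majority rule" means sending observations with a missing split variable to the child node holding the most observed training points, that those training observations also contribute to that node's mean, and that ties go left. The helper name `majority_rule_predict` is hypothetical.

```python
import numpy as np

def majority_rule_predict(x_train, y_train, x_new, threshold):
    """Depth-1 binary split with majority-rule routing (hypothetical helper, not the package API)."""
    x_train = np.asarray(x_train, dtype=float)
    y_train = np.asarray(y_train, dtype=float)
    x_new = np.asarray(x_new, dtype=float)

    observed = ~np.isnan(x_train)
    left = observed & (x_train < threshold)
    right = observed & (x_train >= threshold)

    # Assumption: missing values follow the child with the most observed
    # training points (ties go left)
    missing_goes_left = left.sum() >= right.sum()
    if missing_goes_left:
        left = left | ~observed
    else:
        right = right | ~observed

    # Each child predicts its mean response; empty children fall back to the overall mean
    overall = y_train.mean()
    left_mean = y_train[left].mean() if left.any() else overall
    right_mean = y_train[right].mean() if right.any() else overall

    # Route new observations, sending NaN to the majority child
    go_left = np.where(np.isnan(x_new), missing_goes_left, x_new < threshold)
    return np.where(go_left, left_mean, right_mean)

# Toy data: three observed points go left, one goes right, one is missing
x_toy = [-1.0, -2.0, -3.0, 1.0, np.nan]
y_toy = [0.0, 2.0, 4.0, 10.0, 6.0]
pred_majority = majority_rule_predict(x_toy, y_toy, [-0.5, 2.0, np.nan], threshold=0.0)
print(pred_majority)  # the missing value shares the left node's mean of 3.0
```

Under the majority rule the missing observation is absorbed into an existing node, so its prediction is tied to that node's mean.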

Contact

If you have any questions, feel free to contact me here.
