Python package for the Trinary Tree algorithm
Trinary Tree
The Trinary Tree is an algorithm based on the Classification and Regression Tree (CART). It provides a novel way to handle missing data by assigning missing values to a third node in the originally binary split of the CART. For details on the algorithm, see the arXiv preprint at https://arxiv.org/abs/2309.03561
Note that this code is not optimized for speed: training the trees takes considerably longer than with other tree packages. The package is a proof of concept, and re-implementing the algorithm is recommended if it is to be used in settings where computational speed matters.
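To make the idea of a third node concrete, here is a minimal sketch of a single three-way split on one feature. It is not the package's implementation: `three_way_split` is a hypothetical helper, the convention that values at or below the threshold go left is an assumption, and using the plain mean of the responses as each node's prediction is only for illustration. See the preprint for how the algorithm actually treats the third node.

```python
import numpy as np


def three_way_split(x, threshold):
    """Return boolean masks (left, right, missing) for one feature column.

    Hypothetical helper for illustration; not part of the trinary_tree API.
    Assumes observed values <= threshold go left and NaN marks a missing value.
    """
    x = np.asarray(x, dtype=float)
    missing = np.isnan(x)
    observed = ~missing
    # Compare only the observed values so NaN never enters a comparison.
    left = np.zeros(x.shape, dtype=bool)
    left[observed] = x[observed] <= threshold
    right = observed & ~left
    return left, right, missing


x = np.array([-1.2, 0.4, np.nan, 2.0, np.nan, -0.3])
y = np.array([1.0, 11.0, 6.0, 12.0, 5.0, 0.0])

left, right, missing = three_way_split(x, threshold=0.0)

# Illustrative node predictions: the mean response in each of the three nodes.
node_means = {
    name: float(y[mask].mean())
    for name, mask in [("left", left), ("right", right), ("missing", missing)]
}
print(node_means)
```

Every sample lands in exactly one of the three masks, so no imputation step is needed before the split.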
Installation
You can install the trinary_tree package via pip:
pip install trinary_tree
or directly from GitHub:
pip install git+https://github.com/henningzakrisson/trinary_tree.git
Usage example
The following example fits a Trinary Tree and a Binary Tree using the majority rule algorithm to a dataset with missing values, then compares their mean squared error (MSE) on a held-out test set.
# Import packages
from trinary_tree import BinaryTree, TrinaryTree
from sklearn.model_selection import train_test_split
import numpy as np
# Generate data
rng = np.random.default_rng(seed=11)
X = rng.normal(size=(1000,2))
mu = 10*(X[:,0]>0) + X[:,1]*2
y = rng.normal(mu,1)
# Censor data
censor = rng.choice(np.prod(X.shape), int(0.2*np.prod(X.shape)), replace=False)
X_censored = X.flatten()
X_censored[censor] = np.nan
X_censored = X_censored.reshape(X.shape)
# Train trees
X_train, X_test, y_train, y_test = train_test_split(X_censored,y)
tree_binary = BinaryTree(max_depth = 1)
tree_trinary = TrinaryTree(max_depth = 1)
tree_binary.fit(X_train,y_train)
tree_trinary.fit(X_train,y_train)
# Calculate MSE
mse_binary = np.mean((y_test - tree_binary.predict(X_test))**2)
mse_trinary = np.mean((y_test - tree_trinary.predict(X_test))**2)
print(f"Binary tree MSE: {mse_binary:.3f}")
print(f"Trinary tree MSE: {mse_trinary:.3f}")
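The censoring step in the example removes entries uniformly at random across the whole feature matrix. As a sanity check, the same step can be wrapped in a small function and the resulting share of missing values inspected. This is a sketch under that reading of the example; `censor_at_random` is a hypothetical helper name and not part of trinary_tree.

```python
import numpy as np


def censor_at_random(X, fraction, rng):
    """Return a copy of X with `fraction` of its entries replaced by NaN.

    Hypothetical helper mirroring the censoring step of the usage example.
    Entries are chosen uniformly at random, without replacement.
    """
    X = np.asarray(X, dtype=float)
    n_missing = int(fraction * X.size)
    idx = rng.choice(X.size, n_missing, replace=False)
    X_censored = X.flatten()  # flatten() returns a copy, so X is untouched
    X_censored[idx] = np.nan
    return X_censored.reshape(X.shape)


rng = np.random.default_rng(seed=11)
X = rng.normal(size=(1000, 2))
X_censored = censor_at_random(X, fraction=0.2, rng=rng)

n_missing_entries = int(np.isnan(X_censored).sum())
share_rows_affected = float(np.isnan(X_censored).any(axis=1).mean())
print(f"Missing entries: {n_missing_entries} of {X_censored.size}")
print(f"Rows with at least one missing value: {share_rows_affected:.1%}")
```

Because entries rather than rows are censored, the share of rows with at least one missing value is larger than the share of missing entries.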
Contact
If you have any questions, feel free to contact me here.
Hashes for trinary_tree-0.1.23-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 98ddfa7c6e46a6efd35e16cec6cca2910f06e9bf249830f01b6ed8c537f67679
MD5 | 26786246785aede40b561983e63bee64
BLAKE2b-256 | c7236d899cc853fac49055bfdcbc590b5939c3f5c9ea5f3af0d90fc18dc8a932