ZTree is a Python wrapper around a Java decision tree algorithm.

Project description

ZTree

ZTree is a Python wrapper around a Java-based decision tree algorithm using JPype.

Features

  • Classification and regression via Java backend.
  • Classification and regression tasks are auto-detected from the target array.
  • tree = ZTree(feature_names=col_names, z_thresh=0.2) to instantiate a model.
  • If no feature names are provided, dummy names are generated. If no z_thresh is provided, it defaults to 0.5.
  • tree.fit(X, y) to fit a decision tree with X as the instance array and y as the target array.
  • tree.optimal_fit(X, y) to fit a decision tree with optimized z_thresh.
  • tree.search_optimal_z_thresh(X, y) to return the optimal z_thresh.
  • tree.predict(X) returns the predicted class labels for the input features X.
  • tree.predict_proba(X) returns the predicted class probabilities for classification tasks. For each input sample in X, it returns the probability of belonging to each class (e.g., [0.25, 0.75] for class 0 and class 1).
  • tree.print_tree() to print the trained tree.
  • More features will be added in the future, e.g. JSON.
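
The relationship between class probabilities and class labels described above can be illustrated with a minimal sketch. It does not call ZTree and is not necessarily how tree.predict is implemented; the helper name labels_from_proba and the classes argument are hypothetical, introduced only for illustration. It assumes each row holds one probability per class, in the same order as classes.

```python
def labels_from_proba(proba_rows, classes):
    """For each row of class probabilities, pick the most probable class.

    Hypothetical helper for illustration; not part of the ZTree API.
    """
    labels = []
    for row in proba_rows:
        # index of the largest probability in this row (first one wins on ties)
        best = max(range(len(row)), key=lambda i: row[i])
        labels.append(classes[best])
    return labels


rows = [[0.25, 0.75], [0.9, 0.1]]
print(labels_from_proba(rows, classes=[0, 1]))  # prints [1, 0]
```

For binary tasks, the second column (index 1) is the probability of class 1, which is the column the quick start passes to roc_auc_score.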

Installation

Install via pip:

    pip install ztree

Requirements

To use ZTree, ensure the following dependencies are available in your environment.

Python

  • Python >= 3.8

Python Dependencies

These will be installed automatically when using pip install ztree:

  • numpy – Array and numerical operations
  • scikit-learn – ML utilities and estimator API
  • jpype1 – Bridge to call Java from Python

Java

  • Java Development Kit (JDK) 8 or higher
  • Make sure java is available in your system PATH or set JAVA_HOME

Note: JPype requires a working Java installation to start the JVM from Python.
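
Before importing ZTree, it can help to confirm that a Java executable is discoverable. The following is a rough sketch using only the Python standard library; find_java is a hypothetical helper, not part of ZTree or JPype, and it only mirrors the two options listed above (JAVA_HOME, then PATH). JPype's own JVM lookup may differ.

```python
import os
import shutil


def find_java(env=None):
    """Return a path to a `java` executable, or None if none is found.

    Hypothetical helper: checks JAVA_HOME first, then searches PATH.
    """
    env = os.environ if env is None else env
    java_home = env.get("JAVA_HOME")
    if java_home:
        exe = "java.exe" if os.name == "nt" else "java"
        candidate = os.path.join(java_home, "bin", exe)
        if os.path.isfile(candidate):
            return candidate
    # shutil.which searches the directories listed in the given PATH string
    return shutil.which("java", path=env.get("PATH"))


if __name__ == "__main__":
    java = find_java()
    print(java or "No Java found: install a JDK or set JAVA_HOME")
```

If this prints a path, a Java installation is at least visible to Python; if it prints the warning, install a JDK or set JAVA_HOME before using ZTree.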

ZTree Parameters

  • z_thresh: float
    Default = 0.5. The Z-statistic threshold used for feature evaluation.

  • feature_names: list[str]
    Optional. Used to map column names to features in Java.

Basic Usage / Quick start

# basic imports
import pandas as pd
import numpy as np
import zipfile
from sklearn.metrics import roc_auc_score
from ztree import ZTree
from sklearn.model_selection import train_test_split

# load the dataset (UCI Adult Income in this example); point zip_path at your local copy
zip_path = "C:/Users/ericr/Downloads/adult.zip"
with zipfile.ZipFile(zip_path, 'r') as z:
    with z.open("adult.data") as f1:
        df1 = pd.read_csv(f1, header=None, sep=",", skipinitialspace=True)
    with z.open("adult.test") as f2:
        df2 = pd.read_csv(f2, header=None, sep=",", skipinitialspace=True, skiprows=1)
datafile = pd.concat([df1, df2], ignore_index=True)

# create column names for printing / preprocessing
col_names = ["age", "workclass", "fnlwgt", "education", "education-num", "marital-status", "occupation", "relationship",
             "race", "sex", "capital-gain", "capital-loss",  "hours-per-week", "native-country", "income"]
datafile.columns = col_names
datafile['income'] = datafile['income'].apply(lambda x: 1 if x in {'>50K', '>50K.'} else 0)
datafile = datafile.drop(columns=["fnlwgt"])
col_names.remove("fnlwgt")
# one-hot encoding is not necessary: cast categorical features to strings and leave continuous features as ints/floats
X = datafile.drop(columns=["income"])
for col in X.select_dtypes(include=['object', 'category']).columns:
    X[col] = X[col].astype(str)
col_names = list(X.columns)
y = datafile["income"]
X = X.values
y = y.values
y = y.astype("int64")
# train/test split (30% train, 70% test)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.7)
# fit a tree with a fixed z_thresh and report the test AUC
ztree1 = ZTree(feature_names=col_names, z_thresh=0.5)
ztree1.fit(X_train, y_train)
y_pred1 = ztree1.predict_proba(X_test)[:, 1]  # probability of class 1
auc = roc_auc_score(y_test, y_pred1)
print(f"{auc:.6f}")
ztree1.print_tree()
# fit a second tree with an optimized z_thresh and report the test AUC
ztree2 = ZTree(feature_names=col_names)
ztree2.fit_optimal(X_train, y_train)
y_pred2 = ztree2.predict_proba(X_test)[:, 1]  # probability of class 1
auc = roc_auc_score(y_test, y_pred2)
print(f"{auc:.6f}")
ztree2.print_tree()

Download files

Source Distribution

ztree-0.1.1.tar.gz (7.1 kB)

Uploaded Source

Built Distribution

ztree-0.1.1-py3-none-any.whl (5.1 kB)

Uploaded Python 3

File details

Details for the file ztree-0.1.1.tar.gz.

File metadata

  • Download URL: ztree-0.1.1.tar.gz
  • Upload date:
  • Size: 7.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.5

File hashes

Hashes for ztree-0.1.1.tar.gz
Algorithm Hash digest
SHA256 bec26d6295af8db73b264abff04c8b68db2135f514bc60d3af8490ea4dbdc536
MD5 82b82dbee35a0caa08712aa3f343afd6
BLAKE2b-256 9985bf1c6c753cfbedebc6942c90135d76d027a19f3f3d7cca922a5ee03d031f
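
A downloaded file can be checked against the SHA256 digest listed above with the standard library. This is a minimal sketch: file_sha256 and matches_published_hash are hypothetical helper names, and the local file path is whatever location the archive was saved to.

```python
import hashlib


def file_sha256(path, chunk_size=65536):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel stops once read() returns an empty bytes object
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_hash(path, expected_hex):
    """Compare a file's SHA256 digest with a published hex digest."""
    return file_sha256(path) == expected_hex.strip().lower()
```

For example, matches_published_hash("ztree-0.1.1.tar.gz", "bec26d6295af8db73b264abff04c8b68db2135f514bc60d3af8490ea4dbdc536") should return True for an intact copy of the source distribution saved in the current directory.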

File details

Details for the file ztree-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: ztree-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 5.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.5

File hashes

Hashes for ztree-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 77c371081c630201687d9d02771be67410561e53238fb92777f32d34c2cf46b6
MD5 7028988d471dfe904ab8177b032a92c3
BLAKE2b-256 91d98178a465fe3e50fcee4537f353f9f6ed62e87f65dede4d3522bace951927