
Gradient-based Entire Tree Optimization For Oblique Regression Tree

This repository has been restructured to offer a more organized and user-friendly interface. GET (Gradient-based Entire Tree) is designed to induce oblique decision trees by optimizing the entire tree structure via gradient-based optimization. It supports both regression and classification tasks. For detailed information on the algorithm, please refer to the study “Can a Single Tree Outperform an Entire Forest?”, available at https://arxiv.org/pdf/2411.17003.
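For intuition: a conventional (axis-aligned) tree splits on a single feature against a threshold, while an oblique tree splits on a linear combination of all features. A minimal NumPy sketch of the difference (the weights `w` and bias `b` below are hypothetical, not values produced by this package):

```python
import numpy as np

# Two toy samples with three features each.
X = np.array([[0.5,  1.0, -0.2],
              [0.8, -1.0,  0.3]])

# Axis-aligned split: test a single feature against a threshold,
# e.g. "feature 0 <= 1.0". It cannot separate these two samples.
axis_aligned = X[:, 0] <= 1.0

# Oblique split: test a linear combination of ALL features,
# routing left when w . x + b <= 0 and right otherwise.
w = np.array([1.0, -0.5, 2.0])  # hypothetical learned weights
b = -0.8                        # hypothetical learned bias
oblique = X @ w + b <= 0.0

print(axis_aligned)  # both samples go the same way
print(oblique)       # the oblique hyperplane separates them
```

Gradient-based optimization of the entire tree fits all such hyperplanes jointly, rather than greedily one split at a time.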


Features in this version:

  • GETRegressor(): An oblique regression tree with constant predictions.
  • GETSubPolRegressor(): An oblique regression tree with constant predictions, enhanced with a subtree polishing strategy.

New features will be added in future versions, including:

  • Tree path-based interpretability
  • Classification tree implementations such as GETClassifier() and GETSubPolClassifier()

Package Dependencies

  • scikit-learn 1.5.0
  • numpy 1.26.4
  • pandas 2.2.3
  • h5py 3.13.0
  • torch 2.0.0+

Package Installation

pip install get-oblique

Package Description

GETRegressor class: oblique regression tree with constant predictions.

  • Parameters:
    • treeDepth (int, default=4): The depth of the regression tree.
    • epochNum (int, default=3000): Number of training epochs used during optimization.
    • startNum (int, default=10): Number of random initializations for the tree optimization; more starts increase the chance of escaping poor local optima.
    • device (str, default='cpu'): The computation device to use: 'cpu' or 'cuda'. Set to 'cuda' to enable GPU acceleration.
  • Methods:
    • fit(X, y):
      Trains the model using gradient-based optimization. Automatically moves the data to the specified device and converts it to float tensors.
    • predict(X):
      Predicts target values using the trained tree structure.
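The role of startNum can be illustrated on a toy problem: gradient descent from a single initialization can stall in a local minimum, while keeping the best of several random starts is far more robust. A self-contained sketch (the objective and the optimizer here are illustrative, not the package's internals):

```python
import numpy as np

rng = np.random.default_rng(0)

# A non-convex 1-D objective with several local minima.
def loss(x):
    return np.sin(3 * x) + 0.1 * x ** 2

def local_search(x, lr=0.01, steps=500):
    # Plain gradient descent with a finite-difference gradient.
    for _ in range(steps):
        grad = (loss(x + 1e-5) - loss(x - 1e-5)) / 2e-5
        x -= lr * grad
    return x

# A single start may converge to a poor local minimum; running the
# search from several random starts and keeping the best result
# greatly improves the chance of reaching the global minimum.
starts = rng.uniform(-3, 3, size=10)
candidates = [local_search(x0) for x0 in starts]
best = min(candidates, key=loss)
print("best loss:", loss(best))
```

GET applies the same principle to the whole tree: each of the startNum restarts optimizes all hyperplanes from a fresh random initialization, and the best tree is kept.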

GETSubPolRegressor class: oblique regression tree with constant predictions and subtree polish strategy.

  • Parameters:
    • treeDepth (int, default=4): The depth of the regression tree.
    • epochNum (int, default=3000): Number of training epochs used during optimization.
    • startNum (int, default=10): Number of random initializations for the tree optimization; more starts increase the chance of escaping poor local optima.
    • device (str, default='cpu'): The computation device to use: 'cpu' or 'cuda'. Set to 'cuda' to enable GPU acceleration.
  • Methods:
    • fit(X, y):
      Trains the model using gradient-based optimization and the subtree polish strategy. Automatically moves the data to the specified device and converts it to float tensors.
    • predict(X):
      Predicts target values using the trained tree structure.

Usage Example

To use the GETRegressor class:

import numpy as np
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from GET import GETRegressor

# Load and prepare dataset
data = fetch_california_housing()

# X and y can be either NumPy arrays or PyTorch tensors; here they are NumPy arrays
X, y = data.data, data.target

# Split into train and test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Initialize the model
model = GETRegressor()

# Fit the model
model.fit(X_train, y_train)

# Predict
y_pred = model.predict(X_test)

# Print sample predictions
print("First 10 predicted values:", y_pred[:10])
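Once predictions are available, they can be evaluated with standard regression metrics. A small sketch using hypothetical stand-in arrays for y_test and y_pred (sklearn.metrics.mean_squared_error would work equally well on the real arrays from the example above):

```python
import numpy as np

# Hypothetical ground truth and predictions, standing in for
# y_test and y_pred from the example above.
y_test = np.array([2.1, 0.9, 3.4, 1.5])
y_pred = np.array([2.0, 1.1, 3.0, 1.6])

# Mean squared error and its square root.
mse = np.mean((y_test - y_pred) ** 2)
rmse = np.sqrt(mse)
print(f"MSE: {mse:.4f}, RMSE: {rmse:.4f}")  # MSE: 0.0550, RMSE: 0.2345
```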

Others

If you encounter any errors or notice unexpected tree performance, please don't hesitate to contact us.

License

This repository is published under the terms of the GNU General Public License v3.0.
