Gradient-based Entire Tree Optimization For Oblique Regression Tree

This repository has been restructured to offer a more organized and user-friendly interface. GET (Gradient-based Entire Tree) induces oblique decision trees by optimizing the entire tree structure with gradient-based optimization. The method supports both regression and classification tasks; this version implements the regression trees, and classification trees are planned. For details on the algorithm, please refer to the study “Can a Single Tree Outperform an Entire Forest?”, available at https://arxiv.org/pdf/2411.17003.

Features in this version:

  • GETRegressor(): An oblique regression tree with constant predictions.
  • GETSubPolRegressor(): An oblique regression tree with constant predictions, enhanced with a subtree polishing strategy.
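For readers new to the terminology, the sketch below illustrates what "oblique regression tree with constant predictions" means: each internal node splits on a linear combination of all features rather than on a single feature, and each leaf returns one constant. It is a generic illustration, not GET's internal code. The function name `predict_one`, the breadth-first node layout, and the "go left when the score is at most zero" convention are all assumptions made for this example.

```python
def predict_one(x, weights, biases, leaf_values, depth):
    """Route one sample through a complete binary oblique tree.

    Hypothetical layout (not taken from GET): internal nodes are stored in
    breadth-first order, so node i has children 2*i + 1 (left) and
    2*i + 2 (right). `depth` is the number of split levels.
    """
    node = 0
    n_internal = 2 ** depth - 1
    for _ in range(depth):
        # Oblique split: a weighted sum over all features plus a bias.
        score = sum(w * xi for w, xi in zip(weights[node], x)) + biases[node]
        # Assumed convention: non-positive scores go left, positive go right.
        node = 2 * node + 1 if score <= 0 else 2 * node + 2
    # Leaves follow the internal nodes in the breadth-first numbering.
    return leaf_values[node - n_internal]


# A depth-2 tree on two features: 3 internal nodes, 4 constant leaves.
weights = [[1.0, 1.0], [1.0, -1.0], [1.0, -1.0]]
biases = [-1.0, 0.0, 0.0]
leaf_values = [10.0, 20.0, 30.0, 40.0]

print(predict_one([0.2, 0.3], weights, biases, leaf_values, depth=2))
print(predict_one([2.0, 0.5], weights, biases, leaf_values, depth=2))
```

An axis-aligned tree is the special case in which every weight vector has exactly one non-zero entry.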

New features will be added in future versions, including:

  • Tree path-based interpretability
  • Classification tree implementations such as GETClassifier() and GETSubPolClassifier()

Package Dependencies

  • scikit-learn 1.5.0
  • numpy 1.26.4
  • pandas 2.2.3
  • h5py 3.13.0
  • torch 2.0.0+

Package Installation

pip install GET

Package Description

GETRegressor class: oblique regression tree with constant predictions.

  • Parameters:
    • treeDepth (int, default=4): The depth of the regression tree.
    • epochNum (int, default=3000): Number of training epochs used during optimization.
    • startNum (int, default=10): Number of random initializations for the tree optimization process. More initializations increase the chance of finding a better solution, at the cost of longer training.
    • device (str, default='cpu'): The computation device to use: 'cpu' or 'cuda'. Set to 'cuda' to enable GPU acceleration.
  • Methods:
    • fit(X, y):
      Trains the model using gradient-based optimization. The input data is automatically moved to the specified device and converted to float tensors.
    • predict(X):
      Predicts target values using the trained tree structure.
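The parameters above can be set explicitly when the model is created. The sketch below is one possible way to do so, assuming the constructor accepts the documented parameter names as keyword arguments. The helpers `pick_device` and `build_regressor` are hypothetical names introduced for this example and are not part of the package.

```python
def pick_device():
    """Return 'cuda' when PyTorch reports a usable GPU, otherwise 'cpu'."""
    try:
        import torch
    except ImportError:
        # Without PyTorch there is no GPU backend to select.
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


def build_regressor(**overrides):
    """Create a GETRegressor with the documented defaults, plus overrides.

    Assumption: the documented parameters are accepted as keyword arguments.
    """
    from GET import GETRegressor  # imported lazily so this file loads without GET

    params = {
        "treeDepth": 4,      # depth of the regression tree
        "epochNum": 3000,    # training epochs
        "startNum": 10,      # random initializations
        "device": pick_device(),
    }
    params.update(overrides)
    return GETRegressor(**params)


print("Selected device:", pick_device())
```

For a quick smoke test, something like `build_regressor(epochNum=200, startNum=2)` should shorten training, though a run this short will likely yield a weaker tree.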

GETSubPolRegressor class: oblique regression tree with constant predictions and a subtree polishing strategy.

  • Parameters:
    • treeDepth (int, default=4): The depth of the regression tree.
    • epochNum (int, default=3000): Number of training epochs used during optimization.
    • startNum (int, default=10): Number of random initializations for the tree optimization process. More initializations increase the chance of finding a better solution, at the cost of longer training.
    • device (str, default='cpu'): The computation device to use: 'cpu' or 'cuda'. Set to 'cuda' to enable GPU acceleration.
  • Methods:
    • fit(X, y):
      Trains the model using gradient-based optimization and the subtree polishing strategy. The input data is automatically moved to the specified device and converted to float tensors.
    • predict(X):
      Predicts target values using the trained tree structure.
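To illustrate why several random initializations (startNum) can help, the toy sketch below runs plain gradient descent on a non-convex one-dimensional function from several random starting points and keeps the best result. It only demonstrates the general multi-start idea. The objective, learning rate, and function names are invented for this example and do not reflect how GET implements its optimization.

```python
import math
import random


def loss(w):
    # Toy non-convex objective with several local minima.
    return math.sin(3.0 * w) + 0.1 * w * w


def grad(w):
    # Analytic derivative of the toy objective.
    return 3.0 * math.cos(3.0 * w) + 0.2 * w


def descend(w, epochs=500, lr=0.01):
    # Plain gradient descent from a single starting point.
    for _ in range(epochs):
        w -= lr * grad(w)
    return w


def multi_start(start_num, seed=0):
    # Run gradient descent from `start_num` random starts and keep the best.
    rng = random.Random(seed)
    starts = [rng.uniform(-5.0, 5.0) for _ in range(start_num)]
    finals = [descend(w) for w in starts]
    return min(finals, key=loss)


print("loss with 1 start:  ", loss(multi_start(1)))
print("loss with 10 starts:", loss(multi_start(10)))
```

With a fixed seed, the ten-start run includes the single start among its candidates, so its final loss can never be worse. The price is roughly ten times the compute.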

Usage Example

To use the GETRegressor class:

import numpy as np
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from GET import GETRegressor

# Load and prepare dataset
data = fetch_california_housing()

# X and y can be either NumPy arrays or PyTorch tensors; here they are NumPy arrays
X, y = data.data, data.target

# Split into train and test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Initialize the model
model = GETRegressor()

# Fit the model
model.fit(X_train, y_train)

# Predict
y_pred = model.predict(X_test)

# Print sample predictions
print("First 10 predicted values:", y_pred[:10])
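Beyond printing sample predictions, it is often useful to score them on the test set. The sketch below computes RMSE and R² in plain Python. It assumes `y_test` and `y_pred` are one-dimensional sequences of numbers; if `predict` returns a PyTorch tensor, converting it first (for example with `.tolist()`) may be needed. The helper names are chosen for this example.

```python
import math


def rmse(y_true, y_pred):
    """Root mean squared error between two equal-length sequences."""
    errors = [(t - p) ** 2 for t, p in zip(y_true, y_pred)]
    return math.sqrt(sum(errors) / len(errors))


def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y_true = list(y_true)
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


# Small hand-made check; with the usage example above one would instead call
# rmse(y_test, y_pred) and r2(y_test, y_pred).
print(rmse([0.0, 0.0], [3.0, 4.0]))
print(r2([1.0, 2.0, 3.0], [2.0, 2.0, 2.0]))
```

An R² of 0 means the model does no better than always predicting the mean of the targets; values closer to 1 indicate a better fit.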

GitHub Repository Link

https://github.com/maoqiangqiang/GET

Others

If you encounter any errors or notice unexpected tree performance, please don't hesitate to contact us.

License

This repository is published under the terms of the GNU General Public License v3.0.
