muCART - Measure Inducing Classification and Regression Trees

muCART is a Python package that implements Measure Inducing Classification and Regression Trees for Functional Data.

The estimators implement the familiar fit/predict/score interface and support multiple predictors, possibly of different lengths, passed as a List of np.ndarray objects (one per predictor). The following tasks are supported, depending on the loss function used inside each node of the tree:

  • Regression (mse, mae)
  • Binary and Multiclass Classification (gini, misclassification error, entropy)

A custom cross-validation object is provided for grid search hyperparameter tuning (with any splitter from scikit-learn); it parallelizes the search with multiprocessing (default n_jobs = -1).

Installation

The package can be installed from the terminal with pip install muCART. Inside each node of the tree, the optimization problems (quadratic, with equality and/or inequality constraints) are formulated with Pyomo, which in turn needs an external solver to interface with. All the code was tested on Ubuntu with the solver Ipopt. Download the executable binary, then add the folder containing it to your PATH.
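On Linux, adding the solver to the PATH might look like the following (the install location is illustrative; adjust it to wherever you unpacked the Ipopt binary):

```shell
# Hypothetical location of the unpacked Ipopt binary -- adjust to your setup
export PATH="$PATH:$HOME/ipopt/bin"
# Verify that the solver is now discoverable
ipopt --version
```

Placing the export line in your shell profile (e.g. ~/.bashrc) makes the change persistent across sessions.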

Usage

The following lines show how to fit an estimator with its own parameters and grid search object, using a StratifiedKFold splitter:

import numpy as np
import muCART.grid_search as gs
from muCART.mu_cart import muCARTClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.datasets import load_wine
from sklearn.metrics import balanced_accuracy_score

X, Y = load_wine(return_X_y=True)
train_index = list(range(100))
test_index = list(range(100, len(X)))
# wrap the single predictor in a List
X = [X]

min_samples_leaf_list = list(range(1, 5))
lambda_list = np.logspace(-5, 5, 10, base=2)
solver_options = {'solver': 'ipopt',
                  'max_iter': 500}

estimator = muCARTClassifier(solver_options)
parameters = {'min_samples_leaf': min_samples_leaf_list,
              'lambda': lambda_list,
              'max_depth': [None]}
cv = StratifiedKFold(n_splits=2,
                     random_state=46,
                     shuffle=True)
grid_search = gs.GridSearchCV(estimator,
                              parameters,
                              cv,
                              scoring=balanced_accuracy_score,
                              verbose=False,
                              n_jobs=-1)
# extract train samples for each predictor
X_train = [x[train_index] for x in X]
grid_search.fit(X_train, Y[train_index])
# extract test samples for each predictor
X_test = [x[test_index] for x in X]
score = grid_search.score(X_test, Y[test_index])

The test folder in the GitHub repo contains two sample scripts that show how to use the estimator for both classification and regression tasks. Regarding scoring, both the estimators and the grid search class use accuracy/R^2 as the default score (when scoring = None), but you can provide any Callable scoring function from sklearn.metrics. Note that scores are maximized (higher is better), so error metrics such as sklearn.metrics.mean_squared_error must be wrapped in a custom function that flips their sign.
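For example, a sign-flipped wrapper around sklearn.metrics.mean_squared_error can be passed as the scoring argument (the wrapper's name here is illustrative, not part of the package):

```python
from sklearn.metrics import mean_squared_error

def neg_mean_squared_error(y_true, y_pred):
    """Negated MSE: returns higher values for better fits,
    matching the grid search's higher-is-better convention."""
    return -mean_squared_error(y_true, y_pred)
```

The wrapper can then be used like any other scorer, e.g. scoring = neg_mean_squared_error in the grid search constructor.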

Citing

The code published in this package has been used in the case studies of this paper.
