RolloTree — Rolling Lookahead Optimal Classification Trees

An implementation of the rolling subtree lookahead algorithm from "Rolling Lookahead Learning for Optimal Classification Trees" (published in IISE Transactions).

RolloTree builds interpretable decision trees by solving a sequence of small mixed-integer programs (MIPs). It starts with an optimal depth-2 tree and iteratively expands misclassified leaves, combining the scalability of greedy methods with the optimality guarantees of MIP-based approaches.

Features

  • Solver-agnostic — uses PuLP, so you can choose among:
    • HiGHS (open-source, installed by default)
    • Gurobi (commercial, optional — install gurobipy separately)
    • CBC (open-source, bundled with PuLP)
  • Full sklearn compatibility — fit() / predict() / score() / predict_proba() / get_params() / set_params() — works with GridSearchCV, Pipeline, cross_val_score, and clone()
  • Class probabilities — predict_proba() returns per-class probabilities from leaf distributions
  • Feature importances — feature_importances_ based on split frequency across branch nodes
  • Tree visualization — export_text() for ASCII and export_graphviz() for DOT/Graphviz output
  • Tree inspection — apply(), decision_path(), get_n_leaves(), get_depth()
  • Model persistence — save() / load() via joblib, plus standard pickle support
  • Input validation — clear error messages for non-binary features, single-class targets, and more
  • Two impurity criteria — Gini index or misclassification error
  • Arbitrary depth — depth-2 base tree extended via rolling subtree optimization
  • Parallel solving — n_jobs=-1 to solve independent subproblems across CPU cores

Installation

pip install rollotree

Or in editable/development mode (from a local clone):

pip install -e ".[dev]"

Dependencies: pulp, highspy, numpy, pandas, scipy, joblib.

For faster prediction and tree building with Numba JIT compilation:

pip install "rollotree[fast]"

To use the Gurobi solver backend:

pip install "rollotree[gurobi]"

Quick Start

import pandas as pd
from rollotree import RollingOCT

# Load data
train = pd.read_csv("rollotree/data/train.csv")
X_train = train.drop("y", axis=1)
y_train = train["y"]

test = pd.read_csv("rollotree/data/test.csv")
X_test = test.drop("y", axis=1)
y_test = test["y"]

# Train a depth-3 tree using the open-source HiGHS solver
model = RollingOCT(depth=3, criterion="gini", solver="highs")
model.fit(X_train, y_train)

# Evaluate
print(f"Test accuracy: {model.score(X_test, y_test):.3f}")

# Class probabilities
proba = model.predict_proba(X_test)
print(f"Probabilities shape: {proba.shape}")

# Feature importances
importances = model.feature_importances_
print(f"Top feature: {importances.argmax()}, importance: {importances.max():.3f}")

# Tree visualization
from rollotree import export_text
print(export_text(model.tree_))

sklearn Integration

from sklearn.model_selection import cross_val_score, StratifiedKFold, GridSearchCV

# Cross-validation
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(model, X_train, y_train, cv=cv)
print(f"CV accuracy: {scores.mean():.3f} (+/- {scores.std():.3f})")

# Grid search
grid = GridSearchCV(
    RollingOCT(solver="highs"),
    {"depth": [2, 3, 4], "criterion": ["gini", "misclassification"]},
    cv=cv,
)
grid.fit(X_train, y_train)
print(f"Best params: {grid.best_params_}")

Model Persistence

# Save and load
model.save("my_tree.joblib")
loaded = RollingOCT.load("my_tree.joblib")

# Standard pickle also works
import pickle
data = pickle.dumps(model)
loaded = pickle.loads(data)

Parallel Execution

For deeper trees on larger datasets, use n_jobs to solve independent subproblems in parallel:

# Use all CPU cores for parallel subproblem solving
model = RollingOCT(depth=5, solver="highs", n_jobs=-1)
model.fit(X_train, y_train)

# Or specify an exact number of workers
model = RollingOCT(depth=5, solver="highs", n_jobs=4)
model.fit(X_train, y_train)

The rolling expansion at each depth level solves independent OCT-2 MIP subproblems for each parent node. With n_jobs > 1 these are dispatched across processes via ProcessPoolExecutor. Speedup scales with the number of parents per level — most effective at depth 4+.
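
A simple way to check whether parallelism pays off on your data is to time both settings; this is only a rough comparison, and speedups depend on depth, dataset size, and core count:

import time

for n_jobs in (1, -1):
    model = RollingOCT(depth=5, solver="highs", n_jobs=n_jobs)
    start = time.perf_counter()
    model.fit(X_train, y_train)
    print(f"n_jobs={n_jobs}: fit took {time.perf_counter() - start:.1f}s")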

API Reference

RollingOCT

RollingOCT(
    depth=2,              # Maximum tree depth (>= 2)
    criterion="gini",     # "gini" or "misclassification"
    solver="highs",       # "highs", "gurobi", or "cbc"
    time_limit=1800,      # Max seconds per depth-2 subproblem
    mip_gap=None,         # MIP optimality gap (e.g. 0.01 for 1%)
    big_m=99,             # Penalty for empty-leaf splits
    log_to_console=False,
    min_samples_split=2,  # Min samples to solve a subproblem
    min_samples_leaf=1,   # Min samples per leaf node
    n_jobs=1,             # Parallel workers: 1=sequential, -1=all cores
)

Methods:

Method Description
fit(X, y) Train the model. X is a binary feature matrix (DataFrame or array), y is the target. Returns self.
predict(X) Return class predictions as a numpy array.
predict_proba(X) Return class probability estimates (n_samples × n_classes).
score(X, y) Return accuracy (fraction correct).
apply(X) Return leaf node IDs for each sample.
decision_path(X) Return sparse CSR matrix of nodes visited by each sample.
get_n_leaves() Return number of active (non-pruned) leaf nodes.
get_depth() Return the actual depth of the deepest active path.
get_params() Return dict of estimator parameters (sklearn protocol).
set_params(**params) Set estimator parameters (sklearn protocol). Returns self.
save(path) Save the fitted model to a file (joblib).
RollingOCT.load(path) Load a saved model (class method).
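
For example, the inspection methods can be combined to trace where individual samples land. A short sketch reusing the fitted model and test data from the Quick Start:

# Leaf assignments and routing for the first five test samples
leaf_ids = model.apply(X_test.iloc[:5])
path = model.decision_path(X_test.iloc[:5])   # sparse CSR matrix: samples x nodes
print("Leaf IDs:", leaf_ids)
print("Nodes visited per sample:", path.toarray().sum(axis=1))
print("Leaves:", model.get_n_leaves(), "| Depth:", model.get_depth())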

Attributes (after fitting):

Attribute Description
tree_ The fitted DecisionTree object
depth_results_ Dict mapping depth → DepthResult(depth, training_accuracy, test_accuracy, elapsed_time)
classes_ Sorted list of unique class labels
features_ List of feature indices used
feature_importances_ Feature importance array (split frequency, normalized to sum to 1)
n_features_in_ Number of features seen during fit()
feature_names_in_ Feature names (when X is a DataFrame)
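
A small sketch of reading these attributes after fit() (field access on DepthResult follows the names listed above):

print("Classes:", model.classes_)
print("Features seen during fit:", model.n_features_in_)
for depth, result in model.depth_results_.items():
    print(f"depth {depth}: train accuracy {result.training_accuracy:.3f} "
          f"in {result.elapsed_time:.1f}s")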

Visualization functions:

Function Description
export_text(tree, feature_names=None) ASCII tree representation
export_graphviz(tree, feature_names=None, class_names=None) DOT format string for Graphviz
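
For instance, a DOT file can be written and rendered with the Graphviz command-line tools (the feature and class names passed below are illustrative choices, not requirements):

from rollotree import export_graphviz

dot = export_graphviz(
    model.tree_,
    feature_names=list(X_train.columns),
    class_names=[str(c) for c in model.classes_],
)
with open("tree.dot", "w") as f:
    f.write(dot)
# Render on the command line with: dot -Tpng tree.dot -o tree.png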

Solver Options

Solver Install Notes
"highs" pip install highspy (included in requirements) Best open-source MIP solver. Default.
"gurobi" pip install gurobipy + license Fastest commercial solver.
"cbc" Bundled with PuLP Fallback open-source solver.

Examples

See the examples/ directory for Jupyter notebooks (all with saved outputs):

  • 01_quickstart.ipynb — Fit/predict/score, predict_proba, feature_importances_
  • 02_visualization.ipynb — export_text, export_graphviz, apply, decision_path, tree inspection
  • 03_sklearn_integration.ipynb — GridSearchCV, Pipeline, cross_val_score, model persistence
  • 04_advanced.ipynb — Depth analysis, preprocessing, solver tuning, parallel execution, tree internals, Numba, per-leaf stats, criteria comparison, early stopping

Dataset

An example dataset is provided under rollotree/data/ — a binarized version of the Wine Dataset (3-class, 130 binary features).
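
When the package is installed from PyPI rather than run from a repository checkout, the bundled CSVs can be located with importlib.resources. This is a sketch assuming the data files are shipped as regular files inside the installed package:

from importlib.resources import files
import pandas as pd

# Resolve the packaged copy of the example training data
train_path = files("rollotree") / "data" / "train.csv"
train = pd.read_csv(train_path)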

How It Works

  1. OCT-2 Formulation: Solves a MIP to find the optimal depth-2 binary classification tree. Binary variables select which feature to split on at each node, and the objective minimizes total leaf impurity (a brute-force illustration of this objective follows this list).

  2. Rolling Subtree (RST) Algorithm: Identifies misclassified leaves, groups them by parent, then solves a new OCT-2 subproblem for each parent's data subset. The resulting subtrees are merged back into the main tree. This repeats level-by-level until the target depth is reached.

  3. LP Relaxation Integrality: The OCT-2 formulation's feasible region is an integral polyhedron (Proposition 2 in the paper), so the LP relaxation always yields an integer-optimal solution.
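
To make the OCT-2 objective concrete, here is a purely illustrative brute-force search over depth-2 trees on a few binary features: it enumerates every (root, left child, right child) feature triple, assigns each leaf its majority class, and keeps the tree with the fewest misclassified samples. The function name and the restriction to 10 features are illustrative choices; the library solves this subproblem as a MIP instead of enumerating.

import itertools
import numpy as np

def brute_force_oct2(X, y):
    """Return (errors, root, left, right) for the best depth-2 tree by enumeration."""
    X, y = np.asarray(X), np.asarray(y)
    best = None
    for root, left, right in itertools.product(range(X.shape[1]), repeat=3):
        errors = 0
        go_right = X[:, root] == 1
        for side_mask, child in ((~go_right, left), (go_right, right)):
            for leaf_val in (0, 1):
                leaf_mask = side_mask & (X[:, child] == leaf_val)
                if leaf_mask.any():
                    # Majority-class leaf: errors = samples outside the majority class
                    _, counts = np.unique(y[leaf_mask], return_counts=True)
                    errors += leaf_mask.sum() - counts.max()
        if best is None or errors < best[0]:
            best = (errors, root, left, right)
    return best

# Keep it fast by using only the first 10 binary features of the example data
print(brute_force_oct2(X_train.iloc[:, :10], y_train))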

Project Structure

rollotree/
    __init__.py              # Public API: RollingOCT, SolverConfig, etc.
    classifier.py            # RollingOCT (sklearn-like fit/predict)
    tree/
        nodes.py             # DecisionNode, LeafNode, DecisionTree
        export.py            # export_text(), export_graphviz()
        impurity.py          # GiniCriterion, MisclassificationCriterion
        utils.py             # Node generation, index mapping helpers
        _numba.py            # Optional Numba-accelerated tree routing
    solver/
        base.py              # SolverStatus, OCT2Solution, SolverConfig
        pulp_solver.py       # PuLPOCT2Solver (HiGHS / Gurobi / CBC)
    rolling/
        optimizer.py         # RollingOptimizer, DepthResult
        parallel.py          # Multiprocessing worker for subproblem solving
    preprocessing/
        helpers.py           # Data binarization and preprocessing
    data/
        train.csv, test.csv  # Example Wine dataset
tests/                       # 129 pytest test cases
benchmarks/                  # Performance benchmarks
examples/                    # Jupyter notebooks

Citation

@article{organ2026rolling,
      title={Rolling Lookahead Learning for Optimal Classification Trees},
      author={Zeynel Batuhan Organ and Enis Kayış and Taghi Khaniyev},
      journal={IISE Transactions},
      year={2026},
      publisher={Taylor \& Francis},
      doi={10.1080/24725854.2026.2613786},
      url={https://www.tandfonline.com/doi/abs/10.1080/24725854.2026.2613786}
}

A preprint is also available on arXiv (2304.10830).

Contact

Feel free to reach out with questions or feedback.

Project details


Download files

Download the file for your platform.

Source Distribution

rollotree-2.0.0.tar.gz (56.7 kB)

Uploaded Source

Built Distribution

rollotree-2.0.0-py3-none-any.whl (48.2 kB)

Uploaded Python 3

File details

Details for the file rollotree-2.0.0.tar.gz.

File metadata

  • Download URL: rollotree-2.0.0.tar.gz
  • Size: 56.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for rollotree-2.0.0.tar.gz
Algorithm Hash digest
SHA256 bc9c008f994fdf101d9dcc6ceac3667e3ca9d634eddf7d4a294076fb3e32a9d1
MD5 ef33c5bf85b334bee703d441b27dbcdd
BLAKE2b-256 1aacdcaa16636026d353853526a5bae9c8688456a6431d2e14d727aaaf94dd83

Provenance

The following attestation bundles were made for rollotree-2.0.0.tar.gz:

Publisher: publish.yml on koftezz/rolling-lookahead-dt

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file rollotree-2.0.0-py3-none-any.whl.

File metadata

  • Download URL: rollotree-2.0.0-py3-none-any.whl
  • Size: 48.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for rollotree-2.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 c20cc70bd85dc97151885a115aa6e5682270a7be97999a94221d579008387c47
MD5 1ffbf759edf085807d0ecf2b33156960
BLAKE2b-256 997080ca8f3d184cebcd9c27cbc960b67fc23f2fcb5f930003d4e18a05098517

Provenance

The following attestation bundles were made for rollotree-2.0.0-py3-none-any.whl:

Publisher: publish.yml on koftezz/rolling-lookahead-dt

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
