
Rolling Lookahead Optimal Classification Trees


RolloTree — Rolling Lookahead Optimal Classification Trees

License: GPL v3

An implementation of the rolling subtree lookahead algorithm from "Rolling Lookahead Learning for Optimal Classification Trees" (published in IISE Transactions).

RolloTree builds interpretable decision trees by solving a sequence of small mixed-integer programs (MIPs). It starts with an optimal depth-2 tree and then, one level at a time, expands the tree where leaves are misclassified. This combines the scalability of greedy methods with the optimality of MIP-based approaches within each depth-2 subproblem.

Features

  • Solver-agnostic — uses PuLP so you can choose between:
    • HiGHS (open-source, installed by default)
    • Gurobi (commercial, optional — install gurobipy separately)
    • CBC (open-source, bundled with PuLP)
  • sklearn-style API: fit() / predict() / score()
  • Two impurity criteria — Gini index or misclassification error
  • Arbitrary depth — depth-2 base tree extended via rolling subtree optimization
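For reference, the two criteria can be computed for the labels that reach a single leaf as below. This is a minimal sketch using the standard textbook definitions; the function names are hypothetical (they are not rollotree's GiniCriterion / MisclassificationCriterion classes), and how the package aggregates impurity across leaves, for example by weighting each leaf by its sample count, is not specified here.

```python
from collections import Counter


# Hypothetical helpers illustrating the textbook definitions; not rollotree's API.
def gini(labels):
    """Gini index 1 - sum_k p_k^2 over the class proportions in one leaf."""
    n = len(labels)
    if n == 0:
        return 0.0  # an empty leaf contributes no impurity
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def misclassification(labels):
    """Fraction of samples in one leaf that are not in its majority class."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - max(Counter(labels).values()) / n


leaf = ["a", "a", "a", "b"]
print(gini(leaf))               # 0.375
print(misclassification(leaf))  # 0.25
```

Gini is strictly smaller for purer leaves even when the majority class does not change, which is why it is often preferred as a smoother split objective than misclassification error.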

Installation

pip install rollotree

Or in editable/development mode (from a local clone):

pip install -e ".[dev]"

Dependencies: pulp, highspy, numpy, pandas.

To use the Gurobi solver backend:

pip install "rollotree[gurobi]"

Quick Start

import pandas as pd
from rollotree import RollingOCT

# Load the example Wine data (paths assume you run from the project root)
train = pd.read_csv("rollotree/data/train.csv")
X_train = train.drop("y", axis=1)
y_train = train["y"]

test = pd.read_csv("rollotree/data/test.csv")
X_test = test.drop("y", axis=1)
y_test = test["y"]

# Train a depth-3 tree using the open-source HiGHS solver
model = RollingOCT(depth=3, criterion="gini", solver="highs")
model.fit(X_train, y_train)

# Evaluate
print(f"Test accuracy: {model.score(X_test, y_test):.3f}")

API Reference

RollingOCT

RollingOCT(
    depth=2,              # Maximum tree depth (>= 2)
    criterion="gini",     # "gini" or "misclassification"
    solver="highs",       # "highs", "gurobi", or "cbc"
    time_limit=1800,      # Max seconds per depth-2 subproblem
    mip_gap=None,         # MIP optimality gap (e.g. 0.01 for 1%)
    big_m=99,             # Penalty for empty-leaf splits
    log_to_console=False,
    min_samples_split=2,  # Min samples to solve a subproblem
    min_samples_leaf=1,   # Min samples per leaf node
)

Methods:

| Method | Description |
|---|---|
| fit(X, y) | Train the model. X is a binary feature matrix (DataFrame or array) and y is the target. Returns self. |
| predict(X) | Return class predictions as a NumPy array. |
| score(X, y) | Return accuracy (fraction of samples classified correctly). |

Attributes (after fitting):

| Attribute | Description |
|---|---|
| tree_ | The fitted DecisionTree object |
| depth_results_ | Dict mapping depth → DepthResult(depth, training_accuracy, test_accuracy, elapsed_time) |
| classes_ | Sorted list of unique class labels |
| features_ | List of feature indices used |

Solver Options

| Solver | Install | Notes |
|---|---|---|
| "highs" | pip install highspy (included in requirements) | Open-source MIP solver. Default. |
| "gurobi" | pip install gurobipy + license | Commercial MIP solver; requires a Gurobi license. |
| "cbc" | Bundled with PuLP | Fallback open-source solver. |

Examples

See the examples/ directory for Jupyter notebooks.

Dataset

An example dataset is provided under rollotree/data/ — a binarized version of the Wine Dataset (3-class, 130 binary features).
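fit() expects binary features, so numeric data has to be binarized first. One common scheme is threshold indicators: one 0/1 column per cut point. The sketch below assumes that scheme; the helper name, the cut points, and the toy values are hypothetical, and this is not necessarily how the Wine data was binarized or what rollotree/preprocessing/helpers.py does.

```python
import pandas as pd


# Hypothetical helper (not part of rollotree): threshold-indicator binarization.
def binarize_thresholds(df, column, thresholds):
    """Return one 0/1 column per threshold t, equal to 1 where df[column] <= t."""
    out = pd.DataFrame(index=df.index)
    for t in thresholds:
        out[f"{column}<={t}"] = (df[column] <= t).astype(int)
    return out


# Toy values for illustration only
raw = pd.DataFrame({"alcohol": [12.1, 13.4, 14.2]})
X_bin = binarize_thresholds(raw, "alcohol", [12.5, 13.5])
print(X_bin)
```

Each binary column then corresponds to a candidate split "value <= t", so a tree over these columns can express univariate threshold rules on the original feature.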

How It Works

  1. OCT-2 Formulation: Solves a MIP to find the optimal depth-2 binary classification tree. Binary variables select which feature to split on at each node, and the objective minimizes total leaf impurity.

  2. Rolling Subtree (RST) Algorithm: Identifies misclassified leaves, groups them by parent, then solves a new OCT-2 subproblem for each parent's data subset. The resulting subtrees are merged back into the main tree. This repeats level-by-level until the target depth is reached.

  3. LP Relaxation Integrality: The feasible region of the OCT-2 formulation's LP relaxation is an integral polyhedron (Proposition 2 in the paper). Every vertex is therefore integral, so an optimal vertex solution of the LP relaxation is already an integer-optimal solution.
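Steps 1 and 2 can be illustrated in plain NumPy on toy data. This is not the package's method: RolloTree solves OCT-2 as a MIP through PuLP, whereas the sketch below substitutes exhaustive enumeration, uses the misclassification count as the impurity, and ignores big_m, min_samples_split and min_samples_leaf. The heap-style node numbering and all function names (leaf_errors, best_depth1, best_depth2, parents_to_reoptimize) are hypothetical, and enumeration is only practical for a small number of features.

```python
import numpy as np

# Hypothetical illustration; not rollotree's API or its MIP formulation.


def leaf_errors(y):
    """Misclassified count when a leaf predicts its majority class."""
    if len(y) == 0:
        return 0
    _, counts = np.unique(y, return_counts=True)
    return int(len(y) - counts.max())


def best_depth1(X, y):
    """Best single split on 0/1 features: returns (errors, feature)."""
    best_err, best_f = None, None
    for f in range(X.shape[1]):
        right = X[:, f] == 1
        err = leaf_errors(y[~right]) + leaf_errors(y[right])
        if best_err is None or err < best_err:
            best_err, best_f = err, f
    return best_err, best_f


def best_depth2(X, y):
    """Exhaustive depth-2 search: returns (errors, root, left_feature, right_feature).

    For a fixed root feature the two child splits are independent, so each
    child is optimized separately (O(p^2 * n) instead of O(p^3 * n)).
    """
    best = None
    for r in range(X.shape[1]):
        right = X[:, r] == 1
        left_err, left_f = best_depth1(X[~right], y[~right])
        right_err, right_f = best_depth1(X[right], y[right])
        total = left_err + right_err
        if best is None or total < best[0]:
            best = (total, r, left_f, right_f)
    return best


def parents_to_reoptimize(leaf_of_sample, y, leaf_prediction):
    """Map parent node id -> indices of all samples under that parent,
    for every parent with at least one misclassified leaf child.

    Assumes heap-style node ids: root is 1, children of node i are 2*i and
    2*i + 1, so the parent of node i is i // 2.
    """
    leaf_of_sample = np.asarray(leaf_of_sample)
    y = np.asarray(y)
    predicted = np.array([leaf_prediction[leaf] for leaf in leaf_of_sample])
    wrong_leaves = set(leaf_of_sample[predicted != y].tolist())
    parents = sorted({leaf // 2 for leaf in wrong_leaves})
    return {p: np.flatnonzero(leaf_of_sample // 2 == p) for p in parents}


# Step 1: XOR labels. No single split lowers the error (2 of 4 wrong either way),
# but an optimal depth-2 tree classifies all four samples correctly.
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.array([0, 1, 1, 0])
print(best_depth1(X, y))  # (2, 0)
print(best_depth2(X, y))  # (0, 0, 1, 1)

# Step 2: leaves 4-7 of a depth-2 tree; only leaf 4 misclassifies a sample,
# so only parent node 2 and its samples would get a new depth-2 subproblem.
groups = parents_to_reoptimize(
    [4, 4, 5, 6, 7, 7], [0, 1, 0, 1, 1, 1], {4: 0, 5: 0, 6: 1, 7: 1}
)
print({p: idx.tolist() for p, idx in groups.items()})  # {2: [0, 1, 2]}
```

The XOR example shows why lookahead matters: a greedy depth-1 criterion sees no benefit from any root split, while optimizing two levels jointly finds a perfect tree. Solving a depth-2 problem on each selected parent's subset and replacing that parent's subtree is what grows the tree by one level per round.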

Project Structure

rollotree/
    __init__.py              # Public API: RollingOCT, SolverConfig, etc.
    classifier.py            # RollingOCT (sklearn-like fit/predict)
    tree/
        nodes.py             # DecisionNode, LeafNode, DecisionTree
        impurity.py          # GiniCriterion, MisclassificationCriterion
        utils.py             # Node generation, index mapping helpers
    solver/
        base.py              # SolverStatus, OCT2Solution, SolverConfig
        pulp_solver.py       # PuLPOCT2Solver (HiGHS / Gurobi / CBC)
    rolling/
        optimizer.py         # RollingOptimizer, DepthResult
    preprocessing/
        helpers.py           # Data binarization and preprocessing
    data/
        train.csv, test.csv  # Example Wine dataset
tests/                       # 70 pytest test cases
examples/                    # Jupyter notebooks

Citation

@article{organ2026rolling,
      title={Rolling Lookahead Learning for Optimal Classification Trees},
      author={Zeynel Batuhan Organ and Enis Kayış and Taghi Khaniyev},
      journal={IISE Transactions},
      year={2026},
      publisher={Taylor \& Francis},
      doi={10.1080/24725854.2026.2613786},
      url={https://www.tandfonline.com/doi/abs/10.1080/24725854.2026.2613786}
}

A preprint is also available on arXiv (2304.10830).

Contact

Feel free to reach out with questions or feedback.

Download files


Source Distribution

rollotree-0.1.0.tar.gz (42.6 kB)

Built Distribution


rollotree-0.1.0-py3-none-any.whl (37.5 kB)

File details

Details for the file rollotree-0.1.0.tar.gz.

File metadata

  • Download URL: rollotree-0.1.0.tar.gz
  • Size: 42.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.6

File hashes

Hashes for rollotree-0.1.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8c0d7590f0e134042461dd5b7affeaf80376b42b83272eea70b006cfed3f8d19 |
| MD5 | d3e6970eb87c0b4a06560f4ee04593ce |
| BLAKE2b-256 | f59cab3212ee78fd2a0e4e58fd8fceec2e43479c4628c9ef6d0c4f9e4d38626e |


File details

Details for the file rollotree-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: rollotree-0.1.0-py3-none-any.whl
  • Size: 37.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.6

File hashes

Hashes for rollotree-0.1.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | af3b1132df096e35dab72d560918e747f82b747a5a3144b231ea1e4e9e4f7e9e |
| MD5 | 61996580a242190475452191563dfae1 |
| BLAKE2b-256 | 0ae7a9b86691cdb78801bf6c547bcf0500687fb8e055e187531f649c755b709e |

