
Rolling Lookahead Optimal Classification Trees


RolloTree — Rolling Lookahead Optimal Classification Trees

License: GPL v3

An implementation of the rolling subtree lookahead algorithm from "Rolling Lookahead Learning for Optimal Classification Trees" (published in IISE Transactions).

RolloTree builds interpretable decision trees by solving a sequence of small mixed-integer programs (MIPs). It starts with an optimal depth-2 tree and iteratively expands misclassified leaves by solving further depth-2 subproblems. This combines the scalability of greedy methods with the optimality guarantees of MIP-based approaches at the level of each depth-2 subproblem.

Features

  • Solver-agnostic: uses PuLP, so you can choose between:
    • HiGHS (open-source, installed by default)
    • Gurobi (commercial, optional; install gurobipy separately)
    • CBC (open-source, bundled with PuLP)
  • sklearn-style API: fit() / predict() / score()
  • Two impurity criteria: Gini index or misclassification error
  • Arbitrary depth: a depth-2 base tree is extended via rolling subtree optimization
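
For a leaf whose samples have class proportions p_k, the textbook definitions of the two criteria are Gini = 1 - sum_k p_k^2 and misclassification error = 1 - max_k p_k. The sketch below computes both for a single leaf. The function names are hypothetical, and how GiniCriterion and MisclassificationCriterion weight leaves (for example by sample count) inside the MIP objective is not shown and may differ.

```python
from collections import Counter


def gini(labels):
    """Gini index of one leaf: 1 - sum_k p_k^2 (0 for an empty or pure leaf)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def misclassification(labels):
    """Misclassification error of one leaf: 1 - max_k p_k (0 for an empty leaf)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - max(Counter(labels).values()) / n


leaf = ["a", "a", "a", "b"]
print(gini(leaf))               # 0.375
print(misclassification(leaf))  # 0.25
```

Gini penalizes mixed leaves more smoothly, while misclassification error counts only the samples outside the majority class.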

Installation

pip install rollotree

Or in editable/development mode (from a local clone):

pip install -e ".[dev]"

Dependencies: pulp, highspy, numpy, pandas.

To use the Gurobi solver backend:

pip install "rollotree[gurobi]"

Quick Start

import pandas as pd
from rollotree import RollingOCT

# Load the example Wine data (paths assume you run from the repository root)
train = pd.read_csv("rollotree/data/train.csv")
X_train = train.drop("y", axis=1)
y_train = train["y"]

test = pd.read_csv("rollotree/data/test.csv")
X_test = test.drop("y", axis=1)
y_test = test["y"]

# Train a depth-3 tree using the open-source HiGHS solver
model = RollingOCT(depth=3, criterion="gini", solver="highs")
model.fit(X_train, y_train)

# Evaluate
print(f"Test accuracy: {model.score(X_test, y_test):.3f}")

API Reference

RollingOCT

RollingOCT(
    depth=2,              # Maximum tree depth (>= 2)
    criterion="gini",     # "gini" or "misclassification"
    solver="highs",       # "highs", "gurobi", or "cbc"
    time_limit=1800,      # Max seconds per depth-2 subproblem
    mip_gap=None,         # MIP optimality gap (e.g. 0.01 for 1%)
    big_m=99,             # Penalty for empty-leaf splits
    log_to_console=False,
    min_samples_split=2,  # Min samples to solve a subproblem
    min_samples_leaf=1,   # Min samples per leaf node
)
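
To make mip_gap concrete: a common definition of the relative MIP gap, and the one Gurobi documents for its MIPGap parameter, is |incumbent - best_bound| / |incumbent|. The sketch below assumes RollingOCT forwards mip_gap to the solver's relative-gap setting; relative_mip_gap is a hypothetical helper, and solvers differ in how they treat edge cases such as a zero incumbent.

```python
def relative_mip_gap(incumbent, best_bound):
    """Relative gap |incumbent - bound| / |incumbent| (one common convention)."""
    if incumbent == 0:
        # Conventions differ here; this sketch reports 0 only if the bound also matches.
        return 0.0 if best_bound == 0 else float("inf")
    return abs(incumbent - best_bound) / abs(incumbent)


# A minimization run that has found total impurity 20.0 and proven a lower bound of 19.9
print(relative_mip_gap(20.0, 19.9) <= 0.01)  # True: mip_gap=0.01 would allow stopping here
```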

Methods:

| Method | Description |
|---|---|
| fit(X, y) | Train the model. X is a binary feature matrix (DataFrame or array); y is the target. Returns self. |
| predict(X) | Return class predictions as a numpy array. |
| score(X, y) | Return accuracy (fraction correct). |
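
fit() expects binary features. One common way to binarize a numeric column is to create one 0/1 indicator per cut point, x[j] <= t. The sketch below uses a hypothetical helper, binarize_thresholds; it is not the package's preprocessing API (rollotree/preprocessing/helpers.py may use a different scheme), and the bundled Wine data already ships binarized.

```python
def binarize_thresholds(rows, thresholds):
    """Turn numeric rows into 0/1 features of the form x[j] <= t.

    rows:       list of numeric feature vectors
    thresholds: dict {column index: list of cut points}
    Returns (binary_rows, feature_names).
    """
    cuts = [(j, t) for j, points in sorted(thresholds.items()) for t in points]
    names = [f"x{j}<={t}" for j, t in cuts]
    binary_rows = [[int(row[j] <= t) for j, t in cuts] for row in rows]
    return binary_rows, names


rows = [[12.1, 0.5], [13.7, 2.4], [14.2, 1.1]]
X_bin, names = binarize_thresholds(rows, {0: [13.0, 14.0], 1: [1.0]})
print(names)  # ['x0<=13.0', 'x0<=14.0', 'x1<=1.0']
print(X_bin)  # [[1, 1, 1], [0, 1, 0], [0, 0, 0]]
```

The resulting rows can be wrapped in a pandas DataFrame (with names as columns) before calling fit().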

Attributes (after fitting):

| Attribute | Description |
|---|---|
| tree_ | The fitted DecisionTree object |
| depth_results_ | Dict mapping depth → DepthResult(depth, training_accuracy, test_accuracy, elapsed_time) |
| classes_ | Sorted list of unique class labels |
| features_ | List of feature indices used |
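
depth_results_ holds one entry per depth. The sketch below summarizes such a dict using a hypothetical NamedTuple stand-in with the four fields listed above (the real DepthResult is defined in rollotree/rolling/optimizer.py and may be a different type) and a hypothetical format_depth_results helper. It assumes the fields are readable as attributes and, as a further assumption, that test_accuracy may be None when no test set was evaluated.

```python
from typing import NamedTuple, Optional


class DepthResult(NamedTuple):
    # Hypothetical stand-in with the fields listed above; the real class is
    # defined in rollotree/rolling/optimizer.py and may be a different type.
    depth: int
    training_accuracy: float
    test_accuracy: Optional[float]  # assumption: may be None without a test set
    elapsed_time: float


def _fmt(value):
    return "n/a" if value is None else f"{value:.3f}"


def format_depth_results(depth_results):
    """One summary line per depth, shallowest first."""
    return [
        f"depth={r.depth} train={_fmt(r.training_accuracy)} "
        f"test={_fmt(r.test_accuracy)} time={r.elapsed_time:.1f}s"
        for _, r in sorted(depth_results.items())
    ]
```

With a fitted model, the helper would be called as format_depth_results(model.depth_results_).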

Solver Options

| Solver | Install | Notes |
|---|---|---|
| "highs" | pip install highspy (included in requirements) | Recommended open-source MIP solver; the default. |
| "gurobi" | pip install gurobipy (or rollotree[gurobi]), plus a Gurobi license | Commercial solver; typically the fastest option. |
| "cbc" | Bundled with PuLP | Open-source fallback solver. |

Examples

See the examples/ directory for Jupyter notebooks.

Dataset

An example dataset is provided under rollotree/data/ — a binarized version of the Wine Dataset (3-class, 130 binary features).

How It Works

  1. OCT-2 Formulation: Solves a MIP to find the optimal depth-2 binary classification tree. Binary variables select which feature to split on at each node, and the objective minimizes total leaf impurity.
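
The search space of the OCT-2 problem can be stated directly: one root feature plus one split feature for each of the two children. A minimal sketch, assuming binary features and the misclassification count as the objective, finds the same kind of optimum by exhaustive enumeration in place of the MIP. This is a swapped-in technique for illustration only (p^3 candidate trees for p features), not the package's PuLPOCT2Solver; best_depth2_tree and leaf_errors are hypothetical names.

```python
from collections import Counter
from itertools import product


def leaf_errors(labels):
    """Samples a leaf misclassifies when it predicts its majority class."""
    return len(labels) - max(Counter(labels).values()) if labels else 0


def best_depth2_tree(X, y):
    """Exhaustively search depth-2 trees on binary features.

    Returns (errors, (root, left_feature, right_feature)), where the left
    child receives rows with row[root] == 0. Ties keep the first tree found.
    """
    best = None
    for root, f_left, f_right in product(range(len(X[0])), repeat=3):
        leaves = {}
        for row, label in zip(X, y):
            child = f_left if row[root] == 0 else f_right
            # A leaf is identified by the root value and the child-split value.
            leaves.setdefault((row[root], row[child]), []).append(label)
        total = sum(leaf_errors(labels) for labels in leaves.values())
        if best is None or total < best[0]:
            best = (total, (root, f_left, f_right))
    return best


X = [list(bits) for bits in product([0, 1], repeat=3)]
y = [row[0] ^ row[1] for row in X]  # XOR of the first two features
print(best_depth2_tree(X, y))       # (0, (0, 1, 1))
```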

  2. Rolling Subtree (RST) Algorithm: Identifies misclassified leaves, groups them by parent, then solves a new OCT-2 subproblem for each parent's data subset. The resulting subtrees are merged back into the main tree. This repeats level-by-level until the target depth is reached.
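
The rolling step can be sketched under stated assumptions: exhaustive depth-2 search stands in for the OCT-2 MIP, the objective is the misclassification count, features are binary, and the tuple tree encoding and all function names are hypothetical. It is not the package's RollingOptimizer and ignores time_limit, min_samples_split and min_samples_leaf. Each pass visits the parents of the current bottom leaves and, wherever a parent's leaves still misclassify rows, replaces that parent's subtree with a depth-2 subtree fitted to the parent's data, adding one level per pass.

```python
from collections import Counter
from itertools import product

# Hypothetical tree encoding for this sketch only:
#   ("leaf", label)  or  ("split", feature, left_subtree, right_subtree)
# Rows with row[feature] == 0 go left; rows with row[feature] == 1 go right.


def majority(labels):
    """Most frequent label (0 for an empty leaf)."""
    return Counter(labels).most_common(1)[0][0] if labels else 0


def subset(X, y, feature, value):
    """Rows (and their labels) whose binary `feature` equals `value`."""
    keep = [i for i, row in enumerate(X) if row[feature] == value]
    return [X[i] for i in keep], [y[i] for i in keep]


def predict_one(tree, row):
    while tree[0] == "split":
        tree = tree[2] if row[tree[1]] == 0 else tree[3]
    return tree[1]


def errors(tree, X, y):
    """Number of misclassified rows."""
    return sum(predict_one(tree, row) != label for row, label in zip(X, y))


def stump(X, y, feature):
    """A single split on `feature` with majority-class leaves."""
    return ("split", feature,
            ("leaf", majority(subset(X, y, feature, 0)[1])),
            ("leaf", majority(subset(X, y, feature, 1)[1])))


def solve_depth2(X, y):
    """Exhaustive stand-in for the OCT-2 MIP: fewest-error depth-2 tree."""
    best_tree, best_err = None, None
    for root, f_left, f_right in product(range(len(X[0])), repeat=3):
        X0, y0 = subset(X, y, root, 0)
        X1, y1 = subset(X, y, root, 1)
        tree = ("split", root, stump(X0, y0, f_left), stump(X1, y1, f_right))
        err = errors(tree, X, y)
        if best_err is None or err < best_err:
            best_tree, best_err = tree, err
    return best_tree


def roll(tree, X, y, level, target):
    """Re-solve every subtree rooted at depth `target` that still misclassifies rows."""
    if tree[0] == "leaf":
        return tree
    if level == target:
        return solve_depth2(X, y) if errors(tree, X, y) > 0 else tree
    X0, y0 = subset(X, y, tree[1], 0)
    X1, y1 = subset(X, y, tree[1], 1)
    return ("split", tree[1],
            roll(tree[2], X0, y0, level + 1, target),
            roll(tree[3], X1, y1, level + 1, target))


def rolling_tree(X, y, depth):
    """Depth-2 base tree, then one rolling pass per extra level of depth."""
    tree = solve_depth2(X, y)
    for d in range(3, depth + 1):
        # Parents of the current bottom leaves sit at depth d - 2.
        tree = roll(tree, X, y, 0, d - 2)
    return tree
```

On 3-bit parity data (y = x0 XOR x1 XOR x2 over all 8 rows), no depth-2 tree does better than 4 errors, while rolling_tree(X, y, 3) reaches 0 training errors.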

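  3. LP Relaxation Integrality: The OCT-2 formulation's feasible region is an integral polyhedron (Proposition 2 in the paper), so its LP relaxation has an integral optimal vertex; an LP solver that returns a basic (vertex) solution, such as simplex, therefore yields an integer-optimal solution.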

Project Structure

rollotree/
    __init__.py              # Public API: RollingOCT, SolverConfig, etc.
    classifier.py            # RollingOCT (sklearn-like fit/predict)
    tree/
        nodes.py             # DecisionNode, LeafNode, DecisionTree
        impurity.py          # GiniCriterion, MisclassificationCriterion
        utils.py             # Node generation, index mapping helpers
    solver/
        base.py              # SolverStatus, OCT2Solution, SolverConfig
        pulp_solver.py       # PuLPOCT2Solver (HiGHS / Gurobi / CBC)
    rolling/
        optimizer.py         # RollingOptimizer, DepthResult
    preprocessing/
        helpers.py           # Data binarization and preprocessing
    data/
        train.csv, test.csv  # Example Wine dataset
tests/                       # 70 pytest test cases
examples/                    # Jupyter notebooks

Citation

@article{organ2026rolling,
      title={Rolling Lookahead Learning for Optimal Classification Trees},
      author={Zeynel Batuhan Organ and Enis Kayış and Taghi Khaniyev},
      journal={IISE Transactions},
      year={2026},
      publisher={Taylor \& Francis},
      doi={10.1080/24725854.2026.2613786},
      url={https://www.tandfonline.com/doi/abs/10.1080/24725854.2026.2613786}
}

A preprint is also available on arXiv (2304.10830).

Contact

Feel free to reach out with questions or feedback.

