Rolling Lookahead Optimal Classification Trees

RolloTree — Rolling Lookahead Optimal Classification Trees

License: GPL v3

An implementation of the rolling subtree lookahead algorithm from "Rolling Lookahead Learning for Optimal Classification Trees" (published in IISE Transactions).

RolloTree builds interpretable decision trees by solving a sequence of small mixed-integer programs (MIPs). It starts with an optimal depth-2 tree and iteratively expands misclassified leaves, combining the scalability of greedy methods with the optimality guarantees that MIP-based approaches provide for each subproblem.

Features

  • Solver-agnostic: uses PuLP, so you can choose between:
    • HiGHS (open-source, installed by default)
    • Gurobi (commercial, optional; install gurobipy separately)
    • CBC (open-source, bundled with PuLP)
  • sklearn-style API: fit() / predict() / score()
  • Two impurity criteria: Gini index or misclassification error
  • Arbitrary depth: a depth-2 base tree extended via rolling subtree optimization
  • Parallel solving: n_jobs=-1 solves independent subproblems across all CPU cores
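As a reference point for the two criteria, here are the textbook per-leaf definitions in plain Python. The function names are hypothetical; this is not the package's GiniCriterion or MisclassificationCriterion code, and how rollotree weights and sums leaf values in its objective is not spelled out here.

```python
from collections import Counter
from typing import Hashable, Sequence


def gini(labels: Sequence[Hashable]) -> float:
    """Gini index of one leaf: 1 - sum_k p_k^2 (0.0 for a pure or empty leaf)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def misclassification(labels: Sequence[Hashable]) -> float:
    """Misclassification error of one leaf: share of samples outside the majority class."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - max(Counter(labels).values()) / n


leaf = ["a", "a", "a", "b"]
print(gini(leaf))               # 0.375
print(misclassification(leaf))  # 0.25
```

For this mixed leaf, Gini penalizes the impurity more (0.375) than misclassification error (0.25); both are 0 for a pure leaf.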

Installation

pip install rollotree

Or in editable/development mode (from a local clone):

pip install -e ".[dev]"

Dependencies: pulp, highspy, numpy, pandas.

For faster prediction and tree building with Numba JIT compilation:

pip install "rollotree[fast]"

To use the Gurobi solver backend:

pip install "rollotree[gurobi]"

Quick Start

import pandas as pd
from rollotree import RollingOCT

# Load the example data (relative paths: run from the root of a local clone)
train = pd.read_csv("rollotree/data/train.csv")
X_train = train.drop("y", axis=1)
y_train = train["y"]

test = pd.read_csv("rollotree/data/test.csv")
X_test = test.drop("y", axis=1)
y_test = test["y"]

# Train a depth-3 tree using the open-source HiGHS solver
model = RollingOCT(depth=3, criterion="gini", solver="highs")
model.fit(X_train, y_train)

# Evaluate
print(f"Test accuracy: {model.score(X_test, y_test):.3f}")

Parallel Execution

For deeper trees on larger datasets, use n_jobs to solve independent subproblems in parallel:

# Use all CPU cores for parallel subproblem solving
model = RollingOCT(depth=5, solver="highs", n_jobs=-1)
model.fit(X_train, y_train)

# Or specify an exact number of workers
model = RollingOCT(depth=5, solver="highs", n_jobs=4)
model.fit(X_train, y_train)

The rolling expansion at each depth level solves an independent OCT-2 MIP subproblem for each parent node. With n_jobs > 1, these subproblems are dispatched across processes via ProcessPoolExecutor. The speedup grows with the number of parent nodes per level, so it is most noticeable at depth 4 and beyond.
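The per-level dispatch described above can be sketched with the standard library alone. Everything here is illustrative: solve_subproblem is a hypothetical placeholder that returns the majority class instead of solving an OCT-2 MIP, leaves are assumed to use heap-style ids (node i has children 2i and 2i+1, so its parent is i // 2), and -1 is resolved to os.cpu_count().

```python
import os
from collections import Counter
from concurrent.futures import ProcessPoolExecutor
from typing import Dict, List


def parents_to_expand(leaf_errors: Dict[int, int]) -> List[int]:
    """Parents of leaves that still misclassify samples (heap ids: parent of i is i // 2)."""
    return sorted({leaf // 2 for leaf, errors in leaf_errors.items() if errors > 0})


def resolve_n_jobs(n_jobs: int) -> int:
    """-1 means all CPU cores; any other value is clamped to at least one worker."""
    if n_jobs == -1:
        return os.cpu_count() or 1
    return max(1, n_jobs)


def solve_subproblem(labels: List[str]) -> str:
    """Hypothetical placeholder for an OCT-2 solve: returns the majority class."""
    return Counter(labels).most_common(1)[0][0]


def solve_level(subsets: Dict[int, List[str]], n_jobs: int = 1) -> Dict[int, str]:
    """Solve one independent subproblem per parent, sequentially or in worker processes."""
    parents = sorted(subsets)
    workers = resolve_n_jobs(n_jobs)
    if workers == 1:
        results = [solve_subproblem(subsets[p]) for p in parents]
    else:
        # Subproblems share no state, so they can be mapped across processes
        with ProcessPoolExecutor(max_workers=workers) as pool:
            results = list(pool.map(solve_subproblem, [subsets[p] for p in parents]))
    return dict(zip(parents, results))


if __name__ == "__main__":
    # The guard stops worker processes started with "spawn" from re-running this block
    leaf_errors = {4: 0, 5: 2, 6: 1, 7: 3}  # leaf id -> misclassified samples
    print(parents_to_expand(leaf_errors))   # [2, 3]
    subsets = {2: ["a", "b", "b"], 3: ["c", "c", "a"]}  # toy labels per parent
    print(solve_level(subsets, n_jobs=1))   # {2: 'b', 3: 'c'}; n_jobs=-1 uses all cores
```

If rollotree starts its workers with the spawn method (the default on Windows and macOS), scripts that call fit with n_jobs > 1 on those platforms likely need the same if __name__ == "__main__": guard.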

API Reference

RollingOCT

RollingOCT(
    depth=2,              # Maximum tree depth (>= 2)
    criterion="gini",     # "gini" or "misclassification"
    solver="highs",       # "highs", "gurobi", or "cbc"
    time_limit=1800,      # Max seconds per depth-2 subproblem
    mip_gap=None,         # MIP optimality gap (e.g. 0.01 for 1%)
    big_m=99,             # Penalty for empty-leaf splits
    log_to_console=False,
    min_samples_split=2,  # Min samples to solve a subproblem
    min_samples_leaf=1,   # Min samples per leaf node
    n_jobs=1,             # Parallel workers: 1=sequential, -1=all cores
)
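mip_gap is a relative tolerance: a solver may stop once its incumbent is provably within that relative distance of the optimum. The sketch below uses the definition Gurobi documents for MIPGap, |bound - incumbent| / |incumbent|; other solvers may normalize differently, and how rollotree forwards mip_gap to each backend is not covered here. The function names and the numbers in the example are hypothetical.

```python
import math


def relative_mip_gap(incumbent: float, bound: float) -> float:
    """Relative gap |bound - incumbent| / |incumbent| (Gurobi-style MIPGap definition)."""
    if incumbent == 0:
        # Assumed convention for a zero incumbent: gap is 0 only if the bound is 0 too
        return 0.0 if bound == 0 else math.inf
    return abs(bound - incumbent) / abs(incumbent)


def within_gap(incumbent: float, bound: float, mip_gap: float) -> bool:
    """True once the incumbent is within a relative mip_gap of the best proven bound."""
    return relative_mip_gap(incumbent, bound) <= mip_gap


# Minimization example: incumbent objective 12.0, best proven lower bound 11.9
print(round(relative_mip_gap(12.0, 11.9), 4))  # 0.0083
print(within_gap(12.0, 11.9, 0.01))            # True (within 1%)
print(within_gap(12.0, 11.9, 0.005))           # False
```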

Methods:

| Method | Description |
|---|---|
| fit(X, y) | Train the model. X is a binary feature matrix (DataFrame or array); y is the target. Returns self. |
| predict(X) | Return class predictions as a NumPy array. |
| score(X, y) | Return accuracy (fraction of correct predictions). |
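Since fit expects a binary feature matrix, a quick check before a long solve can catch columns that were not binarized. The helper below is hypothetical (not part of rollotree's API) and assumes 0/1 encoding.

```python
from typing import Union

import numpy as np
import pandas as pd


def non_binary_columns(X: Union[pd.DataFrame, np.ndarray]) -> list:
    """Labels (DataFrame) or positions (array) of columns holding anything but 0/1."""
    df = pd.DataFrame(X)
    # isin is elementwise; NaN never matches, so missing values are flagged as well
    ok = df.isin([0, 1]).all(axis=0)
    return ok[~ok].index.tolist()


X = pd.DataFrame({"f1": [0, 1, 1], "f2": [0, 2, 1], "f3": [1, 1, float("nan")]})
print(non_binary_columns(X))  # ['f2', 'f3']
```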

Attributes (after fitting):

| Attribute | Description |
|---|---|
| tree_ | The fitted DecisionTree object |
| depth_results_ | Dict mapping depth → DepthResult(depth, training_accuracy, test_accuracy, elapsed_time) |
| classes_ | Sorted list of unique class labels |
| features_ | List of feature indices used |

Solver Options

| Solver | Install | Notes |
|---|---|---|
| "highs" | pip install highspy (included in requirements) | Default; a strong open-source MIP solver. |
| "gurobi" | pip install gurobipy, plus a Gurobi license | Commercial; typically the fastest option. |
| "cbc" | Bundled with PuLP | Open-source fallback. |
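When a script must run on machines with different solvers installed, a small fallback can pick the first available backend. This is a sketch under assumptions: the mapping from rollotree's solver strings to PuLP solver names ("HiGHS", "GUROBI", "PULP_CBC_CMD") is inferred, not documented here, and pick_solver is a hypothetical helper.

```python
from typing import Iterable

# Assumed mapping from rollotree's solver strings to PuLP solver names
PULP_NAMES = {"highs": "HiGHS", "gurobi": "GUROBI", "cbc": "PULP_CBC_CMD"}


def pick_solver(preferred: Iterable[str], available: Iterable[str]) -> str:
    """Return the first preferred solver string whose PuLP backend is available."""
    installed = set(available)
    for name in preferred:
        if PULP_NAMES.get(name) in installed:
            return name
    raise RuntimeError("none of the preferred MIP solvers is available")


# Gurobi preferred, but only the open-source backends are installed
print(pick_solver(["gurobi", "highs", "cbc"], ["HiGHS", "PULP_CBC_CMD"]))  # highs
```

In a real script, the available names could come from pulp.listSolvers(onlyAvailable=True), and the result would be passed as RollingOCT(solver=...).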

Examples

See the examples/ directory for Jupyter notebooks.

Dataset

An example dataset is provided under rollotree/data/ — a binarized version of the Wine Dataset (3-class, 130 binary features).
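How the Wine features were binarized is not described here. One common scheme turns each numeric column into threshold indicators x >= t at a few quantiles; the pandas sketch below illustrates it (binarize_numeric, its defaults, and the col>=t naming are hypothetical, not the package's preprocessing helpers).

```python
import pandas as pd


def binarize_numeric(df: pd.DataFrame, n_thresholds: int = 3) -> pd.DataFrame:
    """Encode each numeric column as 0/1 indicators `col>=t` at interior quantiles."""
    qs = [(i + 1) / (n_thresholds + 1) for i in range(n_thresholds)]
    parts = {}
    for col in df.columns:
        # Deduplicate thresholds so discrete columns do not yield repeated indicators
        for t in sorted(set(df[col].quantile(qs))):
            indicator = (df[col] >= t).astype(int)
            if indicator.nunique() > 1:  # skip indicators that are constant on this data
                parts[f"{col}>={t:g}"] = indicator
    return pd.DataFrame(parts, index=df.index)


raw = pd.DataFrame({"x1": [1, 2, 3, 4], "x2": [5, 5, 5, 5]})  # toy values
out = binarize_numeric(raw)
print(list(out.columns))  # ['x1>=1.75', 'x1>=2.5', 'x1>=3.25']
```

The constant column x2 produces no indicators, since a feature that never varies cannot be used for a split.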

How It Works

  1. OCT-2 Formulation: Solves a MIP to find the optimal depth-2 binary classification tree. Binary variables select which feature to split on at each node, and the objective minimizes total leaf impurity.

  2. Rolling Subtree (RST) Algorithm: Identifies misclassified leaves, groups them by parent, then solves a new OCT-2 subproblem for each parent's data subset. The resulting subtrees are merged back into the main tree. This repeats level-by-level until the target depth is reached.

  3. LP Relaxation Integrality: The OCT-2 formulation's feasible region is an integral polyhedron (Proposition 2 in the paper), so every vertex of its LP relaxation is integral and an optimal vertex solution of the relaxation is already integer-optimal.
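To make the OCT-2 objective concrete, the sketch below finds an optimal depth-2 tree on binary features by exhaustive enumeration rather than by solving a MIP, using the misclassification count as the impurity. It is a stand-in for intuition only: the package formulates this step as a MIP in PuLP, with details (Gini weighting, big_m, min_samples_leaf) that the sketch ignores. Once the root feature is fixed, the two children can be chosen independently, so enumeration costs roughly O(p²·n) for p features and n samples.

```python
from collections import Counter
from typing import List, Sequence, Tuple

Row = Sequence[int]  # one sample's binary features


def leaf_errors(labels: List) -> int:
    """Samples outside the majority class of a leaf (0 for an empty leaf)."""
    return len(labels) - max(Counter(labels).values()) if labels else 0


def best_child_split(rows: List[Row], labels: List, p: int) -> Tuple[int, int]:
    """Best split of one depth-1 node: (feature index, or -1 for no split; errors)."""
    best = (-1, leaf_errors(labels))
    for j in range(p):
        left = [y for x, y in zip(rows, labels) if x[j] == 0]
        right = [y for x, y in zip(rows, labels) if x[j] == 1]
        errors = leaf_errors(left) + leaf_errors(right)
        if errors < best[1]:
            best = (j, errors)
    return best


def optimal_depth2(X: List[Row], y: List) -> Tuple[int, int, int, int]:
    """Exhaustive optimal depth-2 tree: (root, left feature, right feature, errors)."""
    p = len(X[0])
    best = None
    for r in range(p):
        children = []
        for side in (0, 1):
            idx = [i for i, x in enumerate(X) if x[r] == side]
            # Given the root, each child is optimized independently of the other
            children.append(best_child_split([X[i] for i in idx], [y[i] for i in idx], p))
        total = children[0][1] + children[1][1]
        if best is None or total < best[3]:
            best = (r, children[0][0], children[1][0], total)
    return best


# XOR of two features: no single split helps, yet depth 2 separates it perfectly
X = [[0, 0], [0, 1], [1, 0], [1, 1]]
y = [0, 1, 1, 0]
print(optimal_depth2(X, y))  # (0, 1, 1, 0)
```

On the XOR toy data, a greedy learner sees no gain from any single root split, whereas the depth-2 search returns a tree with zero training errors; this is the kind of lookahead the rolling scheme reapplies under each parent node.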

Project Structure

rollotree/
    __init__.py              # Public API: RollingOCT, SolverConfig, etc.
    classifier.py            # RollingOCT (sklearn-like fit/predict)
    tree/
        nodes.py             # DecisionNode, LeafNode, DecisionTree
        impurity.py          # GiniCriterion, MisclassificationCriterion
        utils.py             # Node generation, index mapping helpers
        _numba.py            # Optional Numba-accelerated tree routing
    solver/
        base.py              # SolverStatus, OCT2Solution, SolverConfig
        pulp_solver.py       # PuLPOCT2Solver (HiGHS / Gurobi / CBC)
    rolling/
        optimizer.py         # RollingOptimizer, DepthResult
        parallel.py          # Multiprocessing worker for subproblem solving
    preprocessing/
        helpers.py           # Data binarization and preprocessing
    data/
        train.csv, test.csv  # Example Wine dataset
tests/                       # 90+ pytest test cases
benchmarks/                  # Performance benchmarks
examples/                    # Jupyter notebooks

Citation

@article{organ2026rolling,
      title={Rolling Lookahead Learning for Optimal Classification Trees},
      author={Zeynel Batuhan Organ and Enis Kayış and Taghi Khaniyev},
      journal={IISE Transactions},
      year={2026},
      publisher={Taylor \& Francis},
      doi={10.1080/24725854.2026.2613786},
      url={https://www.tandfonline.com/doi/abs/10.1080/24725854.2026.2613786}
}

A preprint is also available on arXiv (2304.10830).

Contact

Feel free to reach out with questions or feedback.

Download files

Source Distribution

rollotree-1.3.0.tar.gz (47.8 kB)

Built Distribution

rollotree-1.3.0-py3-none-any.whl (42.4 kB)

File details

Details for the file rollotree-1.3.0.tar.gz.

File metadata

  • Download URL: rollotree-1.3.0.tar.gz
  • Size: 47.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for rollotree-1.3.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | d4512e035352b826ef98a091a5111f2c666bce29af2a3c7a555120864ec32f0f |
| MD5 | a6d891e5791b1d60bc354e4207479e8f |
| BLAKE2b-256 | 932f7d1e7fbeb29cc868ec0e513feba561628926f4cfaa91f90d6b3e984f3776 |
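To check a downloaded file against these digests, Python's standard hashlib module is enough. A minimal sketch (sha256_of and verify are illustrative names):

```python
import hashlib
import tempfile
from pathlib import Path
from typing import Union

# Published SHA256 digest of rollotree-1.3.0.tar.gz, as listed above
EXPECTED_SDIST_SHA256 = "d4512e035352b826ef98a091a5111f2c666bce29af2a3c7a555120864ec32f0f"


def sha256_of(path: Union[str, Path], chunk_size: int = 1 << 20) -> str:
    """Hex SHA-256 digest of a file, read in 1 MiB chunks so memory use stays flat."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Union[str, Path], expected: str) -> bool:
    """Compare a file against a published hex digest (case-insensitive)."""
    return sha256_of(path) == expected.strip().lower()


# For the real archive: verify("rollotree-1.3.0.tar.gz", EXPECTED_SDIST_SHA256)
# Self-check on a throwaway file:
with tempfile.TemporaryDirectory() as tmp:
    sample = Path(tmp) / "sample.bin"
    sample.write_bytes(b"rollotree")
    sample_ok = verify(sample, hashlib.sha256(b"rollotree").hexdigest())
print(sample_ok)  # True
```

pip can enforce the same check at install time in hash-checking mode: list rollotree==1.3.0 with one --hash=sha256:<digest> option per distribution file in a requirements file and install it with --require-hashes.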

Provenance

The following attestation bundles were made for rollotree-1.3.0.tar.gz:

Publisher: publish.yml on koftezz/rolling-lookahead-dt

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file rollotree-1.3.0-py3-none-any.whl.

File metadata

  • Download URL: rollotree-1.3.0-py3-none-any.whl
  • Size: 42.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for rollotree-1.3.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 2489c0a74b94ed806d95f0e9bf6afb07172d767c2095acacea92399213ced69e |
| MD5 | 843624c904d7672f190593ea409f5f23 |
| BLAKE2b-256 | 1b0f1bf677833dc4df47c8c74829946ddf07e65877c1b7bfd1b9e19603fbcfcb |

Provenance

The following attestation bundles were made for rollotree-1.3.0-py3-none-any.whl:

Publisher: publish.yml on koftezz/rolling-lookahead-dt

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
