Representation-enhanced Leaf Gradient Boosting Machine: raw-feature tree routing with representation-conditioned leaf models.

RepLeafGBM

Representation-enhanced Leaf Gradient Boosting Machine — gradient boosting that routes on raw features and predicts with small linear models over learned representations inside each leaf.

RepLeafGBM is not a neural network inside a tree, nor a tree over embeddings. It is a boosted ensemble of raw-feature routers with representation-conditioned local predictors.

Stability: from 1.0.0 the public API follows Semantic Versioning. What that covers (estimator parameters, the save/load format, registered encoder/objective/metric names, exported symbols) and what stays experimental (repleafgbm.external, router_extraction, native-backend internals) is spelled out in docs/adr/0003-api-stability.md.

Highlights from the experiment log (see docs/audit_v0.md and experiments/results/):

  • On synthetic signals with smooth structure inside regimes, embedded-linear leaves beat constant leaves, LightGBM, and sklearn HistGradientBoosting.
  • Refitting LightGBM's own routes with representation-conditioned leaves (router_extraction) improves on LightGBM's RMSE by 2-12%, which isolates the leaf contribution from split quality.
  • On standard real OpenML datasets RepLeafGBM is competitive with the major GBM libraries (by mean rank across 9 datasets), and constant-leaf RepLeaf even edges out LightGBM and XGBoost on classification. On that real data the leaf embeddings add no accuracy over a constant leaf, consistent with the documented finding that their advantage is specific to smooth/periodic structure. Reproduce: python benchmarks/openml_suite.py (see experiments/results/openml_benchmark.md).

Motivation

GBDTs dominate tabular ML because axis-aligned splits on raw features handle discontinuities, interactions, and messy data extremely well. Tabular deep learning has shown that numerical feature embeddings (PLR, periodic) capture smooth nonlinear structure that constant-leaf trees approximate only with many splits.

RepLeafGBM combines both with a deliberately asymmetric design:

  • Routing — every split is on a raw feature, exactly like a normal GBDT. Trees do what trees are good at: finding discontinuous boundaries and partitioning the space into local regions.
  • Leaf output — each leaf holds a small ridge-regularized linear model over a learned representation z_theta(x) instead of a constant. The embedding does what it is good at: smooth interpolation within a local region.
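f_t(x) = b_{t, l_t(x_raw)} + w_{t, l_t(x_raw)}^T z_theta(x)
F_T(x) = F_0(x) + sum_{t=1}^{T} eta * f_t(x)

Here l_t(x_raw) is the leaf that tree t routes the raw features to, b and w are that leaf's intercept and weight vector, F_0 is the initial prediction, and eta is the learning rate.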
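The two formulas above can be sketched in a few lines of NumPy. This is only an illustration of the arithmetic, not the library's implementation: the one-split router, the identity stand-in for z_theta, and the names route, encode, tree_output, and ensemble_predict are all assumptions made up for this example.

```python
import numpy as np

def route(x_raw):
    """Toy router: a single split on raw feature 0 (hypothetical)."""
    return 0 if x_raw[0] <= 0.0 else 1

def encode(x_raw):
    """Stand-in for z_theta(x); here simply the identity map (assumption)."""
    return np.asarray(x_raw, dtype=float)

# Per-leaf parameters of one tree: an intercept b and a weight vector w per leaf.
leaf_b = np.array([-2.0, 3.0])
leaf_w = np.array([[0.0, 2.0],
                   [0.0, 2.0]])

def tree_output(x_raw):
    """f_t(x) = b_leaf + w_leaf^T z(x), with the leaf chosen from raw features only."""
    leaf = route(x_raw)
    return leaf_b[leaf] + leaf_w[leaf] @ encode(x_raw)

def ensemble_predict(x_raw, trees, f0, eta):
    """F_T(x) = F_0 + sum_t eta * f_t(x)."""
    return f0 + eta * sum(tree(x_raw) for tree in trees)

x = np.array([0.5, 1.0])
single = tree_output(x)  # leaf 1: 3.0 + 2.0 * 1.0 = 5.0
total = ensemble_predict(x, [tree_output, tree_output], f0=0.0, eta=0.1)
print(single, total)
```

The point of the asymmetry is visible in tree_output: the split looks only at x_raw, while the encoded vector enters only through the leaf's linear model.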

What makes it different

  • Embeddings are never used for splitting (interpretable raw-feature routing, no histogram blow-up, no curse of dimensionality in split search).
  • The encoder is frozen during boosting in v0, preserving the stage-wise additive structure of gradient boosting.
  • Leaf models fall back to constants when leaves are too small, so the model degrades gracefully toward a classic GBDT.
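As a rough picture of the constant fallback, the sketch below fits one leaf under squared error: a ridge-regularized linear model over the embeddings when the leaf has enough rows, and a plain mean otherwise. The function name fit_leaf, the parameters ridge_lambda and min_samples, and the threshold value are hypothetical; the library's actual criterion and solver may differ.

```python
import numpy as np

def fit_leaf(Z, residual, ridge_lambda=1.0, min_samples=20):
    """Return (b, w) for one leaf.

    Z: (n, d) embeddings of the rows routed to this leaf.
    residual: (n,) current residuals for those rows (squared-error case).
    """
    n, d = Z.shape
    if n < min_samples:
        # Too few rows for a stable linear fit: behave like a classic GBDT leaf.
        return float(residual.mean()), np.zeros(d)
    A = np.hstack([np.ones((n, 1)), Z])      # prepend an intercept column
    penalty = ridge_lambda * np.eye(d + 1)
    penalty[0, 0] = 0.0                      # leave the intercept unpenalized
    coef = np.linalg.solve(A.T @ A + penalty, A.T @ residual)
    return float(coef[0]), coef[1:]

rng = np.random.default_rng(0)
Z = rng.normal(size=(200, 2))
residual = 1.0 + 2.0 * Z[:, 0]

b_big, w_big = fit_leaf(Z, residual, ridge_lambda=1e-6)  # linear leaf
b_small, w_small = fit_leaf(Z[:5], residual[:5])         # falls back to a constant
print(b_big, w_big, b_small, w_small)
```

With all weights at zero the leaf output reduces to its intercept, which is why a model made only of small leaves behaves like an ordinary constant-leaf GBDT.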

Minimal example

import numpy as np
from repleafgbm import RepLeafRegressor
from repleafgbm.data import RepLeafDataset

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 4))
y = np.where(X[:, 0] > 0, 3.0, -2.0) + 2.0 * X[:, 1] + rng.normal(0, 0.1, 500)

model = RepLeafRegressor(
    n_estimators=50,
    learning_rate=0.1,
    num_leaves=8,
    leaf_model="embedded_linear",   # or "constant", "raw_linear"
    encoder="plr",                  # or "identity"
    max_leaf_emb_dim=16,
    random_state=42,
)
model.fit(X, y)
pred = model.predict(X)

model.save_model("repleaf_model")
loaded = RepLeafRegressor.load_model("repleaf_model")

pandas DataFrames with categorical columns are supported through the dataset API:

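# df_train / df_valid are pandas DataFrames with a "city" column; y_train / y_valid are the targets.
train_data = RepLeafDataset(df_train, y_train, categorical_features=["city"])
valid_data = RepLeafDataset(df_valid, y_valid, metadata=train_data.metadata)  # reuse the training metadata
model.fit(train_data, eval_set=[valid_data])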

Current status (v0)

Implemented:

  • Native NumPy backend: histogram-based split search with sibling-histogram subtraction, leaf-wise tree growth — plus optional Rust kernels (pip install repleafgbm-native, or pip install ./native from source; auto-detected; ~5.8x faster constant-leaf training, parity-tested against the NumPy reference)
  • leaf_model: "constant", "embedded_linear", "raw_linear"
  • Encoders: "identity", "plr" (simplified piecewise-linear + linear term), "periodic" (PBLD-style frozen sinusoidal features), "cross" (residual-correlated pairwise products), and learned "torch_periodic" / "torch_plr" / "torch_mlp" (optional [torch] extra; pretrained on the initial residual, then frozen; torch is needed only at fit time). identity is the evidence-backed default on real tabular data; the others are specialists for known smooth/oscillatory or interaction structure (see docs for guidance). A random projection down to max_leaf_emb_dim acts as an emergency cap on embedding width.
  • Regression (squared error, plus objective="huber" / "quantile" / "poisson" — parameterized instances like Quantile(alpha=0.9) work too), binary classification (logistic), and multiclass classification (softmax, one tree per class per round — automatic at 3+ classes)
  • Multi-output regression via shared-routing vector leaves (pass a 2-D y; one tree per round whose leaves emit a vector — squared-error only), and label_smoothing for classification
  • Early stopping (early_stopping_rounds, best_iteration_, prediction at the best iteration) and eval metrics: rmse, mae, logloss, multi_logloss, auc, accuracy, or any user-supplied callable (repleafgbm.make_metric)
  • Feature importance (feature_importances_, gain or split count)
  • RepLeafDataset with pandas/categorical support (native subset splits, pandas.Categorical order fidelity, opt-in frequency encoding) and embedding caching
  • Directory-based save_model / load_model with schema validation and a human-readable summary.txt (model.summary())
  • repleafgbm.external: LightGBM, XGBoost, and CatBoost as external base models (scores + leaf indices, optional native early stopping), generic OOF utility, stacking feature builders, and RouterExtractionRegressor / RouterExtractionClassifier — LightGBM routing with RepLeaf leaf models refit on the frozen routes, with replay-stage early stopping. pip install "repleafgbm[external]" pulls LightGBM (the base used by router_extraction); XGBoost/CatBoost support is exercised via the [bench] extra.
  • pytest suite, runnable examples, and an experiments/ research scaffold
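The sibling-histogram subtraction mentioned for the native backend is a standard trick in histogram-based GBDTs: once the parent's histogram and one child's histogram are known, the other child's histogram is their difference, so only one child needs a pass over its rows. A minimal sketch, assuming a single pre-binned feature and gradient sums only (the helper build_histogram is hypothetical, not a repleafgbm function):

```python
import numpy as np

def build_histogram(bins, grad, n_bins):
    """Sum the gradients falling into each bin of one pre-binned feature."""
    return np.bincount(bins, weights=grad, minlength=n_bins)

rng = np.random.default_rng(0)
n_rows, n_bins = 1000, 8
bins = rng.integers(0, n_bins, size=n_rows)   # bin index of each row
grad = rng.normal(size=n_rows)                # per-row gradient

go_left = rng.random(n_rows) < 0.3            # rows sent to the smaller child

parent_hist = build_histogram(bins, grad, n_bins)
left_hist = build_histogram(bins[go_left], grad[go_left], n_bins)

# The larger sibling comes for free: no pass over its rows is needed.
right_hist = parent_hist - left_hist

# Reference computed the slow way, only to check the identity.
direct_right = build_histogram(bins[~go_left], grad[~go_left], n_bins)
print(np.allclose(right_hist, direct_right))
```

In practice the histogram is built for the smaller child and subtracted for the larger one, which is where the saving comes from.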

Not implemented (see docs/roadmap.md): encoder updates during boosting, GPU/distributed training.
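To make the encoder entries in the Implemented list more concrete, here is a generic sketch of frozen sinusoidal features followed by a random projection that caps the embedding width. It is not the library's PBLD-style construction: the frequency sampling, the scaling, and the names periodic_features and cap_dim are assumptions chosen for illustration.

```python
import numpy as np

def periodic_features(X, n_frequencies=4, seed=0):
    """Map each raw column to sin/cos features at fixed, randomly drawn frequencies."""
    rng = np.random.default_rng(seed)
    freqs = rng.normal(size=(X.shape[1], n_frequencies))   # drawn once, then frozen
    angles = 2.0 * np.pi * X[:, :, None] * freqs[None, :, :]           # (n, p, k)
    feats = np.concatenate([np.sin(angles), np.cos(angles)], axis=2)   # (n, p, 2k)
    return feats.reshape(X.shape[0], -1)

def cap_dim(Z, max_dim, seed=0):
    """Randomly project Z down to max_dim columns if it is wider than that."""
    if Z.shape[1] <= max_dim:
        return Z
    rng = np.random.default_rng(seed)
    P = rng.normal(size=(Z.shape[1], max_dim)) / np.sqrt(max_dim)
    return Z @ P

rng = np.random.default_rng(1)
X = rng.normal(size=(10, 4))

Z = periodic_features(X)          # 4 columns * 2 * 4 frequencies = 32 features
Z_capped = cap_dim(Z, max_dim=16) # projected down to 16 columns
print(Z.shape, Z_capped.shape)
```

Because the frequencies and the projection matrix depend only on the seed, the same transformation can be replayed at prediction time without any training state beyond that seed.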

Installation

pip install repleafgbm                 # core (numpy, pandas, scikit-learn)
pip install repleafgbm-native          # + optional Rust split/leaf kernels (auto-detected)
pip install "repleafgbm[external]"     # + LightGBM external_model / router_extraction
pip install "repleafgbm[bench]"        # + XGBoost / CatBoost for benchmarks
pip install "repleafgbm[torch]"        # + learned torch encoders

repleafgbm ships type information (PEP 561) — type checkers see its public API out of the box. repleafgbm-native is a separate package of prebuilt Linux/macOS/Windows wheels; once installed the Rust backend is selected automatically (split_backend="auto"), giving ~5.8x faster constant-leaf training while staying parity-tested against the NumPy reference.

Development

git clone <repo-url> && cd repleafgbm
pip install -e ".[dev]"

Alternatively, skip installation and run everything from the repo root with PYTHONPATH=src. Build the API reference with pip install -e ".[docs]" && bash scripts/build_docs.sh.

Running tests and examples

bash scripts/check.sh               # lint + tests + all examples
python -m pytest tests/ -q          # PYTHONPATH=src if not installed
python examples/regression_basic.py
python examples/binary_classification_basic.py
python examples/multiclass_classification_basic.py
python examples/dataset_api_basic.py

Benchmarks

Small synthetic benchmarks track progress across development (they are not performance claims):

python benchmarks/benchmark_synthetic_regression.py [--quick]
python benchmarks/benchmark_synthetic_binary.py [--quick]

LightGBM / XGBoost / CatBoost are included automatically when installed (pip install -e ".[bench]"). The latest snapshot and analysis live in docs/audit_v0.md.

License

MIT
