
Representation-enhanced Leaf Gradient Boosting Machine: raw-feature tree routing with representation-conditioned leaf models.

Project description

RepLeafGBM


Representation-enhanced Leaf Gradient Boosting Machine — gradient boosting that routes on raw features and predicts with small linear models over learned representations inside each leaf.

RepLeafGBM is not a neural network inside a tree, nor a tree over embeddings. It is a boosted ensemble of raw-feature routers with representation-conditioned local predictors.

How it works

GBDTs dominate tabular ML because axis-aligned splits on raw features handle discontinuities, interactions, and messy data extremely well. Tabular deep learning has shown that numerical feature embeddings (PLR, periodic) capture smooth nonlinear structure that constant-leaf trees approximate only with many splits. RepLeafGBM combines both with a deliberately asymmetric design:

  • Routing — every split is on a raw feature, exactly like a normal GBDT. Trees find the discontinuous boundaries and partition the space into local regions.
  • Leaf output — each leaf holds a small ridge-regularized linear model over a learned representation z_theta(x) instead of a constant. The embedding does the smooth interpolation within each region.
f_t(x) = b_l + w_l^T z_theta(x)        # one tree: route x_raw to leaf l, predict with that leaf's (b_l, w_l)
F_T(x) = F_0(x) + sum_t eta * f_t(x)   # boosted sum over T trees
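To make the asymmetry concrete, here is a minimal NumPy sketch of how a single fitted tree and the boosted sum turn into a prediction: routing reads only the raw features, and the leaf that is reached applies its own linear model to the embedding. The dict-based node layout and the helpers route, predict_tree and predict_ensemble are hypothetical names for illustration, not RepLeafGBM's internal format or API.

```python
import numpy as np

# Hypothetical node layout (not RepLeafGBM's internal format):
#   split node: ("split", feature_index, threshold, left_id, right_id)
#   leaf node:  ("leaf", b, w)   -> leaf model b + w @ z


def route(tree, x_raw):
    """Follow raw-feature splits from the root (node 0) down to a leaf."""
    node = tree[0]
    while node[0] == "split":
        _, feature, threshold, left, right = node
        node = tree[left] if x_raw[feature] <= threshold else tree[right]
    return node


def predict_tree(tree, x_raw, z):
    """f_t(x) = b + w^T z(x), using the (b, w) of the leaf that x_raw reaches."""
    _, b, w = route(tree, x_raw)
    return b + w @ z


def predict_ensemble(trees, X_raw, Z, f0, eta):
    """F_T(x) = F_0 + sum_t eta * f_t(x); routing uses X_raw, leaves use Z."""
    out = np.full(len(X_raw), f0, dtype=float)
    for tree in trees:
        out += eta * np.array([predict_tree(tree, x, z) for x, z in zip(X_raw, Z)])
    return out


# Toy tree: split on raw feature 0 at 0.0; both leaves are linear in z[1].
tree = {
    0: ("split", 0, 0.0, 1, 2),
    1: ("leaf", -2.0, np.array([0.0, 2.0])),
    2: ("leaf", 3.0, np.array([0.0, 2.0])),
}
X = np.array([[1.0, 0.5], [-1.0, 0.5]])
print(predict_ensemble([tree], X, X, f0=0.0, eta=1.0))  # identity "embedding": Z = X
```

With Z = X, the toy tree captures the jump at x_0 = 0 through routing and the slope in x_1 through the leaf model, which is the division of work described above.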

Three properties fall out of this design:

  • Embeddings are never used for splitting → interpretable raw-feature routing, no split-histogram blow-up, no curse of dimensionality in the search.
  • The encoder is frozen during boosting (v0), preserving the stage-wise additive structure of gradient boosting.
  • Leaf models fall back to a constant when a leaf is too small, so the model degrades gracefully toward a classic GBDT.
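One plausible way to fit such a leaf, sketched here under stated assumptions, is a second-order (Newton) step with a ridge penalty, the gradient/hessian formulation familiar from XGBoost-style boosting; the authoritative formulation is the one in docs/math.md. The function fit_leaf, its min_rows_for_linear threshold, and the choice to penalize the intercept as well are assumptions made for this sketch. Penalizing the intercept means that with no embedding columns the fit reduces exactly to the classic constant leaf -sum(g) / (sum(h) + l2), which is also what the fallback returns.

```python
import numpy as np


def fit_leaf(g, h, Z, l2=1.0, min_rows_for_linear=None):
    """Newton-style ridge fit of one leaf over its embedding rows Z.

    Minimizes  sum_i [g_i * f_i + 0.5 * h_i * f_i**2] + 0.5 * l2 * ||theta||**2
    where f_i = b + w @ z_i and theta = (b, w). Returns (b, w).
    """
    n, d = Z.shape
    if min_rows_for_linear is None:
        min_rows_for_linear = 2 * (d + 1)  # hypothetical threshold
    if n < min_rows_for_linear:
        # Too few rows: fall back to the classic constant leaf.
        return -g.sum() / (h.sum() + l2), np.zeros(d)
    A = np.hstack([np.ones((n, 1)), Z])              # design matrix [1, z]
    lhs = A.T @ (h[:, None] * A) + l2 * np.eye(d + 1)
    theta = -np.linalg.solve(lhs, A.T @ g)
    return theta[0], theta[1:]


# Squared error 0.5 * (f - y)**2 at current prediction f = 0: g = -y, h = 1.
rng = np.random.default_rng(0)
Z = rng.normal(size=(200, 1))
y = 1.0 + 2.0 * Z[:, 0]
b, w = fit_leaf(-y, np.ones(200), Z, l2=1e-6)
print(b, w)  # close to 1.0 and [2.0]
```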

See docs/math.md for the full formulation and leaf fitting.

Installation

pip install repleafgbm                 # core (numpy, pandas, scikit-learn)
pip install repleafgbm-native          # + optional Rust split/leaf kernels (auto-detected)
pip install "repleafgbm[external]"     # + LightGBM external_model / router_extraction
pip install "repleafgbm[bench]"        # + XGBoost / CatBoost for benchmarks
pip install "repleafgbm[torch]"        # + learned torch encoders

repleafgbm ships type information (PEP 561). repleafgbm-native is a separate package of prebuilt Linux/macOS/Windows wheels; once it is installed, the Rust backend is selected automatically (split_backend="auto"), giving ~5.8x faster constant-leaf training while remaining parity-tested against the NumPy reference.

Quickstart

import numpy as np
from repleafgbm import RepLeafRegressor
from repleafgbm.data import RepLeafDataset

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 4))
y = np.where(X[:, 0] > 0, 3.0, -2.0) + 2.0 * X[:, 1] + rng.normal(0, 0.1, 500)

model = RepLeafRegressor(
    n_estimators=50,
    learning_rate=0.1,
    num_leaves=8,
    leaf_model="embedded_linear",   # or "constant", "raw_linear"
    encoder="plr",                  # or "identity", "periodic", "cross"
    max_leaf_emb_dim=16,
    random_state=42,
)
model.fit(X, y)
pred = model.predict(X)

model.save_model("repleaf_model")
loaded = RepLeafRegressor.load_model("repleaf_model")

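pandas DataFrames with categorical columns are supported through the dataset API. Here df_train / df_valid are DataFrames with a categorical "city" column and y_train / y_valid are the matching targets:

train_data = RepLeafDataset(df_train, y_train, categorical_features=["city"])
# Build the validation set from the training metadata so both sets are encoded the same way.
valid_data = RepLeafDataset(df_valid, y_valid, metadata=train_data.metadata)
model.fit(train_data, eval_set=[valid_data])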

API overview

The public API is scikit-learn compatible (fit / predict / predict_proba, get_params / set_params).

Main classes

| Class | Use |
|---|---|
| RepLeafRegressor | Regression (squared error, plus huber / quantile / poisson); multi-output via vector leaves. |
| RepLeafClassifier | Binary (logistic) and multiclass (softmax) classification, chosen automatically from the labels. Has predict_proba. |
| RepLeafDataset | pandas / categorical inputs, eval sets, and embedding caching. |

Key parameters (constructor; same for both estimators)

| Parameter | Default | Meaning |
|---|---|---|
| n_estimators | 100 | Boosting rounds (trees). |
| learning_rate | 0.1 | Shrinkage per round. |
| num_leaves | 31 | Maximum leaves per tree (leaf-wise growth). |
| min_samples_leaf | 20 | Minimum rows per leaf. |
| leaf_model | "embedded_linear" | Leaf predictor (see below). |
| encoder | "identity" | Representation z_theta(x) (see below). |
| max_leaf_emb_dim | 64 | Cap on embedding dimension; larger embeddings are randomly projected down to it. |
| l2_leaf | 1.0 | Ridge penalty for leaf models. |
| early_stopping_rounds | None | Stop after this many rounds without improvement on the eval_set metric. |
| random_state | 42 | Seed; same seed ⇒ same model. |
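max_leaf_emb_dim caps the width of z_theta(x) with a random projection. A minimal sketch, assuming a seeded dense Gaussian projection (the library may use a different projection or scaling): entries are drawn from N(0, 1/max_dim), so squared norms are preserved in expectation, in the Johnson-Lindenstrauss sense. cap_embedding is a hypothetical helper.

```python
import numpy as np


def cap_embedding(Z, max_dim=64, random_state=42):
    """Project Z to at most max_dim columns with a fixed Gaussian random matrix.

    Embeddings at or below the cap pass through unchanged (returned matrix is None).
    The matrix comes from a seeded generator, so the same seed gives the same map;
    a fitted model would store R and reuse it at predict time.
    """
    d = Z.shape[1]
    if d <= max_dim:
        return Z, None
    rng = np.random.default_rng(random_state)
    R = rng.normal(0.0, 1.0 / np.sqrt(max_dim), size=(d, max_dim))
    return Z @ R, R


Z = np.random.default_rng(0).normal(size=(5, 300))   # e.g. a wide embedding
Z_small, R = cap_embedding(Z, max_dim=64)
print(Z_small.shape)  # (5, 64)
```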

leaf_model

| Value | Leaf predicts |
|---|---|
| "constant" | A constant (classic GBDT leaf). |
| "embedded_linear" | Ridge linear model over z_theta(x) (default). |
| "raw_linear" | Ridge linear model over the raw features. |

encoder

| Value | Representation |
|---|---|
| "identity" | Standardized raw features. Evidence-backed default on real tabular data. |
| "plr" | Piecewise-linear + linear term. |
| "periodic" | Frozen sinusoidal (PBLD-style) features. |
| "cross" | Residual-correlated pairwise products (interactions). |
| "torch_periodic" / "torch_plr" / "torch_periodic_plr" / "torch_mlp" | Learned encoders ([torch] extra); supervised-pretrained on the initial residual, then frozen. torch is needed only at fit time, not at predict time. |

API stability follows Semantic Versioning from 1.0.0; exactly what is covered vs. experimental is in docs/adr/0003-api-stability.md.
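The encoder table above can be made more tangible with rough NumPy sketches of three of its rows, written from the one-line descriptions only; the library's actual encoders may differ, and "plr" and the full PBLD details are not sketched. identity_encoder, periodic_encoder and cross_encoder, and their parameters n_freq, sigma and top_k, are hypothetical. periodic_encoder is a plain random-frequency cos/sin map in the spirit of, not identical to, PBLD; cross_encoder reads "residual-correlated" as ranking all pairwise products by absolute Pearson correlation with the current residual.

```python
from itertools import combinations

import numpy as np


def identity_encoder(X, mean, std):
    """Standardized raw features (mean/std taken from the training data)."""
    return (X - mean) / std


def periodic_encoder(X, n_freq=4, sigma=1.0, random_state=0):
    """Frozen cos/sin(2*pi*c*x) features with random per-feature frequencies c.

    The frequencies are drawn once from a seeded generator and never trained.
    """
    rng = np.random.default_rng(random_state)
    freqs = rng.normal(0.0, sigma, size=(X.shape[1], n_freq))   # (features, n_freq)
    angles = 2 * np.pi * X[:, :, None] * freqs[None, :, :]        # (rows, features, n_freq)
    feats = np.concatenate([np.cos(angles), np.sin(angles)], axis=2)
    return feats.reshape(len(X), -1)                               # (rows, features * 2 * n_freq)


def cross_encoder(X, residual, top_k=2):
    """Pairwise products x_i * x_j, keeping the top_k pairs most correlated with the residual."""
    pairs = list(combinations(range(X.shape[1]), 2))
    prods = np.stack([X[:, i] * X[:, j] for i, j in pairs], axis=1)
    p = prods - prods.mean(axis=0)
    r = residual - residual.mean()
    corr = np.abs(p.T @ r) / (np.linalg.norm(p, axis=0) * np.linalg.norm(r) + 1e-12)
    keep = np.argsort(corr)[::-1][:top_k]
    return prods[:, keep], [pairs[k] for k in keep]


rng = np.random.default_rng(0)
X = rng.normal(size=(500, 4))
residual = X[:, 1] * X[:, 3] + rng.normal(0, 0.1, 500)   # hidden interaction
Z_id = identity_encoder(X, X.mean(axis=0), X.std(axis=0))
Z_per = periodic_encoder(X)
Z_cross, chosen = cross_encoder(X, residual, top_k=1)
print(Z_per.shape, chosen)  # (500, 32) [(1, 3)]
```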

Features

| Area | Support |
|---|---|
| Backends | NumPy reference (histogram split search, leaf-wise growth) + optional Rust kernels (repleafgbm-native, ~5.8x faster, parity-tested). |
| Tasks | Regression, binary & multiclass classification, multi-output regression (vector leaves). |
| Objectives | Squared error, huber, quantile, poisson, logistic, softmax (also as parameterized instances such as Quantile(alpha=0.9)). |
| Leaf models | constant, embedded_linear, raw_linear. |
| Encoders | identity, plr, periodic, cross, learned torch_*. |
| Training | Early stopping, eval metrics (rmse, mae, logloss, multi_logloss, auc, accuracy, or a custom callable via make_metric), feature importances, sample weights, class_weight, label_smoothing. |
| Data | RepLeafDataset with pandas/categorical (native subset splits) and embedding caching. |
| Persistence | Directory-based save_model / load_model with schema validation and a human-readable summary(). |
| External | LightGBM / XGBoost / CatBoost as base models, OOF + stacking utilities, and RouterExtraction{Regressor,Classifier} ([external] extra). |
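Each objective in the table enters boosting only through per-row gradients g and hessians h of the loss with respect to the raw score f. The sketch below gives the standard textbook forms in NumPy; it is not RepLeafGBM's objective code, and the constant hessian used for huber and quantile is an assumption (a common workaround, since their exact hessians are zero over part or all of their domain). In this formulation, class probabilities are the sigmoid (binary) or softmax (multiclass) of the raw scores.

```python
import numpy as np

# Gradient g and hessian h of each loss with respect to the raw score f.
# Huber's exact hessian is 0 outside |f - y| <= delta and the quantile (pinball)
# loss's is 0 almost everywhere, so both use a constant hessian of 1 here (assumption).


def squared_error(y, f):
    return f - y, np.ones_like(f)                       # loss 0.5 * (f - y)**2


def huber(y, f, delta=1.0):
    return np.clip(f - y, -delta, delta), np.ones_like(f)


def quantile(y, f, alpha=0.9):
    # pinball loss: alpha * (y - f) if y >= f else (1 - alpha) * (f - y)
    return np.where(y >= f, -alpha, 1.0 - alpha), np.ones_like(f)


def poisson(y, f):
    mu = np.exp(f)                                      # log link, loss mu - y * f
    return mu - y, mu


def logistic(y, f):
    p = 1.0 / (1.0 + np.exp(-f))                        # y in {0, 1}
    return p - y, p * (1.0 - p)


def softmax(y, F):
    """y: integer labels (n,); F: raw scores (n, K). Returns (n, K) arrays."""
    F = F - F.max(axis=1, keepdims=True)                # numerical stability
    P = np.exp(F)
    P /= P.sum(axis=1, keepdims=True)
    Y = np.eye(F.shape[1])[y]                           # one-hot labels
    return P - Y, P * (1.0 - P)                         # diagonal of the exact hessian


g, h = logistic(np.array([1.0, 0.0]), np.array([0.0, 2.0]))
print(g, h)
```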

Full implemented-vs-planned status (and what is intentionally not done yet — encoder updates during boosting, GPU / distributed training) is in docs/roadmap.md.

Benchmarks

These small benchmarks track development progress; they are not performance claims. OpenML mean rank across 9 standard datasets (lower is better):

| Model | Regression (4 datasets) | Classification (5 datasets) |
|---|---|---|
| CatBoost | 2.00 | 2.00 |
| RepLeaf (constant) | 4.00 | 2.60 |
| XGBoost | 2.75 | 3.40 |
| LightGBM | 2.50 | 3.60 |
| RepLeaf (embedded_linear) | 4.50 | 4.20 |
| HistGradientBoosting | 5.25 | 5.20 |

On these real datasets RepLeafGBM is competitive with the major GBM libraries, and the constant leaf is the better choice there (mean rank 2.60 vs 4.20 for embedded_linear on classification, 4.00 vs 4.50 on regression). The leaf embeddings pay off on smooth or periodic structure: on such synthetic signals embedded_linear beats both the constant leaf and LightGBM, and refitting LightGBM's own routes with representation-conditioned leaves (router_extraction) reduces its RMSE by 2–12%. Reproduce with python benchmarks/openml_suite.py; full numbers are in experiments/results/openml_benchmark.md.
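The router_extraction idea can be illustrated on a toy problem: keep a fixed set of routes and refit only what happens inside each leaf. In the sketch below the "router" is a single hand-picked split at x = 0 standing in for routes taken from an existing model (with LightGBM, per-tree leaf indices are available from predict(..., pred_leaf=True)). refit_leaves, the toy periodic embedding, and the synthetic data are assumptions for illustration, not the RouterExtraction API or the benchmark setup; a real refit would repeat this for every tree of the boosted ensemble.

```python
import numpy as np


def refit_leaves(leaf_ids, Z, y, l2=1e-3):
    """Keep a fixed routing; refit each leaf as a ridge model over Z (squared error)."""
    pred = np.empty(len(y))
    for leaf in np.unique(leaf_ids):
        m = leaf_ids == leaf
        A = np.hstack([np.ones((m.sum(), 1)), Z[m]])     # [1, z] for rows in this leaf
        theta = np.linalg.solve(A.T @ A + l2 * np.eye(A.shape[1]), A.T @ y[m])
        pred[m] = A @ theta
    return pred


def rmse(y_true, y_pred):
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))


rng = np.random.default_rng(0)
x = rng.uniform(-3, 3, size=1000)
y = np.where(x > 0, 2.0, 0.0) + np.sin(2 * x) + rng.normal(0, 0.05, size=1000)

# Stand-in for routes extracted from an existing model: one split at x = 0.
leaf_ids = (x > 0).astype(int)

# Constant leaves: per-leaf mean of y.
const_pred = (np.bincount(leaf_ids, weights=y) / np.bincount(leaf_ids))[leaf_ids]

# Same routes, linear leaves over a toy periodic embedding of x.
k = np.arange(1, 4)
Z = np.hstack([x[:, None], np.sin(x[:, None] * k), np.cos(x[:, None] * k)])
lin_pred = refit_leaves(leaf_ids, Z, y)

print(rmse(y, const_pred), rmse(y, lin_pred))
```

On this signal (a jump plus sin(2x)), constant leaves capture only the jump, while linear leaves over the embedding also recover the smooth part, so the second RMSE comes out far smaller.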

Development

git clone https://github.com/Matapanino/repleafgbm.git && cd repleafgbm
pip install -e ".[dev]"
bash scripts/check.sh               # lint + tests + all examples
python -m pytest tests/ -q          # PYTHONPATH=src if not installed
python examples/regression_basic.py

See CONTRIBUTING.md for the contribution workflow.


License

MIT



Download files

Download the file for your platform.

Source Distribution

repleafgbm-1.5.0.tar.gz (254.0 kB)

Uploaded Source

Built Distribution


repleafgbm-1.5.0-py3-none-any.whl (112.8 kB)

Uploaded Python 3

File details

Details for the file repleafgbm-1.5.0.tar.gz.

File metadata

  • Download URL: repleafgbm-1.5.0.tar.gz
  • Upload date:
  • Size: 254.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for repleafgbm-1.5.0.tar.gz
Algorithm Hash digest
SHA256 dc1af063da1a0cdeb3d3dd4785dd1846c38e23e55153635647cfbe5bea80bb7a
MD5 fa26f59fdea1ff6fb4698f7b3d07fb9d
BLAKE2b-256 eca9804959e9f38d190e87dec7a4362cdea93ab683f0937718069653162552dd


Provenance

The following attestation bundles were made for repleafgbm-1.5.0.tar.gz:

Publisher: publish.yml on Matapanino/repleafgbm

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file repleafgbm-1.5.0-py3-none-any.whl.

File metadata

  • Download URL: repleafgbm-1.5.0-py3-none-any.whl
  • Upload date:
  • Size: 112.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for repleafgbm-1.5.0-py3-none-any.whl
Algorithm Hash digest
SHA256 d564e8929618d09d030dd8008a403d72cd21c8eb4cb2c1bbf6cc4ec70ad67918
MD5 7171e3fa384732c356ef8ad8a4dfbd4d
BLAKE2b-256 f41b11189635cf400c6b4a95a0706b2b38bac10bb0e9bd689645fbaac62eb6ea


Provenance

The following attestation bundles were made for repleafgbm-1.5.0-py3-none-any.whl:

Publisher: publish.yml on Matapanino/repleafgbm

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
