Representation-enhanced Leaf Gradient Boosting Machine: raw-feature tree routing with representation-conditioned leaf models.

RepLeafGBM

Representation-enhanced Leaf Gradient Boosting Machine — gradient boosting that routes on raw features and predicts with small linear models over learned representations inside each leaf.

RepLeafGBM is not a neural network inside a tree, nor a tree over embeddings. It is a boosted ensemble of raw-feature routers with representation-conditioned local predictors.

How it works

GBDTs dominate tabular ML because axis-aligned splits on raw features handle discontinuities, interactions, and messy data extremely well. Tabular deep learning has shown that numerical feature embeddings (PLR, periodic) capture smooth nonlinear structure that constant-leaf trees approximate only with many splits. RepLeafGBM combines both with a deliberately asymmetric design:

  • Routing — every split is on a raw feature, exactly like a normal GBDT. Trees find the discontinuous boundaries and partition the space into local regions.
  • Leaf output — each leaf holds a small ridge-regularized linear model over a learned representation z_theta(x) instead of a constant. The embedding does the smooth interpolation within each region.
f_t(x) = b + w^T z_theta(x)            # one tree: route on x_raw, predict in the leaf
F_T(x) = F_0(x) + sum_t eta * f_t(x)   # boosted sum
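
The two formulas can be made concrete with a single hand-built tree. The sketch below is an illustration under stated assumptions, not the library's implementation: the split (feature 0 at threshold 0.0) is fixed by hand instead of searched, the representation is simply the raw feature matrix, and `fit_ridge_leaf` and `predict_tree` are hypothetical helper names.

```python
import numpy as np

def fit_ridge_leaf(Z, r, l2=1.0):
    """Fit r ~ b + w.z with a ridge penalty on w (intercept left unpenalized)."""
    z_mean, r_mean = Z.mean(axis=0), r.mean()
    Zc = Z - z_mean
    w = np.linalg.solve(Zc.T @ Zc + l2 * np.eye(Z.shape[1]), Zc.T @ (r - r_mean))
    return r_mean - z_mean @ w, w

def predict_tree(X_raw, Z, feature, threshold, leaves):
    """Route each row on a raw feature, then apply that leaf's linear model to z."""
    go_right = X_raw[:, feature] > threshold
    out = np.empty(len(X_raw))
    for side, (b, w) in zip((False, True), leaves):
        mask = go_right == side
        out[mask] = b + Z[mask] @ w
    return out

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 4))
y = np.where(X[:, 0] > 0, 3.0, -2.0) + 2.0 * X[:, 1] + rng.normal(0, 0.1, 500)

Z = X                                # assumption: identity representation
F0 = np.full(len(y), y.mean())       # F_0: constant initial prediction
residual = y - F0                    # negative gradient of squared error
right = X[:, 0] > 0.0                # assumption: hand-picked split
leaves = [fit_ridge_leaf(Z[~right], residual[~right]),
          fit_ridge_leaf(Z[right], residual[right])]

eta = 1.0                            # full step, since this is a one-tree demo
F1 = F0 + eta * predict_tree(X, Z, 0, 0.0, leaves)
rmse_const = np.sqrt(np.mean((y - F0) ** 2))
rmse_tree = np.sqrt(np.mean((y - F1) ** 2))
print(f"constant: {rmse_const:.3f}  one tree with linear leaves: {rmse_tree:.3f}")
```

The split absorbs the jump at `X[:, 0] = 0`, and the linear leaves absorb the slope in `X[:, 1]`, which a constant-leaf tree could only approximate with further splits.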

Three properties fall out of this design:

  • Embeddings are never used for splitting → interpretable raw-feature routing, no split-histogram blow-up, no curse of dimensionality in the search.
  • The encoder is frozen during boosting (v0), preserving the stage-wise additive structure of gradient boosting.
  • Leaf models fall back to a constant when a leaf is too small, so the model degrades gracefully toward a classic GBDT.
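
The third property can be sketched as a guard in front of the ridge fit. This is a minimal illustration, not the library's code: the function name `fit_leaf`, the parameter `min_rows_for_linear`, and the rule "fewer rows than a fixed threshold means constant" are all assumptions, and the constant used is the plain mean residual (the unregularized squared-error leaf value).

```python
import numpy as np

def fit_leaf(Z, r, l2=1.0, min_rows_for_linear=20):
    """Return (b, w). w is all zeros when the leaf is too small for a linear fit."""
    n, d = Z.shape
    if n < min_rows_for_linear:
        # Too few rows: behave like a constant GBDT leaf.
        return float(r.mean()), np.zeros(d)
    z_mean, r_mean = Z.mean(axis=0), r.mean()
    Zc = Z - z_mean
    w = np.linalg.solve(Zc.T @ Zc + l2 * np.eye(d), Zc.T @ (r - r_mean))
    return float(r_mean - z_mean @ w), w

rng = np.random.default_rng(0)
Z = rng.normal(size=(200, 3))
r = 1.5 + Z @ np.array([2.0, 0.0, -1.0]) + rng.normal(0, 0.05, 200)

b_big, w_big = fit_leaf(Z, r)                # enough rows: linear model
b_small, w_small = fit_leaf(Z[:5], r[:5])    # 5 rows: constant fallback
```

Because the fallback returns the same `(b, w)` shape with `w = 0`, prediction code does not need a separate branch for constant leaves.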

See docs/math.md for the full formulation and leaf fitting.

Installation

pip install repleafgbm                 # core (numpy, pandas, scikit-learn)
pip install repleafgbm-native          # + optional Rust split/leaf kernels (auto-detected)
pip install "repleafgbm[external]"     # + LightGBM external_model / router_extraction
pip install "repleafgbm[bench]"        # + XGBoost / CatBoost for benchmarks
pip install "repleafgbm[torch]"        # + learned torch encoders

repleafgbm ships type information (PEP 561). repleafgbm-native is a separate package of prebuilt Linux/macOS/Windows wheels. Once it is installed, the Rust backend is selected automatically (split_backend="auto"), giving roughly 5.8x faster constant-leaf training than the NumPy reference while staying parity-tested against it.
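
To see which of the two distributions are present in the current environment, the standard library's `importlib.metadata` is enough. The helper name `installed_versions` is made up for this sketch, and it only reports installed distributions; it does not ask repleafgbm which backend it actually selected.

```python
from importlib import metadata

def installed_versions(names=("repleafgbm", "repleafgbm-native")):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found

versions = installed_versions()
for name, version in versions.items():
    print(f"{name}: {version or 'not installed'}")
```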

Quickstart

import numpy as np
from repleafgbm import RepLeafRegressor
from repleafgbm.data import RepLeafDataset

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 4))
y = np.where(X[:, 0] > 0, 3.0, -2.0) + 2.0 * X[:, 1] + rng.normal(0, 0.1, 500)

model = RepLeafRegressor(
    n_estimators=50,
    learning_rate=0.1,
    num_leaves=8,
    leaf_model="embedded_linear",   # or "constant", "raw_linear"
    encoder="plr",                  # or "identity", "periodic", "cross"
    max_leaf_emb_dim=16,
    random_state=42,
)
model.fit(X, y)
pred = model.predict(X)

model.save_model("repleaf_model")
loaded = RepLeafRegressor.load_model("repleaf_model")

pandas DataFrames with categorical columns are supported through the dataset API:

train_data = RepLeafDataset(df_train, y_train, categorical_features=["city"])
model.fit(train_data, eval_set=[RepLeafDataset(df_valid, y_valid,
                                               metadata=train_data.metadata)])

API overview

The public API is scikit-learn compatible (fit / predict / predict_proba, get_params / set_params).

Main classes

| Class | Use |
| --- | --- |
| RepLeafRegressor | Regression (squared error, plus huber / quantile / poisson); multi-output via vector leaves. |
| RepLeafClassifier | Binary (logistic) and multiclass (softmax), chosen automatically from the labels. Has predict_proba. |
| RepLeafDataset | pandas / categorical inputs, eval sets, and embedding caching. |

Key parameters (constructor; same for both estimators)

| Parameter | Default | Meaning |
| --- | --- | --- |
| n_estimators | 100 | Boosting rounds (trees). |
| learning_rate | 0.1 | Shrinkage per round. |
| num_leaves | 31 | Max leaves per tree (leaf-wise growth). |
| min_samples_leaf | 20 | Minimum rows per leaf. |
| leaf_model | "embedded_linear" | Leaf predictor (see below). |
| encoder | "identity" | Representation z_theta(x) (see below). |
| max_leaf_emb_dim | 64 | Cap on embedding dimension (random projection above it). |
| l2_leaf | 1.0 | Ridge penalty for leaf models. |
| early_stopping_rounds | None | Stop when an eval_set metric plateaus. |
| random_state | 42 | Seed; same seed ⇒ same model. |
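
The max_leaf_emb_dim entry says only that a random projection is applied above the cap. One common way to do that, shown here as an assumption and not as the library's method, is a fixed Gaussian projection scaled so that squared norms are preserved in expectation. `cap_embedding_dim` is a hypothetical helper name.

```python
import numpy as np

def cap_embedding_dim(Z, max_dim, seed=42):
    """Project Z down to max_dim columns with a fixed Gaussian matrix if it is wider."""
    d = Z.shape[1]
    if d <= max_dim:
        return Z                      # already within the cap: leave untouched
    rng = np.random.default_rng(seed)
    # Entries ~ N(0, 1/max_dim), so E[||z P||^2] = ||z||^2.
    P = rng.normal(size=(d, max_dim)) / np.sqrt(max_dim)
    return Z @ P

Z = np.random.default_rng(0).normal(size=(100, 256))
Z_capped = cap_embedding_dim(Z, max_dim=64)         # shape (100, 64)
Z_small = cap_embedding_dim(Z[:, :16], max_dim=64)  # shape (100, 16), unchanged
```

Seeding the projection keeps it identical between fit and predict, which any such scheme needs in order to give the same leaf inputs at both stages.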

leaf_model

| Value | Leaf predicts |
| --- | --- |
| "constant" | A constant (classic GBDT leaf). |
| "embedded_linear" | Ridge linear model over z_theta(x) (default). |
| "raw_linear" | Ridge linear model over the raw features. |

encoder

| Value | Representation |
| --- | --- |
| "identity" | Standardized raw features. Evidence-backed default on real tabular data. |
| "plr" | Piecewise-linear + linear term. |
| "periodic" | Frozen sinusoidal (PBLD-style) features. |
| "cross" | Residual-correlated pairwise products (interactions). |
| "torch_periodic" / "torch_plr" / "torch_periodic_plr" / "torch_mlp" | Learned encoders ([torch] extra; supervised-pretrained on the initial residual, then frozen; torch is needed only at fit time, not at predict time). |

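To give a feel for what the "identity", "periodic", and "plr" rows describe, here is a rough NumPy sketch of each kind of representation. It follows only the one-line descriptions in the table: the function names, the way frequencies are drawn, and the choice of quantile knots are assumptions made for illustration and may differ from what the library does.

```python
import numpy as np

def identity_encoder(X, mean, std):
    """Standardized raw features."""
    return (X - mean) / std

def periodic_encoder(X, freqs):
    """sin/cos features with fixed frequencies; freqs has shape (n_features, k)."""
    angles = 2.0 * np.pi * X[:, :, None] * freqs[None, :, :]
    feats = np.concatenate([np.sin(angles), np.cos(angles)], axis=2)
    return feats.reshape(len(X), -1)

def piecewise_linear_encoder(X, knots):
    """Linear term plus hinge features max(0, x - knot); knots is (n_features, k)."""
    hinges = np.maximum(0.0, X[:, :, None] - knots[None, :, :])
    return np.concatenate([X, hinges.reshape(len(X), -1)], axis=1)

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))

mean, std = X.mean(axis=0), X.std(axis=0)
freqs = rng.normal(size=(3, 4))                        # drawn once, never updated
knots = np.quantile(X, [0.25, 0.5, 0.75], axis=0).T    # shape (3, 3)

Z_id = identity_encoder(X, mean, std)          # (200, 3)
Z_per = periodic_encoder(X, freqs)             # (200, 24)
Z_plr = piecewise_linear_encoder(X, knots)     # (200, 12)
```

In each case the encoder's parameters (`mean`/`std`, `freqs`, `knots`) are computed once and then held fixed, so the same representation can be reused by every boosting round.
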
API stability follows Semantic Versioning from 1.0.0; exactly what is covered vs. experimental is in docs/adr/0003-api-stability.md.

Features

| Area | Support |
| --- | --- |
| Backends | NumPy reference (histogram split search, leaf-wise growth) + optional Rust kernels (repleafgbm-native, ~5.8x faster, parity-tested). |
| Tasks | Regression, binary & multiclass classification, multi-output regression (vector leaves). |
| Objectives | Squared error, huber, quantile, poisson, logistic, softmax (parameterized instances like Quantile(alpha=0.9) too). |
| Leaf models | constant, embedded_linear, raw_linear. |
| Encoders | identity, plr, periodic, cross, learned torch_*. |
| Training | Early stopping, eval metrics (rmse, mae, logloss, multi_logloss, auc, accuracy, or a custom callable via make_metric), feature importances, sample weights, class_weight, label_smoothing. |
| Data | RepLeafDataset with pandas/categorical (native subset splits) and embedding caching. |
| Persistence | Directory-based save_model / load_model with schema validation and a human-readable summary(). |
| External | LightGBM / XGBoost / CatBoost as base models, OOF + stacking utilities, and RouterExtraction{Regressor,Classifier} ([external] extra). |
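
As background for the parameterized quantile objective, the sketch below shows the standard pinball loss and the pseudo-residual that gradient boosting would fit for it. It is generic textbook math, not code from the library; the function names are hypothetical, and each "tree" is reduced to a single constant so that only the objective is on display.

```python
import numpy as np

def pinball_loss(y, pred, alpha):
    """Mean pinball (quantile) loss at level alpha."""
    diff = y - pred
    return np.mean(np.maximum(alpha * diff, (alpha - 1.0) * diff))

def pinball_negative_gradient(y, pred, alpha):
    """Pseudo-residual: alpha where y > pred, alpha - 1 otherwise."""
    return np.where(y > pred, alpha, alpha - 1.0)

rng = np.random.default_rng(0)
y = rng.normal(size=2000)
alpha = 0.9

# Boosting with a constant-only learner: each round adds the shrunken mean
# pseudo-residual, so the prediction drifts toward the alpha-quantile of y.
pred = 0.0
for _ in range(500):
    pred += 0.1 * pinball_negative_gradient(y, pred, alpha).mean()

q_hat = float(pred)
q_true = float(np.quantile(y, alpha))
print(f"boosted constant: {q_hat:.3f}  empirical 0.9-quantile: {q_true:.3f}")
```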

Full implemented-vs-planned status (and what is intentionally not done yet — encoder updates during boosting, GPU / distributed training) is in docs/roadmap.md.

Benchmarks

These small benchmarks track development progress; they are not performance claims. OpenML mean rank across 9 standard datasets (lower is better):

| Model | Regression (4 datasets) | Classification (5 datasets) |
| --- | --- | --- |
| CatBoost | 2.00 | 2.00 |
| RepLeaf (constant) | 4.00 | 2.60 |
| XGBoost | 2.75 | 3.40 |
| LightGBM | 2.50 | 3.60 |
| RepLeaf (embedded_linear) | 4.50 | 4.20 |
| HistGradientBoosting | 5.25 | 5.20 |

With constant leaves, RepLeafGBM is competitive with the major GBM libraries on these real datasets: by mean rank it places second on classification (2.60) and fourth on regression (4.00), and it outranks the embedded_linear variant on both, so a constant leaf is the safer choice there. The leaf embeddings pay off on smooth or periodic structure: on synthetic signals of that kind, embedded_linear beats both the constant-leaf variant and LightGBM, and refitting LightGBM's own routes with representation-conditioned leaves (router_extraction) reduces its RMSE by 2–12%. Reproduce with python benchmarks/openml_suite.py; full numbers are in experiments/results/openml_benchmark.md.
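
For readers unfamiliar with the mean-rank metric used in the table, here is a small standard-library sketch of how such a number is typically computed. The scores, model names, and dataset names are made up purely for illustration, and the tie rule (tied models share the average rank) is an assumption; the benchmark script may handle ties differently.

```python
from statistics import mean

def mean_ranks(scores, lower_is_better=True):
    """scores: {dataset: {model: metric}} -> {model: mean rank across datasets}."""
    per_model = {}
    for results in scores.values():
        ordered = sorted(results.items(), key=lambda kv: kv[1],
                         reverse=not lower_is_better)
        i = 0
        while i < len(ordered):
            j = i
            while j + 1 < len(ordered) and ordered[j + 1][1] == ordered[i][1]:
                j += 1
            tied_rank = (i + j) / 2 + 1      # average of 1-based positions i+1..j+1
            for model, _ in ordered[i:j + 1]:
                per_model.setdefault(model, []).append(tied_rank)
            i = j + 1
    return {model: mean(ranks) for model, ranks in per_model.items()}

# Made-up RMSE values for three placeholder models on three placeholder datasets.
toy_scores = {
    "dataset_1": {"model_a": 0.50, "model_b": 0.40, "model_c": 0.60},
    "dataset_2": {"model_a": 0.30, "model_b": 0.35, "model_c": 0.35},
    "dataset_3": {"model_a": 1.10, "model_b": 1.00, "model_c": 1.20},
}
ranks = mean_ranks(toy_scores)
for model, rank in sorted(ranks.items(), key=lambda kv: kv[1]):
    print(f"{model}: {rank:.2f}")
```

A mean rank says how often a model lands near the top, not by how much it wins, which is one reason to read the full per-dataset numbers alongside it.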

Development

git clone https://github.com/Matapanino/repleafgbm.git && cd repleafgbm
pip install -e ".[dev]"
bash scripts/check.sh               # lint + tests + all examples
python -m pytest tests/ -q          # PYTHONPATH=src if not installed
python examples/regression_basic.py

See CONTRIBUTING.md for the contribution workflow.

Documentation

License

MIT
