

GBNet


PyTorch modules for XGBoost and LightGBM.



What is GBNet?

Gradient boosting machine (GBM) libraries like XGBoost and LightGBM are excellent for tabular data but can be cumbersome to extend with custom losses or model architectures because you must supply gradients and Hessians by hand.

GBNet wraps GBM libraries in PyTorch Modules so you can:

  • Define losses and architectures in plain PyTorch
  • Let PyTorch autograd compute gradients / Hessians
  • Use XGBoost / LightGBM / boosted linear layers as building blocks inside larger models

At the core of GBNet are three PyTorch Modules:

  • gbnet.xgbmodule.XGBModule – XGBoost as a PyTorch Module
  • gbnet.lgbmodule.LGBModule – LightGBM as a PyTorch Module
  • gbnet.gblinear.GBLinear – a linear PyTorch Module trained with boosting rather than gradient descent

On top of these, GBNet ships higher-level models in gbnet.models, including forecasting, ordinal regression and survival models.


Installation

GBNet is on PyPI:

pip install gbnet

Using a virtual environment or conda environment is recommended. If you run into build / wheel issues for key dependencies (PyTorch, XGBoost, LightGBM), install them first following their platform-specific instructions, then install gbnet.


Quick Start: GBMs as PyTorch Modules

Basic pattern: treat XGBModule or LGBModule as a PyTorch nn.Module, build the rest of your model architecture in plain PyTorch, and call gb_step() during training to advance the boosted model. PyTorch components are updated with the usual optimizer step() logic.

import numpy as np
import torch
import xgboost as xgb

from gbnet import xgbmodule

# Toy regression data
np.random.seed(0)
n = 1000
input_dim = 20
output_dim = 1
X = np.random.random([n, input_dim])
B = np.random.random([input_dim, output_dim])
Y = X.dot(B) + 0.1 * np.random.randn(n, output_dim)

# XGBModule is a PyTorch Module wrapping XGBoost
model = xgbmodule.XGBModule(
    batch_size=n,
    input_dim=input_dim,
    output_dim=output_dim,
    params={}
)
loss_fn = torch.nn.MSELoss()

X_dmatrix = xgb.DMatrix(X)

losses = []
for _ in range(100):
    model.train()
    model.zero_grad()

    preds = model(X_dmatrix)
    loss = loss_fn(preds, torch.tensor(Y, dtype=torch.float32))
    loss.backward(create_graph=True)  # create_graph=True is required for gbnet
    losses.append(loss.item())

    model.gb_step()  # one boosting round, using gradients/Hessians taken from the autograd graph

model.eval()
preds = model(X_dmatrix)  # standard PyTorch-style inference
losses                    # decrease to near zero

Key ideas:

  • Gradients / Hessians are extracted from the PyTorch graph; no need to implement them manually.
  • gb_step() is the “one more boosting round” operation.
  • XGBModule and LGBModule are sums of trees, so they cannot propagate gradients to upstream layers; they must sit in the first layer of your architecture (see the sketch after this list).
  • Training data must stay fixed during training (no mini-batching), so model training resembles GBM training more than neural network training.
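
For example, a boosted module can feed a downstream PyTorch head: the trees advance via gb_step() while the head is updated by an ordinary optimizer. Below is a minimal sketch that reuses the XGBModule interface from the quick start; the nn.Linear head, hidden width, and Adam settings are illustrative choices, not part of GBNet.

import numpy as np
import torch
import xgboost as xgb

from gbnet import xgbmodule

np.random.seed(0)
n, input_dim, hidden_dim, output_dim = 1000, 20, 4, 1
X = np.random.random([n, input_dim])
Y = np.sin(6 * X[:, :1]) + 0.1 * np.random.randn(n, output_dim)

# Boosted first layer (trees cannot receive gradients from upstream layers) ...
gbm = xgbmodule.XGBModule(
    batch_size=n, input_dim=input_dim, output_dim=hidden_dim, params={}
)
# ... feeding an ordinary PyTorch head trained with a standard optimizer
head = torch.nn.Linear(hidden_dim, output_dim)
optimizer = torch.optim.Adam(head.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()

X_dmatrix = xgb.DMatrix(X)
target = torch.tensor(Y, dtype=torch.float32)

for _ in range(100):
    gbm.train(); head.train()
    gbm.zero_grad(); optimizer.zero_grad()

    hidden = gbm(X_dmatrix)           # boosted features, shape (n, hidden_dim)
    preds = head(hidden)
    loss = loss_fn(preds, target)
    loss.backward(create_graph=True)  # create_graph=True so Hessians are available

    gbm.gb_step()     # one boosting round for the trees
    optimizer.step()  # one gradient step for the head

gbm.eval(); head.eval()
preds = head(gbm(X_dmatrix))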

Built-In Models

GBNet includes higher-level models built on these Modules. These live in gbnet.models and follow a scikit-learn-style fit/predict API.

Forecasting

gbnet.models.forecasting.Forecast provides a time-series model with trend + seasonality + changepoints using GBNet components. It is designed to be competitive with Prophet-style workflows while remaining flexible.

Minimal usage:

import pandas as pd
from gbnet.models import forecasting

# df has columns: 'ds' (datetime), 'y' (target)
df = pd.read_csv("your_timeseries.csv")
df["ds"] = pd.to_datetime(df["ds"])

model = forecasting.Forecast()
model.fit(df, df["y"])

forecast_df = model.predict(df)
print(forecast_df.head())

See examples/simple_forecast_example.ipynb for a more complete forecasting example.
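
The same pattern works without a CSV. Here is a self-contained sketch on a synthetic daily series; the trend/seasonality construction is illustrative, Forecast is used exactly as above, and the columns returned by predict depend on the library.

import numpy as np
import pandas as pd
from gbnet.models import forecasting

# Synthetic daily series: linear trend + weekly seasonality + noise
dates = pd.date_range("2020-01-01", periods=730, freq="D")
t = np.arange(len(dates))
y = 0.05 * t + 5 * np.sin(2 * np.pi * t / 7) + np.random.randn(len(dates))
df = pd.DataFrame({"ds": dates, "y": y})

train, test = df.iloc[:600], df.iloc[600:]

model = forecasting.Forecast()
model.fit(train, train["y"])

forecast_df = model.predict(test)  # forecast the held-out 130 days
print(forecast_df.head())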

Ordinal Regression

GBOrd in gbnet.models.ordinal_regression implements ordinal regression using GBMs with a PyTorch-defined loss.

  • Example notebook: examples/ordinal_regression_comparison.ipynb
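
A hedged sketch of the fit/predict pattern on synthetic ordinal labels follows; the no-argument GBOrd() constructor and the format of the predictions are assumptions here, so see the notebook above for the actual options.

import numpy as np
from gbnet.models import ordinal_regression

# Synthetic ordinal data: a latent score cut into four ordered levels 0..3
np.random.seed(0)
X = np.random.randn(1000, 5)
latent = X @ np.array([1.0, -0.5, 0.3, 0.0, 0.8]) + 0.5 * np.random.randn(1000)
y = np.digitize(latent, bins=[-1.0, 0.0, 1.0])

model = ordinal_regression.GBOrd()  # constructor arguments assumed to have defaults
model.fit(X, y)
preds = model.predict(X)            # predicted ordinal levels (output format assumed)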

Survival Models

GBNet includes survival analysis models under gbnet.models.survival, such as:

  • HazardSurvivalModel – continuous-time hazard model with a gradient-boosted hazard backbone
  • BetaSurvivalModel – discrete-time survival using Beta distributions with boosting
  • ThetaSurvivalModel – discrete-time survival via a geometric distribution parameterized by a GBM (the likelihood is sketched below)
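
To make the last point concrete, the geometric parameterization corresponds to a simple discrete-time likelihood that can be written directly in PyTorch. The function below is an illustration of that math, not GBNet's internal code: with per-period event probability theta (here a sigmoid of a raw boosted score), an event observed at period t contributes theta * (1 - theta)^(t - 1) to the likelihood, and a subject censored after t periods contributes (1 - theta)^t.

import torch

def geometric_survival_nll(score, duration, event):
    """Negative log-likelihood of a discrete-time geometric survival model.

    score    : raw model output (e.g. from a boosted module), shape (n,)
    duration : observed number of periods, shape (n,)
    event    : 1.0 if the event occurred at `duration`, 0.0 if censored, shape (n,)
    """
    theta = torch.sigmoid(score)          # per-period event probability
    log_theta = torch.log(theta)
    log_1m_theta = torch.log1p(-theta)

    # Event at period t: theta * (1 - theta)^(t - 1)
    ll_event = log_theta + (duration - 1) * log_1m_theta
    # Censored after t periods: (1 - theta)^t
    ll_censored = duration * log_1m_theta

    ll = event * ll_event + (1 - event) * ll_censored
    return -ll.mean()

Because the loss is plain PyTorch, autograd can supply the gradients and Hessians that a boosted backbone needs.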

Example notebooks:

  • examples/hazard_survival_example.ipynb
  • examples/discrete_survival_examples.ipynb

Project Layout and Resources

  • gbnet/ – core library:
    • xgbmodule.py, lgbmodule.py, gblinear.py
    • models/ – forecasting, ordinal regression, survival, and more to come
  • examples/ – Jupyter notebooks:
    • simple_forecast_example.ipynb
    • gblinear_forecast_example.ipynb
    • ordinal_regression_comparison.ipynb
    • hazard_survival_example.ipynb
    • discrete_survival_examples.ipynb
  • docs/ – docs site

Start with the quick-start code above, then open the notebooks in examples/ to see end-to-end workflows.


Contributing

Contributions and issues are welcome. Typical ways to help:

  • Report bugs or performance issues
  • Propose or implement new models built on GBNet modules
  • Improve docs, notebooks, or examples

Before opening a pull request:

  1. Add or update tests for any new functionality.
  2. Run the existing test suite (e.g. pytest) if available in your environment.
  3. Keep code style consistent with the existing modules.

For larger changes, it’s helpful to open an issue first to discuss design.


Citation

If you use GBNet in academic work, please cite:

Horrell, M., (2025). GBNet: Gradient Boosting packages integrated into PyTorch.
Journal of Open Source Software, 10(111), 8047, https://doi.org/10.21105/joss.08047

