
XGBoost for probabilistic prediction.

Project description


xgboost-distribution

XGBoost for probabilistic prediction. Like NGBoost, but faster and in the XGBoost scikit-learn API.

XGBDistribution example

Installation

$ pip install --upgrade xgboost-distribution

Usage

XGBDistribution follows the XGBoost scikit-learn API, except for an additional keyword in the constructor for specifying the distribution. Given some data, we can fit a model:

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

from xgboost_distribution import XGBDistribution

# load_boston was removed in scikit-learn 1.2; we use the California
# housing dataset instead
data = fetch_california_housing()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(X, y)

model = XGBDistribution(
    distribution="normal",
    n_estimators=500,
    early_stopping_rounds=10,  # recent xgboost versions expect this in the constructor
)
model.fit(X_train, y_train, eval_set=[(X_test, y_test)])

After fitting, we can predict the parameters of the distribution for new data. This returns a namedtuple of numpy arrays, one for each parameter of the distribution (note that we follow scipy naming conventions, see e.g. scipy.stats.norm):

preds = model.predict(X_test)
mean, std = preds.loc, preds.scale
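Because the predicted parameters follow scipy naming conventions, they can be fed straight into scipy.stats.norm to obtain prediction intervals or log-likelihoods. A minimal sketch, where the loc/scale arrays are illustrative stand-ins for the output of model.predict:

```python
import numpy as np
from scipy import stats

# Illustrative predicted parameters (in practice these come from
# preds.loc and preds.scale after model.predict(X_test))
loc = np.array([20.0, 25.0, 15.0])
scale = np.array([2.0, 3.0, 1.5])

# Central 95% prediction interval for each test point
lower = stats.norm.ppf(0.025, loc=loc, scale=scale)
upper = stats.norm.ppf(0.975, loc=loc, scale=scale)

# Mean negative log-likelihood of observed targets under the
# predicted per-point distributions
y_obs = np.array([19.0, 26.5, 14.0])
nll = -stats.norm.logpdf(y_obs, loc=loc, scale=scale).mean()
```

The same pattern works for any distribution exposed by scipy.stats, since each predicted parameter maps onto the corresponding scipy keyword.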

NGBoost performance comparison

XGBDistribution follows the method shown in the NGBoost library, namely using natural gradients to estimate the parameters of the distribution.
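To make the natural gradient step concrete, here is a minimal numpy sketch for a normal distribution parametrised as (mu, log_sigma), the parametrisation used in the NGBoost approach. The function name is illustrative, not part of either library; the natural gradient is the ordinary gradient of the negative log-likelihood preconditioned by the inverse Fisher information, solved per data point:

```python
import numpy as np

def natural_gradient_normal(y, mu, log_sigma):
    """Natural gradient of the normal NLL w.r.t. (mu, log_sigma)."""
    var = np.exp(2 * log_sigma)
    # Ordinary gradients of the negative log-likelihood
    grad = np.stack(
        [
            (mu - y) / var,               # d NLL / d mu
            1.0 - (y - mu) ** 2 / var,    # d NLL / d log_sigma
        ],
        axis=-1,
    )
    # Fisher information of N(mu, sigma) w.r.t. (mu, log_sigma) is
    # diag(1 / sigma^2, 2); we solve the linear system per point,
    # mirroring the general LAPACK-based solve
    fisher = np.zeros(y.shape + (2, 2))
    fisher[..., 0, 0] = 1.0 / var
    fisher[..., 1, 1] = 2.0
    return np.linalg.solve(fisher, grad[..., None])[..., 0]
```

For the normal distribution the solve is analytic (the Fisher matrix is diagonal), but the per-point linear solve shown here is what dominates runtime for large datasets.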

Below, we show a performance comparison of the NGBoost NGBRegressor and XGBDistribution models, using the Boston Housing dataset and a normal distribution (with comparable hyperparameters). While the performance of the two models is essentially identical, XGBDistribution is around 50x faster (timing both the fit and predict steps).

Note that the speed-up decreases with dataset size, as it is ultimately limited by the natural gradient computation (performed via LAPACK's gesv). Even with 1 million rows of data, XGBDistribution is still around 10x faster than NGBRegressor.

XGBDistribution vs NGBoost

Full XGBoost features

XGBDistribution offers the full set of XGBoost features available in the XGBoost scikit-learn API, allowing, for example, probabilistic prediction with monotonic constraints:

XGBDistribution monotonic constraints

Note

This project has been set up using PyScaffold 4.0.1. For details and usage information on PyScaffold see https://pyscaffold.org/.

Project details


Download files

Download the file for your platform.

Source Distribution

xgboost-distribution-0.1.0.tar.gz (198.2 kB)

Uploaded Source

Built Distribution

xgboost_distribution-0.1.0-py2.py3-none-any.whl (9.8 kB)

Uploaded Python 2 Python 3

File details

Details for the file xgboost-distribution-0.1.0.tar.gz.

File metadata

  • Download URL: xgboost-distribution-0.1.0.tar.gz
  • Upload date:
  • Size: 198.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.5.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.9.5

File hashes

Hashes for xgboost-distribution-0.1.0.tar.gz

  • SHA256: 2775cd544334b2539630bab9c0ad8425906b40a8e3e7ea0d89b62b97a6537823
  • MD5: 04347cdef57ed51a4d0282c6ba1c7332
  • BLAKE2b-256: 106d0e36501e76b7d4fa59996fcb672284200098349308a6229fa94a46d292cf


File details

Details for the file xgboost_distribution-0.1.0-py2.py3-none-any.whl.

File metadata

  • Download URL: xgboost_distribution-0.1.0-py2.py3-none-any.whl
  • Upload date:
  • Size: 9.8 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.5.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.9.5

File hashes

Hashes for xgboost_distribution-0.1.0-py2.py3-none-any.whl

  • SHA256: 0701388b643609401c9987339ab191cd87cbca103bbf7099944ced4ca4bf516b
  • MD5: 012571fb7bc4d901c0051f3e0bd83399
  • BLAKE2b-256: 1d439f49090d2fdd7bd33697bb0edfa17866e333f872da76d5d8ba9cc5d86036

