Probabilistic Gradient Boosting Machines in PyTorch

Project description

Probabilistic Gradient Boosting Machines (PGBM) is a probabilistic gradient boosting framework in Python based on PyTorch, developed by Airlab in Amsterdam. It provides the following advantages over existing frameworks:

  • Probabilistic regression estimates instead of only point estimates.
  • Auto-differentiation of custom loss functions.
  • Native GPU-acceleration.

It is aimed at users interested in solving large-scale tabular probabilistic regression problems, such as probabilistic time series forecasting. For more details, read our paper or check out the examples.
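
As a quick illustration, below is a minimal sketch of training PGBM and obtaining probabilistic predictions, following the pattern of the examples folder. The dataset is illustrative, and the exact function signatures may differ between versions:

    # Minimal sketch, following the examples folder; signatures may differ by version.
    import torch
    from pgbm import PGBM
    from sklearn.datasets import fetch_california_housing
    from sklearn.model_selection import train_test_split

    X, y = fetch_california_housing(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    # PGBM requires an explicit objective returning (gradient, hessian)...
    def mseloss_objective(yhat, y, sample_weight=None):
        gradient = yhat - y              # first derivative of 0.5 * (yhat - y)^2
        hessian = torch.ones_like(yhat)  # second derivative is constant
        return gradient, hessian

    # ...and an explicit evaluation metric.
    def rmseloss_metric(yhat, y, sample_weight=None):
        return (yhat - y).pow(2).mean().sqrt()

    model = PGBM()
    model.train((X_train, y_train), objective=mseloss_objective, metric=rmseloss_metric)

    yhat_point = model.predict(X_test)      # point estimates
    yhat_dist = model.predict_dist(X_test)  # samples from the predictive distribution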

Installation

Run pip install pgbm from a terminal within the virtual environment of your choice.

Verification

  • Download and run an example from the examples folder to verify that the installation is correct. Use both gpu and cpu as the device to check that you can train on both GPU and CPU (see the sketch after this list).
  • Note that when training on the GPU, the custom CUDA kernel is JIT-compiled when a model is initialized. Hence, the first time you train a model on the GPU can take a bit longer, as PGBM needs to compile the CUDA kernel.
  • When using the Numba backend, several functions also need to be JIT-compiled, so the first time you train a model with this backend can likewise take a bit longer.
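
As a rough sketch, switching between CPU and GPU training amounts to setting the device entry in the parameter dict passed to train. The parameter name here follows the examples and may differ between versions; the data, objective and metric are as in the quick-start sketch above:

    # Sketch: toggle the training device via the params dict (assumed parameter name).
    params = {'device': 'gpu'}  # use {'device': 'cpu'} to train on the CPU
    model = PGBM()
    model.train((X_train, y_train), objective=mseloss_objective,
                metric=rmseloss_metric, params=params)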

Dependencies

The core package has the following dependencies:

  • PyTorch >= 1.7.0, with CUDA 11.0 for GPU acceleration (https://pytorch.org/get-started/locally/)
  • Numpy >= 1.19.2 (install via pip or conda; https://github.com/numpy/numpy)
  • CUDA Toolkit 11.0 (or one matching your PyTorch distribution) (https://developer.nvidia.com/cuda-toolkit)
  • PGBM uses a custom CUDA kernel that needs to be compiled, which may require installing a suitable compiler. Installing PyTorch and the full CUDA Toolkit should be sufficient, but contact the author if it still does not work after installing these dependencies. (The snippet after this list shows how to check your PyTorch and CUDA setup.)
  • To run the experiments comparing against baseline models, a number of additional packages may need to be installed via pip or conda.
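
A quick way to check the PyTorch side of this setup is via PyTorch's standard introspection calls:

    # Environment check for the GPU dependencies listed above.
    import torch

    print(torch.__version__)          # should be >= 1.7.0
    print(torch.version.cuda)         # CUDA version this PyTorch build was compiled against
    print(torch.cuda.is_available())  # True if a CUDA-capable GPU is usable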

We also provide a Numba-based backend for users who do not want to use PyTorch. In that case, Numba needs to be installed. The Numba backend does not support differentiable loss functions. For an example of using PGBM with the Numba backend, see the examples and the sketch below.
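
A minimal sketch of the Numba backend follows. The import path (pgbm_nb) is an assumption based on the repository layout and may differ between versions; because there is no auto-differentiation, the objective must return the gradient and hessian explicitly:

    # Sketch of the Numba backend: NumPy in, NumPy out, no PyTorch required.
    # Import path is an assumption based on the repository layout; check the examples.
    import numpy as np
    from pgbm_nb import PGBM

    # Synthetic regression data for illustration.
    rng = np.random.default_rng(0)
    X_train = rng.normal(size=(1000, 10))
    y_train = X_train @ rng.normal(size=10) + rng.normal(size=1000)

    # Without autodiff, gradient and hessian must be supplied by hand.
    def mseloss_objective(yhat, y, sample_weight=None):
        return yhat - y, np.ones_like(yhat)

    def rmseloss_metric(yhat, y, sample_weight=None):
        return np.sqrt(np.mean((yhat - y) ** 2))

    model = PGBM()
    model.train((X_train, y_train), objective=mseloss_objective, metric=rmseloss_metric)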

Support

See the examples folder for examples, an overview of hyperparameters and a function reference. In general, PGBM works similarly to existing gradient boosting packages such as LightGBM or XGBoost (it should be possible to use it more or less as a drop-in replacement), except that you are required to explicitly define a loss function and a loss metric, as in the sketch below.
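
For example, a quantile (pinball) objective can be plugged in the same way. This is a hypothetical sketch following the (gradient, hessian) convention of the examples; since the true second derivative of the pinball loss is zero, a constant hessian is used as a surrogate:

    # Hypothetical quantile (pinball) objective and metric for the torch backend.
    import torch

    def quantile_objective(yhat, y, sample_weight=None, q=0.9):
        # Gradient of the pinball loss: -q where we under-predict, (1 - q) where we over-predict.
        gradient = torch.where(y > yhat,
                               -q * torch.ones_like(yhat),
                               (1 - q) * torch.ones_like(yhat))
        hessian = torch.ones_like(yhat)  # true hessian is zero; constant surrogate
        return gradient, hessian

    def quantile_metric(yhat, y, sample_weight=None, q=0.9):
        error = y - yhat
        return torch.mean(torch.maximum(q * error, (q - 1) * error))

    # model.train(train_set, objective=quantile_objective, metric=quantile_metric)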

If further support is required, open an issue.

Reference

Olivier Sprangers, Sebastian Schelter, Maarten de Rijke. Probabilistic Gradient Boosting Machines for Large-Scale Probabilistic Regression. Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery and Data Mining (KDD ’21), August 14–18, 2021, Virtual Event, Singapore.

The experiments from our paper can be replicated by running the scripts in the experiments folder. Datasets are downloaded when needed by the experiments, except for Higgs and M5, which should be downloaded in advance and saved to the datasets folder (Higgs) and to datasets/m5 (M5).

License

This project is licensed under the terms of the Apache 2.0 license.

Acknowledgements

This project was developed by Airlab Amsterdam.

Download files

Download the file for your platform.

Source Distribution

pgbm-0.4.tar.gz (29.7 kB)

Built Distributions

pgbm-0.4-py3.8.egg (45.0 kB)

pgbm-0.4-py3-none-any.whl (30.2 kB)

File details

Details for the file pgbm-0.4.tar.gz.

File metadata

  • Download URL: pgbm-0.4.tar.gz
  • Size: 29.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.0.1 pkginfo/1.7.0 requests/2.24.0 requests-toolbelt/0.9.1 tqdm/4.59.0 CPython/3.8.5

File hashes

Hashes for pgbm-0.4.tar.gz

  Algorithm     Hash digest
  SHA256        b4ddb54c9131a8245afcf76a3619760475ae5d0c1e78f635b1e9119403e752f5
  MD5           c2684c0aea313da74aa94d533255f1af
  BLAKE2b-256   d1d37d9cf1e9530490de49fab6a60aecf91971aef3ff27105305707573b72703


File details

Details for the file pgbm-0.4-py3.8.egg.

File metadata

  • Download URL: pgbm-0.4-py3.8.egg
  • Size: 45.0 kB
  • Tags: Egg
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.0.1 pkginfo/1.7.0 requests/2.24.0 requests-toolbelt/0.9.1 tqdm/4.59.0 CPython/3.8.5

File hashes

Hashes for pgbm-0.4-py3.8.egg

  Algorithm     Hash digest
  SHA256        4714c272d12deaa20bb8f01055a1c9deb5786b6d0d32a31e66109e4ca9d98f06
  MD5           6ec35e0f714adadd368bf6e5d7a6557a
  BLAKE2b-256   266137a88a8c603d4b9ecf1191ee70339e39541bdae8cd61bdbf24e40c89ffea


File details

Details for the file pgbm-0.4-py3-none-any.whl.

File metadata

  • Download URL: pgbm-0.4-py3-none-any.whl
  • Size: 30.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.0.1 pkginfo/1.7.0 requests/2.24.0 requests-toolbelt/0.9.1 tqdm/4.59.0 CPython/3.8.5

File hashes

Hashes for pgbm-0.4-py3-none-any.whl

  Algorithm     Hash digest
  SHA256        2c6edd18e6e3f39a2584fa50641df1a5418ed54cb5b9cbefbdc01f84bc3ea2a6
  MD5           ea72dbe25935fb25a475fa1b4fe354ee
  BLAKE2b-256   c837df2895e131052217c84c459d315b099499733c6de710406cff2452d545dc

