A lightweight gradient boosting implementation in Rust.

Forust

A lightweight gradient boosting package

Forust is a lightweight package for building gradient boosted decision tree ensembles. All of the algorithm code is written in Rust, with a Python wrapper. The Rust package can be used directly; however, most examples shown here use the Python wrapper. It implements the same algorithm as the XGBoost package, and in many cases gives nearly identical results.

I developed this package for a few reasons: mainly to better understand the XGBoost algorithm, but also to have a fun project to work on in Rust, and to be able to experiment with adding new features to the algorithm in a smaller, simpler codebase.

All of the Rust code for the package can be found in the src directory, while all of the Python wrapper code is in the py-forust directory.

Installation

The package can be installed directly from PyPI.

pip install forust

To use it in a Rust project, add the following to the [dependencies] section of your Cargo.toml file.

forust-ml = "0.1.5"

Usage

The GradientBooster class is currently the only public-facing class in the package. It can be used to train gradient boosted decision tree ensembles with multiple objective functions.

It can be initialized with the following arguments.

  • objective_type (str, optional): The name of the objective function to optimize. Valid options are "LogLoss", which uses logistic loss as the objective function (binary classification), and "SquaredLoss", which uses squared error as the objective function (continuous regression). Defaults to "LogLoss".
  • iterations (int, optional): Total number of trees to train in the ensemble. Defaults to 100.
  • learning_rate (float, optional): Step size to use at each iteration. Each leaf weight is multiplied by this number. The smaller the value, the more conservative the weights will be. Defaults to 0.3.
  • max_depth (int, optional): Maximum depth of an individual tree. Valid values are 0 to infinity. Defaults to 5.
  • max_leaves (int, optional): Maximum number of leaves allowed on a tree. Valid values are 0 to infinity. This is the total number of final nodes. Defaults to sys.maxsize.
  • l2 (float, optional): L2 regularization term applied to the weights of the tree. Valid values are 0 to infinity. Defaults to 1.0.
  • gamma (float, optional): The minimum reduction in loss required to further split a node. Valid values are 0 to infinity. Defaults to 0.0.
  • min_leaf_weight (float, optional): Minimum sum of the hessian values of the loss function required to be in a node. Defaults to 0.0.
  • base_score (float, optional): The initial prediction value of the model. Defaults to 0.5.
  • nbins (int, optional): Number of bins to calculate to partition the data. Setting this to a smaller number will result in faster training, while potentially sacrificing accuracy. If there are more bins than unique values in a column, all unique values will be used. Defaults to 256.
  • parallel (bool, optional): Should multiple cores be used when training and predicting with this model? Defaults to True.
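
To make the roles of l2, gamma, min_leaf_weight, and learning_rate more concrete, the sketch below follows the standard XGBoost formulas for leaf weights and split gain. It is only an illustration under that assumption: the function names are hypothetical and are not part of the forust API, and forust's internal code may differ in details such as constant factors.

```python
def leaf_weight(grad_sum, hess_sum, l2=1.0, learning_rate=0.3):
    """Leaf weight in the standard XGBoost formulation, shrunk by the learning rate."""
    return -grad_sum / (hess_sum + l2) * learning_rate


def split_gain(g_left, h_left, g_right, h_right, l2=1.0, gamma=0.0):
    """Loss reduction from splitting a node into a left and a right child."""
    def score(g, h):
        # A larger l2 shrinks the score, which makes splits less attractive.
        return g * g / (h + l2)

    parent = score(g_left + g_right, h_left + h_right)
    return 0.5 * (score(g_left, h_left) + score(g_right, h_right) - parent) - gamma


def is_valid_split(g_left, h_left, g_right, h_right,
                   l2=1.0, gamma=0.0, min_leaf_weight=0.0):
    """Keep a split only if both children hold enough hessian mass and the gain is positive."""
    if h_left < min_leaf_weight or h_right < min_leaf_weight:
        return False
    return split_gain(g_left, h_left, g_right, h_right, l2, gamma) > 0.0


# Hypothetical gradient and hessian sums, chosen only for illustration.
print(leaf_weight(grad_sum=-4.0, hess_sum=3.0))
print(split_gain(-3.0, 2.0, 1.0, 1.0))
print(is_valid_split(-3.0, 2.0, 1.0, 1.0, gamma=2.0))
```

Under these formulas, raising gamma or min_leaf_weight rejects more candidate splits, while a smaller learning_rate scales every leaf weight down.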

Training and Predicting

Once the booster has been initialized, it can be fit on a provided dataset and performance field (the target values, y). After fitting, the model can be used to predict on a dataset. In this example, the predictions are the log odds of a given record being 1.

# Small example dataset
from seaborn import load_dataset

df = load_dataset("titanic")
X = df.select_dtypes("number").drop(columns=["survived"])
y = df["survived"]

# Initialize a booster with defaults.
from forust import GradientBooster
model = GradientBooster(objective_type="LogLoss")
model.fit(X, y)

# Predict on data
model.predict(X.head())
# array([-1.94919663,  2.25863229,  0.32963671,  2.48732194, -3.00371813])
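
Because the predictions are log odds, they usually need to be passed through the logistic function before being read as probabilities. The following is a minimal sketch of that conversion using only the standard library; the helper name is hypothetical and the input values are copied from the example output above.

```python
import math


def log_odds_to_probability(log_odds):
    """Apply the logistic (sigmoid) function to a single log-odds value."""
    return 1.0 / (1.0 + math.exp(-log_odds))


# Log-odds values copied from the example prediction output.
predictions = [-1.94919663, 2.25863229, 0.32963671, 2.48732194, -3.00371813]

probabilities = [log_odds_to_probability(p) for p in predictions]

# A log odds of 0 corresponds to a probability of 0.5, so thresholding the
# probability at 0.5 is the same as thresholding the log odds at 0.
predicted_class = [int(p >= 0.5) for p in probabilities]
print(predicted_class)
```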

The fit method accepts the following arguments.

  • X (FrameLike): Either a pandas DataFrame, or a 2 dimensional numpy array, with numeric data.
  • y (ArrayLike): Either a pandas Series, or a 1 dimensional numpy array. If "LogLoss" was the objective type specified, then this should only contain 1 or 0 values, where 1 is the positive class being predicted. If "SquaredLoss" is the objective type, then any continuous variable can be provided.
  • sample_weight (Optional[ArrayLike], optional): Instance weights to use when training the model. If None is passed, a weight of 1 will be used for every record. Defaults to None.
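
As a rough illustration of how the objective type and sample_weight interact, the sketch below computes the per-record gradient and hessian for both objectives using the textbook definitions. These functions are hypothetical and are not taken from the forust source; the squared loss here assumes the common 1/2 * (prediction - y)^2 convention.

```python
import math


def log_loss_grad_hess(y, log_odds, sample_weight=1.0):
    """Gradient and hessian of logistic loss with respect to the log-odds prediction."""
    p = 1.0 / (1.0 + math.exp(-log_odds))
    return (p - y) * sample_weight, p * (1.0 - p) * sample_weight


def squared_loss_grad_hess(y, prediction, sample_weight=1.0):
    """Gradient and hessian of 1/2 * (prediction - y)^2 with respect to the prediction."""
    return (prediction - y) * sample_weight, sample_weight


# A positive record (y=1) currently predicted at log odds 0 (probability 0.5).
print(log_loss_grad_hess(1, 0.0))
# The same record with a weight of 2 contributes twice as much.
print(log_loss_grad_hess(1, 0.0, sample_weight=2.0))
print(squared_loss_grad_hess(3.0, 5.0))
```

In this formulation a weight simply scales a record's contribution to the gradient and hessian sums, which is why a weight of 1 for every record is equivalent to passing None.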

The predict method accepts the following arguments.

  • X (FrameLike): Either a pandas DataFrame, or a 2 dimensional numpy array, with numeric data.
  • parallel (Optional[bool], optional): Optionally specify if the predict function should run in parallel on multiple threads. If None is passed, the parallel attribute of the booster will be used. Defaults to None.

Inspecting the Model

Once the booster has been fit, each individual tree structure can be retrieved in text form, using the text_dump method. This method returns a list, the same length as the number of trees in the model.

model.text_dump()[0]
# 0:[0 < 3] yes=1,no=2,missing=2,gain=91.50833,cover=209.388307
#       1:[4 < 13.7917] yes=3,no=4,missing=4,gain=28.185467,cover=94.00148
#             3:[1 < 18] yes=7,no=8,missing=8,gain=1.4576768,cover=22.090348
#                   7:[1 < 17] yes=15,no=16,missing=16,gain=0.691266,cover=0.705011
#                         15:leaf=-0.15120,cover=0.23500
#                         16:leaf=0.154097,cover=0.470007

The json_dump method performs the same action, but returns the model as a JSON representation rather than a text string.

To see an estimate of how a given feature is used in the model, the partial_dependence method is provided. This method calculates the partial dependence values of a feature. For each unique value of the feature, this gives the estimated prediction at that value, with the effects of all other features averaged out. This information gives an estimate of how a given feature impacts the model.

The partial_dependence method takes the following parameters.

  • X (FrameLike): Either a pandas DataFrame, or a 2 dimensional numpy array. This should be the same data passed into the model's fit or predict method, with the columns in the same order.
  • feature (Union[str, int]): The feature for which to calculate the partial dependence values. This can be the name of a column, if the provided X is a pandas DataFrame, or the index of the feature.

This method returns a 2 dimensional numpy array, where the first column is the sorted unique values of the feature, and then the second column is the partial dependence values for each feature value.
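
To show the idea behind these values, here is a brute-force sketch of partial dependence for any prediction function. It is not how forust necessarily computes it (a tree-based implementation can use the tree structure directly); the function name and the toy prediction function are hypothetical and exist only for illustration.

```python
import numpy as np


def brute_force_partial_dependence(predict, X, feature):
    """Average the predictions over the data while holding one feature fixed.

    predict is any callable mapping a 2D array to a 1D array of predictions,
    and feature is the column index to vary.
    """
    X = np.asarray(X, dtype=float)
    values = np.unique(X[:, feature])    # Sorted unique values of the feature.
    values = values[~np.isnan(values)]   # Ignore missing values in this sketch.
    averages = []
    for value in values:
        X_mod = X.copy()
        X_mod[:, feature] = value        # Force every record to this value.
        averages.append(predict(X_mod).mean())
    # First column: feature values. Second column: partial dependence values.
    return np.column_stack([values, averages])


def toy_predict(X):
    """Stand-in for a fitted model: a simple linear function of two features."""
    return 2.0 * X[:, 0] + X[:, 1]


X_toy = np.array([[1.0, 10.0], [2.0, 20.0], [1.0, 30.0]])
pd_toy = brute_force_partial_dependence(toy_predict, X_toy, feature=0)
print(pd_toy)
```

The result has the same layout as described above: one row per unique feature value, with the averaged prediction in the second column.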

This information can be plotted to visualize how a feature is used in the model, like so.

from seaborn import lineplot
import matplotlib.pyplot as plt

# Feature index 1 is the "age" column of X.
pd_values = model.partial_dependence(X, 1)
ax = lineplot(x=pd_values[:, 0], y=pd_values[:, 1])
plt.title("Partial Dependence Plot")
plt.xlabel("Age")
plt.ylabel("Log Odds")

Saving the model

To save and subsequently load a trained booster, the save_booster and load_booster methods can be used. Each accepts a path: save_booster writes the model to it, and load_booster reads the model back from it. The model is saved and loaded as a JSON object.

model.save_booster("model_path.json")

# To load a model from a JSON file.
loaded_model = GradientBooster.load_booster("model_path.json")

TODOs

This is still a work in progress.

  • Early stopping rounds
    • We should be able to accept a validation dataset, and this should be able to be used to determine when to stop training.
  • Monotonicity support
    • Right now features are used in the model without any constraints.
  • Ability to save a model.
    • The way the underlying trees are structured, they would lend themselves to being saved as JSON objects.
  • Clean up the CI/CD pipeline.
