A lightweight gradient boosting implementation in Rust.

Forust

A lightweight gradient boosting package

Forust is a lightweight package for building gradient boosted decision tree ensembles. All of the algorithm code is written in Rust, with a Python wrapper. The Rust package can be used directly; however, most examples shown here are for the Python wrapper. For a self-contained Rust example, see here. It implements the same algorithm as the XGBoost package, and in many cases will give nearly identical results.
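
As a rough illustration of what "the same algorithm as XGBoost" involves, the sketch below computes LogLoss gradients and hessians, the regularized leaf weight, and the split gain using the formulas published for XGBoost. It is written from those formulas, not from Forust's source, so the regularization names `lam` and `gamma` and the 1/2 factor are assumptions about how Forust scores splits. In XGBoost, the "cover" of a node is its hessian sum, which is likely what the cover values in the tree dumps further down report.

```python
import numpy as np


def logloss_grad_hess(y, raw_pred):
    """Gradient and hessian of LogLoss with respect to the raw (log-odds) prediction."""
    p = 1.0 / (1.0 + np.exp(-np.asarray(raw_pred, dtype=float)))
    y = np.asarray(y, dtype=float)
    return p - y, p * (1.0 - p)


def leaf_weight(g_sum, h_sum, lam=1.0):
    """Optimal leaf value for a node with gradient sum G and hessian sum H: -G / (H + lam)."""
    return -g_sum / (h_sum + lam)


def split_gain(g_left, h_left, g_right, h_right, lam=1.0, gamma=0.0):
    """XGBoost-style gain from splitting a parent node into left and right children."""
    def score(g, h):
        return g * g / (h + lam)

    parent = score(g_left + g_right, h_left + h_right)
    return 0.5 * (score(g_left, h_left) + score(g_right, h_right) - parent) - gamma


# Four records, all starting from a raw prediction of 0 (probability 0.5).
y = np.array([0, 0, 1, 1])
g, h = logloss_grad_hess(y, np.zeros(4))

# Evaluate the split that separates the first two records from the last two.
gain = split_gain(g[:2].sum(), h[:2].sum(), g[2:].sum(), h[2:].sum())
left_value = leaf_weight(g[:2].sum(), h[:2].sum())
```

A positive gain means the split improves the regularized objective; `gamma` acts as the minimum gain a split has to exceed.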

I developed this package for a few reasons: mainly to better understand the XGBoost algorithm, but also to have a fun project to work on in Rust, and because I wanted to be able to experiment with adding new features to the algorithm in a smaller, simpler codebase.

All of the Rust code for the package can be found in the src directory, while all of the Python wrapper code is in the py-forust directory.

Documentation

Documentation for the Python API can be found here.

Installation

The package can be installed directly from PyPI.

pip install forust

To use it in a Rust project, add the following to your Cargo.toml file.

[dependencies]
forust-ml = "0.4.8"

Usage

For details on all of the methods and their respective parameters, see the Python API documentation.

The GradientBooster class is currently the only public-facing class in the package, and can be used to train gradient boosted decision tree ensembles with multiple objective functions.

Training and Predicting

Once the booster has been initialized, it can be fit on a provided dataset and performance (target) field. After fitting, the model can be used to predict on a dataset. In this example, the predictions are the log odds of a given record being 1.

# Small example dataset
from seaborn import load_dataset

df = load_dataset("titanic")
X = df.select_dtypes("number").drop(columns=["survived"])
y = df["survived"]

# Initialize a booster with defaults.
from forust import GradientBooster
model = GradientBooster(objective_type="LogLoss")
model.fit(X, y)

# Predict on data
model.predict(X.head())
# array([-1.94919663,  2.25863229,  0.32963671,  2.48732194, -3.00371813])

# predict contributions
model.predict_contributions(X.head())
# array([[-0.63014213,  0.33880048, -0.16520798, -0.07798772, -0.85083578,
#        -1.07720813],
#       [ 1.05406709,  0.08825999,  0.21662544, -0.12083538,  0.35209258,
#        -1.07720813],
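
With objective_type="LogLoss", the values returned by predict are log odds. A minimal sketch of converting them to probabilities by hand with the logistic function in NumPy, using the five predictions printed above; the helper name `log_odds_to_probability` is illustrative and not part of Forust.

```python
import numpy as np


def log_odds_to_probability(log_odds):
    """Map log odds to probabilities with the logistic (sigmoid) function."""
    return 1.0 / (1.0 + np.exp(-np.asarray(log_odds, dtype=float)))


# The five log-odds predictions shown in the example output above.
preds = np.array([-1.94919663, 2.25863229, 0.32963671, 2.48732194, -3.00371813])
probs = log_odds_to_probability(preds)

# Records with positive log odds get probabilities above 0.5.
predicted_class = (probs > 0.5).astype(int)
```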

When predicting, the maximum iteration that will be used can be set with the set_prediction_iteration method. If early_stopping_rounds has been set, this defaults to the best iteration; otherwise, all of the trees are used.

If evaluation data was passed to fit (as is needed for early stopping), the evaluation history can be retrieved with the get_evaluation_history method.
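
Conceptually, a boosted model's raw prediction is a starting value plus the sum of each tree's output, so limiting the prediction iteration amounts to summing only the first k trees. The sketch below models trees as plain Python functions; the names `predict_additive`, `base_score`, and `max_iteration` are illustrative and are not Forust's API.

```python
from typing import Callable, Optional, Sequence

# A "tree" here is any function mapping a feature row to a leaf value.
Tree = Callable[[Sequence[float]], float]


def predict_additive(
    trees: Sequence[Tree],
    row: Sequence[float],
    base_score: float = 0.0,
    max_iteration: Optional[int] = None,
) -> float:
    """Add the outputs of the first max_iteration trees (all trees if None) to base_score."""
    used = trees if max_iteration is None else trees[:max_iteration]
    return base_score + sum(tree(row) for tree in used)


# Three toy trees that look only at the first feature.
trees = [
    lambda r: 0.5 if r[0] < 3 else -0.5,
    lambda r: 0.25 if r[0] < 2 else -0.25,
    lambda r: 0.1,
]
row = [1.0]
full = predict_additive(trees, row)                        # 0.5 + 0.25 + 0.1
first_two = predict_additive(trees, row, max_iteration=2)  # 0.5 + 0.25
```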

model = GradientBooster(objective_type="LogLoss")
model.fit(X, y, evaluation_data=[(X, y)])

model.get_evaluation_history()[0:3]

# array([[588.9158873 ],
#        [532.01055803],
#        [496.76933646]])
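
The history above holds one metric value per boosting iteration (the single column appears to correspond to the single evaluation dataset), and for LogLoss lower is better. A minimal sketch of how a best iteration and an early-stopping point can be derived from such a history, using the common "stop after N rounds without improvement" rule; this is an assumption about, not a copy of, Forust's exact logic, and `early_stopping_point` is a hypothetical helper.

```python
from typing import Optional, Sequence, Tuple


def early_stopping_point(
    history: Sequence[float], rounds: int
) -> Tuple[int, Optional[int]]:
    """Return (best_iteration, stop_iteration) for a lower-is-better metric.

    stop_iteration is the first iteration at which `rounds` iterations have
    passed without improving on the best value, or None if that never happens.
    """
    best_iter, best_value = 0, history[0]
    for i, value in enumerate(history):
        if value < best_value:
            best_iter, best_value = i, value
        elif i - best_iter >= rounds:
            return best_iter, i
    return best_iter, None


# The first three values of the evaluation history shown above.
best, stop = early_stopping_point([588.9158873, 532.01055803, 496.76933646], rounds=2)
```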

Inspecting the Model

Once the booster has been fit, each individual tree structure can be retrieved in text form using the text_dump method. This method returns a list with one entry per tree in the model.

model.text_dump()[0]
# 0:[0 < 3] yes=1,no=2,missing=2,gain=91.50833,cover=209.388307
#       1:[4 < 13.7917] yes=3,no=4,missing=4,gain=28.185467,cover=94.00148
#             3:[1 < 18] yes=7,no=8,missing=8,gain=1.4576768,cover=22.090348
#                   7:[1 < 17] yes=15,no=16,missing=16,gain=0.691266,cover=0.705011
#                         15:leaf=-0.15120,cover=0.23500
#                         16:leaf=0.154097,cover=0.470007
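
Each line of a text dump follows a regular pattern: split nodes list the feature index, threshold, child node ids, gain and cover, while leaf nodes list the leaf value and cover. A minimal parsing sketch, assuming the format shown above is stable (it is inferred from this output, not from a documented specification). In this example the feature indices appear to be column positions in X, so feature 0 would be pclass.

```python
import re

SPLIT_RE = re.compile(
    r"^\s*(?P<node>\d+):\[(?P<feature>\d+) < (?P<threshold>[^\]]+)\] "
    r"yes=(?P<yes>\d+),no=(?P<no>\d+),missing=(?P<missing>\d+),"
    r"gain=(?P<gain>[^,]+),cover=(?P<cover>\S+)\s*$"
)
LEAF_RE = re.compile(
    r"^\s*(?P<node>\d+):leaf=(?P<leaf>[^,]+),cover=(?P<cover>\S+)\s*$"
)


def parse_dump_line(line: str) -> dict:
    """Parse one node line of a text_dump tree into a dictionary."""
    m = SPLIT_RE.match(line)
    if m:
        return {
            "node": int(m["node"]),
            "feature": int(m["feature"]),
            "threshold": float(m["threshold"]),
            "yes": int(m["yes"]),
            "no": int(m["no"]),
            "missing": int(m["missing"]),
            "gain": float(m["gain"]),
            "cover": float(m["cover"]),
        }
    m = LEAF_RE.match(line)
    if m:
        return {"node": int(m["node"]), "leaf": float(m["leaf"]), "cover": float(m["cover"])}
    raise ValueError(f"unrecognized dump line: {line!r}")


root = parse_dump_line("0:[0 < 3] yes=1,no=2,missing=2,gain=91.50833,cover=209.388307")
leaf = parse_dump_line("      15:leaf=-0.15120,cover=0.23500")
```

Parsed this way, the covers of leaves 15 and 16 (0.235 and 0.470007) add up to roughly the 0.705011 cover of their parent node 7.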

The json_dump method performs the same action, but returns the model as a JSON representation rather than a text string.

To estimate how a given feature is used in the model, the partial_dependence method is provided. This method calculates the partial dependence values of a feature: for each unique value of the feature, it gives an estimate of the predicted value, with the effects of all other features averaged out. This shows how a given feature impacts the model's predictions.

This information can be plotted to visualize how a feature is used in the model, like so.

from seaborn import lineplot
import matplotlib.pyplot as plt

pd_values = model.partial_dependence(X=X, feature="age", samples=None)

fig = lineplot(x=pd_values[:, 0], y=pd_values[:, 1])
plt.title("Partial Dependence Plot")
plt.xlabel("Age")
plt.ylabel("Log Odds")
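
The partial dependence definition can also be written as a brute-force computation: for each grid value, overwrite the feature in every row, predict, and average. This is a sketch of the textbook definition using NumPy and a generic predict function; Forust may compute partial dependence differently internally, and the `grid` argument and `toy_predict` model are specific to this sketch. The result uses the same two-column layout as pd_values.

```python
import numpy as np


def partial_dependence(predict, X, feature_index, grid=None):
    """Brute-force partial dependence of predict(X) on one column of X.

    Returns a two-column array: grid value, mean prediction at that value.
    """
    X = np.asarray(X, dtype=float)
    if grid is None:
        column = X[:, feature_index]
        grid = np.unique(column[~np.isnan(column)])  # unique non-missing values
    result = np.empty((len(grid), 2))
    for i, value in enumerate(grid):
        X_mod = X.copy()
        X_mod[:, feature_index] = value  # force every row to this value
        result[i] = (value, np.mean(predict(X_mod)))
    return result


# A toy linear "model": the PD for column 0 should be 2 * v + mean(column 1).
def toy_predict(X):
    return 2.0 * X[:, 0] + X[:, 1]


X_toy = np.array([[1.0, 10.0], [2.0, 20.0], [np.nan, 30.0]])
pd_toy = partial_dependence(toy_predict, X_toy, feature_index=0)
```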

We can see how this changes when a model is fit with a monotonicity constraint on the feature, using the monotone_constraints parameter. A value of -1 enforces a negative (non-increasing) relationship between age and the prediction.

model = GradientBooster(
    objective_type="LogLoss",
    monotone_constraints={"age": -1},
)
model.fit(X, y)

pd_values = model.partial_dependence(X=X, feature="age")
fig = lineplot(
    x=pd_values[:, 0],
    y=pd_values[:, 1],
)
plt.title("Partial Dependence Plot with Monotonicity")
plt.xlabel("Age")
plt.ylabel("Log Odds")
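
One way to confirm the constraint took effect is to check that the partial dependence values never increase as age increases. A small helper sketch, assuming pd_values has the two-column layout used above (feature value, prediction); `is_monotone` is a hypothetical helper and the tolerance is an arbitrary choice to absorb floating-point noise. With the constrained model, `is_monotone(pd_values, -1)` would be expected to return True.

```python
import numpy as np


def is_monotone(pd_values, direction, tol=1e-9):
    """Check a two-column partial dependence array against a monotone direction.

    direction=1 means non-decreasing, -1 means non-increasing, 0 means no check.
    """
    pd_values = np.asarray(pd_values, dtype=float)
    order = np.argsort(pd_values[:, 0])  # sort rows by feature value
    diffs = np.diff(pd_values[order, 1])
    if direction == 1:
        return bool(np.all(diffs >= -tol))
    if direction == -1:
        return bool(np.all(diffs <= tol))
    return True


example = np.array([[20.0, 0.5], [10.0, 0.7], [30.0, 0.1]])
```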

Feature importance values can be calculated with the calculate_feature_importance method, which returns a dictionary mapping each feature to its importance. Note that a feature that was never used for splitting will not appear in the dictionary. The importance type is passed as the first argument; the example below uses "Gain".

model.calculate_feature_importance("Gain")
# {
#   'parch': 0.0713072270154953, 
#   'age': 0.11609109491109848,
#   'sibsp': 0.1486879289150238,
#   'fare': 0.14309120178222656,
#   'pclass': 0.5208225250244141
# }
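
The importances above sum to (almost exactly) one, which suggests they are normalized. A sketch of one common definition of "Gain" importance, as used in XGBoost: the average gain of the splits that use a feature, normalized across features. Treat this as an assumption about Forust's definition rather than a description of it; the `gain_importance` helper and the split gains below are hypothetical. Features with no splits are simply absent, matching the behavior described above.

```python
from collections import defaultdict


def gain_importance(splits, normalize=True):
    """Average split gain per feature, optionally scaled to sum to one.

    `splits` is an iterable of (feature_name, gain) pairs, one per split node.
    """
    totals = defaultdict(float)
    counts = defaultdict(int)
    for feature, gain in splits:
        totals[feature] += gain
        counts[feature] += 1
    importance = {f: totals[f] / counts[f] for f in totals}
    if normalize:
        norm = sum(importance.values())
        importance = {f: v / norm for f, v in importance.items()}
    return importance


# Hypothetical split gains, not taken from the model above.
splits = [("pclass", 90.0), ("fare", 30.0), ("age", 2.0), ("age", 4.0)]
imp = gain_importance(splits)
```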

Saving the model

To save and subsequently load a trained booster, use the save_booster and load_booster methods. Each accepts a path, which the model is written to or read from. The model is saved and loaded as a JSON object.

# Save the trained booster to a JSON file.
model.save_booster("model_path.json")

# To load a model from a json path.
loaded_model = GradientBooster.load_booster("model_path.json")

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

forust-0.4.8.tar.gz (1.4 MB)

Uploaded Source

Built Distributions

forust-0.4.8-cp312-none-win_amd64.whl (472.8 kB)

Uploaded CPython 3.12 Windows x86-64

forust-0.4.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (573.6 kB)

Uploaded CPython 3.12 manylinux: glibc 2.17+ x86-64

forust-0.4.8-cp312-cp312-macosx_10_12_x86_64.whl (524.0 kB)

Uploaded CPython 3.12 macOS 10.12+ x86-64

forust-0.4.8-cp311-none-win_amd64.whl (473.1 kB)

Uploaded CPython 3.11 Windows x86-64

forust-0.4.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (574.8 kB)

Uploaded CPython 3.11 manylinux: glibc 2.17+ x86-64

forust-0.4.8-cp311-cp311-macosx_10_12_x86_64.whl (525.9 kB)

Uploaded CPython 3.11 macOS 10.12+ x86-64

forust-0.4.8-cp310-none-win_amd64.whl (473.2 kB)

Uploaded CPython 3.10 Windows x86-64

forust-0.4.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (574.8 kB)

Uploaded CPython 3.10 manylinux: glibc 2.17+ x86-64

forust-0.4.8-cp310-cp310-macosx_10_12_x86_64.whl (525.9 kB)

Uploaded CPython 3.10 macOS 10.12+ x86-64

forust-0.4.8-cp39-none-win_amd64.whl (474.0 kB)

Uploaded CPython 3.9 Windows x86-64

forust-0.4.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (574.7 kB)

Uploaded CPython 3.9 manylinux: glibc 2.17+ x86-64

forust-0.4.8-cp39-cp39-macosx_10_12_x86_64.whl (526.1 kB)

Uploaded CPython 3.9 macOS 10.12+ x86-64

forust-0.4.8-cp38-none-win_amd64.whl (473.5 kB)

Uploaded CPython 3.8 Windows x86-64

forust-0.4.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (574.9 kB)

Uploaded CPython 3.8 manylinux: glibc 2.17+ x86-64

forust-0.4.8-cp38-cp38-macosx_10_12_x86_64.whl (526.0 kB)

Uploaded CPython 3.8 macOS 10.12+ x86-64

File details

Details for the file forust-0.4.8.tar.gz.

File metadata

  • Download URL: forust-0.4.8.tar.gz
  • Upload date:
  • Size: 1.4 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: maturin/1.5.1

File hashes

Hashes for forust-0.4.8.tar.gz
Algorithm Hash digest
SHA256 68db9d231c06e10d8e3ccbf9129cf121e5eb24b8109c04decce16b0f84347634
MD5 99ecfbb0e3f66355d717045c0366ec30
BLAKE2b-256 f45ccc59c7e70c14dfebb5edb211f76bb5636863928e1fd60cfbf1acb817b65b

See more details on using hashes here.

File details

Details for the file forust-0.4.8-cp312-none-win_amd64.whl.

File hashes

Hashes for forust-0.4.8-cp312-none-win_amd64.whl
Algorithm Hash digest
SHA256 8778d68fcf55a50569299b6b90c7734f46622abf3810a91dd147d6b4654ce70b
MD5 165f75cd1221ea4c041f889bc1baf0c5
BLAKE2b-256 350400f02ddd2051cad356698d98d8d85f3513127f8b832c1096d4ba604a820b

File details

Details for the file forust-0.4.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for forust-0.4.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 9b69c36a11890eea02aa0dd3895c828c8c58282c18e9353cc4be11cce8a01e43
MD5 ac8c94d29641078ad4127b4dac6d956e
BLAKE2b-256 2757b915a21d9c0b7e1f5fc38da95e485ad4a670de158be340ab2675bc4dd26b

File details

Details for the file forust-0.4.8-cp312-cp312-macosx_10_12_x86_64.whl.

File hashes

Hashes for forust-0.4.8-cp312-cp312-macosx_10_12_x86_64.whl
Algorithm Hash digest
SHA256 4c813a88ff64c2379c94d0f4dbd5231eba58e51010f97181f5d4aaec74724b93
MD5 6ebf314a2db07d9fa3a0cfe0dfc61be0
BLAKE2b-256 0390ea9a63a9f621d881c153f2e2dc2d4dd35dc409423bc0c2a93761cbc4d303

File details

Details for the file forust-0.4.8-cp311-none-win_amd64.whl.

File hashes

Hashes for forust-0.4.8-cp311-none-win_amd64.whl
Algorithm Hash digest
SHA256 8276ac4d072745b2480adbdfca4fe92c79f3d18f4c254351745a81b6de11234b
MD5 11df41317995b011b758a0b2335ea8c8
BLAKE2b-256 e562219134a8297de497e16751903bcc074b8fa0cbc59b474ea3fd0b310e2d74

File details

Details for the file forust-0.4.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for forust-0.4.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 00eeb1f3939ec0e14a87d4c0b22487e6a33636a0185772f9c929773a6123a98b
MD5 2b86f75b34483515e503d8b2bd6b3d8b
BLAKE2b-256 2cdf9422c689306ed035b820748a1c909c61fc72638bca9dd991856dc1064988

File details

Details for the file forust-0.4.8-cp311-cp311-macosx_10_12_x86_64.whl.

File hashes

Hashes for forust-0.4.8-cp311-cp311-macosx_10_12_x86_64.whl
Algorithm Hash digest
SHA256 ac6712fe20a310afae318ed1ad47925a3e1560a4596b4edc6fb327e65e698479
MD5 aaee261ec87562a621244126d1c7fe9b
BLAKE2b-256 cbe36a1dd24c47b6a2a6c3a09a7fcf88b4b92aa922241998baa3c153d0da9ff0

File details

Details for the file forust-0.4.8-cp310-none-win_amd64.whl.

File hashes

Hashes for forust-0.4.8-cp310-none-win_amd64.whl
Algorithm Hash digest
SHA256 2df04863c899af383c6a95683232b96cdf9bf14f7502bee38ad3ac3d8a2bb95f
MD5 0d3f4eb19ba059e3475c3a5a0ddf0fd8
BLAKE2b-256 82950cb9f00c4c16b4d196de52f5a1cbb4cad36b41a8747466aad69c6e5fd45a

File details

Details for the file forust-0.4.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for forust-0.4.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 ef8908ded756dc048f3952e9d55a5f4cf1874f5d708bd27b5a1cdee55a70e4e7
MD5 bfa3538dfae679a2ee3b635742f60816
BLAKE2b-256 a0d62a915c249d95f71839987263366b07b506a842d311e9bfba49e5e1b953fc

File details

Details for the file forust-0.4.8-cp310-cp310-macosx_10_12_x86_64.whl.

File hashes

Hashes for forust-0.4.8-cp310-cp310-macosx_10_12_x86_64.whl
Algorithm Hash digest
SHA256 294e73ff6b78b1fbaef9d4b9803b42d6e6af32c57f220bc1edbeff2ecb823f3e
MD5 428021464cec0aeae3e3fceae20dbfe2
BLAKE2b-256 5d6b98ca1cf1b661d610a53a97392d2430ba5057e9e1d62e8a2e15c928e9c64c

File details

Details for the file forust-0.4.8-cp39-none-win_amd64.whl.

File metadata

  • Download URL: forust-0.4.8-cp39-none-win_amd64.whl
  • Upload date:
  • Size: 474.0 kB
  • Tags: CPython 3.9, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: maturin/1.5.1

File hashes

Hashes for forust-0.4.8-cp39-none-win_amd64.whl
Algorithm Hash digest
SHA256 812e96d7002138de86229901d95cc56263989e87c25f5fb360d80cec963bdc9b
MD5 7539687cddef829b967301878a42cc80
BLAKE2b-256 7a9d4d12c3f04b768422fabdb9b90edbc8279d1238e5d262108ccba8d42272de

File details

Details for the file forust-0.4.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for forust-0.4.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 96a0683df370032d939a2f8c93c2e2287c9e0db4e614ba86ef82664f3c2bfb27
MD5 1b817860b7db857acca7b21011385ca8
BLAKE2b-256 afd412b56b3a49403259224623a61145ac6f5c849d5b59b740341ef4ec3b2cbc

File details

Details for the file forust-0.4.8-cp39-cp39-macosx_10_12_x86_64.whl.

File hashes

Hashes for forust-0.4.8-cp39-cp39-macosx_10_12_x86_64.whl
Algorithm Hash digest
SHA256 d5c4a519f6418e1263ea030c718d7f79be14c4684b22a59b2fab7033e0db6fca
MD5 6efcfb6cdf99e02532b64380629d6e18
BLAKE2b-256 11cd80a23d4ff36ed13d1faceb135e7b53b480941449181f6a0280196e56dbbf

File details

Details for the file forust-0.4.8-cp38-none-win_amd64.whl.

File metadata

  • Download URL: forust-0.4.8-cp38-none-win_amd64.whl
  • Upload date:
  • Size: 473.5 kB
  • Tags: CPython 3.8, Windows x86-64
  • Uploaded using Trusted Publishing? No
  • Uploaded via: maturin/1.5.1

File hashes

Hashes for forust-0.4.8-cp38-none-win_amd64.whl
Algorithm Hash digest
SHA256 a088e533bce5eca803fe500cfede4349bef4ed47eab7f73a149d6e3513845061
MD5 252085bdaf21bcc080e7fbdc554d9773
BLAKE2b-256 4bde3c53e9d95d77d4f6a24fa90b81f2977389a4a1c676ab40d474eb28487e3f

File details

Details for the file forust-0.4.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for forust-0.4.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 de335fab168931030a1788452d8092a5af0a71c4fa149ac447402ebf7999455b
MD5 9c0789f9bb217ab4c908a60ef0dcf123
BLAKE2b-256 fc707a1ad6db3672ce4cd15c97947a25de27f160f1e007551c560be021414b39

File details

Details for the file forust-0.4.8-cp38-cp38-macosx_10_12_x86_64.whl.

File hashes

Hashes for forust-0.4.8-cp38-cp38-macosx_10_12_x86_64.whl
Algorithm Hash digest
SHA256 b0b6540eb71a0a8d43f274e593551b7063603114440e274e1e690ef43ac28126
MD5 1b464ba1f5fa3311b91caed2b8ba840b
BLAKE2b-256 a1e891968a080ef430323ced1f61bd8874724c0a93e169f192bd6659032d0079
