Forust
A lightweight gradient boosting package
Forust is a lightweight package for building gradient boosted decision tree ensembles. All of the algorithm code is written in Rust, with a Python wrapper. The Rust crate can be used directly; however, most examples shown here use the Python wrapper. It implements the same algorithm as the XGBoost package, and in many cases gives nearly identical results.
I developed this package for a few reasons: mainly to better understand the XGBoost algorithm, but also to have a fun project to work on in Rust, and to be able to experiment with adding new features to the algorithm in a smaller, simpler codebase.
Usage
The GradientBooster class is currently the only public-facing class in the package, and can be used to train gradient boosted decision tree ensembles with multiple objective functions.
It can be initialized with the following arguments; an example initialization follows the list.
- objective_type (str, optional): The name of the objective function to optimize. Valid options are "LogLoss", to use logistic loss as the objective function, or "SquaredLoss", to use squared error as the objective function. Defaults to "LogLoss".
- iterations (int, optional): Total number of trees to train in the ensemble. Defaults to 100.
- learning_rate (float, optional): Step size to use at each iteration. Each leaf weight is multiplied by this number; the smaller the value, the more conservative the weights will be. Defaults to 0.3.
- max_depth (int, optional): Maximum depth of an individual tree. Valid values are 0 to infinity. Defaults to 5.
- max_leaves (int, optional): Maximum number of leaves allowed on a tree, i.e. the total number of final nodes. Valid values are 0 to infinity. Defaults to sys.maxsize.
- l2 (float, optional): L2 regularization term applied to the weights of the tree. Valid values are 0 to infinity. Defaults to 1.0.
- gamma (float, optional): The minimum amount of loss required to further split a node. Valid values are 0 to infinity. Defaults to 0.0.
- min_leaf_weight (float, optional): Minimum sum of the hessian values of the loss function required to be in a node. Defaults to 0.0.
- base_score (float, optional): The initial prediction value of the model. Defaults to 0.5.
- nbins (int, optional): Number of bins used to partition the data. Setting this to a smaller number will result in faster training time, while potentially sacrificing accuracy. If there are more bins than unique values in a column, all unique values will be used. Defaults to 256.
- parallel (bool, optional): Whether multiple cores should be used when training and predicting with this model. Defaults to True.
- dtype (Union[np.dtype, str], optional): Datatype used for the model. Valid options are a numpy 32-bit float or a numpy 64-bit float. Using a 32-bit float can be faster in some instances, but may lead to less precise results. Defaults to "float64".
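For example, a booster for a regression problem might be configured as follows (the argument values here are purely illustrative):
from forust import GradientBooster
# All arguments shown are documented above; the values are illustrative.
model = GradientBooster(
    objective_type="SquaredLoss",
    iterations=250,
    learning_rate=0.1,
    max_depth=4,
    l2=1.0,
    gamma=0.1,
    nbins=128,
    parallel=True,
)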
Training and Predicting
Once the booster has been initialized, it can be fit on a provided dataset and target (performance) field. After fitting, the model can be used to predict on a dataset. In the case of this example, the predictions are the log odds of a given record being 1.
# Small example dataset
from seaborn import load_dataset
df = load_dataset("titanic")
X = df.select_dtypes("number").drop(columns=["survived"])
y = df["survived"]
# Initialize a booster with defaults.
from forust import GradientBooster
model = GradientBooster(objective_type="LogLoss")
model.fit(X, y)
# Predict on data
model.predict(X.head())
# array([-1.94919663, 2.25863229, 0.32963671, 2.48732194, -3.00371813])
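Because the LogLoss objective produces predictions on the log-odds scale, they can be mapped to probabilities with the logistic (sigmoid) function. A small sketch using numpy (numpy is not part of the package; it is only used here for the transform):
import numpy as np
# Apply the logistic function to convert log odds to probabilities.
log_odds = model.predict(X.head())
probabilities = 1 / (1 + np.exp(-log_odds))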
The fit method accepts the following arguments; an example call follows the list.
- X (FrameLike): Either a pandas DataFrame or a 2-dimensional numpy array with numeric data.
- y (ArrayLike): Either a pandas Series or a 1-dimensional numpy array.
- sample_weight (Optional[ArrayLike], optional): Instance weights to use when training the model. If None is passed, a weight of 1 will be used for every record. Defaults to None.
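For instance, per-record weights can be passed alongside the data. A minimal sketch (the uniform weights used here simply reproduce the default behavior of sample_weight=None):
import numpy as np
# Uniform weights of 1, equivalent to passing sample_weight=None.
weights = np.ones(X.shape[0])
model.fit(X, y, sample_weight=weights)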
The predict method accepts the following arguments.
- X (FrameLike): Either a pandas DataFrame or a 2-dimensional numpy array with numeric data.
Inspecting the Model
Once the booster has been fit, each individual tree structure can be retrieved in text form using the text_dump method. This method returns a list the same length as the number of trees in the model.
model.text_dump()[0]
# 0:[0 < 3] yes=1,no=2,missing=2,gain=91.50833,cover=209.388307
# 1:[4 < 13.7917] yes=3,no=4,missing=4,gain=28.185467,cover=94.00148
# 3:[1 < 18] yes=7,no=8,missing=8,gain=1.4576768,cover=22.090348
# 7:[1 < 17] yes=15,no=16,missing=16,gain=0.691266,cover=0.705011
# 15:leaf=-0.15120,cover=0.23500
# 16:leaf=0.154097,cover=0.470007
The json_dump method performs the same action, but returns the model as a JSON representation rather than a text string.
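Assuming the JSON representation is returned as a string (an assumption on my part, not stated above), it can be parsed into ordinary Python objects for programmatic inspection:
import json
# Hedged assumption: json_dump returns the model as a JSON string.
model_repr = json.loads(model.json_dump())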
Saving the model
To save and subsequently load a trained booster, the save_booster and load_booster methods can be used. Each accepts a path, which is used to write the model to or read it from. The model is saved and loaded as a JSON object.
trained_model.save_booster("model_path.json")
# To load a model from a json path.
loaded_model = GradientBooster.load_booster("model_path.json")
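A quick way to confirm the round trip preserved the model is to compare predictions from the original and loaded boosters (a sanity check of my own, not part of the package API):
import numpy as np
# The saved and reloaded boosters should produce identical predictions.
assert np.allclose(trained_model.predict(X), loaded_model.predict(X))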
TODOs
This is still a work in progress.
- Early stopping rounds
  - We should be able to accept a validation dataset that can be used to determine when to stop training.
- Monotonicity support
  - Right now, features are used in the model without any constraints.
- Clean up the CI/CD pipeline.