Official NRGBoost implementation
🔋 NRGBoost: Energy-Based Generative Boosted Trees
Official implementation of the NRGBoost algorithm.
Github: https://github.com/ajoo/nrgboost
Installation
To install the latest version of the Python package, run:
pip install nrgboost
NRGBoost Models
The following example shows how to train an NRGBoost model on the California Housing dataset:
import numpy as np
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
from nrgboost import Dataset, NRGBooster
# Get data
df, y = fetch_california_housing(return_X_y=True, as_frame=True)
df.insert(0, y.name, y)
df, test_df = train_test_split(df, test_size=0.2, random_state=123)
train_df, val_df = train_test_split(df, test_size=0.2, random_state=124)
# Create training set
train_ds = Dataset(train_df)
# Train model
params = {
    'num_trees': 200,
    'shrinkage': 0.15,
    'max_leaves': 256,
    'max_ratio_in_leaf': 2,
    'num_model_samples': 80_000,
    'p_refresh': 0.1,
    'num_chains': 16,
    'burn_in': 100,
}
model = NRGBooster.fit(train_ds, params, seed=1984)
Note: If your dataset contains categorical variables, cast them to the pandas Categorical dtype before calling the Dataset constructor, for example: df[categorical_col] = df[categorical_col].astype("category").
Prediction
To use the trained model for prediction, call the predict method, which takes the name of the column to predict.
Unlike discriminative methods, NRGBoost can predict any column in the data, not just a single designated "target" column.
# Do "early stopping" first:
# find the best boosting round for prediction in validation
# with cumulative=True, predict will return an iterator
# over predictions at different rounds
preds = model.predict(val_df, y.name, cumulative=True)
val_r2 = [r2_score(val_df[y.name], yh) for yh in preds]
best_round = np.argmax(val_r2)
# Compute test R^2 using only the first `best_round` trees
test_preds = model.predict(test_df, y.name, num_rounds=best_round)
test_r2 = r2_score(test_df[y.name], test_preds)
print('Test R^2:', test_r2)
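Being able to predict any column follows from modeling the joint distribution of all columns. As a toy illustration only (not NRGBoost's implementation), assume a hypothetical energy-based model q(a, b) ∝ exp(f(a, b)) over two categorical columns, with f stored as a small table. Normalizing f over whichever column we want to predict, with a log-sum-exp, gives normalized logits for that column conditioned on the other; the helper name conditional_log_probs is made up for this sketch.

```python
import numpy as np

# Toy, hypothetical energy model over two categorical columns "a" (3 values)
# and "b" (2 values): f[i, j] is the unnormalized log-density f(a=i, b=j).
# This is NOT NRGBoost's internal representation, only an illustration.
f = np.array([[1.0, -0.5],
              [0.2, 0.3],
              [-1.0, 2.0]])


def conditional_log_probs(f, axis):
    """Normalize f over `axis` (the column being predicted) with log-sum-exp,
    giving log p(predicted column | other column)."""
    m = f.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    log_norm = m + np.log(np.exp(f - m).sum(axis=axis, keepdims=True))
    return f - log_norm


# Predict column "a" given "b": normalize over the rows (axis 0).
log_p_a_given_b = conditional_log_probs(f, axis=0)
# Predict column "b" given "a": normalize over the columns (axis 1).
log_p_b_given_a = conditional_log_probs(f, axis=1)

print(np.exp(log_p_a_given_b).sum(axis=0))  # each column of probabilities sums to 1
print(np.exp(log_p_b_given_a).sum(axis=1))  # each row of probabilities sums to 1
```

The same table answers both questions; only the axis being normalized changes.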
Note: For numerical columns, NRGBoost currently predicts the expected value under its learned distribution. In the future we plan to make this more flexible, so that the user can select a different point estimate (e.g., the median or another quantile) or access the full distribution.
For categorical columns, NRGBoost outputs logits for each possible value. The prediction is an array of shape (N, K), where N is the number of points and K is the cardinality of the column. The logits are ordered according to the pandas category codes of the column's values.
The output logits are already normalized, so they can be converted to probabilities simply by exponentiation (i.e., there is no need to apply a softmax).
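As a sketch of how such output might be consumed, assume predict returned the hypothetical logits array below for a column with K = 3 categories (the labels in categories are also made up). Exponentiating gives probabilities directly, and the argmax code maps back to a label through the column's categories, which pandas lists in code order via df[col].cat.categories.

```python
import numpy as np

# Hypothetical normalized logits for N=2 points and a categorical column with
# K=3 values, i.e. log-probabilities (a stand-in for what predict returns).
logits = np.log(np.array([[0.7, 0.2, 0.1],
                          [0.1, 0.3, 0.6]]))

# Hypothetical category labels, listed in pandas code order
# (the order given by df[col].cat.categories).
categories = ["red", "green", "blue"]

probs = np.exp(logits)  # already normalized: no softmax needed
assert np.allclose(probs.sum(axis=1), 1.0)  # each row is a distribution over K values

# Most likely value for each point: argmax over codes, then map codes to labels.
predicted_codes = probs.argmax(axis=1)
predicted_labels = [categories[code] for code in predicted_codes]
print(predicted_labels)  # ['red', 'blue']
```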
Sampling
To draw 500 samples from the model we can run:
samples_df = model.sample(500, num_steps=100)
num_steps is the number of Gibbs sampling steps used to generate each individual sample. It lets the user trade off computation time (which scales linearly with num_steps) against bias in the generated samples: more steps cost more time but yield less biased samples.
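To make this trade-off concrete, here is a toy Gibbs sampler for a hypothetical, strongly correlated distribution over two binary variables. It is not NRGBoost's sampler: it assumes each sample comes from its own chain, started at a fixed state and run for num_steps sweeps, and the names TARGET, prob_one, gibbs_sample and total_variation are illustrative. With few steps the samples stay close to the starting state; more steps bring their empirical distribution closer to the target at linearly growing cost.

```python
import random
from collections import Counter

# Hypothetical target distribution over two binary variables (x0, x1).
# It is strongly correlated, so a Gibbs chain mixes slowly.
TARGET = {(0, 0): 0.45, (0, 1): 0.05, (1, 0): 0.05, (1, 1): 0.45}


def prob_one(index, other_value):
    """P(x[index] = 1 | the other variable = other_value) under TARGET."""
    if index == 0:
        p0, p1 = TARGET[(0, other_value)], TARGET[(1, other_value)]
    else:
        p0, p1 = TARGET[(other_value, 0)], TARGET[(other_value, 1)]
    return p1 / (p0 + p1)


def gibbs_sample(num_samples, num_steps, seed=0):
    """Draw each sample from its own chain run for num_steps Gibbs sweeps."""
    rng = random.Random(seed)
    samples = []
    for _ in range(num_samples):
        x = [0, 0]  # every chain starts from the same fixed state
        for _ in range(num_steps):
            # One sweep: resample each variable given the current value of the other.
            x[0] = int(rng.random() < prob_one(0, x[1]))
            x[1] = int(rng.random() < prob_one(1, x[0]))
        samples.append(tuple(x))
    return samples


def total_variation(samples):
    """Total variation distance between the empirical and target distributions."""
    counts = Counter(samples)
    n = len(samples)
    return 0.5 * sum(abs(counts[s] / n - p) for s, p in TARGET.items())


# Cost grows linearly with num_steps; the distance to TARGET (bias) shrinks.
for steps in (1, 5, 50):
    print(steps, round(total_variation(gibbs_sample(5_000, steps)), 3))
```

Starting every chain from the same state exaggerates the bias at small num_steps, which makes the effect easy to see in such a small example.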
Saving and Loading
To save an NRGBoost model, run:
model.save('filename')
The model can then be loaded via:
from nrgboost import NRGBooster
model = NRGBooster.load('filename')
Cite NRGBoost
You can cite NRGBoost as:
@article{bravo2024nrgboost,
title={NRGBoost: Energy-Based Generative Boosted Trees},
author={Bravo, Jo{\~a}o},
journal={arXiv preprint arXiv:2410.03535},
year={2024}
}
Source Distribution
File details
Details for the file nrgboost-0.0.2.tar.gz.
File metadata
- Download URL: nrgboost-0.0.2.tar.gz
- Upload date:
- Size: 61.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 65643583bc4f361ad713d7bb20622d669fc2f00fd1a47306181318b9f64ddd9e |
| MD5 | ad12d1fb7c74bf99d1dd8f50c36306c4 |
| BLAKE2b-256 | e7f0c204fe9c3092795e900a9eaf89ea279c4903883408ae58e3b1aa17d77274 |