Official NRGBoost implementation
🔋 NRGBoost: Energy-Based Generative Boosted Trees
Official implementation of the NRGBoost algorithm.
GitHub: https://github.com/ajoo/nrgboost
Installation
To install the latest version of the Python package, run:
pip install nrgboost
Note: Requires Python 3.10+. Prebuilt wheels are available for Linux (x86-64, glibc 2.28+) and macOS (14.0+, ARM64). Windows is not supported at the moment.
Building from source
To install from a source distribution, you need a C compiler with OpenMP support. On macOS, install Homebrew's libomp first:
brew install libomp
NRGBoost Models
The following example shows how to train an NRGBoost model on the California Housing dataset:
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from nrgboost import Dataset, NRGBooster
# Get data and add the target as a regular (first) column
df, y = fetch_california_housing(return_X_y=True, as_frame=True)
df.insert(0, y.name, y)
# Split into train (64%), validation (16%) and test (20%) sets
df, test_df = train_test_split(df, test_size=0.2, random_state=123)
train_df, val_df = train_test_split(df, test_size=0.2, random_state=124)
# Create training set
train_ds = Dataset(train_df)
# Train model
params = {
'num_trees': 200,
'shrinkage': 0.15,
'max_leaves': 256,
'max_ratio_in_leaf': 2,
'num_model_samples': 80_000,
'p_refresh': 0.1,
'num_chains': 16,
'burn_in': 100,
}
model = NRGBooster.fit(train_ds, params, seed=1984)
Note: If your dataset contains categorical variables, they should be cast to pandas Categorical dtype before calling the Dataset constructor. For example df[categorical_col] = df[categorical_col].astype("category").
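A minimal sketch of that cast on a hypothetical toy DataFrame (the column names and values are made up for illustration). It also prints the category-to-code mapping pandas assigns, since those codes determine the order of the logits returned for categorical columns.

```python
import pandas as pd

# Hypothetical toy frame: one numeric column and two string-valued columns.
df = pd.DataFrame({
    "income": [3.2, 5.1, 2.7, 4.4],
    "region": ["north", "south", "south", "east"],
    "tenure": ["own", "rent", "rent", "own"],
})

# Columns to treat as categorical (listed explicitly rather than inferred).
categorical_cols = ["region", "tenure"]
for col in categorical_cols:
    df[col] = df[col].astype("category")

print(df.dtypes)
# Integer code assigned to each category (sorted order for string values).
print(dict(enumerate(df["region"].cat.categories)))  # → {0: 'east', 1: 'north', 2: 'south'}
```

Listing the columns explicitly avoids depending on how a given pandas version infers the dtype of string columns.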
Prediction
To make predictions with the trained model, call its predict method and pass the name of the column to predict.
Unlike discriminative methods, NRGBoost can predict any column in the data, not just a designated "target" column.
import numpy as np
from sklearn.metrics import r2_score
# Do "early stopping" first:
# find the best boosting round for prediction in validation
# with cumulative=True, predict will return an iterator
# over predictions at different rounds
preds = model.predict(val_df, y.name, cumulative=True)
val_r2 = [r2_score(val_df[y.name], yh) for yh in preds]
best_round = np.argmax(val_r2)
# Compute test R^2 using only the first `best_round` trees
test_preds = model.predict(test_df, y.name, num_rounds=best_round)
test_r2 = r2_score(test_df[y.name], test_preds)
print('Test R^2:', test_r2)
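The early-stopping step above picks the position of the highest validation score. The NumPy-only sketch below reproduces that selection on synthetic data: the generator of noisy arrays is a hypothetical stand-in for the iterator returned by `predict(..., cumulative=True)`, and `r2` re-implements the standard coefficient of determination computed by `sklearn.metrics.r2_score`.

```python
import numpy as np

def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

rng = np.random.default_rng(0)
y_val = rng.normal(size=200)

# Synthetic cumulative predictions: they improve for a few rounds,
# then get worse again (mimicking overfitting as trees are added).
noise_levels = [1.0, 0.6, 0.3, 0.2, 0.35, 0.5]
cumulative_preds = (y_val + s * rng.normal(size=y_val.shape) for s in noise_levels)

val_r2 = [r2(y_val, yh) for yh in cumulative_preds]
best_idx = int(np.argmax(val_r2))  # 0-based position in the iterator
print(best_idx, val_r2[best_idx])
```

`np.argmax` returns a 0-based position. Whether `num_rounds` should receive that position or the position plus one depends on whether the first element of the cumulative iterator corresponds to one tree, which is not stated here; check the nrgboost source before relying on either convention.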
Note: For numerical columns, NRGBoost currently predicts the expected value under its learned distribution. In the future we plan to make this more flexible so that the user can select a different point estimate (e.g., the median or another quantile) or have access to the full distribution.
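To see why the choice of point estimate matters, consider a hypothetical right-skewed distribution over a handful of values. This toy calculation is not NRGBoost output (which currently exposes only the expected value); it simply contrasts the mean with the median of the same discrete distribution.

```python
import numpy as np

def point_estimates(values, probs, q=0.5):
    """Return (mean, q-quantile) of a discrete distribution."""
    values = np.asarray(values, dtype=float)
    probs = np.asarray(probs, dtype=float)
    order = np.argsort(values)
    v, p = values[order], probs[order]
    mean = float(np.dot(v, p))
    # Smallest value whose cumulative probability reaches q
    # (clipped so rounding in the cumulative sum cannot run past the end).
    idx = min(int(np.searchsorted(np.cumsum(p), q)), len(v) - 1)
    return mean, float(v[idx])

# Hypothetical predictive distribution for a single row.
mean, median = point_estimates([1.0, 2.0, 3.0, 10.0], [0.4, 0.3, 0.2, 0.1])
print(mean, median)  # the long right tail pulls the mean (about 2.6) above the median (2.0)
```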
For categorical columns, NRGBoost outputs a logit for each possible outcome. The prediction is an array with shape (N, K), where N is the number of points and K is the cardinality of the column. The logits are ordered according to the pandas category codes of the column, so column j corresponds to code j.
By default, the output logits are normalized, so they can be converted to probabilities directly by exponentiation. There is no need to apply a softmax: the logits are already log-probabilities, whose exponentials sum to 1 for each row.
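A sketch of turning a categorical prediction into probabilities and labels, assuming an (N, K) array of normalized logits as described above. The logits here are synthetic log-probabilities rather than real model output, and the column is a hypothetical example. The last lines confirm that a softmax leaves normalized logits unchanged.

```python
import numpy as np
import pandas as pd

# Hypothetical categorical column; its categories fix the logit order.
col = pd.Series(["cat", "dog", "bird", "dog"], dtype="category")
labels = col.cat.categories  # sorted: 'bird', 'cat', 'dog'

# Stand-in for a normalized (N, K) logit array for this column (N=2, K=3).
logits = np.log(np.array([[0.1, 0.2, 0.7],
                          [0.5, 0.3, 0.2]]))

# Normalized logits are log-probabilities: exponentiating is enough.
probs = np.exp(logits)

# Column j corresponds to category code j, i.e. labels[j].
most_likely = labels[probs.argmax(axis=1)]
print(list(most_likely))  # → ['dog', 'bird']

# A (numerically stable) softmax gives the same result for normalized logits.
shifted = logits - logits.max(axis=1, keepdims=True)
softmax = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
```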
Sampling
To draw 500 samples from the model using only the first best_round trees, run:
samples_df = model.sample(500, num_steps=100, num_rounds=best_round)
num_steps is the number of Gibbs sampling steps used to generate each individual sample. It lets the user trade off computation time (which scales linearly with num_steps) against bias in the generated samples: fewer steps are cheaper, but the samples may stay further from the model's learned distribution.
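To make the trade-off concrete, here is a generic Gibbs sampler for a toy distribution over two discrete variables with probability proportional to exp(-energy). It is a textbook illustration, not NRGBoost's sampler; the energy table and function names are hypothetical. With zero steps the samples are just the uniform starting states, while more sweeps bring the empirical distribution close to the target at a proportionally higher cost.

```python
import numpy as np

def _sample_rows(weights, rng):
    """Draw one index per row, with probability proportional to that row's weights."""
    cdf = np.cumsum(weights, axis=1)
    u = rng.random(weights.shape[0]) * cdf[:, -1]
    return (cdf < u[:, None]).sum(axis=1)

def gibbs_sample(energy, num_samples, num_steps, rng):
    """Sample (i, j) pairs from p(i, j) proportional to exp(-energy[i, j]).

    One independent chain per sample starts from a uniformly random state
    and runs `num_steps` sweeps (resample i given j, then j given i).
    The cost grows linearly with num_steps.
    """
    weights = np.exp(-energy)
    n_i, n_j = weights.shape
    i = rng.integers(n_i, size=num_samples)
    j = rng.integers(n_j, size=num_samples)
    for _ in range(num_steps):
        i = _sample_rows(weights[:, j].T, rng)  # p(i | j) for every chain
        j = _sample_rows(weights[i, :], rng)    # p(j | i) for every chain
    return i, j

def max_abs_error(energy, i, j):
    """Largest gap between empirical and exact joint probabilities."""
    target = np.exp(-energy) / np.exp(-energy).sum()
    counts = np.zeros_like(target)
    np.add.at(counts, (i, j), 1)
    return float(np.abs(counts / len(i) - target).max())

# Hypothetical 3x3 energy table for two discrete variables.
energy = np.array([[0.0, 1.0, 2.0],
                   [1.0, 0.5, 1.5],
                   [2.0, 1.5, 0.0]])
rng = np.random.default_rng(0)
errors = {steps: max_abs_error(energy, *gibbs_sample(energy, 20_000, steps, rng))
          for steps in (0, 1, 20)}
print(errors)
```

Running every chain in parallel with vectorized inverse-CDF draws keeps the toy fast; the per-sample cost is still proportional to the number of sweeps.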
Saving and Loading
To save an NRGBoost model, run:
model.save('filename')
The saved model can then be loaded via:
from nrgboost import NRGBooster
model = NRGBooster.load('filename')
Cite NRGBoost
You can cite NRGBoost as:
@inproceedings{bravo2025nrgboost,
title={{NRGB}oost: Energy-Based Generative Boosted Trees},
author={Jo{\~a}o Bravo},
booktitle={The Thirteenth International Conference on Learning Representations},
year={2025},
url={https://openreview.net/forum?id=wQHyjIZ1SH}
}
Download files
File details (nrgboost 0.0.3)
All files were uploaded via twine/6.1.0 CPython/3.13.12 using Trusted Publishing.

| File | Size | Tags | SHA256 |
|---|---|---|---|
| nrgboost-0.0.3.tar.gz | 57.5 kB | Source | 7b9e6a2a951755a75f34f1ec1185e82c4038938de6d126b046d46ce0624bbda0 |
| nrgboost-0.0.3-cp313-cp313-manylinux_2_28_x86_64.whl | 201.4 kB | CPython 3.13, manylinux: glibc 2.28+ x86-64 | 6d520c6b2a847a44f038604c82ea548e181bdaf6278ed1e265e8acad7e3c4b17 |
| nrgboost-0.0.3-cp313-cp313-macosx_14_0_arm64.whl | 303.2 kB | CPython 3.13, macOS 14.0+ ARM64 | 515a764bf9d8a3deab15d3ea330776c3508d29fa5746d1cda099d7881c46c906 |
| nrgboost-0.0.3-cp312-cp312-manylinux_2_28_x86_64.whl | 201.4 kB | CPython 3.12, manylinux: glibc 2.28+ x86-64 | 4f402d792e6ec6691f5b6204591c30d511835cb6d944fc98cc3851b883c4a0c2 |
| nrgboost-0.0.3-cp312-cp312-macosx_14_0_arm64.whl | 303.2 kB | CPython 3.12, macOS 14.0+ ARM64 | a195038a7a1b5932d4649cec1bd9c0967a69c7f00a60b278226ae9a947d21dc6 |
| nrgboost-0.0.3-cp311-cp311-manylinux_2_28_x86_64.whl | 200.8 kB | CPython 3.11, manylinux: glibc 2.28+ x86-64 | dfe30829ceaf2d0d0ec03eab1744838bed857d56919238e7243c9fb7f273e1fb |
| nrgboost-0.0.3-cp311-cp311-macosx_14_0_arm64.whl | 303.1 kB | CPython 3.11, macOS 14.0+ ARM64 | 691c76e494378363435d55cca87e57be3e11baa6ef519ebb231341e42f5dd9c1 |
| nrgboost-0.0.3-cp310-cp310-manylinux_2_28_x86_64.whl | 200.8 kB | CPython 3.10, manylinux: glibc 2.28+ x86-64 | 6a2a91cc51fb6cd96dc7617083e75d1171f8c12250c93981ddea0a9041c9933d |
| nrgboost-0.0.3-cp310-cp310-macosx_14_0_arm64.whl | 303.1 kB | CPython 3.10, macOS 14.0+ ARM64 | 7d35a62533321157e78d742f0b17689c537eb385cf3c12524b8fa587eac0d291 |

File hashes (MD5, BLAKE2b-256) and Sigstore transparency entries

| File | MD5 | BLAKE2b-256 | Sigstore transparency entry |
|---|---|---|---|
| nrgboost-0.0.3.tar.gz | 9bc8b5a37f4de87a4304b03661378727 | de3657e57957b5c66d9c86021bf3b26f6b3c6a9d9810af55997ee11b7f2603cb | 2014018905 |
| nrgboost-0.0.3-cp313-cp313-manylinux_2_28_x86_64.whl | b99630b0401c8ac0dcee4a275bc78cc1 | 413935cd2fa51a6746d762d7016917b8bcf05ba5c8b4ae026a5ee36cd0ea8ec1 | 2014020212 |
| nrgboost-0.0.3-cp313-cp313-macosx_14_0_arm64.whl | 67f5713555781b67b1f27a0282bcff42 | eb582d163dee3dce4955ea769ed13fdf3583f7bb25db9393267f8fbdf383c471 | 2014019983 |
| nrgboost-0.0.3-cp312-cp312-manylinux_2_28_x86_64.whl | 6850d5ebefba838a177e2a43f03deecd | b5d7e316415f8cf725cb4d7c4e9550c90a31f92f6984cad9cb4e507e1d5c7160 | 2014019879 |
| nrgboost-0.0.3-cp312-cp312-macosx_14_0_arm64.whl | 6dea93c8e85ff47024cc90ad5be26d4c | 313e16ce52e4bc5aa2d6bbb0a43a34e90fa4c45eb329a0453944bac624cdaa79 | 2014019320 |
| nrgboost-0.0.3-cp311-cp311-manylinux_2_28_x86_64.whl | b852ac819d9ae897bc493763e02131ea | d02d1b8f3947211614089ca6ff6913a0f0b49cb6a206800425fc7e7558b0657a | 2014019056 |
| nrgboost-0.0.3-cp311-cp311-macosx_14_0_arm64.whl | fdb4c96f3dd1a40eb253e1ceb1528f3b | d649b967c4149671d7228b1346d0a9e6437b5a321502cbf569ac6a5116dd6e01 | 2014019578 |
| nrgboost-0.0.3-cp310-cp310-manylinux_2_28_x86_64.whl | a2251fe9b77fbe58229e9bb461276764 | cc2a83cb6b200ab1c5aa3962f21ebc02f0755fbaff374dd0bfaebb1c550aa181 | 2014020331 |
| nrgboost-0.0.3-cp310-cp310-macosx_14_0_arm64.whl | f8094ad060efb0042e5da888bb5f212c | d417244e1f433bf48dee1a55323d558d39d866e6fec304ea65824f8af2fd58ad | 2014020113 |

Provenance
Each file has an attestation bundle with statement type https://in-toto.io/Statement/v1 and predicate type https://docs.pypi.org/attestations/publish/v1. The subject of each bundle is the file itself, and its subject digest equals the file's SHA256 listed above. All bundles share the following details:
- Publisher: publish.yaml on Ajoo/nrgboost
- Permalink: Ajoo/nrgboost@feef73a3edb20b911c2f7214b13f810909ef20ad
- Branch / Tag: refs/tags/v0.0.3
- Owner: https://github.com/Ajoo
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yaml@feef73a3edb20b911c2f7214b13f810909ef20ad
- Trigger Event: push
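To check a downloaded file against its published SHA256 digest, here is a small standard-library sketch (the helper names `sha256_of` and `verify` are hypothetical, not part of nrgboost):

```python
import hashlib

def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read fixed-size chunks until f.read returns b"" at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def verify(path, expected_hex):
    """True if the file's SHA256 matches the published hex digest (case-insensitive)."""
    return sha256_of(path) == expected_hex.strip().lower()
```

For example, `verify("nrgboost-0.0.3.tar.gz", "7b9e6a2a951755a75f34f1ec1185e82c4038938de6d126b046d46ce0624bbda0")` should return True for an intact copy of the source distribution. pip can also enforce digests itself through hash-checking mode (`--require-hashes` with `--hash=sha256:...` entries in a requirements file), in which case every requirement, including dependencies, must be pinned with a hash.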