
Official NRGBoost implementation

🔋 NRGBoost: Energy-Based Generative Boosted Trees

Official implementation of the NRGBoost algorithm.

GitHub: https://github.com/ajoo/nrgboost

Installation

To install the latest version of the Python package, run:

pip install nrgboost

Note: Requires Python 3.10+. Prebuilt wheels are available for Linux (x86-64, glibc 2.28+) and macOS (14.0+, ARM64). Windows is not currently supported.

Building from source

To install from a source distribution you need a C compiler with OpenMP. On macOS, install Homebrew's libomp first:

brew install libomp

NRGBoost Models

The following example shows how to train an NRGBoost model on the California Housing dataset:

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

from nrgboost import Dataset, NRGBooster

# Get data
df, y = fetch_california_housing(return_X_y=True, as_frame=True)
df.insert(0, y.name, y)
df, test_df = train_test_split(df, test_size=0.2, random_state=123)
train_df, val_df = train_test_split(df, test_size=0.2, random_state=124)

# Create training set
train_ds = Dataset(train_df)

# Train model
params = {
    'num_trees': 200,
    'shrinkage': 0.15,
    'max_leaves': 256,
    'max_ratio_in_leaf': 2,
    'num_model_samples': 80_000,
    'p_refresh': 0.1,
    'num_chains': 16,
    'burn_in': 100,
    }
model = NRGBooster.fit(train_ds, params, seed=1984)

Note: If your dataset contains categorical variables, cast them to the pandas Categorical dtype before calling the Dataset constructor, for example: df[categorical_col] = df[categorical_col].astype("category").
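The cast described in the note can be sketched with pandas alone. The helper name cast_categoricals, the toy DataFrame and its column names are illustrative assumptions and are not part of nrgboost:

```python
import pandas as pd


def cast_categoricals(df: pd.DataFrame, columns) -> pd.DataFrame:
    """Return a copy of `df` with `columns` cast to the pandas Categorical dtype.

    Hypothetical helper for illustration; not part of nrgboost.
    """
    out = df.copy()
    for col in columns:
        out[col] = out[col].astype("category")
    return out


# Toy data with one categorical and one numerical column (made-up values).
toy_df = pd.DataFrame({
    "ocean_proximity": ["INLAND", "NEAR BAY", "INLAND"],
    "median_income": [3.2, 8.1, 2.5],
})
toy_df = cast_categoricals(toy_df, ["ocean_proximity"])

# Each value is now backed by an integer code into the sorted categories.
print(toy_df["ocean_proximity"].cat.categories)
print(toy_df["ocean_proximity"].cat.codes.tolist())
```

The resulting DataFrame can then be passed to the Dataset constructor as in the training example.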

Prediction

To predict with the trained model, call the predict method and pass the name of the column to predict. Unlike discriminative methods, NRGBoost can predict any column in the data, not just a specific "target" column.

import numpy as np
from sklearn.metrics import r2_score

# Do "early stopping" first:
# find the best boosting round for prediction in validation
# with cumulative=True, predict will return an iterator 
# over predictions at different rounds
preds = model.predict(val_df, y.name, cumulative=True)
val_r2 = [r2_score(val_df[y.name], yh) for yh in preds]
best_round = np.argmax(val_r2) 

#%% Compute test R^2 using only the first `best_round` trees
test_preds = model.predict(test_df, y.name, num_rounds=best_round)
test_r2 = r2_score(test_df[y.name], test_preds)
print('Test R^2:', test_r2)

Note: For numerical columns, NRGBoost currently predicts the expected value under its learned distribution. In the future we plan to make this more flexible so that the user can select a different point estimate (e.g., the median or another quantile) or access the full distribution.

For categorical columns, NRGBoost outputs logits for each possible outcome. The prediction is an array of shape (N, K), where N is the number of points and K is the cardinality of the column. The logits are ordered by the pandas category codes of the possible values. By default, the output logits are already normalized, so they can be converted to probabilities directly by exponentiation (i.e., there is no need to apply softmax, since the exponentiated logits already sum to 1).
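As a minimal sketch of how such an output could be post-processed, the snippet below uses a hand-made array of normalized logits as a stand-in for what predict would return (the values and the column are invented for illustration). It recovers probabilities by exponentiation and maps the most likely code back to its category label:

```python
import numpy as np
import pandas as pd

# Hypothetical categorical column. pandas sorts the categories, so the codes
# are 0 -> "high", 1 -> "low", 2 -> "mid".
col = pd.Series(["low", "mid", "high", "mid"], dtype="category")

# Stand-in for an (N, K) prediction with N=2 points and K=3 categories:
# normalized logits, i.e. log-probabilities. Values are made up.
logits = np.log(np.array([[0.1, 0.7, 0.2],
                          [0.3, 0.3, 0.4]]))

# Normalized logits need no softmax; exponentiation gives probabilities.
probs = np.exp(logits)

# Most likely pandas code per row, mapped back to the category labels.
codes = probs.argmax(axis=1)
labels = col.cat.categories[codes]
print(probs.sum(axis=1), list(labels))
```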

Sampling

To draw 500 samples from the model using only the first best_round trees run:

samples_df = model.sample(500, num_steps=100, num_rounds=best_round)

num_steps is the number of Gibbs sampling steps used to generate each individual sample. It lets the user trade off computation time (which scales linearly with num_steps) against bias in the generated samples.
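To illustrate why the number of Gibbs steps affects bias, here is a toy Gibbs sampler over two correlated binary variables. This is a generic sketch and not NRGBoost's sampler: the function gibbs_binary_pair, the joint table and the start state are all assumptions made for the example.

```python
import numpy as np


def gibbs_binary_pair(joint, num_chains, num_steps, rng):
    """Run independent Gibbs chains on a joint table over two binary variables.

    joint[a, b] is the probability of (a, b). Every chain starts at (0, 0)
    and alternately resamples a given b, then b given a, `num_steps` times.
    """
    p_a1_given_b = joint[1, :] / joint.sum(axis=0)  # P(a=1 | b), indexed by b
    p_b1_given_a = joint[:, 1] / joint.sum(axis=1)  # P(b=1 | a), indexed by a
    a = np.zeros(num_chains, dtype=int)
    b = np.zeros(num_chains, dtype=int)
    for _ in range(num_steps):
        a = (rng.random(num_chains) < p_a1_given_b[b]).astype(int)
        b = (rng.random(num_chains) < p_b1_given_a[a]).astype(int)
    return a, b


# Strongly correlated toy distribution; the true marginal P(a=1) is 0.5.
joint = np.array([[0.45, 0.05],
                  [0.05, 0.45]])
rng = np.random.default_rng(0)

a_short, _ = gibbs_binary_pair(joint, num_chains=5000, num_steps=1, rng=rng)
a_long, _ = gibbs_binary_pair(joint, num_chains=5000, num_steps=50, rng=rng)

frac_short = a_short.mean()  # still biased toward the start state
frac_long = a_long.mean()    # close to the true marginal
print(f"P(a=1) estimate: 1 step = {frac_short:.2f}, 50 steps = {frac_long:.2f}")
```

With a single step, the chains have barely moved from their starting point, so the estimate of P(a=1) is far from 0.5. With 50 steps the estimate is close to the true marginal, at 50 times the cost.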

Saving and Loading

To save an NRGBoost model, run:

model.save('filename')

The saved model can then be loaded via:

from nrgboost import NRGBooster

model = NRGBooster.load('filename')

Cite NRGBoost

You can cite NRGBoost as:

@inproceedings{bravo2025nrgboost,
    title={{NRGB}oost: Energy-Based Generative Boosted Trees},
    author={Jo{\~a}o Bravo},
    booktitle={The Thirteenth International Conference on Learning Representations},
    year={2025},
    url={https://openreview.net/forum?id=wQHyjIZ1SH}
}
