Diffusion and Flow-based Models with Multi-Output XGBoost
A Python library for training and sampling from diffusion and flow-based generative models using multi-output XGBoost ensembles, as described in "Scaling Up Diffusion and Flow-based XGBoost Models". This is the installable package version. For the research version used to reproduce the results from the paper, see this repo.
Installation
Install via uv:
uv pip install forest_diffusion_mo
Or with pip:
pip install forest_diffusion_mo
Requirements: Python ≥3.10
Quick Start
Basic Usage
import numpy as np
from forest_diffusion_mo import ForestModel
# Create some sample data (shape: n_samples × n_features)
X = np.random.randn(100, 3).astype(np.float32) # XGB casts to float32 internally
# Initialize the model
model = ForestModel(
    logdir='my_model_dir',  # XGB ensembles are saved to disk in parallel during training
    multi_output=True,      # True for multi-output XGB ensembles, otherwise uses single-output ensembles
    diffusion_type='vp',    # 'vp' for variance preserving diffusion or 'flow' for flow matching
    n_t=10,                 # number of diffusion/flow timesteps
)
# Preprocess the data (handles scaling and encoding), then train
X_proc = model.preprocess(X)
model.train(X_proc)
# Generate new samples
samples = model.generate(n=100)
print(samples.shape) # (100, 3)
Load Trained Model and Sample
# Models are saved to logdir automatically during training
loaded_model = ForestModel.load_model('my_model_dir')
samples = loaded_model.generate(n=1000)
ForestModel Parameters
The ForestModel class requires a logdir, and accepts the following optional parameters:
Generative Model Configuration
| Parameter | Type | Default | Description |
|---|---|---|---|
| multi_output | bool | True | Whether to use multi-output or single-output XGB ensembles. |
| diffusion_type | str | 'vp' | 'vp' for variance preserving diffusion, or 'flow' for flow matching. |
| n_t | int | 50 | Number of diffusion/flow timesteps. Higher values = slower training/generation but better quality samples. |
| duplicate_K | int | 100 | Number of noise augmentation samples per original sample during training. Higher = more coverage of training data but slower. |
| xgb_hypers | dict | {} | XGBoost hyperparameters (e.g., {'max_depth': 7, 'n_estimators': 100}). See XGBoost documentation. |
| scaler | str | 'min_max' | Scaling method. 'min_max' creates one scaler per class y, 'single_min_max' uses a single unified scale over all classes. |
| eps | float | 0.001 | Minimum noise level for the diffusion process. Prevents blow up at t=0 for vp diffusion. Should be set to eps=0.0 for flow. |
| beta_min | float | 0.1 | Minimum noise schedule parameter (vp only). |
| beta_max | float | 8.0 | Maximum noise schedule parameter (vp only). |
| solver | str | 'euler' | SDE/ODE solver used during generation: 'euler', 'heun', or 'rk4'. Higher order = slower but potentially more accurate. |
| seed | int | 0 | Random seed for data preprocessing and diffusion. |
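To make the roles of eps, beta_min, beta_max, n_t, and duplicate_K more concrete, here is a minimal NumPy sketch of the standard variance preserving SDE with a linear noise schedule. It is an illustration under stated assumptions, not the package's implementation: the function vp_marginal, the evenly spaced time grid, and the toy data are all hypothetical. It shows that the noise standard deviation is exactly zero at t=0, which is why a small positive eps is used for vp.

```python
import numpy as np


def vp_marginal(t, beta_min=0.1, beta_max=8.0):
    """Mean coefficient and std of x_t given x_0 for a VP SDE with linear beta(t).

    Assumes beta(t) = beta_min + t * (beta_max - beta_min); hypothetical helper.
    """
    log_mean = -0.25 * t**2 * (beta_max - beta_min) - 0.5 * t * beta_min
    mean_coeff = np.exp(log_mean)
    std = np.sqrt(1.0 - np.exp(2.0 * log_mean))
    return mean_coeff, std


# Assumed time grid: n_t evenly spaced levels from eps to 1.
n_t, eps = 10, 0.001
t_grid = np.linspace(eps, 1.0, n_t)
mean_coeff, std = vp_marginal(t_grid)

# Noise augmentation in the spirit of duplicate_K: each row is repeated K times
# and each copy receives its own Gaussian noise draw.
rng = np.random.default_rng(0)
X0 = rng.standard_normal((5, 3))        # toy data: 5 samples, 3 features
K = 4                                   # small K for illustration only
X_dup = np.tile(X0, (K, 1))             # shape (K * 5, 3)
noise = rng.standard_normal(X_dup.shape)
level = 3                               # any index into t_grid
X_t = mean_coeff[level] * X_dup + std[level] * noise
```

Because mean_coeff**2 + std**2 equals 1 at every time level, unit-variance data keeps unit variance as it is noised, which is the sense in which the process is variance preserving.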
Data Encoding Information
| Parameter | Type | Default | Description |
|---|---|---|---|
| cat_indexes | list | [] | List of column indices that are categorical (will be one-hot encoded). |
| bin_indexes | list | [] | List of column indices that are binary. |
| int_indexes | list | [] | List of column indices that are integer/ordinal. |
| true_min_max_values | list | None | List of form [[min_x, min_y], [max_x, max_y]]. Pre-computed min/max values for each feature. Use if consistent preprocessing across datasets is required. |
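As an illustration of the true_min_max_values format, the sketch below computes shared per-feature bounds over two datasets so that both can be scaled consistently. It assumes the nested list holds one minimum and one maximum per feature, in column order; the arrays X_a and X_b are hypothetical toy data.

```python
import numpy as np

rng = np.random.default_rng(0)
X_a = rng.uniform(-1.0, 2.0, size=(50, 2))   # hypothetical dataset A, two features
X_b = rng.uniform(0.0, 5.0, size=(80, 2))    # hypothetical dataset B, same columns

# Take bounds over the union so neither dataset falls outside the scaled range.
X_all = np.vstack([X_a, X_b])
true_min_max_values = [
    X_all.min(axis=0).tolist(),   # [min_x, min_y]
    X_all.max(axis=0).tolist(),   # [max_x, max_y]
]

# The list would then be passed at construction, e.g.
# ForestModel(logdir='my_model_dir', true_min_max_values=true_min_max_values)
```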
Parallelism Configuration
| Parameter | Type | Default | Description |
|---|---|---|---|
| n_jobs | int | -1 | Number of parallel jobs for training (-1 = all cores). |
| backend | str | 'loky' | Joblib backend: 'loky', 'multiprocessing', or 'threading'. We recommend not changing this. |
| n_batch | int | -1 | Number of batches for QuantileDMatrix construction using XGB data iterator (-1 = no batching). |
ForestModel.generate() Optional Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| n | int | None | Number of samples to generate. If None, generates the same number of samples as in the training set. |
| n_t | int | None | Number of solver steps for generation. If None, uses the value n_t from ForestModel construction; cannot be greater than this value. |
| label_y | array-like | None | List of labels for conditional generation. If the model was trained with labels (via preprocess(X, label_y=...)), label_y specifies which class each sample should belong to. If len(label_y) < n, the list is tiled to the matching length. If None, labels are sampled according to the class distribution in the training data. |
| n_jobs | int | -1 | Number of parallel jobs for generation (-1 = all cores). |
| seed | int | self.seed + 1 | Random seed for generation. Should differ from the training seed to avoid starting from noise seen during training. |
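The two label_y behaviours described above can be pictured with a small NumPy sketch. This is only an illustration of the semantics: the helpers tile_labels and sample_labels are hypothetical, and the exact tiling order and random draws inside generate() are assumptions.

```python
import numpy as np


def tile_labels(label_y, n):
    """Cyclically repeat label_y until it has exactly n entries (hypothetical helper)."""
    return np.resize(np.asarray(label_y), n)


def sample_labels(train_labels, n, seed=0):
    """Draw n labels with probabilities equal to the training class frequencies."""
    classes, counts = np.unique(train_labels, return_counts=True)
    rng = np.random.default_rng(seed)
    return rng.choice(classes, size=n, p=counts / counts.sum())


# Fewer labels than requested samples: the list wraps around.
tiled = tile_labels([0, 1, 2], 7)

# No labels given: classes appear roughly in their training proportions.
sampled = sample_labels([0, 0, 0, 1], n=1000)
```

Passing an explicit label_y is the way to oversample a rare class, since the default draw follows the training proportions.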
Example: Training and Generation with Label Conditioning
from sklearn.datasets import load_iris
from forest_diffusion_mo import ForestModel
# Load your data
my_data = load_iris()
X, y = my_data['data'], my_data['target']
print(X.shape) # (150, 4)
# Configure and train
model = ForestModel(
    logdir='my_model_dir',
    multi_output=True,
    diffusion_type='flow',
    eps=0.0,         # `flow` does not blow up at t=0
    cat_indexes=[],  # Iris's four features are all floats
    bin_indexes=[],
    int_indexes=[],
)
# Provide datapoint labels during preprocessing
X_proc = model.preprocess(X, y)
model.train(X_proc)
# Generate synthetic data conditional on a list of labels
label_y = y # Using labels as they appear in the training set distribution
samples = model.generate(n=300, label_y=label_y) # Training labels will be repeated twice
print(samples.shape) # (300, 5), four features followed by the label
Example: Data with Heterogeneous Column Types
from sklearn.datasets import fetch_california_housing
from forest_diffusion_mo import ForestModel
# Load your data
my_data = fetch_california_housing()
X, y = my_data['data'], my_data['target']
cat_indexes = []
bin_indexes = []
int_indexes = [1, 4]  # Housing has two integer-valued features, the rest are floats. Zero indexed.
# Configure and train
model = ForestModel(
    logdir='my_model_dir',
    multi_output=True,
    diffusion_type='flow',
    eps=0.0,  # `flow` does not blow up at t=0
    cat_indexes=cat_indexes,
    bin_indexes=bin_indexes,
    int_indexes=int_indexes,
)
X_proc = model.preprocess(X)
model.train(X_proc)
samples = model.generate(n=100)
print(samples.shape)
XGBoost Version Considerations
Since multi-output trees are an experimental feature in XGBoost, performance of this package can be unstable in certain versions, and GPU training is not fully supported. We leave some notes here for future developers.
XGBoost had errors in the loss computation for multi-output trees before 2.1.0. Do not use lower versions.
The xgboost package supports GPU and CPU training. In some xgboost versions the library allocates ~400 MiB of GPU memory upon initialization, even if CPU training is specified. Since we launch many XGBoost processes in parallel, this can lead to issues if GPU memory is fully consumed. Adding os.environ["CUDA_VISIBLE_DEVICES"] = "" before importing XGBoost prevents GPU use and thus avoids this issue.
However, since our package is currently designed for CPU only, we instead use the lightweight xgboost-cpu package which also avoids the above issue. When XGBoost fully supports multi-output trees, GPU training can be re-examined.
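For environments where the full xgboost package is installed anyway, the two notes above can be combined into a small start-up guard. The sketch below is an assumption-laden illustration: the helper names are hypothetical, and it reads the installed version from package metadata rather than importing XGBoost, so the environment variable is guaranteed to be set first.

```python
import os

# Hide all GPUs. This must run before the first `import xgboost` in the process.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import re
import warnings
from importlib.metadata import PackageNotFoundError, version


def parse_version(text):
    """Return (major, minor, patch) from a version string such as '2.1.0'."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", text)
    if match is None:
        raise ValueError(f"Unrecognized version string: {text!r}")
    return tuple(int(part or 0) for part in match.groups())


def meets_minimum(text, minimum=(2, 1, 0)):
    """True if the version is at least 2.1.0, the first release with the loss fix."""
    return parse_version(text) >= minimum


def installed_xgboost_version():
    """Look up whichever XGBoost distribution is installed, or None if neither is."""
    for dist_name in ("xgboost", "xgboost-cpu"):
        try:
            return version(dist_name)
        except PackageNotFoundError:
            continue
    return None


found = installed_xgboost_version()
if found is not None and not meets_minimum(found):
    warnings.warn(f"XGBoost {found} is older than 2.1.0; multi-output loss may be wrong.")
```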
Citation
If you use this library in your research, please cite the associated paper:
@article{cresswell2024scaling,
title={Scaling Up Diffusion and Flow-based XGBoost Models},
author={Cresswell, Jesse C and Kim, Taewoo},
journal={arXiv:2408.16046},
year={2024}
}
License
This code is licensed under the MIT License, copyright by Layer 6 AI.