Diffusion and Flow-based Models with Multi-Output XGBoost

A Python library for training and sampling from diffusion and flow-based generative models using multi-output XGBoost ensembles, as described in "Scaling Up Diffusion and Flow-based XGBoost Models". This is the installable package version. For the research version used to reproduce the results from the paper, see this repo.

Installation

Install via uv:

uv pip install forest_diffusion_mo

Or with pip:

pip install forest_diffusion_mo

Requirements: Python ≥3.10

Quick Start

Basic Usage

import numpy as np
from forest_diffusion_mo import ForestModel

# Create some sample data (shape: n_samples × n_features)
X = np.random.randn(100, 3).astype(np.float32) # XGB casts to float32 internally
# Initialize the model
model = ForestModel(
    logdir='my_model_dir',  # XGB ensembles are saved to disk in parallel during training
    multi_output=True,      # True for multi-output XGB ensembles, otherwise uses single-output ensembles
    diffusion_type='vp',    # 'vp' for variance preserving diffusion or 'flow' for flow matching
    n_t=10                  # number of diffusion/flow timesteps
)
# Preprocess the data (handles scaling and encoding), then train
X_proc = model.preprocess(X)
model.train(X_proc)
# Generate new samples
samples = model.generate(n=100)
print(samples.shape) # (100, 3)

Load Trained Model and Sample

# Models are saved to logdir automatically during training
loaded_model = ForestModel.load_model('my_model_dir')
samples = loaded_model.generate(n=1000)

ForestModel Parameters

The ForestModel class requires a logdir, and accepts the following optional parameters:

Generative Model Configuration

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| multi_output | bool | True | Whether to use multi-output or single-output XGB ensembles. |
| diffusion_type | str | 'vp' | 'vp' for variance preserving diffusion, or 'flow' for flow matching. |
| n_t | int | 50 | Number of diffusion/flow timesteps. Higher values = slower training/generation but better quality samples. |
| duplicate_K | int | 100 | Number of noise augmentation samples per original sample during training. Higher = more coverage of training data but slower. |
| xgb_hypers | dict | {} | XGBoost hyperparameters (e.g., {'max_depth': 7, 'n_estimators': 100}). See XGBoost documentation. |
| scaler | str | 'min_max' | Scaling method. 'min_max' creates one scaler per class y, 'single_min_max' uses a single unified scale over all classes. |
| eps | float | 0.001 | Minimum noise level for the diffusion process. Prevents blow up at t=0 for vp diffusion. Should be set to eps=0.0 for flow. |
| beta_min | float | 0.1 | Minimum noise schedule parameter (vp only). |
| beta_max | float | 8.0 | Maximum noise schedule parameter (vp only). |
| solver | str | 'euler' | SDE/ODE solver used during generation: 'euler', 'heun', or 'rk4'. Higher order = slower but potentially more accurate. |
| seed | int | 0 | Random seed for data preprocessing and diffusion. |
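
To make duplicate_K, eps, beta_min, and beta_max more concrete, here is a minimal NumPy sketch of how noisy training inputs for a variance preserving model are commonly built. It assumes the standard VP marginal x_t = alpha(t) * x_0 + sigma(t) * z with a linear beta schedule and an evenly spaced time grid from eps to 1. The function names (vp_alpha_sigma, make_noisy_copies) are hypothetical, and this is an illustration rather than the package's internal code.

```python
import numpy as np


def vp_alpha_sigma(t, beta_min=0.1, beta_max=8.0):
    """Signal and noise scales of the standard VP marginal at time t in [0, 1]."""
    log_alpha = -0.25 * t**2 * (beta_max - beta_min) - 0.5 * t * beta_min
    alpha = np.exp(log_alpha)
    sigma = np.sqrt(1.0 - alpha**2)  # sigma -> 0 as t -> 0, hence the eps floor
    return alpha, sigma


def make_noisy_copies(X, t, duplicate_K=100, seed=0):
    """Repeat each row duplicate_K times and noise every copy independently."""
    rng = np.random.default_rng(seed)
    X_dup = np.repeat(X, duplicate_K, axis=0)  # shape (n * duplicate_K, d)
    Z = rng.standard_normal(X_dup.shape)       # fresh noise for every copy
    alpha, sigma = vp_alpha_sigma(t)
    return alpha * X_dup + sigma * Z, Z


eps, n_t = 0.001, 10
timesteps = np.linspace(eps, 1.0, n_t)  # assumed time grid, starting at eps
X = np.random.default_rng(0).standard_normal((5, 3))
X_t, Z = make_noisy_copies(X, timesteps[0], duplicate_K=4)
print(X_t.shape)  # (20, 3)
```

Because sigma(t) vanishes at t=0, quantities that divide by sigma are undefined there, which is consistent with the eps floor described in the table.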

Data Encoding Information

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| cat_indexes | list | [] | List of column indices that are categorical (will be one-hot encoded). |
| bin_indexes | list | [] | List of column indices that are binary. |
| int_indexes | list | [] | List of column indices that are integer/ordinal. |
| true_min_max_values | list | None | List of form [[min_x, min_y], [max_x, max_y]]. Pre-computed min/max values for each feature. Use if consistent preprocessing across datasets is required. |
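
As an illustration of the true_min_max_values layout, the sketch below computes shared per-feature minima and maxima over two datasets with NumPy. It assumes the documented form (one list of minima followed by one list of maxima, one entry per feature); the array names are placeholders, and whether the library also accepts NumPy arrays here is not checked.

```python
import numpy as np

# Two datasets that should be scaled consistently (two features each)
X_train = np.array([[0.0, 10.0], [2.0, 30.0], [1.0, 20.0]])
X_other = np.array([[-1.0, 15.0], [3.0, 25.0]])

# Per-feature min and max over the union of both datasets
X_all = np.vstack([X_train, X_other])
true_min_max_values = [X_all.min(axis=0).tolist(), X_all.max(axis=0).tolist()]
print(true_min_max_values)  # [[-1.0, 10.0], [3.0, 30.0]]
```

The resulting list could then be passed as true_min_max_values when constructing each ForestModel, so that every dataset is scaled against the same range.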

Parallelism Configuration

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| n_jobs | int | -1 | Number of parallel jobs for training (-1 = all cores). |
| backend | str | 'loky' | Joblib backend: 'loky', 'multiprocessing', or 'threading'. We recommend not changing this. |
| n_batch | int | -1 | Number of batches for QuantileDMatrix construction using XGB data iterator (-1 = no batching). |

ForestModel.generate() Optional Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| n | int | None | Number of samples to generate. If None, generates the same number of samples as in the training set. |
| n_t | int | None | Number of solver steps for generation. If None, uses the value of n_t from ForestModel construction. Cannot be greater than that value. |
| label_y | array-like | None | List of labels for conditional generation. If the model was trained with labels (via preprocess(X, label_y=...)), label_y specifies which class each sample should belong to. If len(label_y) < n, the list is tiled to the matching length. If None, labels are sampled according to the class distribution in the training data. |
| n_jobs | int | -1 | Number of parallel jobs for generation (-1 = all cores). |
| seed | int | self.seed + 1 | Random seed for generation. Should differ from the training seed to avoid starting from noise seen during training. |
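
The trade-off between the number of solver steps (n_t) and the solver order (the solver option set at construction) can be illustrated on a toy problem. The sketch below integrates dx/dt = -x from t=0 to t=1 with hand-written Euler and Heun steps and compares both to the exact solution. The functions are hypothetical stand-ins for illustration only, not the package's solvers, and a real generation run integrates a learned vector field instead.

```python
import math


def decay(t, x):
    """Toy vector field dx/dt = -x, with exact solution x(t) = x(0) * exp(-t)."""
    return -x


def euler(f, x, t0, t1, n_steps):
    """First-order solver: one function evaluation per step."""
    h = (t1 - t0) / n_steps
    t = t0
    for _ in range(n_steps):
        x = x + h * f(t, x)
        t += h
    return x


def heun(f, x, t0, t1, n_steps):
    """Second-order solver: two function evaluations per step."""
    h = (t1 - t0) / n_steps
    t = t0
    for _ in range(n_steps):
        k1 = f(t, x)
        k2 = f(t + h, x + h * k1)  # slope at the Euler-predicted endpoint
        x = x + 0.5 * h * (k1 + k2)
        t += h
    return x


exact = math.exp(-1.0)
for n_steps in (5, 10, 50):
    err_euler = abs(euler(decay, 1.0, 0.0, 1.0, n_steps) - exact)
    err_heun = abs(heun(decay, 1.0, 0.0, 1.0, n_steps) - exact)
    print(f"steps={n_steps:3d}  euler_err={err_euler:.5f}  heun_err={err_heun:.5f}")
```

On this toy problem Heun reaches a smaller error than Euler at the same step count, at the cost of twice as many function evaluations per step. With tree ensembles, each evaluation is a model prediction, which is why higher-order solvers are slower.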

Example: Training and Generation with Label Conditioning

from sklearn.datasets import load_iris
from forest_diffusion_mo import ForestModel

# Load your data
my_data = load_iris()
X, y = my_data['data'], my_data['target']
print(X.shape) # (150, 4)

# Configure and train
model = ForestModel(
    logdir='my_model_dir',
    multi_output=True,
    diffusion_type='flow',
    eps=0.0, # `flow` does not blow up at t=0
    cat_indexes=[], # Iris's four features are all floats
    bin_indexes=[],
    int_indexes=[],
)

# Provide datapoint labels during preprocessing
X_proc = model.preprocess(X, y)
model.train(X_proc)

# Generate synthetic data conditional on a list of labels
label_y = y  # Using labels as they appear in the training set distribution
samples = model.generate(n=300, label_y=label_y) # Training labels will be repeated twice
print(samples.shape) # (300, 5), four features followed by the label
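
The tiling of label_y noted in the comment above can be pictured with a short NumPy sketch. It assumes that "tiled" means cyclic repetition truncated to n entries; np.resize is used here only because it has that behaviour, and whether the library uses it internally is not confirmed.

```python
import numpy as np

# A short label list is repeated cyclically until it has n entries
label_y = np.array([0, 1, 2])
n = 7
tiled = np.resize(label_y, n)
print(tiled)  # [0 1 2 0 1 2 0]

# For Iris-like labels (150 entries, n=300) this is two full copies
y = np.repeat([0, 1, 2], 50)
assert np.array_equal(np.resize(y, 300), np.tile(y, 2))
```

Under that assumption, the class balance of the generated data matches the training labels whenever n is a multiple of len(label_y).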

Example: Data with Heterogeneous Column Types

from sklearn.datasets import fetch_california_housing
from forest_diffusion_mo import ForestModel

# Load your data
my_data = fetch_california_housing()
X, y = my_data['data'], my_data['target']
cat_indexes = []
bin_indexes = []
int_indexes = [1, 4] # Housing has two integer-valued features, the rest are floats. Zero indexed.

# Configure and train
model = ForestModel(
    logdir='my_model_dir',
    multi_output=True,
    diffusion_type='flow',
    eps=0.0, # `flow` does not blow up at t=0
    cat_indexes=cat_indexes,
    bin_indexes=bin_indexes,
    int_indexes=int_indexes,
)

X_proc = model.preprocess(X)
model.train(X_proc)
samples = model.generate(n=100)
print(samples.shape)

XGBoost Version Considerations

Since multi-output trees are an experimental feature in XGBoost, performance of this package can be unstable in certain versions, and GPU training is not fully supported. We leave some notes here for future developers.

XGBoost had errors in the loss computation for multi-output trees before 2.1.0. Do not use lower versions.
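
One way to guard against an older installation is to compare the installed version against 2.1 at startup. The helper below is a hypothetical sketch (version_at_least is not part of this package or of XGBoost) that parses only the leading major.minor part of a version string.

```python
import re


def version_at_least(version, minimum=(2, 1)):
    """Return True if a 'major.minor...' version string is >= minimum."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return (int(match.group(1)), int(match.group(2))) >= minimum


print(version_at_least("2.0.3"))  # False
print(version_at_least("2.1.0"))  # True

# In practice, pass the installed version, for example:
#   import xgboost
#   assert version_at_least(xgboost.__version__), "XGBoost >= 2.1.0 is required"
```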

The xgboost package supports GPU and CPU training. In some xgboost versions the library allocates ~400 MiB of GPU memory upon initialization, even if CPU training is specified. Since we launch many XGBoost processes in parallel, this can lead to issues if GPU memory is fully consumed. Adding os.environ["CUDA_VISIBLE_DEVICES"] = "" before importing XGBoost prevents GPU use and thus avoids this issue.
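
The ordering matters in the workaround above: the environment variable has to be set before the first import of XGBoost in the process. A minimal sketch follows; the find_spec guard is only there so the snippet also runs where XGBoost is not installed.

```python
import importlib.util
import os

# Hide all GPUs from this process. This must run before XGBoost is
# imported anywhere in the process.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

if importlib.util.find_spec("xgboost") is not None:
    import xgboost as xgb  # imported only after the variable is set

    print("XGBoost", xgb.__version__, "loaded with GPUs hidden")
```

Child processes started afterwards inherit the environment of the parent, so setting the variable once in the main script before launching parallel jobs should cover the workers as well.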

However, since our package is currently designed for CPU only, we instead depend on the lightweight xgboost-cpu package, which also avoids the above issue. When XGBoost fully supports multi-output trees, GPU training can be re-examined.

Citation

If you use this library in your research, please cite the associated paper:

@article{cresswell2024scaling,
  title={Scaling Up Diffusion and Flow-based XGBoost Models},
  author={Cresswell, Jesse C and Kim, Taewoo},
  journal={arXiv:2408.16046},
  year={2024}
}

License

This code is licensed under the MIT License, copyright by Layer 6 AI.
