
Probabilistic predictions for tabular data, using diffusion models and decision trees.


Treeffuser

License: MIT

Treeffuser is an easy-to-use package for probabilistic prediction on tabular data with tree-based diffusion models. It estimates conditional distributions p(y|x), where x is a feature vector and y is a target vector, and can model distributions that are arbitrarily complex (e.g., multimodal, heteroscedastic, non-Gaussian, or heavy-tailed).
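Why estimate the whole distribution rather than a single prediction? A minimal illustration, using only the Python standard library and not calling Treeffuser: suppose that for a fixed x, y is equally likely to lie near +1 or near -1. The conditional mean, which a standard regressor targets, then falls near 0, a value the data almost never take. The mixture centers (±1), the noise scale (0.05), and the helper name `bimodal_samples` are illustrative assumptions.

```python
import random
import statistics

def bimodal_samples(n, seed=0, scale=0.05):
    """Draw n values from an equal mixture of Laplace(+1, scale) and Laplace(-1, scale)."""
    rng = random.Random(seed)
    out = []
    for _ in range(n):
        center = 1.0 if rng.random() < 0.5 else -1.0
        # The difference of two i.i.d. exponentials with mean `scale` is Laplace(0, scale).
        noise = rng.expovariate(1 / scale) - rng.expovariate(1 / scale)
        out.append(center + noise)
    return out

ys = bimodal_samples(10_000)
mean = statistics.fmean(ys)
# Fraction of draws that actually land close to the conditional mean.
near_mean = sum(abs(y - mean) < 0.1 for y in ys) / len(ys)
print(f"conditional mean ~ {mean:.3f}; fraction within 0.1 of it: {near_mean:.4f}")
```

A point estimate summarizes this distribution poorly; samples from p(y|x) preserve both modes.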

It is designed to adhere closely to the scikit-learn API and require minimal user tuning.

Website | GitHub | Documentation | Paper (NeurIPS 2024)

Installation

Install Treeffuser from PyPI:

pip install treeffuser

Install the development version:

pip install git+https://github.com/blei-lab/treeffuser.git@main

The GitHub repository is located at: https://github.com/blei-lab/treeffuser

Usage Example

Here's a simple example demonstrating how to use Treeffuser.

We generate a bimodal, heteroscedastic response: an equal mixture of two sinusoidal components plus heavy-tailed Laplace noise whose scale grows with x.

import matplotlib.pyplot as plt
import numpy as np
from treeffuser import Treeffuser, Samples

# Generate data
seed = 0
rng = np.random.default_rng(seed=seed)
n = 5000
x = rng.uniform(0, 2 * np.pi, size=n)
z = rng.integers(0, 2, size=n)
y = z * np.sin(x - np.pi / 2) + (1 - z) * np.cos(x) + rng.laplace(scale=x / 30, size=n)
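Because the data are simulated, the true conditional density is known, which gives a ground truth to compare samples against. Since `z` is 0 or 1 with equal probability and sin(x - π/2) = -cos(x), p(y|x) is an equal mixture of two Laplace densities centered at -cos(x) and +cos(x), both with scale x/30. The sketch below evaluates it with the standard library; `laplace_pdf` and `true_conditional_density` are our own hypothetical helpers, not part of Treeffuser.

```python
import math

def laplace_pdf(y, loc, scale):
    """Density of a Laplace(loc, scale) distribution at y."""
    return math.exp(-abs(y - loc) / scale) / (2 * scale)

def true_conditional_density(y, x):
    """p(y | x) for the simulated data; requires x > 0 so the noise scale is positive."""
    scale = x / 30  # noise scale grows with x (heteroscedasticity)
    # z = 1 gives sin(x - pi/2) = -cos(x); z = 0 gives cos(x); each with probability 1/2.
    return 0.5 * laplace_pdf(y, -math.cos(x), scale) + 0.5 * laplace_pdf(y, math.cos(x), scale)

# Sanity check: for a fixed x the density should integrate to ~1 (midpoint rule on [-3, 3]).
x0 = 1.0
n_steps = 60_000
step = 6 / n_steps
total = sum(true_conditional_density(-3 + (i + 0.5) * step, x0) * step for i in range(n_steps))
print(f"integral at x={x0}: {total:.4f}")
```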

We fit Treeffuser and generate samples. We then plot the samples against the raw data.

# Fit the model
model = Treeffuser(seed=seed)
model.fit(x, y)

# Generate and plot samples
y_samples = model.sample(x, n_samples=1, seed=seed, verbose=True)
plt.scatter(x, y, s=1, label="observed data")
plt.scatter(x, y_samples[0, :], s=1, alpha=0.7, label="Treeffuser samples")
plt.legend()
plt.show()

(Figure: Treeffuser on heteroscedastic data)

Treeffuser accurately learns the target conditional densities and can generate samples from them.
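The scatter plot is a visual check. A rough numerical counterpart is sketched below: it compares empirical quantiles of observed and generated targets within equal-count bins of x and reports the largest gap. It assumes a one-dimensional target and that `x`, `y`, and a single draw such as `y_samples[0, :]` can be flattened to NumPy arrays of equal length; `binned_quantile_gap` is a hypothetical helper, not part of Treeffuser.

```python
import numpy as np

def binned_quantile_gap(x, y_obs, y_gen, n_bins=10, q=(0.05, 0.5, 0.95)):
    """Largest absolute gap between observed and generated quantiles across bins of x."""
    x, y_obs, y_gen = (np.ravel(a) for a in (x, y_obs, y_gen))
    edges = np.quantile(x, np.linspace(0, 1, n_bins + 1))  # equal-count bin edges
    bin_ids = np.clip(np.searchsorted(edges, x, side="right") - 1, 0, n_bins - 1)
    gaps = []
    for b in range(n_bins):
        mask = bin_ids == b
        if mask.sum() < 2:
            continue  # skip (nearly) empty bins
        gap = np.abs(np.quantile(y_obs[mask], q) - np.quantile(y_gen[mask], q))
        gaps.append(gap.max())
    return float(max(gaps))
```

With the arrays from the example, `binned_quantile_gap(x, y, y_samples[0, :])` returns a single number; a value small relative to the spread of y suggests the samples track the data at the chosen quantiles, though it is not a full test of distributional fit.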

These samples can be used to compute any downstream estimates of interest:

y_samples = model.sample(x, n_samples=100, verbose=True)  # y_samples.shape[0] is 100

# Estimate downstream quantities of interest
y_mean = y_samples.mean(axis=0)  # conditional mean
y_std = y_samples.std(axis=0)    # conditional std
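Other functionals of p(y|x) follow the same pattern: reduce over the sample axis. The sketch below assumes, as the comment in the snippet above indicates, that axis 0 indexes samples, so that `y_samples` can be treated as a NumPy array of shape (n_samples, n) for a one-dimensional target. The helper name `summarize_samples`, the threshold `t`, and the toy array are illustrative assumptions.

```python
import numpy as np

def summarize_samples(y_samples, t=0.0, alpha=0.1):
    """Monte Carlo estimates of conditional quantities from samples of shape (n_samples, n)."""
    y_samples = np.asarray(y_samples)
    lower, upper = np.quantile(y_samples, [alpha / 2, 1 - alpha / 2], axis=0)
    return {
        "mean": y_samples.mean(axis=0),            # E[y | x]
        "median": np.median(y_samples, axis=0),    # conditional median
        "interval": (lower, upper),                # central (1 - alpha) prediction interval
        "p_exceed": (y_samples > t).mean(axis=0),  # P(y > t | x)
    }

# Toy stand-in for sampler output: 4 samples for each of 2 inputs.
toy = np.array([[1.0, -2.0],
                [2.0, -1.0],
                [3.0,  0.5],
                [4.0,  1.5]])
stats = summarize_samples(toy, t=0.0)
print(stats["mean"], stats["p_exceed"])
```

Each estimate is a Monte Carlo approximation, so its accuracy improves with `n_samples`.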

You can also use the Samples helper class:

y_samples = Samples(y_samples)
y_mean = y_samples.sample_mean()
y_std = y_samples.sample_std()
y_quantiles = y_samples.sample_quantile(q=[0.05, 0.95])
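The 5% and 95% quantiles define a 90% prediction interval. If the estimated conditional distribution is well calibrated, roughly 90% of held-out targets should fall inside their intervals. The sketch below computes that empirical coverage with NumPy; it assumes the bounds are arrays aligned with the held-out targets (for example, the two rows of `np.quantile(y_samples, [0.05, 0.95], axis=0)`) and does not depend on the exact return format of `sample_quantile`. The helper `empirical_coverage` is hypothetical.

```python
import numpy as np

def empirical_coverage(y_true, lower, upper):
    """Fraction of targets that fall inside [lower, upper]."""
    y_true, lower, upper = (np.ravel(a) for a in (y_true, lower, upper))
    inside = (y_true >= lower) & (y_true <= upper)
    return float(inside.mean())

# Illustrative check with a known distribution: for standard normal targets,
# the interval [-1.645, 1.645] should cover about 90% of draws.
rng = np.random.default_rng(0)
y_test = rng.standard_normal(100_000)
cov = empirical_coverage(y_test, np.full_like(y_test, -1.645), np.full_like(y_test, 1.645))
print(f"coverage: {cov:.3f}")
```

Coverage well below 0.9 suggests intervals that are too narrow; coverage well above suggests they are too wide.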

See the documentation for more information on available methods and parameters.


Citing Treeffuser

If you use Treeffuser in your work, please cite:

@article{beltranvelez2024treeffuser,
  title={Treeffuser: Probabilistic Predictions via Conditional Diffusions with Gradient-Boosted Trees},
  author={Nicolas Beltran-Velez and Alessandro Antonio Grande and Achille Nazaret and Alp Kucukelbir and David Blei},
  year={2024},
  eprint={2406.07658},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2406.07658},
}
