Unmasking trees for tabular data generation and imputation

unmasking-trees 😷➡️🥳 🌲🌲🌲

UnmaskingTrees is a method for tabular data generation and imputation. It's an order-agnostic autoregressive diffusion model, wherein a training dataset is constructed by incrementally masking features in random order. Per-feature gradient-boosted trees are then trained to unmask each feature. Read more about it in my blog post!
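The training-set construction described above can be sketched in plain NumPy. This is a rough illustration, not the package's actual implementation: for each sample and each random ordering, features are hidden one at a time, and each masking step yields one training row whose target is the just-masked feature value.

```python
import numpy as np

rng = np.random.default_rng(0)

def build_masked_training_set(X, duplicate_K=2, mask_value=np.nan):
    """Sketch of order-agnostic masked training-set construction.

    For each sample and each of duplicate_K random feature orders,
    emit one row per masking step: features are hidden one at a time
    in that order, yielding n_samples * n_dims * duplicate_K rows.
    """
    n_samples, n_dims = X.shape
    rows, targets = [], []
    for x in X:
        for _ in range(duplicate_K):
            order = rng.permutation(n_dims)  # random masking order
            masked = x.astype(float).copy()
            for j in order:
                masked[j] = mask_value       # incrementally mask feature j
                rows.append(masked.copy())
                targets.append((j, x[j]))    # (feature index, value to unmask)
    return np.array(rows), targets

X = rng.normal(size=(5, 3))
rows, targets = build_masked_training_set(X, duplicate_K=2)
print(rows.shape)  # (5 * 3 * 2, 3) = (30, 3)
```

In the actual method, one gradient-boosted tree per feature is then trained to predict that feature's value from the partially masked rows.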

To better model conditional distributions that are multi-modal ("modal" as in "modes", not as in "modalities"), we by default discretize continuous features into n_bins bins. You can customize this via the quantize_cols parameter of the fit method: provide a list of length n_dims, with values in ('continuous', 'categorical', 'integer'). For 'categorical', quantization of that feature is skipped; for 'integer', the feature is quantized only if its number of unique values exceeds n_bins.
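Why discretize? A regressor trained on a bimodal target tends to predict the conditional mean, which falls between the modes; a classifier over bins can put probability mass on both. Here is a minimal sketch of this idea using scikit-learn's KBinsDiscretizer as a stand-in (the package's own 'kdiquantile' strategy is not reproduced here):

```python
import numpy as np
from sklearn.preprocessing import KBinsDiscretizer

rng = np.random.default_rng(0)
# A bimodal feature: a squared-error regressor would predict ~0,
# landing between the two modes at -2 and +2.
x = np.concatenate([rng.normal(-2, 0.3, 500), rng.normal(2, 0.3, 500)])

# Quantile binning turns the continuous target into 8 class labels,
# so a classifier can assign probability to each mode separately.
disc = KBinsDiscretizer(n_bins=8, encode="ordinal", strategy="quantile")
bins = disc.fit_transform(x.reshape(-1, 1)).ravel().astype(int)

print(bins.min(), bins.max())  # labels in [0, 7]
```

Sampling a bin label and then a value within that bin can reproduce both modes, whereas the conditional mean cannot.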

Here's how well it works on imputation with the [Two Moons](https://github.com/calvinmccarter/unmasking-trees/blob/master/paper/moons.ipynb) synthetic dataset:

Installation

Installation from PyPI

pip install utrees

Installation from source

After cloning this repo, install the dependencies on the command-line, then install utrees:

pip install -r requirements.txt
pip install -e .
pytest

Usage

Check out this notebook with the Two Moons example, or this one with the Iris dataset.

Synthetic data generation

You can fit utrees.UnmaskingTrees the way you would an sklearn model, with the added option of passing quantize_cols to fit: a list specifying which columns are continuous (and therefore need to be discretized). By default, all columns are assumed to contain continuous features.

import numpy as np
from sklearn.datasets import make_moons
from utrees import UnmaskingTrees
data, labels = make_moons((100, 100), shuffle=False, noise=0.1, random_state=123)  # size (200, 2)
utree = UnmaskingTrees().fit(data)

Then, you can generate new data:

newdata = utree.generate(n_generate=123)  # size (123, 2)

Missing data imputation

You can fit your UnmaskingTrees model on data with missing elements, provided as np.nan. You can then impute the missing values, optionally with multiple imputations per missing element. Given an array of size (n_samples, n_dims), you will get back an array of size (n_impute, n_samples, n_dims), in which the NaNs have been replaced with imputed values while observed entries are unchanged.

data4impute = data.copy()
data4impute[:, 1] = np.nan
X = np.concatenate([data, data4impute], axis=0)  # size (400, 2)
utree = UnmaskingTrees().fit(X)
imputeddata = utree.impute(n_impute=5)  # size (5, 400, 2)

You can also provide a totally new dataset to be imputed, so the model performs imputation without retraining:

utree = UnmaskingTrees().fit(data)
imputeddata = utree.impute(n_impute=5, X=data4impute)  # size (5, 200, 2)

Hyperparameters

  • depth: Depth of the balanced binary tree used to recursively quantize each feature.
  • duplicate_K: Number of random masking orders per actual sample. The training dataset will be of size (n_samples * n_dims * duplicate_K, n_dims).
  • xgboost_kwargs: Dict of keyword arguments passed to XGBClassifier.
  • strategy: How to quantize continuous features ('kdiquantile', 'quantile', 'uniform', or 'kmeans').
  • random_state: Controls randomness.
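As a rough illustration of the depth hyperparameter (assuming, as a sketch, that a balanced binary tree of depth d partitions a feature into 2**d leaf bins via recursive median splits):

```python
import numpy as np

def recursive_median_split(x, depth):
    """Recursively split a feature at its median, building a balanced
    binary quantization tree; depth d yields 2**d bins (leaf intervals)."""
    if depth == 0:
        return [x]
    m = np.median(x)
    left, right = x[x <= m], x[x > m]
    return recursive_median_split(left, depth - 1) + recursive_median_split(right, depth - 1)

x = np.arange(16.0)
leaves = recursive_median_split(x, depth=2)
print(len(leaves))  # 4 bins for depth=2
```

Deeper trees give finer-grained bins at the cost of more classes for the per-feature classifiers to distinguish.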

Citing this method

Please consider citing the UnmaskingTrees arXiv preprint. The BibTeX entry is:

@article{mccarter2024unmasking,
  title={Unmasking Trees for Tabular Data},
  author={McCarter, Calvin},
  journal={arXiv preprint arXiv:2407.05593},
  year={2024}
}

Also, please consider citing ForestDiffusion (code and paper), which this work builds on.
