Unmasking trees for tabular data generation and imputation
unmasking-trees 😷➡️🥳 🌲🌲🌲
UnmaskingTrees is a method for tabular data generation and imputation. It's an order-agnostic autoregressive diffusion model, wherein a training dataset is constructed by incrementally masking features in random order. Per-feature gradient-boosted trees are then trained to unmask each feature. Read more about it in my blog post!
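The masking step described above can be sketched for a single row in plain Python. This is only an illustration under stated assumptions, not the library's implementation: the function name build_unmasking_examples and the use of None as the mask token are hypothetical. For each random order, one more feature is hidden at every step, and the value that was just hidden becomes the prediction target for that feature's model.

```python
import random


def build_unmasking_examples(row, duplicate_k, seed=0):
    """Return (masked_input, target_feature, target_value) triples for one row.

    Hypothetical helper for illustration only; None stands in for a masked value.
    """
    rng = random.Random(seed)
    n_dims = len(row)
    examples = []
    for _ in range(duplicate_k):
        # Draw a fresh random masking order for this copy of the row.
        order = list(range(n_dims))
        rng.shuffle(order)
        masked = list(row)
        for feature in order:
            # Hide one more feature; its true value is the training target.
            masked[feature] = None
            examples.append((list(masked), feature, row[feature]))
    return examples


row = [0.5, 1.2, 3.0]
examples = build_unmasking_examples(row, duplicate_k=2)
print(len(examples))  # n_dims * duplicate_k examples for this row
```

Each row contributes n_dims examples per masking order, which is consistent with the training-set size given under Hyperparameters.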
To better model conditional distributions which are multi-modal ("modal" as in "modes", not as in "modalities"), we by default discretize continuous features into n_bins bins. You can customize this via the quantize_cols parameter of the fit method: provide a list of length n_dims, with values in ('continuous', 'categorical', 'integer'). Given 'categorical', it skips quantization of that feature; given 'integer', it only quantizes if the number of unique values is greater than n_bins.
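The per-column rule above can be written out as a small decision function. This is a sketch of the rule as stated in the text, not code from the library; should_quantize is a hypothetical name, and the n_bins value used here is chosen only for the example.

```python
def should_quantize(col_type, values, n_bins):
    """Decide whether a column gets discretized, following the rule in the text.

    Hypothetical helper for illustration; not part of the utrees API.
    """
    if col_type == "categorical":
        return False  # categorical features skip quantization
    if col_type == "integer":
        # integer features are quantized only when they have many distinct values
        return len(set(values)) > n_bins
    if col_type == "continuous":
        return True
    raise ValueError(f"unknown column type: {col_type!r}")


n_bins = 4  # illustrative value only
decisions = [
    should_quantize("continuous", [0.1, 0.25, 0.9], n_bins),
    should_quantize("categorical", [0, 1, 2, 1, 0], n_bins),
    should_quantize("integer", [1, 2, 3, 3, 2], n_bins),
    should_quantize("integer", [1, 2, 3, 4, 5, 6], n_bins),
]
print(decisions)  # [True, False, False, True]
```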
Installation
Installation from PyPI
pip install utrees
Installation from source
After cloning this repo, install the dependencies on the command-line, then install utrees:
pip install -r requirements.txt
pip install -e .
pytest
Usage
Check out this notebook with the Two Moons example, or this one with the Iris dataset.
Synthetic data generation
You can fit utrees.UnmaskingTrees the way you would an sklearn model, with the added option that you can call fit with quantize_cols, a list to specify which columns are continuous (and therefore need to be discretized). By default, all columns are assumed to contain continuous features.
import numpy as np
from sklearn.datasets import make_moons
from utrees import UnmaskingTrees
data, labels = make_moons((100, 100), shuffle=False, noise=0.1, random_state=123) # size (200, 2)
utree = UnmaskingTrees().fit(data)
Then, you can generate new data:
newdata = utree.generate(n_generate=123) # size (123, 2)
Missing data imputation
You can fit your UnmaskingTrees model on data with missing elements, provided as np.nan. You can then impute the missing values, potentially with multiple imputations per missing element. Given an array of shape (n_samples, n_dims), you will get back an array of shape (n_impute, n_samples, n_dims), where the NaNs have been replaced while the other values are unchanged.
data4impute = data.copy()
data4impute[:, 1] = np.nan
X = np.concatenate([data, data4impute], axis=0) # size (400, 2)
utree = UnmaskingTrees().fit(X)
imputeddata = utree.impute(n_impute=5) # size (5, 400, 2)
You can also provide a totally new dataset to be imputed, so the model performs imputation without retraining:
utree = UnmaskingTrees().fit(data)
imputeddata = utree.impute(n_impute=5, X=data4impute) # size (5, 200, 2)
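As a sanity check on the behaviour described in this section (NaNs replaced, observed values unchanged), a small validation helper can be useful. The sketch below is plain Python and is not part of utrees; observed_entries_preserved is a hypothetical name, and the toy values stand in for real imputation output. It is written for nested lists, and should also work when passed NumPy arrays of the shapes given above.

```python
import math


def observed_entries_preserved(original, imputations):
    """Check every imputation: no NaN left, and observed values untouched.

    original has shape (n_samples, n_dims); imputations has shape
    (n_impute, n_samples, n_dims). Hypothetical helper for illustration.
    """
    for imputed in imputations:
        for orig_row, imp_row in zip(original, imputed):
            for orig_val, imp_val in zip(orig_row, imp_row):
                if math.isnan(imp_val):
                    return False  # a missing value was not filled in
                if not math.isnan(orig_val) and orig_val != imp_val:
                    return False  # an observed value was altered
    return True


nan = float("nan")
original = [[1.0, nan], [2.0, 0.5]]
# Two toy imputations of the same data (made-up values).
good = [[[1.0, 0.3], [2.0, 0.5]], [[1.0, 0.7], [2.0, 0.5]]]
bad = [[[9.0, 0.3], [2.0, 0.5]]]
print(observed_entries_preserved(original, good))  # True
print(observed_entries_preserved(original, bad))   # False
```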
Hyperparameters
- depth: Depth of balanced binary tree for recursively quantizing each feature.
- duplicate_K: Number of random masking orders per actual sample. The training dataset will be of size (n_samples * n_dims * duplicate_K, n_dims).
- xgboost_kwargs: dict to pass to XGBClassifier.
- strategy: How to quantize continuous features ('kdiquantile', 'quantile', 'uniform', or 'kmeans').
- random_state: Controls randomness.
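Because the training set grows with duplicate_K, it can help to estimate its size before fitting. The arithmetic below follows the formula given for duplicate_K; the helper name is hypothetical, and the duplicate_K value of 50 is chosen only for illustration and is not claimed to be the library default.

```python
def unmasking_training_shape(n_samples, n_dims, duplicate_k):
    """Shape of the constructed training set, per the duplicate_K formula above."""
    return (n_samples * n_dims * duplicate_k, n_dims)


# Two Moons example from the Usage section: 200 samples, 2 features.
shape = unmasking_training_shape(n_samples=200, n_dims=2, duplicate_k=50)
print(shape)  # (20000, 2)
```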
Citing this method
Please consider citing the UnmaskingTrees arXiv preprint. The bibtex is:
@article{mccarter2024unmasking,
title={Unmasking Trees for Tabular Data},
author={McCarter, Calvin},
journal={arXiv preprint arXiv:2407.05593},
year={2024}
}
Also, please consider citing ForestDiffusion (code and paper), which this work builds on.