Package for Synthetic Data Generation using Distributional Learning of VAE
DistVAE-Tabular
DistVAE is a novel approach to distributional learning in the VAE framework. It focuses on accurately capturing the underlying distribution of the observed dataset through nonparametric CDF estimation.
We utilize the continuous ranked probability score (CRPS), a strictly proper scoring rule, as the reconstruction loss while preserving the mathematical derivation of the lower bound of the data log-likelihood. Additionally, we introduce a synthetic data generation mechanism that effectively preserves differential privacy.
For a detailed explanation of the method, see our paper (link).
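To make the role of the CRPS more concrete, here is a minimal sketch of how it can be computed as an integral of the pinball (check) loss over quantile levels, using the standard identity CRPS(F, x) = 2 ∫ ρ_α(x − F⁻¹(α)) dα. This is not the package's implementation: the function names are hypothetical, the standard normal is used only because its CRPS has a known closed form to compare against, and constant factors may differ from the convention used in the paper.

```python
import math
from statistics import NormalDist


def pinball_loss(alpha, residual):
    """Check (pinball) loss: rho_alpha(u) = u * (alpha - 1[u < 0])."""
    return residual * (alpha - (1.0 if residual < 0 else 0.0))


def crps_from_quantiles(x, quantile_fn, n_levels=2000):
    """Approximate CRPS(F, x) = 2 * int_0^1 rho_alpha(x - F^{-1}(alpha)) d alpha.

    Uses a midpoint rule over `n_levels` equally spaced quantile levels.
    `quantile_fn` is any callable returning F^{-1}(alpha) for 0 < alpha < 1.
    """
    total = 0.0
    for i in range(n_levels):
        alpha = (i + 0.5) / n_levels  # midpoint of the i-th sub-interval
        total += pinball_loss(alpha, x - quantile_fn(alpha))
    return 2.0 * total / n_levels


def crps_standard_normal(x):
    """Closed-form CRPS of a standard normal forecast at observation x."""
    std = NormalDist()
    return x * (2.0 * std.cdf(x) - 1.0) + 2.0 * std.pdf(x) - 1.0 / math.sqrt(math.pi)


# Compare the quantile-based approximation with the closed form.
std_normal = NormalDist(0.0, 1.0)
approx = crps_from_quantiles(0.5, std_normal.inv_cdf)
exact = crps_standard_normal(0.5)
print(f"approx={approx:.6f} exact={exact:.6f}")
```

The two printed values should agree closely, which illustrates why minimizing the pinball loss across many quantile levels amounts to minimizing the CRPS.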
1. Installation
Install using pip:
pip install distvae-tabular
2. Usage
from distvae_tabular import distvae
distvae.DistVAE # DistVAE model
distvae.generate_data # generate synthetic data
- See example.ipynb for a detailed example and its results with the loan dataset.
  - Link for downloading the loan dataset: https://www.kaggle.com/datasets/teertha/personal-loan-modeling
Example
"""device setting"""
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
"""load dataset and specify column types"""
import pandas as pd
data = pd.read_csv('./loan.csv')
continuous_features = [
    'Age',
    'Experience',
    'Income',
    'CCAvg',
    'Mortgage',
]
categorical_features = [
    'Family',
    'Personal Loan',
    'Securities Account',
    'CD Account',
    'Online',
    'CreditCard',
]
integer_features = [
    'Age',
    'Experience',
    'Income',
    'Mortgage',
]
"""DistVAE"""
from distvae_tabular import distvae
# The instance is named `model` so that it does not shadow the imported `distvae` module.
model = distvae.DistVAE(
    data=data, # the observed tabular dataset
    continuous_features=continuous_features, # the list of continuous columns of data
    categorical_features=categorical_features, # the list of categorical columns of data
    integer_features=integer_features, # the list of integer-type columns of data
    seed=42, # seed for repeatable results
    latent_dim=4, # the latent dimension size
    beta=0.1, # scale parameter of asymmetric Laplace distribution
    hidden_dim=128, # the number of nodes in MLP
    epochs=5, # the number of epochs (for quick checking)
    batch_size=256, # the batch size
    lr=0.001, # learning rate
    step=0.1, # interval size between knots
    threshold=1e-8, # threshold for clipping alpha_tild (numerical stability)
    device="cpu" # the device used for computation
)
"""training"""
model.train()
"""generate synthetic data"""
syndata = model.generate_data(100)
syndata
"""generate synthetic data with Differential Privacy"""
syndata = model.generate_data(100, lambda_=0.1)
syndata
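Before building the model, it can help to check that the three column lists in the example are consistent with each other and with the dataset. The sketch below is a hypothetical helper, not part of distvae-tabular. It assumes, based only on the example above, that continuous and categorical columns should not overlap and that every integer-type column is also listed as continuous. The package may not actually require either condition.

```python
def check_column_spec(columns, continuous, categorical, integer):
    """Return a list of problems found; an empty list means the spec looks consistent.

    `columns` is the list of column names in the dataset,
    e.g. list(data.columns) for a pandas DataFrame.
    """
    problems = []
    known = set(columns)

    # Every listed feature should exist in the dataset.
    groups = [("continuous", continuous), ("categorical", categorical), ("integer", integer)]
    for name, group in groups:
        missing = [c for c in group if c not in known]
        if missing:
            problems.append(f"{name} columns not found in data: {missing}")

    # Assumption: a column is either continuous or categorical, not both.
    overlap = sorted(set(continuous) & set(categorical))
    if overlap:
        problems.append(f"columns listed as both continuous and categorical: {overlap}")

    # Assumption: integer-type columns are a subset of the continuous ones.
    not_continuous = [c for c in integer if c not in continuous]
    if not_continuous:
        problems.append(f"integer columns not listed as continuous: {not_continuous}")

    return problems


# Column lists copied from the example above.
continuous_features = ['Age', 'Experience', 'Income', 'CCAvg', 'Mortgage']
categorical_features = ['Family', 'Personal Loan', 'Securities Account',
                        'CD Account', 'Online', 'CreditCard']
integer_features = ['Age', 'Experience', 'Income', 'Mortgage']

# Stand-in for list(data.columns); the real dataset may contain more columns.
columns = continuous_features + categorical_features

problems = check_column_spec(columns, continuous_features,
                             categorical_features, integer_features)
print(problems if problems else "column specification looks consistent")
```

With a real DataFrame, pass `list(data.columns)` as the first argument so that misspelled column names are caught before training starts.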
Citation
If you use this code or package, please cite our associated paper:
@article{an2024distributional,
title={Distributional learning of variational AutoEncoder: application to synthetic data generation},
author={An, Seunghwan and Jeon, Jong-June},
journal={Advances in Neural Information Processing Systems},
volume={36},
year={2024}
}