
Package for Synthetic Data Generation using Distributional Learning of VAE


DistVAE-Tabular

DistVAE is a novel approach to distributional learning in the VAE framework that focuses on accurately capturing the underlying distribution of the observed dataset through nonparametric quantile estimation.
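To make "nonparametric quantile estimation" more concrete, here is a minimal NumPy sketch of a monotone linear-spline quantile function, the kind of per-column model we assume DistVAE uses, together with inverse-transform sampling from it. Everything here is illustrative: the function names (`make_knots`, `spline_quantile`, `sample_column`), the knot layout derived from `step`, and the parametrization by nonnegative slopes are our assumptions, not the package's internals. In the full model we would expect `gamma` and the slopes to be produced by the decoder from a latent code.

```python
import numpy as np


def make_knots(step: float) -> np.ndarray:
    """Knots 0 = d_0 < d_1 < ... < 1 spaced `step` apart (assumed layout)."""
    n_intervals = int(round(1.0 / step))
    return np.linspace(0.0, 1.0, n_intervals + 1)[:-1]  # drop the right end, alpha = 1


def spline_quantile(alpha, gamma: float, slopes, knots: np.ndarray):
    """Evaluate Q(alpha) = gamma + sum_l delta_l * max(alpha - d_l, 0).

    delta_l = slopes[l] - slopes[l - 1], so Q has slope slopes[l] on
    [d_l, d_{l+1}]. Nonnegative slopes make Q nondecreasing, as a
    quantile function must be.
    """
    slopes = np.asarray(slopes, dtype=float)
    if slopes.shape != knots.shape:
        raise ValueError("need exactly one slope per knot")
    if np.any(slopes < 0):
        raise ValueError("slopes must be nonnegative")
    deltas = np.diff(slopes, prepend=0.0)
    alpha = np.asarray(alpha, dtype=float)
    hinge = np.maximum(alpha[..., None] - knots, 0.0)  # shape (..., n_knots)
    return gamma + hinge @ deltas


def sample_column(n: int, gamma: float, slopes, knots: np.ndarray, rng) -> np.ndarray:
    """Inverse-transform sampling: draw alpha ~ Uniform(0, 1), return Q(alpha)."""
    alpha = rng.uniform(0.0, 1.0, size=n)
    return spline_quantile(alpha, gamma, slopes, knots)
```

For example, `make_knots(0.1)` yields the 10 knots 0.0, 0.1, ..., 0.9, and all-ones slopes with `gamma=0.0` give Q(alpha) = alpha, i.e. a Uniform(0, 1) column. Parametrizing by nonnegative slopes rather than raw knot values guarantees monotonicity without any sorting step.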

We utilize the continuous ranked probability score (CRPS), a strictly proper scoring rule, as the reconstruction loss while preserving the mathematical derivation of the lower bound of the data log-likelihood. Additionally, we introduce a synthetic data generation mechanism that effectively preserves differential privacy.
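The CRPS can be written as twice the pinball (check) loss integrated over quantile levels: CRPS(F, x) = integral from 0 to 1 of 2 * rho_alpha(x - Q(alpha)) d alpha, where rho_alpha(u) = u * (alpha - 1{u < 0}) and Q is the quantile function of F. The following is a small numerical sketch of that identity for any quantile function `quantile_fn`; the names `pinball_loss` and `crps` are ours, and this is not the package's loss implementation, which may evaluate the integral differently.

```python
import numpy as np


def pinball_loss(u, alpha):
    """Check (pinball) loss rho_alpha(u) = u * (alpha - 1{u < 0})."""
    u = np.asarray(u, dtype=float)
    return u * (alpha - (u < 0))


def crps(quantile_fn, x: float, n_grid: int = 10_000) -> float:
    """CRPS(F, x) = integral over alpha in (0, 1) of 2 * rho_alpha(x - Q(alpha)).

    Approximated with the midpoint rule on a uniform grid of quantile levels;
    `quantile_fn` maps an array of levels alpha to Q(alpha).
    """
    alpha = (np.arange(n_grid) + 0.5) / n_grid
    return float(2.0 * np.mean(pinball_loss(x - quantile_fn(alpha), alpha)))
```

For Q(alpha) = alpha (a Uniform(0, 1) distribution) and x = 0.5 the sketch returns approximately 1/12, matching the closed form E|X - x| - E|X - X'| / 2; for a point mass at c it returns |x - c|. Because rho_alpha(u / beta) = rho_alpha(u) / beta, dividing the pinball loss by a scale beta gives, up to an additive constant, the negative log-density of an asymmetric Laplace distribution with scale beta, which is presumably the role of the `beta` argument in the example (described there as the scale parameter of the asymmetric Laplace distribution).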

For a detailed explanation of the method, see our paper (link).

1. Installation

Install using pip:

pip install distvae-tabular

2. Usage

from distvae_tabular import distvae
distvae.DistVAE # DistVAE model
distvae.generate_data # function for generating synthetic dataset

Example

"""device setting"""
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

"""load dataset and specify column types"""
import pandas as pd
data = pd.read_csv('./loan.csv') 
continuous_features = [
    'Age',
    'Experience',
    'Income', 
    'CCAvg',
    'Mortgage',
]
categorical_features = [
    'Family',
    'Personal Loan',
    'Securities Account',
    'CD Account',
    'Online',
    'CreditCard'
]
integer_features = [
    'Age',
    'Experience',
    'Income', 
    'Mortgage'
]

"""DistVAE"""
from distvae_tabular import distvae

distvae = distvae.DistVAE(
    data=data, # the observed tabular dataset
    continuous_features=continuous_features, # the list of continuous columns of data
    categorical_features=categorical_features, # the list of categorical columns of data
    integer_features=integer_features, # the list of integer-type columns of data
    
    seed=42, # seed for repeatable results
    latent_dim=4, # the latent dimension size
    beta=0.1, # scale parameter of asymmetric Laplace distribution
    hidden_dim=128, # the number of nodes in MLP
    
    epochs=5, # the number of epochs (for quick checking)
    batch_size=256, # the batch size
    lr=0.001, # learning rate
    
    step=0.1, # interval size between knots
    threshold=1e-8, # threshold for clipping alpha_tild (numerical stability)
    device="cpu"
)

"""training"""
distvae.train()

"""generate synthetic data"""
syndata = distvae.generate_data(100)
syndata

"""generate synthetic data with Differential Privacy"""
syndata = distvae.generate_data(100, lambda_=0.1)
syndata
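Once `syndata` is generated, a quick way to eyeball fidelity is to compare marginal statistics against `data`. The sketch below assumes `generate_data` returns a pandas DataFrame with the same column names as the training data (not confirmed here). `compare_quantiles` and `tv_distance` are hypothetical helpers, not part of distvae-tabular, and they only check one-dimensional marginals, not joint structure or privacy.

```python
import numpy as np
import pandas as pd


def compare_quantiles(real: pd.DataFrame, synthetic: pd.DataFrame, columns,
                      probs=(0.1, 0.25, 0.5, 0.75, 0.9)) -> pd.DataFrame:
    """Absolute gap between real and synthetic quantiles, one row per column."""
    probs = list(probs)
    gaps = {
        col: np.abs(real[col].quantile(probs).to_numpy()
                    - synthetic[col].quantile(probs).to_numpy())
        for col in columns
    }
    return pd.DataFrame.from_dict(
        gaps, orient="index", columns=[f"q{round(p * 100)}" for p in probs]
    )


def tv_distance(real: pd.Series, synthetic: pd.Series) -> float:
    """Total variation distance between the category frequencies of two columns."""
    p = real.value_counts(normalize=True)
    q = synthetic.value_counts(normalize=True)
    p, q = p.align(q, fill_value=0.0)  # categories missing on one side get 0
    return 0.5 * float(np.abs(p - q).sum())
```

For the example above, one might call `compare_quantiles(data, syndata, continuous_features)` and `{c: tv_distance(data[c], syndata[c]) for c in categorical_features}`; smaller values indicate closer marginals.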

3. Citation

If you use this code or package, please cite our associated paper:

@article{an2024distributional,
  title={Distributional learning of variational AutoEncoder: application to synthetic data generation},
  author={An, Seunghwan and Jeon, Jong-June},
  journal={Advances in Neural Information Processing Systems},
  volume={36},
  year={2024}
}



Download files


Source Distribution

distvae_tabular-0.1.7.tar.gz (9.4 kB)

Uploaded Source

Built Distribution

distvae_tabular-0.1.7-py3-none-any.whl (9.2 kB)

Uploaded Python 3

File details

Details for the file distvae_tabular-0.1.7.tar.gz.

File metadata

  • Download URL: distvae_tabular-0.1.7.tar.gz
  • Size: 9.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for distvae_tabular-0.1.7.tar.gz
Algorithm Hash digest
SHA256 c43214abc35c3d8518ce14ca53cca62d061c060ef50b508a2f2cacd2f4a11814
MD5 d8baf82379b589b494c238770f73ba92
BLAKE2b-256 8701e95f843a9c7210405186eb8e2ceaa45d8a4b3897e7a4f9e5cb29f7b3d553


File details

Details for the file distvae_tabular-0.1.7-py3-none-any.whl.


File hashes

Hashes for distvae_tabular-0.1.7-py3-none-any.whl
Algorithm Hash digest
SHA256 a31eaa928a3d602a197917e99cfc569a3658e0d4e53667483cd202344fbac793
MD5 e9f4fde865d0602c2bfb01eb05874baa
BLAKE2b-256 182ea99d4e51fb5b5727c12d86ec1616779e8d28d76f58143f784b912fe89cad

