Conditional GAN for Tabular Data
An open source project from Data to AI Lab at MIT.
CTGAN
Implementation of our NeurIPS paper Modeling Tabular data using Conditional GAN.
CTGAN is a GAN-based data synthesizer that can generate synthetic tabular data with high fidelity.
- Free software: MIT license
- Documentation: https://DAI-Lab.github.io/CTGAN
- Homepage: https://github.com/DAI-Lab/CTGAN
Overview
Building on previous work on synthetic data generation (TGAN), we developed a new model called CTGAN. Several major differences make CTGAN outperform TGAN:
- Preprocessing: CTGAN uses a more sophisticated Variational Gaussian Mixture Model to detect the modes of continuous columns.
- Network structure: TGAN uses an LSTM to generate synthetic data column by column. CTGAN uses fully-connected networks, which are more efficient.
- Features to prevent mode collapse: We design a conditional generator and resample the training data to prevent mode collapse on discrete columns. We use WGAN-GP and PacGAN to stabilize the training of the GAN.
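The preprocessing and conditional-generator ideas above can be sketched in plain Python. This is a minimal, hypothetical illustration, not CTGAN's own code, and the function names and the dict-of-lists input are assumptions. `mode_specific_normalize` assumes the mixture weights, means and standard deviations have already been fitted (CTGAN fits them with a Variational Gaussian Mixture Model) and, as a simplification, picks the most likely mode instead of sampling one. `sample_condition` shows how a condition vector for training-by-sampling can be drawn; the `+ 1` inside the logarithm is an assumed smoothing term. It requires Python 3.6+ for `random.choices`.

```python
import math
import random
from collections import Counter


def gaussian_pdf(x, mean, std):
    """Density of the normal distribution N(mean, std**2) at x."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2 * math.pi))


def mode_specific_normalize(value, weights, means, stds):
    """Encode one continuous value as (alpha, beta).

    weights, means and stds describe an already-fitted Gaussian mixture
    (assumed given here). beta is a one-hot vector marking the selected mode
    and alpha is the value scaled within that mode. Simplification: the most
    likely mode is picked instead of sampling one.
    """
    probs = [w * gaussian_pdf(value, m, s)
             for w, m, s in zip(weights, means, stds)]
    k = max(range(len(probs)), key=lambda i: probs[i])
    alpha = (value - means[k]) / (4 * stds[k])
    beta = [1 if i == k else 0 for i in range(len(probs))]
    return alpha, beta


def sample_condition(discrete_data, rng):
    """Hypothetical sketch of training-by-sampling.

    discrete_data maps each discrete column name to its list of values.
    A column is chosen uniformly at random, then one of its categories with
    probability proportional to log(frequency + 1). The condition vector
    concatenates one one-hot block per column; only the chosen column's
    block contains a 1.
    """
    columns = sorted(discrete_data)
    chosen = rng.choice(columns)
    picked = None
    cond = []
    for col in columns:
        counts = Counter(discrete_data[col])
        categories = sorted(counts)
        if col == chosen:
            log_freq = [math.log(counts[c] + 1) for c in categories]
            picked = rng.choices(categories, weights=log_freq, k=1)[0]
        cond.extend(1 if col == chosen and c == picked else 0
                    for c in categories)
    return chosen, picked, cond
```

For example, a value of 2.1 with two equally weighted modes centred at 0.0 and 2.0 (both with standard deviation 0.5) falls in the second mode, so beta is [0, 1] and alpha is (2.1 - 2.0) / (4 * 0.5) = 0.05. Sampling categories by log frequency rather than raw frequency gives rare categories a better chance of being used as the condition during training.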
Install
Requirements
CTGAN has been developed and tested on Python 3.5, 3.6 and 3.7.
Install from PyPI
The recommended way to install CTGAN is using pip:
pip install ctgan
This will pull and install the latest stable release from PyPI.
If you want to install from source or contribute to the project, please read the Contributing Guide.
Data Format
CTGAN expects the input data to be a table given as either a numpy.ndarray or a pandas.DataFrame object with two types of columns:
- Continuous columns: Columns that contain numerical values and which can take any value.
- Discrete columns: Columns that only contain a finite number of possible values, whether these are string values or not.
This is an example of a table with 4 columns:
- A continuous column with float values
- A continuous column with integer values
- A discrete column with string values
- A discrete column with integer values
 | A | B | C | D |
---|---|---|---|---|
0 | 0.1 | 100 | 'a' | 1 |
1 | -1.3 | 28 | 'b' | 2 |
2 | 0.3 | 14 | 'a' | 2 |
3 | 1.4 | 87 | 'a' | 3 |
4 | -0.1 | 69 | 'b' | 2 |
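As a minimal sketch, the example table can be built as a pandas.DataFrame together with the list of discrete column names that CTGAN needs. The variable names are illustrative only.

```python
import pandas as pd

# The 4-column example table above as a DataFrame.
data = pd.DataFrame({
    'A': [0.1, -1.3, 0.3, 1.4, -0.1],  # continuous, float values
    'B': [100, 28, 14, 87, 69],        # continuous, integer values
    'C': ['a', 'b', 'a', 'a', 'b'],    # discrete, string values
    'D': [1, 2, 2, 3, 2],              # discrete, integer values
})

# D holds integers yet is discrete, so it has to be listed explicitly.
discrete_columns = ['C', 'D']
```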
NOTE: CTGAN does not distinguish between float and integer columns, which means that it will sample float values in all cases. If integer values are required, the outputted float values must be rounded to integers in a later step, outside of CTGAN.
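The rounding step can be sketched with pandas as follows. The `real` and `synthetic` tables here are hypothetical stand-ins for your input data and the output of CTGAN; the sketch only rounds continuous columns that held integers in the real data, using `pandas.api.types.is_integer_dtype` to find them.

```python
import pandas as pd
from pandas.api.types import is_integer_dtype

# Hypothetical real data and sampled output; B comes back as floats.
real = pd.DataFrame({'A': [0.1, -1.3], 'B': [100, 28], 'C': ['a', 'b']})
synthetic = pd.DataFrame({'A': [0.12, -1.27], 'B': [99.6, 27.8], 'C': ['a', 'b']})
discrete_columns = ['C']

# Continuous columns that held integers in the real data.
integer_columns = [
    col for col in real.columns
    if col not in discrete_columns and is_integer_dtype(real[col])
]

# Round the sampled floats and cast them back to integers.
for col in integer_columns:
    synthetic[col] = synthetic[col].round().astype(int)
```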
Python Quickstart
In this short tutorial we will guide you through a series of steps that will help you get started with CTGAN.
1. Model the data
Step 1: Prepare your data
Before being able to use CTGAN you will need to prepare your data as specified above.
For this example, we will load some data using the ctgan.load_demo function.
from ctgan import load_demo
data = load_demo()
This will download a copy of the Adult Census Dataset as a dataframe:
age | workclass | fnlwgt | ... | hours-per-week | native-country | income |
---|---|---|---|---|---|---|
39 | State-gov | 77516 | ... | 40 | United-States | <=50K |
50 | Self-emp-not-inc | 83311 | ... | 13 | United-States | <=50K |
38 | Private | 215646 | ... | 40 | United-States | <=50K |
53 | Private | 234721 | ... | 40 | United-States | <=50K |
28 | Private | 338409 | ... | 40 | Cuba | <=50K |
... | ... | ... | ... | ... | ... | ... |
Aside from the table itself, you will need to create a list with the names of the discrete variables.
For this example:
discrete_columns = [
'workclass',
'education',
'marital-status',
'occupation',
'relationship',
'race',
'sex',
'native-country',
'income'
]
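For other datasets, a starting point for this list can be derived from the column dtypes. This is a hypothetical heuristic, not part of CTGAN: it treats every non-numeric column as discrete, so discrete columns stored as integers (like column D in the Data Format example) are missed and must be added by hand. The rows below are copied from the Adult preview above.

```python
import pandas as pd
from pandas.api.types import is_numeric_dtype

# A few rows from the Adult preview above (subset of columns).
data = pd.DataFrame({
    'age': [39, 50, 38],
    'workclass': ['State-gov', 'Self-emp-not-inc', 'Private'],
    'fnlwgt': [77516, 83311, 215646],
    'income': ['<=50K', '<=50K', '<=50K'],
})

# Heuristic: every non-numeric column is treated as discrete.
discrete_columns = [col for col in data.columns if not is_numeric_dtype(data[col])]
```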
Step 2: Fit CTGAN to your data
Once you have the data ready, you need to import and create an instance of the CTGANSynthesizer class and fit it, passing your data and the list of discrete columns.
from ctgan import CTGANSynthesizer
ctgan = CTGANSynthesizer()
ctgan.fit(data, discrete_columns)
This process is likely to take a long time to run.
If you want to make the process shorter or longer, you can control the number of training epochs by passing the epochs argument to the fit call:
ctgan.fit(data, discrete_columns, epochs=5)
2. Generate synthetic data
Once the process has finished, all you need to do is call the sample method of your CTGANSynthesizer instance, indicating the number of rows that you want to generate.
samples = ctgan.sample(1000)
The output is a table with exactly the same format as the input, filled with synthetic data generated by the model.
age | workclass | fnlwgt | ... | hours-per-week | native-country | income |
---|---|---|---|---|---|---|
26.3191 | Private | 124079 | ... | 40.1557 | United-States | <=50K |
39.8558 | Private | 133996 | ... | 40.2507 | United-States | <=50K |
38.2477 | Self-emp-inc | 135955 | ... | 40.1124 | Ecuador | <=50K |
29.6468 | Private | 3331.86 | ... | 27.012 | United-States | <=50K |
20.9853 | Private | 120637 | ... | 40.0238 | United-States | <=50K |
... | ... | ... | ... | ... | ... | ... |
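A quick sanity check on the sampled table can be sketched as below. `check_same_format` is a hypothetical helper, not part of CTGAN: it verifies that the column names and order match the input and that discrete columns contain no categories absent from the real data. The toy tables reuse values from the previews above.

```python
import pandas as pd


def check_same_format(real, synthetic, discrete_columns):
    """Return a list of problems found when comparing synthetic to real."""
    problems = []
    if list(synthetic.columns) != list(real.columns):
        problems.append('column names or order differ')
    for col in discrete_columns:
        unseen = set(synthetic[col]) - set(real[col])
        if unseen:
            problems.append('{}: unseen categories {}'.format(col, sorted(unseen)))
    return problems


# Toy tables using values from the previews above.
real = pd.DataFrame({'age': [39, 50], 'workclass': ['State-gov', 'Private']})
synthetic = pd.DataFrame({'age': [26.3191, 39.8558], 'workclass': ['Private', 'Private']})
problems = check_same_format(real, synthetic, ['workclass'])
```

An empty list means no problems were found; float values in integer columns such as age are expected and are not flagged.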
Join our community
- If you would like to try more dataset examples, please have a look at the examples folder of the repository. Please contact us if you have a usage example that you would like to share with the community.
- If you want to contribute to the project code, please head to the Contributing Guide for more details about how to do it.
- If you have any doubts or feature requests, or detect an error, please open an issue on GitHub.
- Also, do not forget to check the project documentation site!
Citing CTGAN
If you use CTGAN, please cite the following work:
- Lei Xu, Maria Skoularidou, Alfredo Cuesta-Infante, Kalyan Veeramachaneni. Modeling Tabular data using Conditional GAN. NeurIPS, 2019.
@inproceedings{xu2019modeling,
title={Modeling Tabular data using Conditional GAN},
author={Xu, Lei and Skoularidou, Maria and Cuesta-Infante, Alfredo and Veeramachaneni, Kalyan},
booktitle={Advances in Neural Information Processing Systems},
year={2019}
}
History
v0.1.0 - 2019-11-07
First Release - NeurIPS 2019 Version.