Conditional GAN for Tabular Data

An open source project from the Data to AI Lab at MIT.

CTGAN

Implementation of our NeurIPS paper Modeling Tabular data using Conditional GAN.

CTGAN is a GAN-based data synthesizer that can generate synthetic tabular data with high fidelity.

Overview

Based on previous work (TGAN) on synthetic data generation, we develop a new model called CTGAN. Several major differences make CTGAN outperform TGAN.

  • Preprocessing: CTGAN uses a more sophisticated Variational Gaussian Mixture Model to detect the modes of continuous columns.
  • Network structure: TGAN uses an LSTM to generate synthetic data column by column. CTGAN uses fully-connected networks, which are more efficient.
  • Features to prevent mode collapse: We design a conditional generator and resample the training data to prevent mode collapse on discrete columns. We use WGAN-GP and PacGAN to stabilize the training of the GAN.
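
The resampling idea in the last bullet can be illustrated with a small standard-library sketch. It assumes that the categories of a discrete column are drawn with probability proportional to the logarithm of their frequency, which is what the "logged frequency" wording in the v0.2.1 release notes suggests. The function names log_frequency_weights and sample_category are hypothetical, and this is not CTGAN's actual implementation.

```python
import math
import random
from collections import Counter


def log_frequency_weights(values):
    """Return a {category: probability} dict based on log(1 + count)."""
    counts = Counter(values)
    # log(1 + count) compresses large counts, so rare categories
    # receive a larger share than their raw frequency would give them.
    logs = {cat: math.log(1 + n) for cat, n in counts.items()}
    total = sum(logs.values())
    return {cat: weight / total for cat, weight in logs.items()}


def sample_category(values, rng=random):
    """Draw one category using the log-frequency probabilities."""
    probs = log_frequency_weights(values)
    categories = list(probs)
    weights = [probs[c] for c in categories]
    return rng.choices(categories, weights=weights, k=1)[0]


# A heavily imbalanced toy column: 90% "Private", 10% "State-gov".
column = ["Private"] * 90 + ["State-gov"] * 10
probs = log_frequency_weights(column)
```

With these weights the minority category is picked noticeably more often than its raw 10% share, while the majority category still remains the most likely one.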

Install

Requirements

CTGAN has been developed and tested on Python 3.6, 3.7 and 3.8.

Install from PyPI

The recommended way to install CTGAN is using pip:

pip install ctgan

This will pull and install the latest stable release from PyPI.

If you want to install from source or contribute to the project please read the Contributing Guide.

Data Format

CTGAN expects the input data to be a table given as either a numpy.ndarray or a pandas.DataFrame object with two types of columns:

  • Continuous Columns: Columns that contain numerical values and which can take any value.
  • Discrete Columns: Columns that only contain a finite number of possible values, whether these are string values or not.

This is an example of a table with 4 columns:

  • A continuous column with float values
  • A continuous column with integer values
  • A discrete column with string values
  • A discrete column with integer values

      A    B    C  D
0   0.1  100  'a'  1
1  -1.3   28  'b'  2
2   0.3   14  'a'  2
3   1.4   87  'a'  3
4  -0.1   69  'b'  2

NOTE: CTGAN does not distinguish between float and integer columns, which means that it will sample float values in all cases. If integer values are required, the outputted float values must be rounded to integers in a later step, outside of CTGAN.
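
The rounding step mentioned in the note could look like the following sketch. It uses pandas only and does not call CTGAN: the fake_samples frame is a hand-written stand-in for sampled output, and restore_integer_columns is a hypothetical helper name. It also assumes that the integer columns contain no missing values.

```python
import pandas as pd

# The example table from above: A and B are continuous, C and D are discrete.
data = pd.DataFrame({
    "A": [0.1, -1.3, 0.3, 1.4, -0.1],
    "B": [100, 28, 14, 87, 69],
    "C": ["a", "b", "a", "a", "b"],
    "D": [1, 2, 2, 3, 2],
})


def restore_integer_columns(samples, reference):
    """Round columns that are integer-typed in `reference` back to integers."""
    restored = samples.copy()
    for name in reference.columns:
        if pd.api.types.is_integer_dtype(reference[name]):
            # Round first, then cast to the dtype of the original column.
            restored[name] = restored[name].round().astype(reference[name].dtype)
    return restored


# Hand-written stand-in for sampled output; not produced by CTGAN.
fake_samples = pd.DataFrame({
    "A": [0.25, -0.7],
    "B": [57.8, 21.2],
    "C": ["a", "b"],
    "D": [2, 3],
})
rounded = restore_integer_columns(fake_samples, data)
```

Using the original table as the reference keeps the decision about which columns are integers outside of the model, in line with the note above.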

Python Quickstart

In this short tutorial we will guide you through a series of steps that will help you get started with CTGAN.

1. Model the data

Step 1: Prepare your data

Before being able to use CTGAN you will need to prepare your data as specified above.

For this example, we will be loading some data using the ctgan.load_demo function.

from ctgan import load_demo

data = load_demo()

This will download a copy of the Adult Census Dataset as a dataframe:

age workclass fnlwgt ... hours-per-week native-country income
39 State-gov 77516 ... 40 United-States <=50K
50 Self-emp-not-inc 83311 ... 13 United-States <=50K
38 Private 215646 ... 40 United-States <=50K
53 Private 234721 ... 40 United-States <=50K
28 Private 338409 ... 40 Cuba <=50K
... ... ... ... ... ... ...

Aside from the table itself, you will need to create a list with the names of the discrete variables.

For this example:

discrete_columns = [
    'workclass',
    'education',
    'marital-status',
    'occupation',
    'relationship',
    'race',
    'sex',
    'native-country',
    'income'
]
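
For other datasets, a first draft of this list can be derived from the column dtypes. The sketch below is only a heuristic and is not part of CTGAN: guess_discrete_columns and its max_unique parameter are hypothetical names, and the result should always be reviewed by hand, because integer-coded categories look numeric to pandas.

```python
import pandas as pd


def guess_discrete_columns(data, max_unique=None):
    """Return the names of columns that look discrete.

    Non-numeric columns are always included. Numeric columns are included
    only when `max_unique` is set and they have at most that many distinct values.
    """
    discrete = []
    for name in data.columns:
        column = data[name]
        if not pd.api.types.is_numeric_dtype(column):
            discrete.append(name)
        elif max_unique is not None and column.nunique() <= max_unique:
            discrete.append(name)
    return discrete


# A few values copied from the demo table shown earlier.
toy = pd.DataFrame({
    "age": [39, 50, 38, 53, 28],
    "workclass": ["State-gov", "Self-emp-not-inc", "Private", "Private", "Private"],
    "income": ["<=50K", "<=50K", "<=50K", "<=50K", "<=50K"],
})
candidates = guess_discrete_columns(toy)
```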

Step 2: Fit CTGAN to your data

Once you have the data ready, import the CTGANSynthesizer class, create an instance of it, and fit it by passing your data and the list of discrete columns.

from ctgan import CTGANSynthesizer

ctgan = CTGANSynthesizer()
ctgan.fit(data, discrete_columns)

This process is likely to take a long time to run. To make it shorter or longer, you can control the number of training epochs by passing it to the fit call:

ctgan.fit(data, discrete_columns, epochs=5)

2. Generate synthetic data

Once the process has finished, all you need to do is call the sample method of your CTGANSynthesizer instance indicating the number of rows that you want to generate.

samples = ctgan.sample(1000)

The output will be a table with the exact same format as the input and filled with the synthetic data generated by the model.

age workclass fnlwgt ... hours-per-week native-country income
26.3191 Private 124079 ... 40.1557 United-States <=50K
39.8558 Private 133996 ... 40.2507 United-States <=50K
38.2477 Self-emp-inc 135955 ... 40.1124 Ecuador <=50K
29.6468 Private 3331.86 ... 27.012 United-States <=50K
20.9853 Private 120637 ... 40.0238 United-States <=50K
... ... ... ... ... ... ...

3. Generate synthetic data conditioning on one column

The CTGAN model uses a conditional vector. By setting this vector, we increase the probability of getting a given value in a given discrete column.

For example, the following code increases the probability of workclass = " Private".

samples = ctgan.sample(1000, 'workclass', ' Private')

Note that this code does not guarantee workclass=" Private".
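
If every returned row must match the condition, the samples can be filtered afterwards. The following is a minimal pandas sketch that does not call CTGAN: the samples frame is a hand-written stand-in for sampled rows, and keep_matching is a hypothetical helper name. Note the leading space in ' Private', which matches the value used in the call above.

```python
import pandas as pd


def keep_matching(samples, column, value):
    """Return only the rows where `column` equals `value`, with a fresh index."""
    return samples[samples[column] == value].reset_index(drop=True)


# Hand-written stand-in for conditionally sampled rows; not CTGAN output.
samples = pd.DataFrame({
    "age": [26.3, 39.9, 38.2],
    "workclass": [" Private", " Private", " Self-emp-inc"],
})
private_only = keep_matching(samples, "workclass", " Private")
```

Because filtering reduces the number of rows, one option is to request more samples than needed and keep only the matching ones.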

4. Save and load the synthesizer

To save a trained ctgan synthesizer, use

ctgan.save(path_to_a_folder)

To restore a saved synthesizer, use

ctgan = CTGANSynthesizer()
ctgan.fit(data, discrete_columns, epochs=0, load_path=path_to_a_folder)

Please make sure the saved model and the loaded model are for the same dataset.

Join our community

  1. If you would like to try more dataset examples, please have a look at the examples folder of the repository. Please contact us if you have a usage example that you would want to share with the community.
  2. If you want to contribute to the project code, please head to the Contributing Guide for more details about how to do it.
  3. If you have any doubts or feature requests, or if you detect an error, please open an issue on GitHub.

Citing CTGAN

If you use CTGAN, please cite the following work:

  • Lei Xu, Maria Skoularidou, Alfredo Cuesta-Infante, Kalyan Veeramachaneni. Modeling Tabular data using Conditional GAN. NeurIPS, 2019.
@inproceedings{xu2019modeling,
  title={Modeling Tabular data using Conditional GAN},
  author={Xu, Lei and Skoularidou, Maria and Cuesta-Infante, Alfredo and Veeramachaneni, Kalyan},
  booktitle={Advances in Neural Information Processing Systems},
  year={2019}
}

Related Projects

Please note that these libraries are external contributions and are neither maintained nor supervised by the MIT DAI-Lab team.

R interface for CTGAN

A wrapper around CTGAN has been implemented by Kevin Kuo @kevinykuo, bringing the functionalities of CTGAN to R users.

More details can be found in the corresponding repository: https://github.com/kasaai/ctgan

CTGAN Server CLI

A package to easily deploy CTGAN onto a remote server. This package is developed by Timothy Pillow @oregonpillow.

More details can be found in the corresponding repository: https://github.com/oregonpillow/ctgan-server-cli

History

v0.2.2 - 2020-11-10

In this release we introduce several minor improvements to make CTGAN more versatile and properly support new types of data, such as categorical NaN values, as well as conditional sampling and features to save and load models.

Additionally, the dependency ranges and Python versions have been updated to support up-to-date runtimes.

Many thanks to @fealho @leix28 @csala @oregonpillow and @lurosenb for working on making this release possible!

Improvements

  • Drop Python 3.5 support - Issue #79 by @fealho
  • Support NaN values in categorical variables - Issue #78 by @fealho
  • Sample synthetic data conditioning on a discrete column - Issue #69 by @leix28
  • Support recent versions of pandas - Issue #57 by @csala
  • Easy solution for restoring original dtypes - Issue #26 by @oregonpillow

Bugs fixed

  • Loss to nan - Issue #73 by @fealho
  • Swapped the sklearn utils testing import statement - Issue #53 by @lurosenb

v0.2.1 - 2020-01-27

Minor version including changes to ensure the logs are properly printed and the option to disable the log transformation to the discrete column frequencies.

Special thanks to @kevinykuo for the contributions!

Issues Resolved:

  • Option to sample from true data frequency instead of logged frequency - Issue #16 by @kevinykuo
  • Flush stdout buffer for epoch updates - Issue #14 by @kevinykuo

v0.2.0 - 2019-12-18

Reorganization of the project structure with a new Python API, new Command Line Interface and increased data format support.

Issues Resolved:

  • Reorganize the project structure - Issue #10 by @csala
  • Move epochs to the fit method - Issue #5 by @csala

v0.1.0 - 2019-11-07

First Release - NeurIPS 2019 Version.
