
Training Framework for Arbitrary Tabular Generative Adversarial Networks


AT-GAN

Arbitrary Tabular Generative Adversarial Network

A Tabular GAN framework for generating synthetic tabular data from arbitrary mixed-type tabular datasets.



Table of Contents

  1. Overview
  2. Key Features
  3. Installation
  4. CLI Usage
  5. API Usage
  6. Configuration Reference
  7. In-Training Evaluation Suite
  8. Synthetic Data Evaluation (Post-Training)

Overview

at-gan is a framework for training Generative Adversarial Networks on arbitrary tabular data. It is designed to work with continuous, binary, discrete count, and categorical features within a single pipeline.

The framework combines a multi-branch generator (G), a PacGAN-style discriminator (D), an integrated evaluation suite, and Weights & Biases (W&B) integration for sweep orchestration, experiment tracking, and training monitoring and visualization.

Goal: Train a GAN that produces realistic synthetic tabular data from a given dataset, with minimal manual tuning and a transparent, observable training process.


Key Features

Dynamic, Config-Driven Architectures

  • Generator and Discriminator built entirely from a YAML config.
  • Configurable number of layers and units per layer.
  • Configurable activations: relu, leaky_relu, elu, or any other activation supported in Keras.
  • Configurable dropout layers.
  • Optional Batch Normalization for G.

Mixed-Type Data Handling

  • The TabularPreprocessor handles four types of input features:
    • Continuous → MinMaxScaler(-1, 1) → tanh output branch.
    • Discrete Count → MinMaxScaler(0, 1) → sigmoid output branch.
    • Binary → 0/1 and optional β-distributed noise application → sigmoid output branch.
    • Categorical → One-hot encoding and optional label-preserving smoothing → softmax output branch.
  • Per-column decimal precision preservation.
  • Scalers and encoders are stored and reused for inference.
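
The continuous-column route above (min-max scaling to [-1, 1], then mapping generator output back to the original range) can be sketched in plain Python. This is an illustration only, not at-gan's TabularPreprocessor: the function names and the `decimals` argument are hypothetical and stand in for the stored scaler and the per-column precision handling.

```python
def fit_min_max(values):
    """Return the (min, max) pair a min-max scaler would store for one column."""
    return min(values), max(values)


def scale(values, lo, hi, out_min=-1.0, out_max=1.0):
    """Map values from [lo, hi] to [out_min, out_max]."""
    span = hi - lo
    if span == 0:
        # Constant column: avoid division by zero and map to the lower bound.
        return [out_min for _ in values]
    return [out_min + (v - lo) / span * (out_max - out_min) for v in values]


def inverse_scale(scaled, lo, hi, out_min=-1.0, out_max=1.0, decimals=None):
    """Map scaled values back to [lo, hi]; optionally round to the column's precision."""
    restored = [lo + (s - out_min) / (out_max - out_min) * (hi - lo) for s in scaled]
    if decimals is not None:
        restored = [round(v, decimals) for v in restored]
    return restored


ages = [18.0, 41.5, 65.0]
lo, hi = fit_min_max(ages)          # stored at training time, reused at inference
scaled_ages = scale(ages, lo, hi)   # values a tanh output branch can reproduce
restored_ages = inverse_scale(scaled_ages, lo, hi, decimals=1)
```

Keeping `lo` and `hi` from training is what makes the inverse step possible at generation time, which is why the scalers are stored alongside the model.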

GAN Training and Stabilization Techniques

| Technique | Controlled by | What it does |
|---|---|---|
| PacGAN packing | discriminator.pack_size | Concatenates k rows into a single D input → fights mode collapse |
| One-sided label smoothing | discriminator.label_smoothing_min | Real labels sampled from [min, 1.0] instead of hard 1.0 |
| Label flipping | discriminator.label_flipping | Random fraction of real labels flipped to 0 to prevent D overconfidence |
| TTUR | g_lr / d_lr | Different LRs for G and D. Sweeps auto-clamp d_lr ≤ g_lr |
| G:D update ratio | g_updates_per_epoch | Multiple G steps per D step to balance the training process |
| LR cosine decay + warm restarts | lr_cosine_decay | CosineDecayRestarts schedule with configurable alpha floor for the learning rate of G and D |
| Adam beta_1 override | adam_beta_1 | Typically lowered from default 0.9 for training stability |
| Gradient clipping | always-on | clipnorm=1.0 on both Adam optimizers |
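
Three of these techniques (packing, one-sided label smoothing, label flipping) are easy to picture on toy data. The sketch below is a plain-Python illustration under stated assumptions, not at-gan's training loop: `pack_rows` and `real_labels` are hypothetical names, and flipping an exact count of labels (rather than flipping each label with some probability) is an assumption.

```python
import random


def pack_rows(batch, pack_size):
    """Concatenate every `pack_size` consecutive rows into one discriminator input."""
    if len(batch) % pack_size != 0:
        raise ValueError("batch size must be divisible by pack_size")
    packed = []
    for start in range(0, len(batch), pack_size):
        group = batch[start:start + pack_size]
        packed.append([value for row in group for value in row])
    return packed


def real_labels(n, smoothing_min, flip_fraction, rng):
    """Draw n real labels from [smoothing_min, 1.0], then flip a fraction of them to 0."""
    labels = [rng.uniform(smoothing_min, 1.0) for _ in range(n)]
    n_flip = int(round(flip_fraction * n))  # assumption: exact count, not per-label probability
    for idx in rng.sample(range(n), n_flip):
        labels[idx] = 0.0
    return labels


rng = random.Random(1130)
batch = [[float(i), float(i) + 0.5] for i in range(6)]   # 6 rows, 2 features each
packed = pack_rows(batch, pack_size=3)                   # 2 packed rows, 6 features each
labels = real_labels(n=20, smoothing_min=0.9, flip_fraction=0.05, rng=rng)
```

With packing, the discriminator judges groups of rows jointly, so a generator that emits near-identical rows becomes easier to detect.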

In-Training Evaluation Suite

Runs every eval_frequency epochs on held-out real samples, logs results to W&B, and saves the best checkpoint by error score. See In-Training Evaluation Suite.

Experiment Tracking

Weights & Biases integration:

  • Per-epoch loss/metric logging via a dedicated WandbCallback.
  • Training visuals: correlation heatmaps + PCA overlap scatter plots.
  • Local-only mode when --no-wandb is set (uses run_id="offline_run").

Sweeps & Neural Architecture Search

  • W&B sweeps for Neural Architecture Search (NAS) and Hyperparameter Optimization.
  • Mechanism to resume existing W&B sweeps (and single runs).

Synthetic Data Evaluation (Post-Training)

  • Privacy: Distance to Closest Record (DCR)
  • Statistical Fidelity: Synthetic Data Vault (SDV)
  • Utility Retention: Train on Synthetic, Test on Real (TSTR)

Usage Modes

  • 🖥️ CLI: train, sweep, generate, evaluate.
  • 🐍 Python API (at_gan.api): train, sweep, generate, evaluate.

Installation

Requirements: Python 3.10 – 3.12 and dependencies listed in pyproject.toml.

Option A: Install from PyPI (recommended)

pip install at-gan

Option B: Editable install from the GitHub Repository

  1. Clone the GitHub repository
  2. Run the following command:
pip install -e .

Verify installation

at-gan --help
python -c "import at_gan; print(at_gan.__version__)"

Weights & Biases Login (one-time)

wandb login

💡 You can use this framework without W&B by passing --no-wandb (CLI) or enable_wandb=False (API).


CLI Usage

at-gan --help

train: Run or resume a single GAN training run

| Flag | Short | Default | Description |
|---|---|---|---|
| --config | -c | required | Path to the YAML experiment config |
| --wandb / --no-wandb | -w / -nw | --wandb | Toggle W&B tracking |
| --export / --no-export | -e / -ne | --export | Save .keras generator file |
| --generate-samples | -g | 1000 | Auto-generate N samples post-training |

Example:

at-gan train -c configs/config.yaml -w -e -g 5000

Note: A run can be resumed via the resume_run_id config key. See Configuration Reference.


sweep: Run or resume a W&B sweep

| Flag | Short | Description |
|---|---|---|
| --base-config | -c | Baseline experiment config |
| --sweep-config | -s | W&B sweep config (required for new sweeps) |
| --count | -n | Max runs this agent will execute |
| --sweep-id | -id | Resume an existing sweep instead of creating one |

# Launch a new 50-run sweep
at-gan sweep -c configs/config.yaml -s configs/sweep_config.yaml -n 50

# Resume an existing sweep
at-gan sweep -c configs/config.yaml -id abc123 -n 20

generate: Generate synthetic samples from a trained generator

| Flag | Short | Description |
|---|---|---|
| --config | -c | YAML used during the original training run |
| --run-id | -r | W&B run ID or "offline_run" |
| --samples | -n | Number of samples to generate |
| --output | -o | Optional override for CSV output path |

at-gan generate -c configs/config.yaml -r a1b2c3 -n 10000 -o synthetic_data.csv

Note: generate always loads best_generator.keras, not the latest.


evaluate: Run synthetic data evaluation (post-training)

| Flag | Short | Description |
|---|---|---|
| --real | -r | Path to the real data CSV |
| --synthetic | -s | Path to the synthetic data CSV |
| --target | -t | Discrete target column for (optional) TSTR evaluation |

at-gan evaluate -r real_data.csv -s synthetic_data.csv -t target_column

Note: The TSTR evaluation is only performed if a discrete feature (i.e. binary or categorical) is specified as the target column.


API Usage

The Python API exposes the same primary functions as the CLI, making it easy to integrate into existing projects.

See examples/api_example.py and examples/api_example.ipynb in the GitHub Repository for a full API usage example.

Note: The train entry point also accepts a dict instead of a path to a YAML file as input.
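
As a rough illustration of the dict form, the snippet below builds a config that mirrors part of the base YAML shown in the Configuration Reference. The call to the train entry point is deliberately omitted, because its exact signature is not shown in this README. Only a subset of keys appears here (a real config would also carry the model and training sections), and `routed_columns` is a hypothetical sanity check, not part of at-gan.

```python
# Subset of the base config, written as a Python dict instead of YAML.
config = {
    "experiment_name": "test_experiment",
    "resume_run_id": None,
    "seed": 1130,
    "data": {
        "dataset_path": "datasets/example.csv",
        "output_path": "experiments/",
        "continuous_cols": ["age", "heart_rate", "glucose"],
        "binary_cols": ["male", "smoker"],
        "discrete_count_cols": ["cigs_per_day"],
        "categorical_cols": ["education"],
    },
}


def routed_columns(data_cfg):
    """Hypothetical helper: list every routed column and reject duplicates."""
    keys = ("continuous_cols", "binary_cols", "discrete_count_cols", "categorical_cols")
    columns = [col for key in keys for col in data_cfg.get(key, [])]
    duplicates = sorted({col for col in columns if columns.count(col) > 1})
    if duplicates:
        raise ValueError(f"columns routed more than once: {duplicates}")
    return columns


columns = routed_columns(config["data"])
```

Building the dict in code is convenient when the column lists are derived from a DataFrame rather than typed by hand.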


Configuration Reference

Experiments are driven by two YAML files: a base config and a sweep config.

See configs/config.yaml and configs/sweep_config.yaml in the GitHub Repository for examples and recommended default values for most datasets.

Base Config Reference

# =============================================================
#  EXPERIMENT META
# =============================================================
experiment_name: "test_experiment"   # also output directory name
resume_run_id:   null                # W&B run id to resume from checkpoint (optional)
seed:            1130                # seeds Python, NumPy, TensorFlow

# =============================================================
#  DATA
# =============================================================
data:
  dataset_path: "datasets/example.csv"
  output_path:  "experiments/"          # run artifacts found in 'output_path/experiment_name/run_id/'

  # Column routing — every column the GAN should learn MUST be listed here
  continuous_cols:     ["age", "heart_rate", "glucose"]
  binary_cols:         ["male", "smoker"]
  discrete_count_cols: ["cigs_per_day"]
  categorical_cols:    ["education"]

  # Preprocessing toggles
  treat_bin_as_cat:    false    # route binary cols through OHE + softmax
  beta_noise:          true     # Apply Beta-distributed noise on binary cols
  smooth_categorical:  true     # Apply label-preserving noise on OHE groups

# =============================================================
#  MODEL
# =============================================================
model:
  latent_dim: 32             

  generator:
    units:        [64, 64]     
    dropout:      0.0
    activation:   "relu"        # relu | leaky_relu | elu | ...
    batch_norm:   true          # BatchNorm after each Dense layer
    # negative_slope: 0.2       # used only when activation == "leaky_relu"

  discriminator:
    units:               [256, 256]
    dropout:             0.2
    activation:          "leaky_relu"
    negative_slope:      0.2
    pack_size:           3      # PacGAN packing factor (1 disables packing)
    label_smoothing_min: 0.9    # e.g. real labels ~ [0.9, 1.0]
    label_flipping:      0.05   # e.g. 5% of real labels flipped to 0 each step

# =============================================================
#  TRAINING
# =============================================================
training:
  device:               "cpu"   # "cpu" or "gpu"
  epochs:               2000
  batch_size:           512
  g_updates_per_epoch:  2       # G steps per D step

  # Optimizers
  adam_beta_1:          0.5     # GAN-stable Adam beta_1
  g_lr:                 0.0002  # G Learning Rate
  d_lr:                 0.0003  # D Learning Rate

  # LR schedule
  lr_cosine_decay:                 true
  lr_cosine_decay_restart_epochs:  2000   # restart every N epochs
  g_lr_decay_alpha:                0.1    # minimum G LR fraction (floor)
  d_lr_decay_alpha:                0.1    # minimum D LR fraction (floor)

  # Evaluation & checkpointing
  checkpoint_frequency: 100     # save "latest" every N epochs
  eval_frequency:       100     # run evaluation suite every N epochs
  test_split_pct:       0.2     # fraction of data (0.2 = 20%) to hold out for in-training evaluation
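
To show how the LR schedule keys interact, here is a minimal sketch of cosine decay with warm restarts and an alpha floor. It assumes fixed-length cycles (in Keras terms, a CosineDecayRestarts schedule with t_mul = 1 and m_mul = 1); whether at-gan uses those settings, and whether it counts in epochs or optimizer steps, is not stated here. The function name is hypothetical.

```python
import math


def cosine_decay_restarts(step, initial_lr, period, alpha):
    """Learning rate at `step` for cosine decay that restarts every `period` steps."""
    progress = (step % period) / period                   # position in the current cycle, in [0, 1)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))   # falls from 1 towards 0 within a cycle
    return initial_lr * ((1.0 - alpha) * cosine + alpha)  # alpha sets the floor as a fraction of initial_lr


g_lr = 0.0002
lr_start = cosine_decay_restarts(0, g_lr, period=2000, alpha=0.1)       # full LR
lr_mid = cosine_decay_restarts(1000, g_lr, period=2000, alpha=0.1)      # halfway down the cosine
lr_end = cosine_decay_restarts(1999, g_lr, period=2000, alpha=0.1)      # close to the floor
lr_restart = cosine_decay_restarts(2000, g_lr, period=2000, alpha=0.1)  # back to full LR
```

With alpha = 0.1 the rate never drops below 10% of its initial value, and setting alpha = 1 disables the decay entirely.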

Sweep Config Reference

# =============================================================
#  SWEEP STRATEGY & METRICS
# =============================================================
method: bayes              

metric:
  name: Eval/Total_Error     # W&B log key
  goal: minimize

early_terminate:
  type: hyperband            # Kills unpromising runs early to save compute time
  min_iter: 300              # Don't kill any run before e.g. epoch 300
  eta: 3                     # The halving rate for the Hyperband brackets

# =============================================================
# PARAMETERS
# =============================================================
parameters:

  # Sweeps choose from a fixed set of hyperparameter values
  model.latent_dim:
    values: [ 16, 32, 64, 128, 256 ]

  # -----------------------------------------------------------
  #  Generator Architecture
  # -----------------------------------------------------------
  generator.num_hidden_layers:
    values: [ 2, 3, 4 ]
  generator.base_units:
    values: [ 32, 64, 128, 256, 512 ]
  generator.max_units:
    value: 512                         
  generator.architecture_shape:
    values: [ "block", "ascending", "descending" ] 
  
  generator.dropout:
    value: 0.0                                   # e.g. Fixed to 0.0
  generator.activation:
    values: [ 'relu', 'leaky_relu' ]
  generator.batch_norm:
    values: [ true, false ]

  # -----------------------------------------------------------
  #  Discriminator Architecture
  # -----------------------------------------------------------
  discriminator.num_hidden_layers:
    values: [ 2, 3, 4 ]
  discriminator.base_units:
    values: [ 32, 64, 128, 256, 512 ]
  discriminator.max_units:
    value: 512
  discriminator.architecture_shape:
    values: [ "block", "ascending", "descending" ]
  
  discriminator.dropout:
    values: [ 0.0, 0.2, 0.3, 0.5 ]              
  discriminator.activation:
    values: [ 'relu', 'leaky_relu' ]
  discriminator.negative_slope:
    values: [ 0.1, 0.2, 0.3 ]                    
  discriminator.pack_size:
    values: [ 1, 3 ]                             
  discriminator.label_smoothing_min:
    values: [ 0.85, 0.9, 0.95, 1.0 ]             
  discriminator.label_flipping:
    values: [ 0.0, 0.05, 0.1 ]                  

  # -----------------------------------------------------------
  #  Training Loop & Optimizers
  # -----------------------------------------------------------
  training.batch_size:
    values: [ 64, 128, 256, 512 ]
  training.g_updates_per_epoch:
    values: [ 1, 2, 3 ]                          
  training.adam_beta_1:
    values: [ 0.2, 0.5, 0.7, 0.9 ] 
  
  # Learning Rates
  training.g_lr:
    distribution: log_uniform_values
    min: 0.00001                                 
    max: 0.001                                  
  training.d_lr:
    distribution: log_uniform_values
    min: 0.000005                               
    max: 0.0005                      # at-gan ensures d_lr <= g_lr

  # Cosine Decay Warm Restart Parameters
  training.lr_cosine_decay_restart_epochs:
    distribution: int_uniform
    min: 100
    max: 1000
  training.g_lr_decay_alpha:
    distribution: log_uniform_values
    min: 0.01                                    # Decay to 1% of max LR
    max: 1                                       # No decay
  training.d_lr_decay_alpha:
    distribution: log_uniform_values
    min: 0.01
    max: 1
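
The sweep describes each network by num_hidden_layers, base_units, max_units, and architecture_shape rather than by an explicit units list. How at-gan expands these is not documented here, so the sketch below is only one plausible reading: it assumes "ascending" doubles the width per layer, "descending" is the reverse, and max_units caps every layer. Both function names are hypothetical; the clamp mirrors the d_lr ≤ g_lr rule noted above.

```python
def expand_units(num_hidden_layers, base_units, max_units, architecture_shape):
    """Assumed expansion of sweep parameters into a per-layer units list."""
    if architecture_shape == "block":
        units = [base_units] * num_hidden_layers
    elif architecture_shape == "ascending":
        units = [base_units * 2 ** i for i in range(num_hidden_layers)]
    elif architecture_shape == "descending":
        units = [base_units * 2 ** i for i in reversed(range(num_hidden_layers))]
    else:
        raise ValueError(f"unknown architecture_shape: {architecture_shape!r}")
    return [min(u, max_units) for u in units]  # assumption: max_units caps each layer


def clamp_d_lr(g_lr, d_lr):
    """Keep the sampled discriminator LR at or below the generator LR."""
    return min(d_lr, g_lr)


block = expand_units(3, 64, 512, "block")
ascending = expand_units(3, 128, 512, "ascending")
descending = expand_units(4, 128, 512, "descending")
d_lr = clamp_d_lr(g_lr=0.0002, d_lr=0.0003)
```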

In-Training Evaluation Suite

Every eval_frequency epochs, GANCallback generates synthetic samples and runs an evaluation against the held-out real samples to guide the hyperparameter sweep:

| Metric | Computation |
|---|---|
| PCA Error | First Wasserstein distance between real and synthetic data across the first five PCA components |
| Adversarial Error | Absolute AUC deviation of a Random Forest classifier trained to distinguish real and synthetic data (|AUC - 0.5| × 2) |
| Total Error | sqrt((pca_error² + adv_error²) / 2.0) |

Raw errors are passed through the squashing function 1 - exp(-x), so each component lies in [0, 1).
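
The arithmetic behind these formulas can be written out in a few lines. This is a sketch, not the code in GANCallback: the function names are hypothetical, the example inputs (a raw PCA distance of 0.3 and an AUC of 0.75) are made up, and `total_error` simply takes the two components as given without deciding which of them were squashed upstream.

```python
import math


def squash(raw_error):
    """Map a non-negative raw error to [0, 1) via 1 - exp(-x)."""
    return 1.0 - math.exp(-raw_error)


def adversarial_error(auc):
    """|AUC - 0.5| * 2: 0 when the classifier is at chance, 1 when it separates perfectly."""
    return abs(auc - 0.5) * 2.0


def total_error(pca_error, adv_error):
    """Root mean square of the two error components."""
    return math.sqrt((pca_error ** 2 + adv_error ** 2) / 2.0)


pca_error = squash(0.3)              # made-up raw Wasserstein distance
adv_error = adversarial_error(0.75)  # made-up classifier AUC
score = total_error(pca_error, adv_error)
```

Because the total is a root mean square, a large error in either component keeps the score high even when the other is near zero.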

Visual artifacts (auto-logged to W&B)

  • Correlation heatmaps: real, synthetic, and absolute difference.
  • PCA scatter overlay: first two principal components of real vs. synthetic.

Synthetic Data Evaluation (Post-Training)

The evaluate command runs a comprehensive benchmark suite that assesses the quality of the synthetic data generated by the GAN:

  1. Privacy (DCR): Distance to Closest Record. Measures the Euclidean distance (in standard deviations) from each synthetic row to its nearest real training row. A minimum DCR > 0 means that no synthetic row is an exact copy of a real training row.
  2. Statistical Fidelity (SDV): Uses the Synthetic Data Vault (sdmetrics package) to generate a Quality Report, comparing 1D marginal distributions (Column Shapes) and 2D correlations (Column Pair Trends).
  3. Utility Retention (TSTR): Train on Synthetic, Test on Real.
    • Splits real data into real_train (80%) and real_test (20%).
    • Trains a TRTR baseline (RandomForest, GradientBoosting, LogisticRegression) on real_train → baseline F1 on real_test.
    • Trains TSTR models on the entire synthetic set → F1 on the same real_test.
    • Reports TSTR_Mean_F1 / TRTR_Mean_F1 × 100 (F1-Score Retention in %).
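
Two of the reported numbers, the minimum DCR and the F1-Score Retention, reduce to short computations. The sketch below uses only the standard library and toy numeric rows; it is not the evaluate command's implementation. Standardizing by the real data's population standard deviation and the function names are assumptions, and the F1 scores are made-up inputs.

```python
import math
import statistics


def min_dcr(real_rows, synthetic_rows):
    """Smallest standardized Euclidean distance from any synthetic row to any real row."""
    n_features = len(real_rows[0])
    stds = []
    for j in range(n_features):
        std = statistics.pstdev([row[j] for row in real_rows])
        stds.append(std if std > 0 else 1.0)  # constant column: leave unscaled

    def distance(a, b):
        return math.sqrt(sum(((x - y) / s) ** 2 for x, y, s in zip(a, b, stds)))

    closest = [min(distance(s, r) for r in real_rows) for s in synthetic_rows]
    return min(closest)


def f1_retention(tstr_f1_scores, trtr_f1_scores):
    """TSTR mean F1 as a percentage of the TRTR baseline mean F1."""
    return statistics.mean(tstr_f1_scores) / statistics.mean(trtr_f1_scores) * 100.0


real = [[0.0, 0.0], [2.0, 4.0]]
dcr_copy = min_dcr(real, [[2.0, 4.0]])  # synthetic row duplicates a real row
dcr_new = min_dcr(real, [[1.0, 2.0]])   # synthetic row lies between the real rows
retention = f1_retention([0.80, 0.70, 0.75], [0.90, 0.85, 0.95])
```

A retention near 100% indicates that models trained on synthetic data perform about as well on real test data as models trained on real data.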
