
Training Framework for Arbitrary Tabular Generative Adversarial Networks


AT-GAN

Arbitrary Tabular Generative Adversarial Network

A Tabular GAN framework for generating synthetic tabular data from arbitrary mixed-type tabular datasets.



Table of Contents

  1. Overview
  2. Key Features
  3. Installation
  4. CLI Usage
  5. API Usage
  6. Configuration Reference
  7. In-Training Evaluation Suite
  8. Synthetic Data Evaluation (Post-Training)

Overview

at-gan is a framework for training Generative Adversarial Networks on arbitrary tabular data. It is designed to work with continuous, binary, discrete count, and categorical features within a single pipeline.

The framework combines a multi-branch generator (G), a PacGAN-style discriminator (D), an integrated evaluation suite, and a Weights & Biases (W&B) integration for sweep orchestration, experiment tracking, and training monitoring and visualization.

Goal: train a GAN that produces realistic synthetic tabular data from a given dataset, with minimal manual tuning and a transparent, observable training process.


Key Features

Dynamic, Config-Driven Architectures

  • Generator and Discriminator built entirely from the YAML config.
  • Configurable number of layers and units per layer.
  • Configurable activations: relu, leaky_relu, elu, or any other activation supported in Keras.
  • Configurable dropout layers.
  • Optional Batch Normalization for G.

Mixed-Type Data Handling

  • The TabularPreprocessor handles four types of input features:
    • Continuous → MinMaxScaler(-1, 1) → tanh output branch.
    • Discrete Count → MinMaxScaler(0, 1) → sigmoid output branch.
    • Binary → 0/1, with optional β-distributed noise → sigmoid output branch.
    • Categorical → one-hot encoding, with optional label-preserving smoothing → softmax output branch.
  • Per-column decimal precision preservation.
  • Scalers and encoders are stored and reused for inference.
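The scaling routes for continuous and discrete count columns can be sketched as a per-column min-max transform whose inverse clips the generator output and restores the column's decimal precision. This is a minimal sketch under assumptions, not the TabularPreprocessor itself: the class name MinMaxColumn, its fields, and the clip-then-round inverse are hypothetical, and the binary noise and categorical smoothing steps are not reproduced because their exact form is not documented here.

```python
from dataclasses import dataclass


@dataclass
class MinMaxColumn:
    """Hypothetical per-column scaler: maps [data_min, data_max] onto [lo, hi]."""

    data_min: float
    data_max: float
    lo: float        # -1.0 for continuous (tanh), 0.0 for discrete counts (sigmoid)
    hi: float        # upper end of the output activation's range
    decimals: int    # decimal places to restore on inverse transform

    @classmethod
    def fit(cls, values, lo, hi, decimals):
        return cls(min(values), max(values), lo, hi, decimals)

    def transform(self, values):
        span = (self.data_max - self.data_min) or 1.0  # guard against constant columns
        return [self.lo + (v - self.data_min) / span * (self.hi - self.lo) for v in values]

    def inverse(self, scaled):
        span = self.data_max - self.data_min
        out = []
        for s in scaled:
            s = min(max(s, self.lo), self.hi)  # clip generator output into [lo, hi]
            v = self.data_min + (s - self.lo) / (self.hi - self.lo) * span
            out.append(round(v, self.decimals))  # restore per-column precision
        return out


# Usage: a continuous column routed to a tanh branch
age = MinMaxColumn.fit([29, 45, 61], lo=-1.0, hi=1.0, decimals=0)
print(age.transform([29, 45, 61]))      # [-1.0, 0.0, 1.0]
print(age.inverse([-1.0, 0.02, 1.3]))   # [29.0, 45.0, 61.0]
```

Storing the fitted min, max, and precision per column is what allows the same mapping to be reused at inference time, when generated values are converted back to the original units.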

GAN Training and Stabilization Techniques

| Technique | Controlled by | What it does |
|---|---|---|
| PacGAN packing | discriminator.pack_size | Concatenates k rows into a single D input to mitigate mode collapse |
| One-sided label smoothing | discriminator.label_smoothing_min | Real labels are sampled from [min, 1.0] instead of a hard 1.0 |
| Label flipping | discriminator.label_flipping | A random fraction of real labels is flipped to 0 to prevent D overconfidence |
| Two Time-Scale Update Rule (TTUR) | g_lr / d_lr | Separate learning rates for G and D; sweeps auto-clamp d_lr ≤ g_lr |
| G:D update ratio | g_updates_per_epoch | Multiple G steps per D step to balance training |
| Cosine LR decay with warm restarts | lr_cosine_decay | CosineDecayRestarts schedule with a configurable alpha floor for the G and D learning rates |
| Adam beta_1 override | adam_beta_1 | Typically lowered from the default 0.9 for training stability |
| Gradient clipping | always on | clipnorm=1.0 on both Adam optimizers |
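PacGAN packing and the two real-label tricks from the table can be sketched in NumPy. This is an illustrative sketch, not at-gan's training step: the function names pack and real_labels are hypothetical, dropping leftover rows that cannot fill a pack is an assumption, and the labels are drawn per packed D input.

```python
import numpy as np


def pack(batch: np.ndarray, pack_size: int) -> np.ndarray:
    """PacGAN packing: concatenate pack_size consecutive rows into one D input."""
    n, d = batch.shape
    usable = n - n % pack_size  # drop leftover rows that cannot fill a full pack
    # Row-major reshape places rows r0, r1, r2 side by side as [r0 | r1 | r2]
    return batch[:usable].reshape(usable // pack_size, pack_size * d)


def real_labels(n: int, smoothing_min: float, flip_frac: float,
                rng: np.random.Generator) -> np.ndarray:
    """One-sided label smoothing followed by label flipping (real side only)."""
    labels = rng.uniform(smoothing_min, 1.0, size=(n, 1))  # soft targets in [min, 1.0)
    flipped = rng.random((n, 1)) < flip_frac                # pick roughly flip_frac of them
    labels[flipped] = 0.0                                   # flip those to the fake label
    return labels


rng = np.random.default_rng(1130)
real = rng.normal(size=(512, 10))          # 512 preprocessed rows, 10 features
packed = pack(real, pack_size=3)           # shape (170, 30)
y_real = real_labels(len(packed), 0.9, 0.05, rng)
y_fake = np.zeros((len(packed), 1))        # fake labels stay a hard 0 (one-sided smoothing)
```

Because D sees k samples at once, a generator that collapses onto a few modes produces packs that are visibly less diverse than real packs, which is the signal PacGAN relies on.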

In-Training Evaluation Suite

Runs every eval_frequency epochs on held-out real samples, logs results to W&B, and keeps the checkpoint with the lowest error score as the best checkpoint. See In-Training Evaluation Suite.

Experiment Tracking

Weights & Biases integration:

  • Per-epoch loss/metric logging via a dedicated WandbCallback.
  • Training visuals: correlation heatmaps + PCA overlap scatter plots.
  • Local-only mode when --no-wandb is set (uses run_id="offline_run").

Sweeps & Neural Architecture Search

  • W&B sweeps for Neural Architecture Search (NAS) and Hyperparameter Optimization (HPO).
  • Resumption of existing W&B sweeps and single runs.

Synthetic Data Evaluation (Post-Training)

  • Privacy: Distance to Closest Record (DCR)
  • Statistical Fidelity: Synthetic Data Vault (SDV)
  • Utility Retention: Train on Synthetic, Test on Real (TSTR)

Usage Modes

  • 🖥️ CLI: train, sweep, generate, evaluate.
  • 🐍 Python API (at_gan.api): train, sweep, generate, evaluate.

Installation

Requirements: Python 3.10 – 3.12 and dependencies listed in pyproject.toml.

Option A: Install from PyPI (recommended)

pip install at-gan

Option B: Editable install from the GitHub Repository

  1. Clone the GitHub repository
  2. Run the following command from the repository root:
pip install -e .

Verify installation

at-gan --help
python -c "import at_gan; print(at_gan.__version__)"

Weights & Biases Login (one-time)

wandb login

💡 You can use this framework without W&B by passing --no-wandb (CLI) or enable_wandb=False (API).


CLI Usage

at-gan --help

train: Run or resume a single GAN training run

| Flag | Short | Default | Description |
|---|---|---|---|
| --config | -c | required | Path to the YAML experiment config |
| --wandb / --no-wandb | -w / -nw | --wandb | Toggle W&B tracking |
| --export / --no-export | -e / -ne | --export | Save the generator as a .keras file |
| --generate-samples | -g | 1000 | Automatically generate N samples after training |

Example:

at-gan train -c configs/config.yaml -w -e -g 5000

Note: A run can be resumed via the resume_run_id config key. See Configuration Reference.


sweep: Run or resume a W&B sweep

| Flag | Short | Description |
|---|---|---|
| --base-config | -c | Baseline experiment config |
| --sweep-config | -s | W&B sweep config (required for new sweeps) |
| --count | -n | Maximum number of runs this agent executes |
| --sweep-id | -id | Resume an existing sweep instead of creating a new one |

# Launch a new 50-run sweep
at-gan sweep -c configs/config.yaml -s configs/sweep_config.yaml -n 50

# Resume an existing sweep
at-gan sweep -c configs/config.yaml -id abc123 -n 20

generate: Generate synthetic samples from a trained generator

| Flag | Short | Description |
|---|---|---|
| --config | -c | YAML config used during the original training run |
| --run-id | -r | W&B run ID or "offline_run" |
| --samples | -n | Number of samples to generate |
| --output | -o | Optional override for the CSV output path |

at-gan generate -c configs/config.yaml -r a1b2c3 -n 10000 -o synthetic_data.csv

Note: generate always loads best_generator.keras (the checkpoint with the lowest evaluation error), not the latest checkpoint.


evaluate: Run synthetic data evaluation (post-training)

| Flag | Short | Description |
|---|---|---|
| --real | -r | Path to the real data CSV |
| --synthetic | -s | Path to the synthetic data CSV |
| --target | -t | Optional target column for TSTR evaluation |

at-gan evaluate -r real_data.csv -s synthetic_data.csv -t target_column

API Usage

The Python API (at_gan.api) exposes the same primary functions as the CLI (train, sweep, generate, evaluate), which makes it easy to integrate into existing projects.

See examples/api_example.py and examples/api_example.ipynb in the GitHub Repository for a full API usage example.

Note: The train entry point also accepts a config dict instead of a path to a YAML file.
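A dict passed to train presumably mirrors the YAML layout described in the Configuration Reference. The sketch below only builds such a dict and sanity-checks its column routing; it does not call at-gan, whose exact train signature is shown in examples/api_example.py rather than here. The helper check_column_routing is hypothetical and not part of at-gan.

```python
# Hypothetical config dict mirroring the YAML keys of the base config reference.
config = {
    "experiment_name": "api_experiment",
    "resume_run_id": None,
    "seed": 1130,
    "data": {
        "dataset_path": "datasets/example.csv",
        "output_path": "experiments/",
        "continuous_cols": ["age", "heart_rate", "glucose"],
        "binary_cols": ["male", "smoker"],
        "discrete_count_cols": ["cigs_per_day"],
        "categorical_cols": ["education"],
        "treat_bin_as_cat": False,
        "beta_noise": True,
        "smooth_categorical": True,
    },
    # "model" and "training" sections omitted for brevity; same keys as the YAML
}

ROUTING_KEYS = ("continuous_cols", "binary_cols", "discrete_count_cols", "categorical_cols")


def check_column_routing(data_cfg: dict) -> dict:
    """Hypothetical pre-flight check: each column may be routed to only one feature type."""
    routed = {}
    for key in ROUTING_KEYS:
        for col in data_cfg.get(key, []):
            if col in routed:
                raise ValueError(f"column {col!r} is listed in both {routed[col]} and {key}")
            routed[col] = key
    return routed


print(check_column_routing(config["data"])["education"])  # categorical_cols
```

Building the config in Python makes it easy to generate variants programmatically (for example, different column routings per dataset) before handing them to train.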


Configuration Reference

Experiments are driven by two YAML files: a base config and a sweep config.

See configs/config.yaml and configs/sweep_config.yaml in the GitHub Repository for examples and recommended default values for most datasets.

Base Config Reference

# =============================================================
#  EXPERIMENT META
# =============================================================
experiment_name: "test_experiment"   # also output directory name
resume_run_id:   null                # W&B run id to resume from checkpoint (optional)
seed:            1130                # seeds Python, NumPy, TensorFlow

# =============================================================
#  DATA
# =============================================================
data:
  dataset_path: "datasets/example.csv"
  output_path:  "experiments/"          # run artifacts found in 'output_path/experiment_name/run_id/'

  # Column routing — every column the GAN should learn MUST be listed here
  continuous_cols:     ["age", "heart_rate", "glucose"]
  binary_cols:         ["male", "smoker"]
  discrete_count_cols: ["cigs_per_day"]
  categorical_cols:    ["education"]

  # Preprocessing toggles
  treat_bin_as_cat:    false    # route binary cols through OHE + softmax
  beta_noise:          true     # Apply Beta-distributed noise on binary cols
  smooth_categorical:  true     # Apply label-preserving noise on OHE groups

# =============================================================
#  MODEL
# =============================================================
model:
  latent_dim: 32             

  generator:
    units:        [64, 64]     
    dropout:      0.0
    activation:   "relu"        # relu | leaky_relu | elu | ...
    batch_norm:   true          # BatchNorm after each Dense layer
    # negative_slope: 0.2       # used only when activation == "leaky_relu"

  discriminator:
    units:               [256, 256]
    dropout:             0.2
    activation:          "leaky_relu"
    negative_slope:      0.2
    pack_size:           3      # PacGAN packing factor (1 disables packing)
    label_smoothing_min: 0.9    # e.g. real labels ~ [0.9, 1.0]
    label_flipping:      0.05   # e.g. 5% of real labels flipped to 0 each step

# =============================================================
#  TRAINING
# =============================================================
training:
  device:               "cpu"   # "cpu" or "gpu"
  epochs:               2000
  batch_size:           512
  g_updates_per_epoch:  2       # G steps per D step

  # Optimizers
  adam_beta_1:          0.5     # GAN-stable Adam beta_1
  g_lr:                 0.0002  # G Learning Rate
  d_lr:                 0.0003  # D Learning Rate

  # LR schedule
  lr_cosine_decay:                 true
  lr_cosine_decay_restart_epochs:  2000   # restart every N epochs
  g_lr_decay_alpha:                0.1    # minimum G LR fraction (floor)
  d_lr_decay_alpha:                0.1    # minimum D LR fraction (floor)

  # Evaluation & checkpointing
  checkpoint_frequency: 100     # save "latest" every N epochs
  eval_frequency:       100     # run evaluation suite every N epochs
  test_split_pct:       0.2     # fraction of rows held out for in-training evaluation (0.2 = 20%)
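The LR schedule keys can be read through the cosine-decay-with-restarts formula used by Keras' tf.keras.optimizers.schedules.CosineDecayRestarts (parameters initial_learning_rate, first_decay_steps, t_mul, m_mul, alpha). The pure-Python sketch below is an assumption about how at-gan applies it: it measures time in epochs and uses a fixed cycle length (t_mul=1.0, m_mul=1.0) to match "restart every N epochs", whereas Keras defaults to t_mul=2.0 and counts optimizer steps.

```python
import math


def cosine_decay_restarts(epoch: float, initial_lr: float,
                          restart_epochs: int, alpha: float) -> float:
    """Cosine decay with warm restarts, assuming a fixed cycle length (no t_mul/m_mul growth)."""
    frac = (epoch % restart_epochs) / restart_epochs      # position within the current cycle, in [0, 1)
    cosine = 0.5 * (1.0 + math.cos(math.pi * frac))       # falls from 1 to 0 over one cycle
    return initial_lr * ((1.0 - alpha) * cosine + alpha)  # floor at alpha * initial_lr


for epoch in (0, 250, 499, 500):
    print(epoch, cosine_decay_restarts(epoch, 0.0002, restart_epochs=500, alpha=0.1))
```

With the values above (epochs: 2000, lr_cosine_decay_restart_epochs: 2000, alpha 0.1), the schedule effectively performs a single cosine decay from g_lr down to 0.1 × g_lr and never restarts; a decay alpha of 1 keeps the learning rate constant.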

Sweep Config Reference

# =============================================================
#  SWEEP STRATEGY & METRICS
# =============================================================
method: bayes              

metric:
  name: Eval/Total_Error     # W&B log key
  goal: minimize

early_terminate:
  type: hyperband            # Kills unpromising runs early to save compute time
  min_iter: 300              # Don't kill any run before e.g. epoch 300
  eta: 3                     # The halving rate for the Hyperband brackets

# =============================================================
# PARAMETERS
# =============================================================
parameters:

  # 'values': the sweep picks from a fixed set of candidate values
  model.latent_dim:
    values: [ 16, 32, 64, 128, 256 ]

  # -----------------------------------------------------------
  #  Generator Architecture
  # -----------------------------------------------------------
  generator.num_hidden_layers:
    values: [ 2, 3, 4 ]
  generator.base_units:
    values: [ 32, 64, 128, 256, 512 ]
  generator.max_units:
    value: 512                         
  generator.architecture_shape:
    values: [ "block", "ascending", "descending" ] 
  
  generator.dropout:
    value: 0.0                                   # e.g. Fixed to 0.0
  generator.activation:
    values: [ 'relu', 'leaky_relu' ]
  generator.batch_norm:
    values: [ true, false ]

  # -----------------------------------------------------------
  #  Discriminator Architecture
  # -----------------------------------------------------------
  discriminator.num_hidden_layers:
    values: [ 2, 3, 4 ]
  discriminator.base_units:
    values: [ 32, 64, 128, 256, 512 ]
  discriminator.max_units:
    value: 512
  discriminator.architecture_shape:
    values: [ "block", "ascending", "descending" ]
  
  discriminator.dropout:
    values: [ 0.0, 0.2, 0.3, 0.5 ]              
  discriminator.activation:
    values: [ 'relu', 'leaky_relu' ]
  discriminator.negative_slope:
    values: [ 0.1, 0.2, 0.3 ]                    
  discriminator.pack_size:
    values: [ 1, 3 ]                             
  discriminator.label_smoothing_min:
    values: [ 0.85, 0.9, 0.95, 1.0 ]             
  discriminator.label_flipping:
    values: [ 0.0, 0.05, 0.1 ]                  

  # -----------------------------------------------------------
  #  Training Loop & Optimizers
  # -----------------------------------------------------------
  training.batch_size:
    values: [ 64, 128, 256, 512 ]
  training.g_updates_per_epoch:
    values: [ 1, 2, 3 ]                          
  training.adam_beta_1:
    values: [ 0.2, 0.5, 0.7, 0.9 ] 
  
  # Learning Rates
  training.g_lr:
    distribution: log_uniform_values
    min: 0.00001                                 
    max: 0.001                                  
  training.d_lr:
    distribution: log_uniform_values
    min: 0.000005                               
    max: 0.0005                      # at-gan ensures d_lr <= g_lr

  # Cosine Decay Warm Restart Parameters
  training.lr_cosine_decay_restart_epochs:
    distribution: int_uniform
    min: 100
    max: 1000
  training.g_lr_decay_alpha:
    distribution: log_uniform_values
    min: 0.01                                    # Decay to 1% of max LR
    max: 1                                       # No decay
  training.d_lr_decay_alpha:
    distribution: log_uniform_values
    min: 0.01
    max: 1
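The sweep describes each network with num_hidden_layers, base_units, max_units, and architecture_shape instead of an explicit units list. How at-gan expands these is not documented on this page, so the sketch below is purely hypothetical: it assumes "block" repeats base_units, "ascending" doubles the width per layer up to max_units, and "descending" reverses that list. It also shows one assumed way to enforce d_lr ≤ g_lr, by clamping d_lr down to g_lr.

```python
def expand_units(num_hidden_layers: int, base_units: int, max_units: int, shape: str) -> list:
    """Hypothetical mapping from the sweep's architecture knobs to a per-layer units list."""
    if shape == "block":
        return [min(base_units, max_units)] * num_hidden_layers
    # Assumed rule: widths double layer by layer and are capped at max_units
    widths = [min(base_units * 2 ** i, max_units) for i in range(num_hidden_layers)]
    if shape == "ascending":
        return widths
    if shape == "descending":
        return widths[::-1]
    raise ValueError(f"unknown architecture_shape: {shape!r}")


def clamp_lrs(g_lr: float, d_lr: float) -> tuple:
    """Enforce d_lr <= g_lr; clamping d_lr down to g_lr is an assumed rule."""
    return g_lr, min(d_lr, g_lr)


print(expand_units(3, 64, 512, "ascending"))   # [64, 128, 256]
print(expand_units(3, 64, 512, "descending"))  # [256, 128, 64]
print(clamp_lrs(1e-4, 3e-4))                   # (0.0001, 0.0001)
```

Parameterizing depth, width, and shape separately keeps the search space small while still letting the Bayesian optimizer explore funnel-shaped and uniform networks.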

In-Training Evaluation Suite

Every eval_frequency epochs, GANCallback generates synthetic samples and runs an evaluation against the held-out real samples to guide the hyperparameter sweep:

| Metric | Computation |
|---|---|
| PCA Error | Wasserstein-1 (first Wasserstein) distance between real and synthetic data, computed on the first five PCA components |
| Adversarial Error | Deviation from chance of the AUC of a Random Forest classifier trained to distinguish real from synthetic data: 2 × abs(AUC - 0.5) |
| Total Error | sqrt((pca_error² + adv_error²) / 2.0), i.e. the root mean square of the two components |

Raw errors are passed through the squashing function 1 - exp(-x), so each component lies in [0, 1).
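The error arithmetic can be sketched in pure Python. This is a sketch under stated assumptions, not GANCallback's code: it takes already-projected PCA components as input (the PCA fit and the Random Forest are not reproduced), averages the per-component distances (the aggregation across the five components is assumed), squashes both raw errors before combining them, and uses the sorted-sample shortcut for Wasserstein-1, which is exact only for equal-size samples.

```python
import math


def wasserstein_1d(a, b):
    """W1 distance between two equal-size 1-D samples: mean gap between sorted values."""
    if len(a) != len(b):
        raise ValueError("this shortcut assumes equal sample sizes")
    return sum(abs(x - y) for x, y in zip(sorted(a), sorted(b))) / len(a)


def squash(x: float) -> float:
    """Map a non-negative raw error into [0, 1)."""
    return 1.0 - math.exp(-x)


def pca_error(real_pcs, synth_pcs):
    """real_pcs / synth_pcs: one list of projected values per PCA component."""
    dists = [wasserstein_1d(r, s) for r, s in zip(real_pcs, synth_pcs)]
    return squash(sum(dists) / len(dists))  # assumed aggregation: mean over components


def adversarial_error(auc: float) -> float:
    return squash(abs(auc - 0.5) * 2)  # 0 when the classifier is at chance level


def total_error(pca_err: float, adv_err: float) -> float:
    return math.sqrt((pca_err ** 2 + adv_err ** 2) / 2.0)


real_pcs = [[0.0, 1.0, 2.0]]   # a single component, three samples
synth_pcs = [[1.0, 2.0, 3.0]]  # same shape, shifted by 1
print(total_error(pca_error(real_pcs, synth_pcs), adversarial_error(0.5)))
```

The root mean square penalizes a run that is good on one axis but poor on the other more than a plain average would, so the sweep favors generators that are balanced across both criteria.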

Visual artifacts (auto-logged to W&B)

  • Correlation heatmaps: real, synthetic, and absolute difference.
  • PCA scatter overlay: first two principal components of real vs. synthetic.

Synthetic Data Evaluation (Post-Training)

The evaluate command runs a three-part benchmark suite that assesses the quality of the synthetic data generated by the GAN:

  1. Privacy (DCR): Distance to Closest Record. Measures, for each synthetic row, the Euclidean distance (in standard deviations) to its closest real training row. A minimum DCR > 0 guarantees that no synthetic row is an exact copy of a real training row; it does not rule out near-copies.
  2. Statistical Fidelity (SDV): Uses the Synthetic Data Vault's sdmetrics package to generate a Quality Report that compares 1D marginal distributions (Column Shapes) and 2D correlations (Column Pair Trends).
  3. Utility Retention (TSTR): Train on Synthetic, Test on Real.
    • Splits real data into real_train (80%) and real_test (20%).
    • Trains a TRTR (Train on Real, Test on Real) baseline (RandomForest, GradientBoosting, LogisticRegression) on real_train → baseline F1 on real_test.
    • Trains TSTR models on the entire synthetic set → F1 on the same real_test.
    • Reports TSTR_Mean_F1 / TRTR_Mean_F1 × 100 (F1-Score Retention in %).
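The DCR and F1 retention figures can be sketched in pure Python for numeric columns. This is an illustrative sketch, not the evaluate command's implementation: it assumes "in standard deviations" means standardizing every column with the real training data's mean and population standard deviation, categorical columns would need encoding first, and the brute-force nearest-neighbor search is only practical for small tables. The function names are hypothetical.

```python
import math
import statistics


def standardize(rows, ref):
    """Scale each column by the reference (real) data's mean and std."""
    cols = list(zip(*ref))
    means = [statistics.fmean(c) for c in cols]
    stds = [statistics.pstdev(c) or 1.0 for c in cols]  # guard against constant columns
    return [[(v - m) / s for v, m, s in zip(row, means, stds)] for row in rows]


def dcr(synthetic, real_train):
    """Distance from each synthetic row to its closest real training row (brute force)."""
    real_z = standardize(real_train, real_train)
    synth_z = standardize(synthetic, real_train)
    return [min(math.dist(s, r) for r in real_z) for s in synth_z]


def f1_retention(tstr_f1_scores, trtr_f1_scores):
    """Mean TSTR F1 as a percentage of mean TRTR F1."""
    return statistics.fmean(tstr_f1_scores) / statistics.fmean(trtr_f1_scores) * 100


real_train = [[50.0, 120.0], [62.0, 135.0], [41.0, 110.0]]
synthetic = [[50.0, 120.0], [55.0, 128.0]]
print(min(dcr(synthetic, real_train)))  # 0.0 -> the first synthetic row copies a real row
```

A retention near 100% means models trained only on synthetic data score about as well on real_test as models trained on real data.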



Download files


Source Distribution

at_gan-0.12.3.tar.gz (45.4 kB)

Uploaded Source

Built Distribution


at_gan-0.12.3-py3-none-any.whl (45.6 kB)

Uploaded Python 3

File details

Details for the file at_gan-0.12.3.tar.gz.

File metadata

  • Download URL: at_gan-0.12.3.tar.gz
  • Size: 45.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for at_gan-0.12.3.tar.gz
Algorithm Hash digest
SHA256 9c95ea7d3abaf0e34a051f8890f4fbcc37174a4ecd73cfa289f7956a089f6ac4
MD5 30c254926fee5a026e75af21b28a6b3b
BLAKE2b-256 0a0dd928e15298c88447a18ab56386fc2e66ac94fbacd58794dd4f00ac23b94d


Provenance

The following attestation bundles were made for at_gan-0.12.3.tar.gz:

Publisher: release.yaml on Jns-M/at-gan

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file at_gan-0.12.3-py3-none-any.whl.

File metadata

  • Download URL: at_gan-0.12.3-py3-none-any.whl
  • Size: 45.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for at_gan-0.12.3-py3-none-any.whl
Algorithm Hash digest
SHA256 2487b060fa7d008a08a01187a1c6be5bd83f7b557bc81aa00896f5c0f0764d94
MD5 932c8508cbc3a4e5d13151a55958d1ef
BLAKE2b-256 3bc4d055333e0e3bd610dfcdc4a0db4d837b421ab93a0679e3bb19ef8a519012


Provenance

The following attestation bundles were made for at_gan-0.12.3-py3-none-any.whl:

Publisher: release.yaml on Jns-M/at-gan

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
