Training Framework for Arbitrary Tabular Generative Adversarial Networks
AT-GAN
Arbitrary Tabular Generative Adversarial Network
A Tabular GAN framework for generating synthetic tabular data from arbitrary mixed-type tabular datasets.
Table of Contents
- Overview
- Key Features
- Installation
- CLI Usage
- API Usage
- Configuration Reference
- In-Training Evaluation Suite
- Synthetic Data Evaluation (Post-Training)
Overview
at-gan is a framework for training Generative Adversarial Networks on arbitrary tabular data. It is designed to work with continuous, binary, discrete count, and categorical features within a single pipeline.
The framework combines a multi-branch generator (G), a PacGAN-style discriminator (D), an integrated evaluation suite, and Weights & Biases (W&B) integration for sweep orchestration, experiment tracking, and training monitoring and visualization.
Goal: train a GAN that produces realistic synthetic tabular data from a given dataset with minimal manual tuning and a transparent, observable training process.
Key Features
Dynamic, Config-Driven Architectures
- Generator and Discriminator built entirely from a YAML config.
- Configurable number of `layers` and `units`.
- Configurable activations: `relu`, `leaky_relu`, `elu`, or any other activation supported in Keras.
- Configurable `dropout` layers.
- Optional `Batch Normalization` for G.
Mixed-Type Data Handling
- The `TabularPreprocessor` handles four types of input features:
  - Continuous → `MinMaxScaler(-1, 1)` → `tanh` output branch.
  - Discrete Count → `MinMaxScaler(0, 1)` → `sigmoid` output branch.
  - Binary → 0/1 and optional β-distributed noise application → `sigmoid` output branch.
  - Categorical → One-hot encoding and optional label-preserving smoothing → `softmax` output branch.
- Per-column decimal precision preservation.
- Scalers and encoders are stored and reused for inference.
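To make the continuous and categorical branches above more concrete, here is a minimal plain-Python sketch of range scaling and one-hot encoding. The helper names (`fit_min_max`, `scale`, `unscale`, `one_hot`) are hypothetical and not part of at-gan, which according to the list above relies on `MinMaxScaler`. The rounding step is only an assumption about how decimal precision might be restored.

```python
def fit_min_max(values):
    """Remember the column range so it can be reused at inference time."""
    return min(values), max(values)

def scale(value, col_min, col_max, low=-1.0, high=1.0):
    """Map a raw value into [low, high], the range a tanh branch can emit."""
    if col_max == col_min:
        return low  # constant column: any fixed value in range works
    return low + (value - col_min) * (high - low) / (col_max - col_min)

def unscale(scaled, col_min, col_max, low=-1.0, high=1.0, decimals=0):
    """Invert scale() and round back to the column's original precision."""
    raw = col_min + (scaled - low) * (col_max - col_min) / (high - low)
    return round(raw, decimals)

def one_hot(value, categories):
    """Encode a categorical value as a 0/1 vector over known categories."""
    return [1.0 if value == category else 0.0 for category in categories]

ages = [20.0, 35.0, 50.0]
col_min, col_max = fit_min_max(ages)
scaled = [scale(a, col_min, col_max) for a in ages]
restored = [unscale(s, col_min, col_max) for s in scaled]
education = one_hot("college", ["school", "college", "university"])
```

Using `low=0.0` instead of `-1.0` gives the (0, 1) range that the list above pairs with a `sigmoid` output branch.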
GAN Training and Stabilization Techniques
| Technique | Controlled by | What it does |
|---|---|---|
| PacGAN packing | `discriminator.pack_size` | Concatenates k rows into a single D input → fights mode collapse |
| One-sided label smoothing | `discriminator.label_smoothing_min` | Real labels sampled from [min, 1.0] instead of hard 1.0 |
| Label flipping | `discriminator.label_flipping` | Random fraction of real labels flipped to 0 to prevent D overconfidence |
| TTUR | `g_lr` / `d_lr` | Different LRs for G and D. Sweeps auto-clamp `d_lr ≤ g_lr` |
| G:D update ratio | `g_updates_per_epoch` | Multiple G steps per D step to balance the training process |
| LR Cosine decay + warm restarts | `lr_cosine_decay` | `CosineDecayRestarts` schedule with configurable alpha floor for the learning rate of G and D |
| Adam beta_1 override | `adam_beta_1` | Typically lowered from default 0.9 for training stability |
| Gradient clipping | always-on | `clipnorm=1.0` on both Adam optimizers |
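The packing and label techniques in the table can be illustrated with a short plain-Python sketch. The function names are hypothetical, and two details are assumptions rather than documented at-gan behavior: incomplete trailing groups are dropped when packing, and flipping is applied after smoothing.

```python
import random

def pack_rows(rows, pack_size):
    """Concatenate every `pack_size` consecutive rows into one discriminator input."""
    usable = len(rows) - len(rows) % pack_size  # assumption: drop the incomplete tail
    return [
        [x for row in rows[i:i + pack_size] for x in row]
        for i in range(0, usable, pack_size)
    ]

def real_labels(n, smoothing_min, flip_fraction, rng):
    """Real labels drawn from [smoothing_min, 1.0], with a random fraction set to 0."""
    labels = [rng.uniform(smoothing_min, 1.0) for _ in range(n)]
    n_flip = int(round(n * flip_fraction))
    for idx in rng.sample(range(n), n_flip):
        labels[idx] = 0.0
    return labels

rng = random.Random(1130)
batch = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8],
         [0.9, 1.0], [1.1, 1.2], [1.3, 1.4]]
packed = pack_rows(batch, pack_size=3)  # 7 rows of width 2 -> 2 inputs of width 6
labels = real_labels(100, smoothing_min=0.9, flip_fraction=0.05, rng=rng)
```

With `pack_size=1` the packed batch equals the original batch, which matches the note in the configuration reference that a factor of 1 disables packing.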
In-Training Evaluation Suite
Runs every eval_frequency epochs on held-out real samples, logs results to W&B, and saves the best checkpoint by error score. See In-Training Evaluation Suite.
Experiment Tracking
Weights & Biases integration:
- Per-epoch loss/metric logging via a dedicated `WandbCallback`.
- Training visuals: correlation heatmaps + PCA overlap scatter plots.
- Local-only mode when `--no-wandb` is set (uses `run_id="offline_run"`).
Sweeps & Neural Architecture Search
- W&B sweeps for Neural Architecture Search (NAS) and Hyperparameter Optimization.
- Mechanism to resume existing W&B sweeps (and single runs).
Synthetic Data Evaluation (Post-Training)
- Privacy: Distance to Closest Record (DCR)
- Statistical Fidelity: Synthetic Data Vault (SDV)
- Utility Retention: Train on Synthetic, Test on Real (TSTR)
Usage Modes
- 🖥️ CLI: `train`, `sweep`, `generate`, `evaluate`.
- 🐍 Python API (`at_gan.api`): `train`, `sweep`, `generate`, `evaluate`.
Installation
Requirements: Python 3.10 – 3.12 and dependencies listed in pyproject.toml.
Option A: Install from PyPI (recommended)
pip install at-gan
Option B: Editable install from the GitHub Repository
- Clone the GitHub repository
- Run the following command:
pip install -e .
Verify installation
at-gan --help
python -c "import at_gan; print(at_gan.__version__)"
Weights & Biases Login (one-time)
wandb login
💡 You can use this framework without W&B by passing `--no-wandb` (CLI) or `enable_wandb=False` (API).
CLI Usage
at-gan --help
train: Run or resume a single GAN training run
| Flag | Short | Default | Description |
|---|---|---|---|
| `--config` | `-c` | required | Path to the YAML experiment config |
| `--wandb / --no-wandb` | `-w / -nw` | `--wandb` | Toggle W&B tracking |
| `--export / --no-export` | `-e / -ne` | `--export` | Save `.keras` generator file |
| `--generate-samples` | `-g` | `1000` | Auto-generate N samples post-training |
Example:
at-gan train -c configs/config.yaml -w -e -g 5000
Note: A run can be resumed via the resume_run_id config key. See Configuration Reference.
sweep: Run or resume a W&B sweep
| Flag | Short | Description |
|---|---|---|
| `--base-config` | `-c` | Baseline experiment config |
| `--sweep-config` | `-s` | W&B sweep config (required for new sweeps) |
| `--count` | `-n` | Max runs this agent will execute |
| `--sweep-id` | `-id` | Resume an existing sweep instead of creating one |
# Launch a new 50-run sweep
at-gan sweep -c configs/config.yaml -s configs/sweep_config.yaml -n 50
# Resume an existing sweep
at-gan sweep -c configs/config.yaml -id abc123 -n 20
generate: Generate synthetic samples from a trained generator
| Flag | Short | Description |
|---|---|---|
| `--config` | `-c` | YAML used during the original training run |
| `--run-id` | `-r` | W&B run ID or `"offline_run"` |
| `--samples` | `-n` | Number of samples to generate |
| `--output` | `-o` | Optional override for CSV output path |
at-gan generate -c configs/config.yaml -r a1b2c3 -n 10000 -o synthetic_data.csv
Note: generate always loads best_generator.keras, not the latest.
evaluate: Run synthetic data evaluation (post-training)
| Flag | Short | Description |
|---|---|---|
| `--real` | `-r` | Path to the real data CSV |
| `--synthetic` | `-s` | Path to the synthetic data CSV |
| `--target` | `-t` | Optional target column for TSTR evaluation |
at-gan evaluate -r real_data.csv -s synthetic_data.csv -t target_column
API Usage
The Python API exposes the same primary functions as the CLI, making it easy to integrate into existing projects.
See examples/api_example.py and examples/api_example.ipynb in the GitHub Repository for a full API usage example.
Note: The `train` entry point also accepts a `dict` instead of a path to a YAML file as input.
Configuration Reference
Experiments are driven by two YAML files: a base config and a sweep config.
See configs/config.yaml and configs/sweep_config.yaml in the GitHub Repository for examples and recommended default values for most datasets.
Base Config Reference
# =============================================================
# EXPERIMENT META
# =============================================================
experiment_name: "test_experiment" # also output directory name
resume_run_id: null # W&B run id to resume from checkpoint (optional)
seed: 1130 # seeds Python, NumPy, TensorFlow
# =============================================================
# DATA
# =============================================================
data:
  dataset_path: "datasets/example.csv"
  output_path: "experiments/" # run artifacts found in 'output_path/experiment_name/run_id/'
  # Column routing: every column the GAN should learn MUST be listed here
  continuous_cols: ["age", "heart_rate", "glucose"]
  binary_cols: ["male", "smoker"]
  discrete_count_cols: ["cigs_per_day"]
  categorical_cols: ["education"]
  # Preprocessing toggles
  treat_bin_as_cat: false # route binary cols through OHE + softmax
  beta_noise: true # Apply Beta-distributed noise on binary cols
  smooth_categorical: true # Apply label-preserving noise on OHE groups
# =============================================================
# MODEL
# =============================================================
model:
  latent_dim: 32
  generator:
    units: [64, 64]
    dropout: 0.0
    activation: "relu" # relu | leaky_relu | elu | ...
    batch_norm: true # BatchNorm after each Dense layer
    # negative_slope: 0.2 # used only when activation == "leaky_relu"
  discriminator:
    units: [256, 256]
    dropout: 0.2
    activation: "leaky_relu"
    negative_slope: 0.2
    pack_size: 3 # PacGAN packing factor (1 disables packing)
    label_smoothing_min: 0.9 # e.g. real labels ~ [0.9, 1.0]
    label_flipping: 0.05 # e.g. 5% of real labels flipped to 0 each step
# =============================================================
# TRAINING
# =============================================================
training:
  device: "cpu" # "cpu" or "gpu"
  epochs: 2000
  batch_size: 512
  g_updates_per_epoch: 2 # G steps per D step
  # Optimizers
  adam_beta_1: 0.5 # GAN-stable Adam beta_1
  g_lr: 0.0002 # G Learning Rate
  d_lr: 0.0003 # D Learning Rate
  # LR schedule
  lr_cosine_decay: true
  lr_cosine_decay_restart_epochs: 2000 # restart every N epochs
  g_lr_decay_alpha: 0.1 # minimum G LR fraction (floor)
  d_lr_decay_alpha: 0.1 # minimum D LR fraction (floor)
  # Evaluation & checkpointing
  checkpoint_frequency: 100 # save "latest" every N epochs
  eval_frequency: 100 # run evaluation suite every N epochs
  test_split_pct: 0.2 # fraction of data to hold out for in-training evaluation
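To show how the restart period and alpha floor interact, here is a minimal sketch of a cosine schedule with warm restarts. It is not at-gan's code. It measures time in epochs and uses equal-length cycles with no amplitude decay (corresponding to `t_mul = 1` and `m_mul = 1` in the Keras `CosineDecayRestarts` schedule), both of which are assumptions about how the framework configures the schedule. The function name `cosine_restart_lr` is hypothetical.

```python
import math

def cosine_restart_lr(epoch, initial_lr, restart_epochs, alpha):
    """Learning rate at `epoch` for cosine decay with equal-length warm restarts."""
    progress = (epoch % restart_epochs) / restart_epochs  # position in the cycle, [0, 1)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))   # decays from 1 towards 0
    return initial_lr * ((1.0 - alpha) * cosine + alpha)  # never drops below alpha * initial_lr

# Values taken from the base config above: g_lr, restart epochs, g_lr_decay_alpha.
g_lr_start = cosine_restart_lr(0, 0.0002, 2000, 0.1)
g_lr_mid = cosine_restart_lr(1000, 0.0002, 2000, 0.1)
g_lr_end = cosine_restart_lr(1999, 0.0002, 2000, 0.1)
g_lr_restart = cosine_restart_lr(2000, 0.0002, 2000, 0.1)
```

Under these assumptions the rate starts at `g_lr`, approaches `g_lr_decay_alpha * g_lr` at the end of a cycle, and jumps back to `g_lr` at each restart.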
Sweep Config Reference
# =============================================================
# SWEEP STRATEGY & METRICS
# =============================================================
method: bayes
metric:
  name: Eval/Total_Error # W&B log key
  goal: minimize
early_terminate:
  type: hyperband # Kills unpromising runs early to save compute time
  min_iter: 300 # Don't kill any run before e.g. epoch 300
  eta: 3 # The halving rate for the Hyperband brackets
# =============================================================
# PARAMETERS
# =============================================================
parameters:
  # Sweeps choose from a fixed set of hyperparameter values
  model.latent_dim:
    values: [ 16, 32, 64, 128, 256 ]
  # -----------------------------------------------------------
  # Generator Architecture
  # -----------------------------------------------------------
  generator.num_hidden_layers:
    values: [ 2, 3, 4 ]
  generator.base_units:
    values: [ 32, 64, 128, 256, 512 ]
  generator.max_units:
    value: 512
  generator.architecture_shape:
    values: [ "block", "ascending", "descending" ]
  generator.dropout:
    value: 0.0 # e.g. Fixed to 0.0
  generator.activation:
    values: [ 'relu', 'leaky_relu' ]
  generator.batch_norm:
    values: [ true, false ]
  # -----------------------------------------------------------
  # Discriminator Architecture
  # -----------------------------------------------------------
  discriminator.num_hidden_layers:
    values: [ 2, 3, 4 ]
  discriminator.base_units:
    values: [ 32, 64, 128, 256, 512 ]
  discriminator.max_units:
    value: 512
  discriminator.architecture_shape:
    values: [ "block", "ascending", "descending" ]
  discriminator.dropout:
    values: [ 0.0, 0.2, 0.3, 0.5 ]
  discriminator.activation:
    values: [ 'relu', 'leaky_relu' ]
  discriminator.negative_slope:
    values: [ 0.1, 0.2, 0.3 ]
  discriminator.pack_size:
    values: [ 1, 3 ]
  discriminator.label_smoothing_min:
    values: [ 0.85, 0.9, 0.95, 1.0 ]
  discriminator.label_flipping:
    values: [ 0.0, 0.05, 0.1 ]
  # -----------------------------------------------------------
  # Training Loop & Optimizers
  # -----------------------------------------------------------
  training.batch_size:
    values: [ 64, 128, 256, 512 ]
  training.g_updates_per_epoch:
    values: [ 1, 2, 3 ]
  training.adam_beta_1:
    values: [ 0.2, 0.5, 0.7, 0.9 ]
  # Learning Rates
  training.g_lr:
    distribution: log_uniform_values
    min: 0.00001
    max: 0.001
  training.d_lr:
    distribution: log_uniform_values
    min: 0.000005
    max: 0.0005 # at-gan ensures d_lr <= g_lr
  # Cosine Decay Warm Restart Parameters
  training.lr_cosine_decay_restart_epochs:
    distribution: int_uniform
    min: 100
    max: 1000
  training.g_lr_decay_alpha:
    distribution: log_uniform_values
    min: 0.01 # Decay to 1% of max LR
    max: 1 # No decay
  training.d_lr_decay_alpha:
    distribution: log_uniform_values
    min: 0.01
    max: 1
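One way to picture how dotted sweep keys such as `training.g_lr` could be merged into a base config, and how `d_lr` could then be clamped to `g_lr`, is the sketch below. It is an illustration under assumptions, not at-gan's implementation: the helper names `apply_sweep_overrides` and `clamp_d_lr` are hypothetical, and only `training.*` keys are shown because the mapping of the architecture keys onto the base config is not documented here.

```python
import copy

def apply_sweep_overrides(base_config, sweep_params):
    """Return a copy of base_config with dotted sweep keys written into nested dicts."""
    config = copy.deepcopy(base_config)  # leave the caller's base config untouched
    for dotted_key, value in sweep_params.items():
        *parents, leaf = dotted_key.split(".")
        node = config
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value
    return config

def clamp_d_lr(config):
    """Enforce d_lr <= g_lr, as the comment in the sweep config describes."""
    training = config["training"]
    training["d_lr"] = min(training["d_lr"], training["g_lr"])
    return config

base = {"training": {"batch_size": 512, "g_lr": 0.0002, "d_lr": 0.0003}}
sampled = {"training.batch_size": 128, "training.g_lr": 0.0001, "training.d_lr": 0.0004}
run_config = clamp_d_lr(apply_sweep_overrides(base, sampled))
```

In this example the sampled `d_lr` of 0.0004 exceeds the sampled `g_lr` of 0.0001, so the clamp lowers it to 0.0001.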
In-Training Evaluation Suite
Every eval_frequency epochs, GANCallback generates synthetic samples and runs an evaluation against the held-out real samples to guide the hyperparameter sweep:
| Metric | Computation |
|---|---|
| PCA Error | First Wasserstein distance between real and synthetic data across the first five PCA components |
| Adversarial Error | Absolute AUC deviation of a Random Forest classifier trained to distinguish real and synthetic data (|AUC - 0.5| × 2) |
| Total Error | sqrt((pca_error² + adv_error²) / 2.0) |
Raw errors are passed through a squashing function (1 - exp(-x)), so each component lies in [0, 1).
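The formulas in the table and the squashing function can be combined as in the following sketch. The function names are hypothetical, and applying the squashing to each raw error before combining them is an assumption about the order of operations. The PCA and AUC inputs are placeholder numbers, not results.

```python
import math

def squash(raw_error):
    """Map a non-negative raw error into [0, 1) via 1 - exp(-x)."""
    return 1.0 - math.exp(-raw_error)

def adversarial_raw_error(auc):
    """|AUC - 0.5| * 2: 0 when the classifier is at chance, 1 when it separates perfectly."""
    return abs(auc - 0.5) * 2.0

def total_error(pca_error, adv_error):
    """Root mean square of the two error components."""
    return math.sqrt((pca_error ** 2 + adv_error ** 2) / 2.0)

# Placeholder inputs: identical PCA distributions, perfectly separable data.
pca_component = squash(0.0)
adv_component = squash(adversarial_raw_error(1.0))
score = total_error(pca_component, adv_component)
```

A classifier AUC of 0.5 yields an adversarial component of 0, which is the best case: real and synthetic rows cannot be told apart.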
Visual artifacts (auto-logged to W&B)
- Correlation heatmaps: real, synthetic, and absolute difference.
- PCA scatter overlay: first two principal components of real vs. synthetic.
Synthetic Data Evaluation (Post-Training)
The evaluate command runs a comprehensive benchmark suite that assesses the quality of the synthetic data generated by the GAN:
- Privacy (DCR): Distance to Closest Record. Measures the minimum Euclidean distance (in standard deviations) between synthetic rows and real training rows. A `Min. DCR > 0` means that no synthetic row is an exact copy of a real training row.
- Statistical Fidelity (SDV): Uses the Synthetic Data Vault (`sdmetrics` package) to generate a Quality Report, comparing 1D marginal distributions (Column Shapes) and 2D correlations (Column Pair Trends).
- Utility Retention (TSTR): Train on Synthetic, Test on Real.
  - Splits real data into `real_train` (80%) and `real_test` (20%).
  - Trains a TRTR baseline (`RandomForest`, `GradientBoosting`, `LogisticRegression`) on `real_train` → baseline F1 on `real_test`.
  - Trains TSTR models on the entire synthetic set → F1 on the same `real_test`.
  - Reports `TSTR_Mean_F1 / TRTR_Mean_F1 × 100` (F1-Score Retention in %).
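The DCR and F1 retention figures described above reduce to simple arithmetic, sketched here in plain Python. The function names are hypothetical, the rows are assumed to be numeric and already standardized, and the F1 scores are made-up placeholders rather than measured results.

```python
import math

def min_dcr(synthetic_rows, real_rows):
    """Smallest Euclidean distance from any synthetic row to its closest real row."""
    return min(
        min(math.dist(s, r) for r in real_rows)
        for s in synthetic_rows
    )

def f1_retention(tstr_f1_scores, trtr_f1_scores):
    """TSTR mean F1 as a percentage of the TRTR baseline mean F1."""
    tstr_mean = sum(tstr_f1_scores) / len(tstr_f1_scores)
    trtr_mean = sum(trtr_f1_scores) / len(trtr_f1_scores)
    return tstr_mean / trtr_mean * 100.0

real = [(0.0, 0.0), (3.0, 4.0)]
copied = [(0.0, 1.0), (3.0, 4.0)]   # second row duplicates a real row
novel = [(0.0, 1.0), (6.0, 8.0)]    # no row matches a real row

dcr_copied = min_dcr(copied, real)  # 0.0 signals an exact copy
dcr_novel = min_dcr(novel, real)
retention = f1_retention([0.6, 0.7, 0.8], [0.8, 0.8, 0.8])
```

A retention close to 100% would indicate that models trained on the synthetic set perform about as well on `real_test` as models trained on real data.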