
A novel method for generating tabular and relational data using language models.



REaLTabFormer

REaLTabFormer (Realistic Relational and Tabular Data using Transformers) offers a unified framework for synthesizing tabular data of different types. It uses a sequence-to-sequence (Seq2Seq) model to generate synthetic relational datasets. For non-relational tabular data, the REaLTabFormer model uses GPT-2 and can be used out of the box to model any tabular data with independent observations.
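To make the idea of modeling a table with a language model concrete, here is a minimal sketch of flattening a row into a token sequence and decoding it back. The `column=value` format and the `<bos>`/`<eos>` markers are assumptions made for illustration only; REaLTabFormer uses its own tokenization, which this sketch does not reproduce.

```python
from typing import Dict, List

# Hypothetical boundary and separator tokens, for this illustration only.
BOS, EOS, SEP = "<bos>", "<eos>", "="


def encode_row(row: Dict[str, str], columns: List[str]) -> List[str]:
    """Flatten one row into a token sequence using a fixed column order."""
    tokens = [BOS]
    for col in columns:
        tokens.append(f"{col}{SEP}{row[col]}")
    tokens.append(EOS)
    return tokens


def decode_tokens(tokens: List[str]) -> Dict[str, str]:
    """Invert encode_row, skipping the boundary tokens."""
    row: Dict[str, str] = {}
    for tok in tokens:
        if tok in (BOS, EOS):
            continue
        # Split at the first separator so values may themselves contain "=".
        col, _, value = tok.partition(SEP)
        row[col] = value
    return row


columns = ["age", "income"]
seq = encode_row({"age": "42", "income": "5000"}, columns)
print(seq)  # ['<bos>', 'age=42', 'income=5000', '<eos>']
print(decode_tokens(seq))  # {'age': '42', 'income': '5000'}
```

Keeping a fixed column order means a GPT-2 style model generates one column at a time, conditioned on the columns before it. For relational data, a parent row encoded this way could serve as the encoder input and its child rows as the decoder target of the Seq2Seq model.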

REaLTabFormer: Generating Realistic Relational and Tabular Data using Transformers
Paper on arXiv


Installation

REaLTabFormer is available on PyPI and can be installed with pip (Python >= 3.7):

pip install realtabformer

Usage

The examples below show how to train REaLTabFormer models and generate synthetic data from them.

Note: When training a non-relational tabular model, REaLTabFormer applies an optimal stopping criterion based on the synthetic data distribution: training stops once the synthetic data distribution is close to the real data distribution.

Set the epochs parameter to a large number so the model has room to fit the data well. Training will still stop early once the stopping criterion is met.
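Why a large epochs value is safe can be seen in a generic early-stopping loop, where epochs is only an upper bound. The sketch below is an illustration under stated assumptions, not the library's criterion: `distance_to_real` and `threshold` are hypothetical stand-ins for whatever measure REaLTabFormer actually computes between synthetic and real data.

```python
from typing import Callable


def train_with_stopping(
    train_one_epoch: Callable[[], None],
    distance_to_real: Callable[[], float],  # hypothetical synthetic-vs-real measure
    max_epochs: int,
    threshold: float,  # hypothetical "close enough" level
) -> int:
    """Train for at most max_epochs; return the number of epochs actually run."""
    for epoch in range(1, max_epochs + 1):
        train_one_epoch()
        # Stop as soon as synthetic data is close enough to the real data.
        if distance_to_real() <= threshold:
            return epoch
    return max_epochs


# Simulated distances that shrink as training progresses.
distances = iter([0.9, 0.5, 0.2, 0.1])
epochs_run = train_with_stopping(
    lambda: None, lambda: next(distances), max_epochs=200, threshold=0.25)
print(epochs_run)  # 3
```

Even with max_epochs=200, the loop ends after three epochs here, which is the behavior the note describes.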

REaLTabFormer for regular tabular data

# pip install realtabformer
import pandas as pd
from realtabformer import REaLTabFormer

df = pd.read_csv("foo.csv")

# NOTE: Remove any unique identifiers in the
# data that you don't want to be modeled.

# Non-relational or parent table.
rtf_model = REaLTabFormer(
    model_type="tabular",
    gradient_accumulation_steps=4,
    logging_steps=100)

# Fit the model on the dataset.
# Additional parameters can be
# passed to the `.fit` method.
rtf_model.fit(df)

# Save the model to the current directory.
# A new directory `rtf_model/` will be created.
# In it, a directory with the model's
# experiment id `idXXXX` will also be created
# where the artefacts of the model will be stored.
rtf_model.save("rtf_model/")

# Generate synthetic data with the same
# number of observations as the real dataset.
samples = rtf_model.sample(n_samples=len(df))

# Load the saved model. The directory to the
# experiment must be provided.
rtf_model2 = REaLTabFormer.load_from_dir(
    path="rtf_model/idXXXX")
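The note about removing unique identifiers can be checked mechanically. Below is a minimal sketch, assuming rows are plain dictionaries, that flags columns whose values are all distinct. `id_like_columns` is a hypothetical helper, not part of realtabformer, and it is only a heuristic: continuous columns such as prices may also be all-distinct, so review the result before dropping anything.

```python
from typing import Dict, List, Sequence


def id_like_columns(rows: Sequence[Dict[str, object]]) -> List[str]:
    """Return columns whose values are distinct in every row (candidate IDs)."""
    # With fewer than two rows every column is trivially unique, so flag nothing.
    if len(rows) < 2:
        return []
    n = len(rows)
    columns = list(rows[0].keys())
    return [c for c in columns if len({r[c] for r in rows}) == n]


rows = [
    {"id": 1, "city": "A", "age": 30},
    {"id": 2, "city": "A", "age": 30},
    {"id": 3, "city": "B", "age": 41},
]
print(id_like_columns(rows))  # ['id']
```

With pandas, the same check is `df.columns[df.nunique() == len(df)]`, and the flagged columns can then be removed with `df.drop(columns=...)` before calling `rtf_model.fit`.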

REaLTabFormer for relational data

# pip install realtabformer
import os
import pandas as pd
from pathlib import Path
from realtabformer import REaLTabFormer

parent_df = pd.read_csv("foo.csv")
child_df = pd.read_csv("bar.csv")
join_on = "unique_id"

# Make sure that the key columns in both the
# parent and the child table have the same name.
assert ((join_on in parent_df.columns) and
        (join_on in child_df.columns))

# Non-relational or parent table. Don't include the
# unique_id field.
parent_model = REaLTabFormer(model_type="tabular")
parent_model.fit(parent_df.drop(join_on, axis=1))

pdir = Path("rtf_parent/")
parent_model.save(pdir)

# # Get the most recently saved parent model,
# # or specify some other saved model.
# parent_model_path = pdir / "idXXX"
parent_model_path = sorted([
    p for p in pdir.glob("id*") if p.is_dir()],
    key=os.path.getmtime)[-1]

child_model = REaLTabFormer(
    model_type="relational",
    parent_realtabformer_path=parent_model_path,
    output_max_length=None,
    train_size=0.8)

child_model.fit(
    df=child_df,
    in_df=parent_df,
    join_on=join_on)

# Generate parent samples.
parent_samples = parent_model.sample(len(parent_df))

# Create the unique ids based on the index.
parent_samples.index.name = join_on
parent_samples = parent_samples.reset_index()

# Generate the relational observations.
child_samples = child_model.sample(
    input_unique_ids=parent_samples[join_on],
    input_df=parent_samples.drop(join_on, axis=1),
    gen_batch=64)
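The relational example trains the child model on child_df while passing parent_df and join_on. Conceptually, this amounts to pairing each parent's features with the list of its children. The sketch below shows one way to build such pairs with plain dictionaries; `build_pairs` is hypothetical and only illustrates the grouping, not how realtabformer prepares its Seq2Seq inputs internally.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

Row = Dict[str, object]


def build_pairs(
    parents: List[Row], children: List[Row], join_on: str
) -> List[Tuple[Row, List[Row]]]:
    """Pair each parent's features with its child rows, dropping the join key."""
    by_key: Dict[object, List[Row]] = defaultdict(list)
    for child in children:
        by_key[child[join_on]].append(
            {k: v for k, v in child.items() if k != join_on})
    pairs = []
    for parent in parents:
        features = {k: v for k, v in parent.items() if k != join_on}
        # .get avoids inserting empty entries; childless parents get [].
        pairs.append((features, by_key.get(parent[join_on], [])))
    return pairs


parents = [{"unique_id": 1, "hh_size": 3}, {"unique_id": 2, "hh_size": 1}]
children = [
    {"unique_id": 1, "age": 40},
    {"unique_id": 1, "age": 12},
    {"unique_id": 3, "age": 5},  # orphan: no matching parent
]
pairs = build_pairs(parents, children, "unique_id")
print(pairs)
```

Parents without children map to an empty list, and child rows whose key has no matching parent are dropped. The key itself is excluded from both sides, which mirrors why the example passes it separately as input_unique_ids rather than inside input_df.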

Validators for synthetic samples

The REaLTabFormer framework provides an interface for building observation validators that filter out invalid synthetic samples. The example below uses the GeoValidator. The chart on the left shows the distribution of the generated latitude and longitude without validation; the chart on the right shows synthetic samples that were validated with the GeoValidator against the California boundary. Even though the model was not optimally trained for this example, invalid samples (those falling outside the boundary) are already scarce in the data generated without a validator.
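One way a geographic validator like this can decide validity is a point-in-polygon test. The sketch below uses even-odd ray casting on a simple polygon given as (longitude, latitude) vertices. It is a swapped-in illustration, not GeoValidator's implementation, and the rectangular "boundary" in the usage is made up. In practice, shapely's `Polygon(...).contains(Point(x, y))` performs this check on real geometries.

```python
from typing import Dict, List, Tuple

Vertex = Tuple[float, float]


def point_in_polygon(x: float, y: float, polygon: List[Vertex]) -> bool:
    """Even-odd rule: cast a ray to the right of (x, y) and count edge crossings."""
    inside = False
    n = len(polygon)
    for i in range(n):
        x1, y1 = polygon[i]
        x2, y2 = polygon[(i + 1) % n]
        # Only edges that straddle the horizontal line through y can be crossed;
        # this also guarantees y1 != y2, so the division below is safe.
        if (y1 > y) != (y2 > y):
            x_cross = x1 + (y - y1) * (x2 - x1) / (y2 - y1)
            if x < x_cross:
                inside = not inside
    return inside


def filter_samples(samples: List[Dict[str, float]], polygon: List[Vertex]):
    """Keep only samples whose (Longitude, Latitude) falls inside the polygon."""
    return [s for s in samples
            if point_in_polygon(s["Longitude"], s["Latitude"], polygon)]


# Made-up rectangular boundary, not the real California outline.
box = [(-124.0, 32.0), (-114.0, 32.0), (-114.0, 42.0), (-124.0, 42.0)]
samples = [{"Longitude": -120.0, "Latitude": 37.0},
           {"Longitude": -110.0, "Latitude": 37.0}]
print(filter_samples(samples, box))  # [{'Longitude': -120.0, 'Latitude': 37.0}]
```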


# !pip install geopandas &> /dev/null
# !pip install realtabformer &> /dev/null
# !git clone https://github.com/joncutrer/geopandas-tutorial.git &> /dev/null
import geopandas
import seaborn as sns
import matplotlib.pyplot as plt
from realtabformer import REaLTabFormer
from realtabformer import rtf_validators as rtf_val
from shapely.geometry import MultiPolygon
from sklearn.datasets import fetch_california_housing


def plot_sf(data, samples, title=None):
    xlims = (-126, -113.5)
    ylims = (31, 43)
    bins = (50, 50)

    dd = samples.copy()
    pp = dd.loc[
        dd["Longitude"].between(data["Longitude"].min(), data["Longitude"].max()) &
        dd["Latitude"].between(data["Latitude"].min(), data["Latitude"].max())
    ]

    g = sns.JointGrid(data=pp, x="Longitude", y="Latitude", marginal_ticks=True)
    g.plot_joint(
        sns.histplot,
        bins=bins,
    )

    states[states['NAME'] == 'California'].boundary.plot(ax=g.ax_joint)
    g.ax_joint.set_xlim(*xlims)
    g.ax_joint.set_ylim(*ylims)

    g.plot_marginals(sns.histplot, element="step", color="#03012d")

    if title:
        g.ax_joint.set_title(title)

    plt.tight_layout()

# Get geographic files
states = geopandas.read_file('geopandas-tutorial/data/usa-states-census-2014.shp')
states = states.to_crs("EPSG:4326")  # GPS Projection

# Get the California housing dataset
data = fetch_california_housing(as_frame=True).frame

# Use a small number of epochs for the demo (default: 200).
rtf_model = REaLTabFormer(
    model_type="tabular",
    batch_size=64,
    epochs=10,
    gradient_accumulation_steps=4,
    logging_steps=100)

# Fit the model. For the demo we also reduce num_bootstrap (default: 500).
rtf_model.fit(data, num_bootstrap=10)

# Save the trained model
rtf_model.save("rtf_model/")

# Sample raw data without validator
samples_raw = rtf_model.sample(n_samples=10240, gen_batch=512)

# Sample data with the geographic validator
obs_validator = rtf_val.ObservationValidator()
obs_validator.add_validator(
    "geo_validator",
    rtf_val.GeoValidator(
        MultiPolygon(states[states['NAME'] == 'California'].geometry.iloc[0])),
    ("Longitude", "Latitude")
)

samples_validated = rtf_model.sample(
    n_samples=10240, gen_batch=512,
    validator=obs_validator,
)

# Visualize the samples
plot_sf(data, samples_raw, title="Raw samples")
plot_sf(data, samples_validated, title="Validated samples")
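Passing validator= to sample still requests n_samples rows in the example above, which suggests invalid rows are discarded and replaced. A common way to do that is rejection sampling: generate in batches, keep the rows that pass, and repeat until enough are collected. This is a sketch under that assumption; `generate_batch`, `is_valid`, and `max_rounds` are hypothetical names, not realtabformer's API.

```python
import random
from typing import Callable, Dict, List

Row = Dict[str, float]


def sample_valid(
    generate_batch: Callable[[int], List[Row]],
    is_valid: Callable[[Row], bool],
    n_samples: int,
    gen_batch: int,
    max_rounds: int = 1000,
) -> List[Row]:
    """Collect n_samples rows that pass is_valid, generating gen_batch rows per round."""
    kept: List[Row] = []
    for _ in range(max_rounds):
        # Keep only the rows in this batch that pass validation.
        kept.extend(row for row in generate_batch(gen_batch) if is_valid(row))
        if len(kept) >= n_samples:
            return kept[:n_samples]
    # Guard against a validator that rejects almost everything.
    raise RuntimeError(f"only {len(kept)} valid rows after {max_rounds} rounds")


# Toy generator: uniform points in the unit square; valid if below the diagonal.
rng = random.Random(0)


def toy_batch(size: int) -> List[Row]:
    return [{"x": rng.random(), "y": rng.random()} for _ in range(size)]


valid = sample_valid(toy_batch, lambda r: r["x"] + r["y"] < 1.0,
                     n_samples=100, gen_batch=64)
print(len(valid))  # 100
```

The expected number of generated rows is roughly n_samples divided by the acceptance rate, so when invalid samples are scarce, as in the California example, validation adds little overhead.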

Citation

Please cite our work if you use the REaLTabFormer in your projects or research.

@article{solatorio2023realtabformer,
  title={REaLTabFormer: Generating Realistic Relational and Tabular Data using Transformers},
  author={Solatorio, Aivin V. and Dupriez, Olivier},
  journal={arXiv preprint arXiv:2302.02041},
  year={2023}
}

Acknowledgments

We thank the World Bank-UNHCR Joint Data Center on Forced Displacement (JDC) for funding the project "Enhancing Responsible Microdata Access to Improve Policy and Response in Forced Displacement Situations" (KP-P174174-GINP-TF0B5124). Part of the funding went into supporting the development of the REaLTabFormer framework, which was used to generate the synthetic population for research on disclosure risk and the mosaic effect.

We also send 🤗 to HuggingFace 🤗 for all the open-source software they release. And to all open-source projects, thank you!

