
Official implementation of "TabEBM: A Tabular Data Augmentation Method with Class-Specific Energy-Based Models", NeurIPS 2024.


TabEBM: A Tabular Data Augmentation Method with Distinct Class-Specific Energy-Based Models

License: Apache License 2.0

Python 3.10+

Official code for the paper "TabEBM: A Tabular Data Augmentation Method with Distinct Class-Specific Energy-Based Models", published in the Thirty-Eighth Annual Conference on Neural Information Processing Systems (NeurIPS 2024).

Authored by Andrei Margeloiu, Xiangjian Jiang, Nikola Simidjievski, Mateja Jamnik, University of Cambridge, UK

📌 Overview


TL;DR: We introduce a new data augmentation method for tabular data that is fast, requires no additional training, and can be applied to any downstream predictive model.

Abstract: Data collection is often difficult in critical fields such as medicine, physics, and chemistry. As a result, classification methods usually perform poorly with these small datasets, leading to weak predictive performance. Increasing the training set with additional synthetic data, similar to data augmentation in images, is commonly believed to improve downstream classification performance. However, current tabular generative methods that learn either the joint distribution $p(\mathbf{x}, y)$ or the class-conditional distribution $p(\mathbf{x} \mid y)$ often overfit on small datasets, resulting in poor-quality synthetic data, usually worsening classification performance compared to using real data alone. To solve these challenges, we introduce TabEBM, a novel class-conditional generative method using Energy-Based Models (EBMs). Unlike existing methods that use a shared model to approximate all class-conditional densities, our key innovation is to create distinct EBM generative models for each class, each modelling its class-specific data distribution individually. This approach creates robust energy landscapes, even in ambiguous class distributions. Our experiments show that TabEBM generates synthetic data with higher quality and better statistical fidelity than existing methods. When used for data augmentation, our synthetic data consistently improves the classification performance across diverse datasets of various sizes, especially small ones.

📖 Citation

For attribution in academic contexts, please cite this work as:

@article{margeloiu2024tabebm,
	title={TabEBM: A Tabular Data Augmentation Method with Distinct Class-Specific Energy-Based Models},
	author={Andrei Margeloiu and Xiangjian Jiang and Nikola Simidjievski and Mateja Jamnik},
	journal={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
	year={2024},
}

🔑 Features

  • TabEBM.py contains (i) the implementation of TabEBM, and (ii) a helper function plot_TabEBM_energy_contour to show the energy contour (or unnormalized probability) approximated by TabEBM
  • TabEBM_approximated_density.ipynb shows the TabEBM approximation of the density of the real data distribution
  • TabEBM_generate_data.ipynb shows how to generate data using TabEBM
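
The "energy contour (or unnormalized probability)" mentioned above relies on the standard EBM relation that the density is proportional to exp(-E(x)). The sketch below illustrates only that relation on a small grid. `toy_energy` is a hypothetical quadratic energy chosen for illustration; it is not the energy function TabEBM uses, and none of these names come from the library.

```python
import numpy as np


def toy_energy(points, centre):
    """Hypothetical quadratic energy: low near `centre`, high far from it.

    This is an illustrative stand-in, not TabEBM's energy function.
    """
    return 0.5 * np.sum((points - centre) ** 2, axis=1)


def unnormalised_density(energy):
    """Standard EBM relation: p(x) is proportional to exp(-E(x))."""
    return np.exp(-energy)


# Evaluate on a small 2D grid, as one would before drawing a contour plot.
xs, ys = np.meshgrid(np.linspace(-2, 2, 5), np.linspace(-2, 2, 5))
grid = np.column_stack([xs.ravel(), ys.ravel()])

energy = toy_energy(grid, centre=np.array([0.0, 0.0]))
density = unnormalised_density(energy).reshape(xs.shape)
```

Low energy corresponds to high unnormalised density, so the grid point closest to `centre` gets the largest value. The resulting `density` array has the shape a contour-plotting routine would expect.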

🚀 Installation

pip install tabebm

To reproduce the results reported in the paper:

  • Create a conda environment
conda create -n tabebm python=3.10.12
conda activate tabebm
  • Install tabebm and dependencies
git clone https://github.com/andreimargeloiu/TabEBM
cd TabEBM/
pip install --no-cache-dir -r requirements_paper.txt
pip install .

💥 Running Experiments with TabEBM

  • Tutorial 1: Generate synthetic data with TabEBM

    • The library can generate synthetic data with three lines of code.

      from tabebm.TabEBM import TabEBM
      
      # X_train: training features; y_train: the corresponding class labels
      tabebm = TabEBM()
      augmented_data = tabebm.generate(X_train, y_train, num_samples=100)
      
      # Output:
      # augmented_data[class_id] = numpy.ndarray of generated data for a specific `class_id`
      
  • Tutorial 2: Analyse the data distribution learned by TabEBM

    • The library allows computation of TabEBM’s energy function and the unnormalised data density.

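      import matplotlib.pyplot as plt
      from tabebm.TabEBM import plot_TabEBM_energy_contour
      
      # circles_dataset and plot_tabebm_probabilities are not imported here;
      # they are assumed to be helpers defined in the accompanying tutorial notebook.
      X, y = circles_dataset(n_samples=300, noise=2)
      plot_tabebm_probabilities(X, y, title_prefix='(noise=2)', h=0.2)
      plt.show()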
      
  • Tutorial 3: Augment real-world data with TabEBM

    • We provide a minimal example of using TabEBM to augment real-world datasets for improved downstream performance.
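
As a rough sketch of the augmentation step in Tutorials 1 and 3, the per-class output of `generate` can be stacked with the real training set before fitting a downstream model. `merge_real_and_synthetic` is a hypothetical helper, not part of the library. It assumes the returned mapping is keyed by the class label itself. If the library uses a different key format, the keys would need to be mapped back to labels first. The example runs on stand-in arrays rather than on real TabEBM output.

```python
import numpy as np


def merge_real_and_synthetic(X_real, y_real, synthetic_by_class):
    """Stack real data with per-class synthetic samples.

    Assumption: `synthetic_by_class` maps each class label to a 2D array
    of generated rows for that class.
    """
    X_parts, y_parts = [X_real], [y_real]
    for class_id, X_syn in synthetic_by_class.items():
        X_parts.append(X_syn)
        # Every synthetic row in this entry gets the label of its class.
        y_parts.append(np.full(len(X_syn), class_id, dtype=y_real.dtype))
    return np.concatenate(X_parts), np.concatenate(y_parts)


# Stand-in data so the sketch runs without TabEBM installed.
rng = np.random.default_rng(0)
X_train = rng.normal(size=(6, 3))
y_train = np.array([0, 0, 0, 1, 1, 1])
fake_augmented = {0: rng.normal(size=(4, 3)), 1: rng.normal(size=(4, 3))}

X_aug, y_aug = merge_real_and_synthetic(X_train, y_train, fake_augmented)
```

The merged `X_aug` and `y_aug` can then be passed to any downstream classifier in place of the original training set.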

