Official implementation of "TabEBM: A Tabular Data Augmentation Method with Distinct Class-Specific Energy-Based Models", NeurIPS 2024.

TabEBM: A Tabular Data Augmentation Method with Distinct Class-Specific Energy-Based Models

License: Apache License 2.0 | Python 3.10+

Official code for the paper "TabEBM: A Tabular Data Augmentation Method with Distinct Class-Specific Energy-Based Models", published in the Thirty-Eighth Annual Conference on Neural Information Processing Systems (NeurIPS 2024).

Authored by Andrei Margeloiu*, Xiangjian Jiang*, Nikola Simidjievski, Mateja Jamnik, University of Cambridge, UK

📌 Overview

TL;DR: We introduce a high-performance tabular data augmentation method that is fast, requires no additional training, and can be applied to any downstream predictive model. The optimized implementation features advanced caching, GPU acceleration, and memory-efficient SGLD sampling.

Abstract: Data collection is often difficult in critical fields such as medicine, physics, and chemistry. As a result, classification methods usually perform poorly with these small datasets, leading to weak predictive performance. Increasing the training set with additional synthetic data, similar to data augmentation in images, is commonly believed to improve downstream classification performance. However, current tabular generative methods that learn either the joint distribution $p(\mathbf{x}, y)$ or the class-conditional distribution $p(\mathbf{x} \mid y)$ often overfit on small datasets, resulting in poor-quality synthetic data, usually worsening classification performance compared to using real data alone. To solve these challenges, we introduce TabEBM, a novel class-conditional generative method using Energy-Based Models (EBMs). Unlike existing methods that use a shared model to approximate all class-conditional densities, our key innovation is to create distinct EBM generative models for each class, each modelling its class-specific data distribution individually. This approach creates robust energy landscapes, even in ambiguous class distributions. Our experiments show that TabEBM generates synthetic data with higher quality and better statistical fidelity than existing methods. When used for data augmentation, our synthetic data consistently improves the classification performance across diverse datasets of various sizes, especially small ones.
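The idea of one separate generative model per class can be illustrated with a small sketch. This is not TabEBM's energy function: as a stand-in, each class gets its own unnormalised density built from a Gaussian kernel over that class's samples only, so that $p_c(\mathbf{x}) \propto \exp(-E_c(\mathbf{x}))$. The names `class_energy`, `fit_class_energies` and the `bandwidth` parameter are hypothetical.

```python
import numpy as np


def class_energy(x, class_samples, bandwidth=0.5):
    """Stand-in energy E_c(x) = -log sum_i exp(-||x - x_i||^2 / (2 h^2)).

    Lower energy means higher unnormalised density exp(-E_c(x)).
    """
    x = np.atleast_2d(x)
    # Squared distance from every query point to every sample of this class
    sq = ((x[:, None, :] - class_samples[None, :, :]) ** 2).sum(axis=-1)
    logits = -sq / (2.0 * bandwidth ** 2)
    # Numerically stable log-sum-exp over the class samples
    m = logits.max(axis=1, keepdims=True)
    return -(m[:, 0] + np.log(np.exp(logits - m).sum(axis=1)))


def fit_class_energies(X, y, bandwidth=0.5):
    """Build one independent energy function per class (no shared model)."""
    return {
        c: (lambda pts, S=X[y == c]: class_energy(pts, S, bandwidth))
        for c in np.unique(y)
    }


# Two well-separated toy classes in 2-D
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(-2, 0.5, (50, 2)), rng.normal(2, 0.5, (50, 2))])
y = np.array([0] * 50 + [1] * 50)
energies = fit_class_energies(X, y)

probe = np.array([[-2.0, -2.0]])  # sits inside the class-0 cluster
print(energies[0](probe), energies[1](probe))  # class 0 gives the lower energy
```

Because each energy only sees its own class, dense regions of class 0 leave $E_1$ unchanged; this is the separation the abstract contrasts with a single shared model for all class-conditional densities.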

📖 Citation

For attribution in academic contexts, please cite this work as:

@article{margeloiu2024tabebm,
	title={TabEBM: A Tabular Data Augmentation Method with Distinct Class-Specific Energy-Based Models},
	author={Andrei Margeloiu and Xiangjian Jiang and Nikola Simidjievski and Mateja Jamnik},
	journal={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
	year={2024},
}

🔑 Features

  • Optimized TabEBM Implementation: High-performance TabEBM with advanced caching, memory optimizations, and GPU acceleration
  • Fast Synthetic Data Generation: Generate synthetic tabular data with minimal configuration and improved speed
  • TabPFN-v2 Integration: Seamless integration with the latest TabPFN-v2 for enhanced gradient-based sampling
  • Memory-Efficient SGLD: Optimized Stochastic Gradient Langevin Dynamics with pre-computed noise tensors
  • Comprehensive Tutorials: Three interactive notebooks covering data generation, real-world augmentation, and density analysis
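The SGLD bullet mentions pre-computed noise tensors. A minimal NumPy sketch of that pattern, assuming only a generic energy gradient: the Gaussian noise for every step is drawn in one allocation before the loop, and each update moves the samples downhill in energy and then adds that step's noise. `sgld_sample` and its defaults (`step_size`, `noise_std`) are illustrative assumptions, not the library's API. Decoupling the noise scale from the step size is a common practical variant in EBM work; textbook Langevin dynamics ties the two together.

```python
import numpy as np


def sgld_sample(grad_energy, x0, n_steps=200, step_size=0.1, noise_std=0.01, seed=0):
    """Langevin-style sampling: x <- x - step_size * grad E(x) + noise_std * eps."""
    rng = np.random.default_rng(seed)
    x = np.array(x0, dtype=float)  # copy, so the caller's starting points are untouched
    # Draw the noise for all steps in a single allocation before the loop
    noise = noise_std * rng.standard_normal((n_steps, *x.shape))
    for t in range(n_steps):
        x -= step_size * grad_energy(x)  # move downhill in energy
        x += noise[t]                    # add the pre-computed Gaussian noise
    return x


# Toy energy E(x) = ||x - mu||^2 / 2 with gradient (x - mu)
mu = np.array([1.0, -1.0])
start = np.random.default_rng(1).uniform(-5, 5, size=(64, 2))
samples = sgld_sample(lambda x: x - mu, start)
print(samples.mean(axis=0))  # close to mu
```

Pre-computing the noise trades memory (`n_steps × n_samples × n_features` floats) for fewer random-number calls inside the loop.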

🔥 Performance Optimizations

  • Model Caching: Intelligent caching system to avoid redundant model training across classes
  • Vectorized Operations: Optimized tensor operations and reduced device transfers for faster computation
  • Stratified Sampling: Smart subsampling for large datasets while maintaining class balance
  • Gradient Computation: Enhanced gradient-based sampling with TabPFN-v2's energy landscapes
  • Memory Management: Pre-allocated tensors and efficient memory usage patterns
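The `max_data_size` argument in the Quick Start and the stratified-sampling bullet suggest a cap on the number of training rows that preserves class balance. How TabEBM subsamples internally is not documented here; the hypothetical helper `stratified_subsample` below shows one way to do it in NumPy, keeping each class's share of the data.

```python
import numpy as np


def stratified_subsample(X, y, max_size, seed=0):
    """Return at most ~max_size rows, keeping each class's share of the data."""
    y = np.asarray(y)
    if len(y) <= max_size:
        return X, y
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    # Integer quota proportional to class frequency; at least one row per class,
    # so the total can exceed max_size by at most the number of tiny classes
    quotas = np.maximum(1, counts * max_size // len(y))
    keep = np.concatenate([
        rng.choice(np.flatnonzero(y == c), size=q, replace=False)
        for c, q in zip(classes, quotas)
    ])
    keep.sort()  # restore original row order
    return X[keep], y[keep]


y = np.array([0] * 900 + [1] * 100)
X = np.arange(len(y) * 3).reshape(-1, 3)
Xs, ys = stratified_subsample(X, y, max_size=100)
print(len(ys), int((ys == 1).sum()))  # 100 10
```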

🚀 Installation

Quick Installation (Optimized Version)

pip install tabebm

This installs the latest optimized version with TabPFN-v2 integration, GPU acceleration, and performance enhancements.

To reproduce the results reported in the paper

  • Create conda environment
conda create -n tabebm python=3.10.12
conda activate tabebm
  • Install tabebm and dependencies
git clone https://github.com/andreimargeloiu/TabEBM
cd TabEBM/
pip install --no-cache-dir -r requirements_paper.txt
pip install .
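To confirm which version ended up installed (for example, the released package versus a local `pip install .` checkout), the standard-library `importlib.metadata` module can be queried. `installed_version` is an illustrative helper name.

```python
from importlib import metadata


def installed_version(dist_name="tabebm"):
    """Installed version string of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


print(installed_version())  # a version string if tabebm is installed, else None
```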

💥 Running Experiments with TabEBM

⚡ Quick Start

The optimized TabEBM implementation provides a streamlined API for high-performance synthetic data generation:

from sklearn.datasets import make_classification
from tabebm.TabEBM import TabEBM

# Toy training data for illustration; substitute your own X_train, y_train
X_train, y_train = make_classification(n_samples=200, n_features=6, n_informative=4, random_state=0)

# Initialize with optimized configuration
tabebm = TabEBM(max_data_size=10000)  # Automatic GPU detection and caching

# Generate synthetic data with enhanced performance
augmented_data = tabebm.generate(
    X_train, y_train,
    num_samples=100,
    sgld_steps=200,    # Optimized SGLD with pre-computed noise
    debug=True         # Monitor optimization progress
)

# Output: a dict with one numpy.ndarray of generated samples per class, e.g.
# augmented_data['class_0'] = generated samples for class 0
# augmented_data['class_1'] = generated samples for class 1
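A common next step is to append the synthetic samples to the real training set before fitting a downstream model. A minimal sketch, assuming the output is a dict keyed like `'class_0'`, `'class_1'` (the format in the Quick Start comments) with one array of samples per class; `to_training_set` is a hypothetical helper, not part of `tabebm`.

```python
import numpy as np


def to_training_set(X_real, y_real, augmented):
    """Stack real rows with synthetic rows from a {class key: array} dict."""
    X_parts, y_parts = [np.asarray(X_real)], [np.asarray(y_real)]
    for key, samples in augmented.items():
        # 'class_1' -> 1; a plain integer key such as 1 also parses
        label = int(str(key).rsplit("_", 1)[-1])
        samples = np.asarray(samples)
        X_parts.append(samples)
        y_parts.append(np.full(len(samples), label, dtype=y_parts[0].dtype))
    return np.vstack(X_parts), np.concatenate(y_parts)


# Placeholder arrays standing in for real data and generator output
X_real = np.zeros((4, 2))
y_real = np.array([0, 0, 1, 1])
fake = {"class_0": np.ones((3, 2)), "class_1": np.full((2, 2), 2.0)}
X_aug, y_aug = to_training_set(X_real, y_real, fake)
print(X_aug.shape, np.bincount(y_aug))  # (9, 2) [5 4]
```

Reading the label from the key's trailing integer keeps the helper usable whether the keys are strings like `'class_1'` or plain integers.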

📚 Interactive Tutorials

  • Tutorial 1 (Colab): Generate synthetic data with TabEBM

    • The library can generate synthetic data with three lines of code.

      from tabebm.TabEBM import TabEBM
      
      tabebm = TabEBM()
      augmented_data = tabebm.generate(X_train, y_train, num_samples=100)
      
      # Output:
      # augmented_data['class_<k>'] = numpy.ndarray of samples generated for class k
      
  • Tutorial 2 (Colab): Augment real-world data with TabEBM

    • We provide a click-to-run example of using TabEBM to augment a real-world dataset for improved downstream performance.
  • Tutorial 3 (Colab): Analyse the data distribution learned by TabEBM

    • The library allows computation of TabEBM’s energy function and the unnormalised data density.

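      import matplotlib.pyplot as plt
      from sklearn.datasets import make_circles
      from tabebm.TabEBM import plot_tabebm_probabilities  # plotting helper called below
      
      # Toy two-class data; sklearn's make_circles stands in for the circles_dataset helper
      X, y = make_circles(n_samples=300, noise=0.1, factor=0.5, random_state=0)
      plot_tabebm_probabilities(X, y, title_prefix='(noise=0.1)', h=0.2)
      plt.show()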
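In an EBM the unnormalised density is $\exp(-E(\mathbf{x}))$, so lower energy means higher density. A contour plot like the one in Tutorial 3 needs the energy evaluated on a 2-D grid; the sketch below does this for any energy function and normalises the result numerically over the grid. Reading `h` as the grid spacing is an assumption based on the `h=0.2` argument in the tutorial snippet; `density_on_grid` and the quadratic toy energy are illustrative only.

```python
import numpy as np


def density_on_grid(energy_fn, x_range, y_range, h=0.2):
    """Evaluate exp(-E) on a 2-D mesh with spacing h and normalise over the grid."""
    xs = np.arange(x_range[0], x_range[1] + h / 2, h)
    ys = np.arange(y_range[0], y_range[1] + h / 2, h)
    xx, yy = np.meshgrid(xs, ys)
    pts = np.column_stack([xx.ravel(), yy.ravel()])
    energy = np.asarray(energy_fn(pts))
    # Shift by the minimum energy before exponentiating to avoid overflow
    unnorm = np.exp(-(energy - energy.min()))
    # Riemann-sum estimate of the normalising constant (cell area h * h)
    density = unnorm / (unnorm.sum() * h * h)
    return xx, yy, density.reshape(xx.shape)


# Toy energy: a quadratic bowl centred at the origin
xx, yy, dens = density_on_grid(lambda p: (p ** 2).sum(axis=1), (-3, 3), (-3, 3))
i = dens.argmax()
print(xx.flat[i], yy.flat[i])  # grid point nearest the origin
```

The returned `xx`, `yy`, `dens` arrays can be passed directly to matplotlib's `contourf(xx, yy, dens)`.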
      

Download files

Source distribution: tabebm-2025.8.19.tar.gz (19.2 kB), uploaded via twine/6.1.0 CPython/3.10.18 (not via Trusted Publishing).

Hashes for tabebm-2025.8.19.tar.gz:

Algorithm     Hash digest
SHA256        6111611326747a680f93dfadcbac1d602ce20cb722b9b6cbff1f556b9f48d503
MD5           fff1e8a751ed59265b64f229a8a88711
BLAKE2b-256   5d95a6f327ed88907ca68a7e49e65c9bebe3d9f35b3a623e1c8b89b15597c094
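The published SHA256 digest can be checked locally before installing from a downloaded archive. A minimal sketch using the standard-library `hashlib`; `sha256_of` is an illustrative helper name.

```python
import hashlib
from pathlib import Path

EXPECTED_SHA256 = "6111611326747a680f93dfadcbac1d602ce20cb722b9b6cbff1f556b9f48d503"


def sha256_of(path, chunk_size=1 << 20):
    """Hex SHA-256 digest of a file, read in chunks (1 MiB by default) to bound memory."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# After fetching the sdist, e.g. `pip download tabebm==2025.8.19 --no-deps --no-binary :all:`
# print(sha256_of("tabebm-2025.8.19.tar.gz") == EXPECTED_SHA256)
```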
