
TabularTextMultimodalFusion

A framework for multimodal fusion of tabular and text data.

A unified framework for experimenting with various architectures that combine tabular data (numerical + categorical) and textual data using pretrained language models (e.g., BERT/DistilBERT).

Inspired by and extending the ideas in TabularTextTransformer, this repo explores fusion architectures, contrastive learning, and graph-based methods for multimodal classification.


🔧 Features

  • Multimodal Fusion: Cross-attention, skip connections, late fusion, GAT-based fusion
  • Advanced Encodings: Custom numerical encodings (RBF, Fourier, Chebyshev, Sigmoid, Positional vectors)
  • Graph Neural Networks: Graph-based multimodal GNN via torch_geometric
  • Contrastive Learning: Multiple contrastive loss variants (MMD, MINE, InfoNCE)
  • Comprehensive Benchmarking: Multiple datasets with standardized preprocessing
  • Unified Framework: Consistent API for all model architectures

📦 Installation

Option 1: Install from PyPI (Recommended for Users)

Prerequisites: Install PyTorch and torch-geometric first, then install the package.

# 1. Install PyTorch with CUDA support (adjust CUDA version as needed)
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121

# 2. Install torch-geometric dependencies
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric \
    --find-links https://data.pyg.org/whl/torch-2.1.0+cu121.html

# 3. Install this package
pip install tabulartextmultimodalfusion

Quick Start:

# Import models
from tabulartextmultimodalfusion.models import (
    CrossAttention,
    CombinedModelConcat4,
    TabularEmbedding
)

# Import dataset utilities
from tabulartextmultimodalfusion.dataset import (
    prepareTensorDatasetWithTokenizer,
    preprocess_dataset
)

# Import settings
from tabulartextmultimodalfusion.settings import load_settings

See example.py for a complete working example with synthetic data, model initialization, and training.


Option 2: Development Installation (For Contributors/Researchers)

For development, experiments, or contributing to the project:

Method A: Using Conda (Recommended)

# 1. Clone the repository
git clone https://github.com/nadav22799/TabularTextMultimodalFusion
cd TabularTextMultimodalFusion

# 2. Create conda environment with all dependencies
conda env create -f environment.yaml

# 3. Activate the environment
conda activate TTMF

# 4. Install package in editable mode
pip install -e .

# 5. Run example to verify installation
python example.py

Method B: Using pip only

# 1. Clone the repository
git clone https://github.com/nadav22799/TabularTextMultimodalFusion
cd TabularTextMultimodalFusion

# 2. Install dependencies
pip install -r requirements.txt

# 3. Install package in editable mode
pip install -e .

# 4. Run example to verify installation
python example.py

🎯 Model Selection Guide

Base Model Architectures

Choose from the following model families based on your use case:

🔥 Cross-Attention Models (Recommended)

Our proposed approaches for optimal text-tabular fusion:

  • CrossAttention: Core cross-attention mechanism between text and tabular features
  • CrossAttentionSkipNet: Cross-attention enhanced with skip connections for better gradient flow
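
The package's CrossAttention module has its own configuration, but the core mechanism can be pictured in plain PyTorch: tabular feature tokens act as queries that attend over text token embeddings, with a residual connection as in the skip-connected variant. A minimal, generic sketch (class name and shapes are illustrative, not the library API):

import torch
import torch.nn as nn

class SimpleCrossAttention(nn.Module):
    """Tabular tokens attend over text tokens; residual + norm keeps gradient flow stable."""
    def __init__(self, dim: int, num_heads: int = 4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, tab_emb, text_emb):
        # tab_emb: (batch, n_tab_tokens, dim); text_emb: (batch, seq_len, dim)
        attended, _ = self.attn(query=tab_emb, key=text_emb, value=text_emb)
        return self.norm(tab_emb + attended)   # residual (skip-style) connection

# Toy usage: 8 tabular feature tokens attend over 128 text tokens
fused = SimpleCrossAttention(256)(torch.randn(2, 8, 256), torch.randn(2, 128, 256))
print(fused.shape)   # torch.Size([2, 8, 256])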

🔄 Fusion-Based Models

Alternative fusion strategies:

  • FusionSkipNet: Skip connections with feature fusion
  • CombinedModelGAT: Graph Attention Network for combined feature processing
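
For the GAT-based fusion family, the attention layer itself comes from torch_geometric. The sketch below only shows GATConv on a toy graph; how this repo actually constructs the multimodal graph (nodes, edges) is not assumed here:

import torch
from torch_geometric.nn import GATConv

x = torch.randn(4, 256)                    # 4 nodes with 256-d fused features (toy example)
edge_index = torch.tensor([[0, 1, 2, 3],   # source nodes
                           [1, 0, 3, 2]])  # target nodes
gat = GATConv(in_channels=256, out_channels=64, heads=4, concat=True)
out = gat(x, edge_index)                   # (4, 256): 4 heads x 64 dims concatenated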

🤖 BERT-Based Approaches

Different strategies for incorporating BERT:

  • LateFuseBERT: Late fusion of BERT text embeddings with tabular features
  • AllTextBERT: Converts tabular data to text for unified BERT processing
  • TabularForBert: Tabular data preprocessing optimized for BERT compatibility
  • BertWithTabular: BERT with additional tabular feature processing layers
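
As a rough illustration of the late-fusion idea (not the LateFuseBERT class itself), one can pool a pretrained BERT's first token, run the tabular features through a small MLP, and classify the concatenation. The model name, feature sizes, and layer shapes below are placeholders:

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
text_encoder = AutoModel.from_pretrained("bert-base-uncased")
tab_mlp = nn.Sequential(nn.Linear(10, 64), nn.ReLU())              # 10 tabular features
classifier = nn.Linear(text_encoder.config.hidden_size + 64, 2)    # binary classification head

batch = tokenizer(["great product, fast shipping"], return_tensors="pt")
text_vec = text_encoder(**batch).last_hidden_state[:, 0]           # [CLS] token embedding
tab_vec = tab_mlp(torch.randn(1, 10))
logits = classifier(torch.cat([text_vec, tab_vec], dim=-1))        # late fusion by concatenation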

📊 Single-Modality Baselines

For comparison and ablation studies:

  • OnlyTabular: Tabular data only (MLP-based)
  • OnlyText: Text data only (BERT-based)

Configuration Options

Fusion Methods

Control how text and tabular features are combined:

# Without BERT self-attention on final embeddings
fusion_methods = ['Concat2', 'Concat4', 'SumW2', 'SumW4']

# With BERT self-attention on final embeddings (suffix 's')
fusion_methods = ['Concat2s', 'Concat4s', 'SumW2s', 'SumW4s']
  • Concat: Concatenation fusion (2 = 2x dims, 4 = 4x dims)
  • SumW: Weighted sum fusion (2 = 2x dims, 4 = 4x dims)
  • 's' suffix: Applies BERT self-attention on final token embeddings
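
The two fusion families can be pictured with plain tensors (a generic sketch, independent of the package's internals): Concat stacks the pooled modality embeddings and therefore grows the feature dimension, while SumW keeps the dimension and learns per-modality weights.

import torch
import torch.nn as nn

text_vec = torch.randn(2, 256)   # pooled text embedding (batch of 2)
tab_vec = torch.randn(2, 256)    # pooled tabular embedding

# Concat-style fusion: feature dimension grows with the number of stacked views
concat_fused = torch.cat([text_vec, tab_vec], dim=-1)   # (2, 512)

# SumW-style fusion: learned per-modality weights, original dimension preserved
w = nn.Parameter(torch.tensor([0.5, 0.5]))
sum_fused = w[0] * text_vec + w[1] * tab_vec             # (2, 256)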

Numerical Encoders

Transform numerical tabular features for better cross-modal alignment:

  • Fourier: Fourier feature encoding for periodic patterns
  • FourierVec: Vectorized Fourier encoding
  • PosEnVec: Positional encoding vectors
  • RBF: Radial Basis Function encoding for non-linear relationships
  • RBFVec: Vectorized RBF encoding
  • Sigmoid: Sigmoid transformation for bounded features
  • Chebyshev: Chebyshev polynomial encoding
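
The common idea behind these encoders is to expand each scalar into a richer vector before it is aligned with the text embeddings. A minimal sketch of Fourier and RBF expansions in plain PyTorch (the package's exact frequency and center choices are not assumed):

import torch

def fourier_features(x, num_freqs=8):
    """Map each scalar to [sin(2^k * x), cos(2^k * x)] pairs."""
    freqs = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
    angles = x.unsqueeze(-1) * freqs
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

def rbf_features(x, centers, gamma=1.0):
    """Radial basis expansion of each scalar around fixed centers."""
    return torch.exp(-gamma * (x.unsqueeze(-1) - centers) ** 2)

ages = torch.tensor([23.0, 47.0, 61.0])
print(fourier_features(ages).shape)                           # torch.Size([3, 16])
print(rbf_features(ages, torch.linspace(0, 100, 10)).shape)   # torch.Size([3, 10])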

Loss Functions

Optimize cross-modal representation learning:

  • MMD: Maximum Mean Discrepancy for distribution alignment
  • MINE: Mutual Information Neural Estimation
  • InfoNCE: Info Noise Contrastive Estimation
  • Contrastive: Standard contrastive learning loss
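
To make the contrastive objective concrete, here is a standard symmetric InfoNCE loss over paired text/tabular embeddings (a textbook formulation; the repo's implementations may differ in details such as temperature or projection heads):

import torch
import torch.nn.functional as F

def info_nce(text_emb, tab_emb, temperature=0.07):
    """Matching text/tabular rows are positives; every other row in the batch is a negative."""
    text_emb = F.normalize(text_emb, dim=-1)
    tab_emb = F.normalize(tab_emb, dim=-1)
    logits = text_emb @ tab_emb.t() / temperature   # (B, B) cosine similarities
    targets = torch.arange(text_emb.size(0))        # positives sit on the diagonal
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))

loss = info_nce(torch.randn(8, 128), torch.randn(8, 128))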

Model Naming Convention

Models follow the pattern: {BaseModel}{FusionMethod}[{NumericalEncoder}][{LossFunction}]

Examples:

  • CrossAttentionConcat4: Cross-attention with Concat4 (4x-dimension) concatenation fusion
  • CrossAttentionConcat4s: Same as above but with self-attention on final embeddings
  • CrossAttentionConcat4Fourier: Cross-attention + Concat4 + Fourier encoding
  • CrossAttentionConcat4MMD: Cross-attention + Concat4 + MMD loss

💡 Quick Start Recommendations

| Use Case | Recommended Model | Why |
|---|---|---|
| Best Overall Performance | CrossAttentionConcat4s | Optimal fusion with self-attention |
| Limited Compute | CrossAttentionConcat2 | Smaller feature dimensions |
| Periodic/Seasonal Data | CrossAttentionConcat4Fourier | Fourier encoding for patterns |
| High-Dimensional Tabular | CrossAttentionConcat4RBF | RBF handles complex relationships |
| Distribution Alignment | CrossAttentionConcat4MMD | MMD loss for better alignment |
| Baseline Comparison | OnlyText, OnlyTabular | Single-modality benchmarks |

🚀 Running Experiments

Quick Start

Run experiments using CLI arguments:

# Run experiment 1 (architecture comparison)
python src/main.py --version exp1

# Run experiment 2 (numerical encoders)
python src/main.py --version exp2

# Run experiment 3 (loss functions)
python src/main.py --version exp3

# Run MIMIC experiments
python src/main_mimic.py --version exp1

Customization Options

1. Version Selection (CLI)

Choose experiment type via command line:

python src/main.py --version exp1  # Architecture comparison
python src/main.py --version exp2  # Numerical encoder comparison
python src/main.py --version exp3  # Loss function comparison

2. Dataset Selection (Manual)

Edit src/main.py to customize datasets:

DATASETS = ["wine_10", "airbnb", "kick"]  # Select from supported datasets

3. Model Selection (Automatic by Version)

Models are automatically selected based on the version:

  • exp1: Tests all architecture variants and baselines
  • exp2: Tests numerical encoders with CrossAttentionConcat4
  • exp3: Tests loss functions with CrossAttentionConcat4

Or manually override in src/main.py:

if args.version == "exp1":
    MODELS = ["CrossAttentionConcat4s", "BertWithTabular"]  # Custom selection

🧪 Experiment Types

| Version | Focus | Models Compared |
|---|---|---|
| exp1 | Architecture Comparison | All fusion architectures vs baselines |
| exp2 | Numerical Encoders | Different encoders with best architecture |
| exp3 | Loss Functions | Contrastive learning variants |

Experiment 1: Architecture Comparison

Tests fundamental fusion approaches:

MODELS = [
    # Our proposed methods
    "CrossAttentionSumW4", "CrossAttentionConcat4", 
    "CrossAttentionConcat4s", "CrossAttentionSumW4s",
    # Alternative fusion
    "FusionSkipNet", "CombinedModelGAT",
    # BERT variants
    "BertWithTabular", "LateFuseBERT", "AllTextBERT",
    # Baselines
    "OnlyTabular", "OnlyText"
]

Experiment 2: Numerical Encoder Ablation

Uses best architecture (CrossAttentionConcat4) with different encoders:

MODELS = [
    "CrossAttentionConcat4Fourier", "CrossAttentionConcat4RBF",
    "CrossAttentionConcat4FourierVec", "CrossAttentionConcat4PosEnVec",
    "CrossAttentionConcat4Chebyshev", "CrossAttentionConcat4Sigmoid"
]

Experiment 3: Loss Function Comparison

Tests contrastive learning approaches:

MODELS = [
    "CrossAttentionConcat4MMD", "CrossAttentionConcat4MINE",
    "CrossAttentionConcat4InfoNCE", "CrossAttentionConcat4Contrastive"
]

📁 Package Structure

TabularTextMultimodalFusion/
├── src/
│   ├── tabulartextmultimodalfusion/    # Main package
│   │   ├── __init__.py
│   │   ├── models.py                   # Model architectures
│   │   ├── dataset.py                  # Data loading and preprocessing
│   │   ├── settings.py                 # Configuration
│   │   ├── optimization.py             # Training and optimization
│   │   ├── GridSearch.py               # Hyperparameter tuning
│   │   ├── load_mimic.py               # MIMIC dataset loader
│   │   └── mimic_utils.py              # MIMIC utilities
│   ├── main.py                         # Experiment runner
│   ├── main_mimic.py                   # MIMIC experiment runner
│   └── mimic_pretrain.yaml             # MIMIC configuration
├── example.py           # Complete working example (Quick Start)
├── environment.yaml     # Conda environment
├── requirements.txt     # Dependencies
├── setup.py             # Package setup
├── LICENSE              # MIT License
└── README.md            # This file

📚 Supported Datasets

Dataset Directory Setup

Create a datasets/ directory in the project root and place each dataset file in it under the filename listed in the table below:

mkdir datasets
# Download datasets and place them in the datasets/ directory

Dataset Table

| Dataset Name | Filename | URL |
|---|---|---|
| airbnb | cleansed_listings_dec18.csv | https://www.kaggle.com/datasets/tylerx/melbourne-airbnb-open-data |
| kick | kickstarter_train.csv | https://www.kaggle.com/datasets/codename007/funding-successful-projects?select=train.csv |
| cloth | Womens Clothing E-Commerce Reviews.csv | https://www.kaggle.com/datasets/nicapotato/womens-ecommerce-clothing-reviews |
| wine_10 / wine_100 | winemag-data-130k-v2.csv | https://www.kaggle.com/datasets/zynicide/wine-reviews |
| income | adult.csv | https://www.kaggle.com/datasets/uciml/adult-census-income |
| pet | petfinder_train.csv | https://www.kaggle.com/competitions/petfinder-adoption-prediction/data |
| jigsaw | jigsaw_train_100k.csv | https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification |
| mimic | Special - See below | https://physionet.org/content/mimiciv/3.1/, https://physionet.org/content/mimic-iv-note/2.2/ |

MIMIC Dataset

The MIMIC-IV dataset requires special handling. The configuration and MIMIC utilities (load_mimic.py, mimic_utils.py) are adapted from https://github.com/google-research/lanistr.

  1. Access: Request access to MIMIC-IV at https://physionet.org/content/mimiciv/3.1/, https://physionet.org/content/mimic-iv-note/2.2/
  2. Download: Download the MIMIC-IV dataset following PhysioNet instructions
  3. Preprocessing: Follow the MedFuse extraction pipeline at https://github.com/nyuad-cai/MedFuse/tree/main/mimic4extract
  4. Configuration: Update the paths in src/mimic_pretrain.yaml to point to where you extracted the MIMIC-IV and MIMIC-IV-Note data (following the same structure as the Lanistr repository). Replace the YOUR_PATH placeholders with your actual data directories (an illustrative excerpt follows this list):
    • task_data_dir: Path to extracted in-hospital-mortality data
    • unimodal_data_dir: Path to unimodal data directory
    • preprocessed_data_dir: Path to preprocessed data
    • normalizer_file: Path to normalizer.csv
    • discretizer_config_path: Path to discretizer_config.json
  5. Run: Use python src/main_mimic.py --version exp1 for MIMIC experiments
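
An illustrative excerpt of the keys to edit in src/mimic_pretrain.yaml (the directory layout shown is a placeholder, not the required structure; keep the repository's remaining keys unchanged and substitute your own paths for YOUR_PATH):

task_data_dir: /YOUR_PATH/in-hospital-mortality
unimodal_data_dir: /YOUR_PATH/unimodal
preprocessed_data_dir: /YOUR_PATH/preprocessed
normalizer_file: /YOUR_PATH/normalizer.csv
discretizer_config_path: /YOUR_PATH/discretizer_config.json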

Adding Custom Datasets

You can add any custom dataset by following these steps:

  1. Add your dataset file to the datasets/ directory
  2. Define dataset settings in src/tabulartextmultimodalfusion/settings.py by adding a new configuration block
  3. Implement data loading in src/tabulartextmultimodalfusion/dataset.py following existing preprocessing patterns
  4. Update your experiments to include the new dataset name

🔍 Results and Analysis

Performance Metrics

  • Accuracy: Overall classification performance
  • F1-Score: Balanced precision and recall
  • AUC-ROC: Area under the ROC curve (binary classification)
  • Training Time: Computational efficiency
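
These are the usual scikit-learn metrics; a minimal sketch of how they can be computed from predictions (the repo's evaluation code may aggregate them differently):

from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

y_true = [0, 1, 1, 0, 1]
y_pred = [0, 1, 0, 0, 1]
y_prob = [0.2, 0.9, 0.4, 0.1, 0.8]   # positive-class probabilities

print("Accuracy:", accuracy_score(y_true, y_pred))
print("F1:      ", f1_score(y_true, y_pred))
print("AUC-ROC: ", roc_auc_score(y_true, y_prob))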

Expected Findings

  • Cross-attention models typically outperform simple fusion baselines
  • Numerical encoders provide significant improvements for datasets with complex numerical relationships
  • Contrastive losses help when text and tabular modalities have different distributions

🙏 Attribution

Parts of the preprocessing pipeline (settings.py, dataset.py) are adapted from:

Yury Petyushin, Tabular Text Transformer, MIT License

We thank the original authors for their valuable contribution. This project modifies and builds upon that work with new architectures and optimization strategies.


📜 License

This project is licensed under the MIT License. See the LICENSE file for details.


🤝 Contributing

Contributions are welcome! Here's how you can help:

  • ๐Ÿ› Bug Reports: Open an issue with detailed reproduction steps
  • ๐Ÿ’ก Feature Requests: Suggest new fusion strategies or loss functions
  • ๐Ÿ“Š New Datasets: Add support for additional multimodal datasets
  • ๐Ÿ”ง Code Improvements: Submit pull requests for optimizations

Development Guidelines

  1. Follow existing code style and documentation patterns
  2. Add tests for new model architectures
  3. Update documentation for new features
  4. Ensure reproducibility with fixed random seeds

🗺️ Roadmap

  • Transformer-based Fusion: Implement transformer layers for cross-modal attention
  • Multi-task Learning: Support for multiple prediction tasks
  • Hyperparameter Optimization: Automated hyperparameter tuning
  • Model Interpretability: Attention visualization and feature importance
  • Distributed Training: Multi-GPU support for large-scale experiments
  • Pre-trained Models: Release pre-trained checkpoints for common datasets

