TabularTextMultimodalFusion
A framework for multimodal fusion of tabular and text data.
A unified framework for experimenting with various architectures that combine tabular data (numerical + categorical) and textual data using pretrained language models (e.g., BERT/DistilBERT).
Inspired by and extending the ideas in TabularTextTransformer, this repo explores fusion architectures, contrastive learning, and graph-based methods for multimodal classification.
Features
- Multimodal Fusion: Cross-attention, skip connections, late fusion, GAT-based fusion
- Advanced Encodings: Custom numerical encodings (RBF, Fourier, Chebyshev, Sigmoid, Positional vectors)
- Graph Neural Networks: Graph-based multimodal GNN via torch_geometric
- Contrastive Learning: Multiple contrastive loss variants (MMD, MINE, InfoNCE)
- Comprehensive Benchmarking: Multiple datasets with standardized preprocessing
- Unified Framework: Consistent API for all model architectures
Installation
Option 1: Install from PyPI (Recommended for Users)
Prerequisites: Install PyTorch and torch-geometric first, then install the package.
# 1. Install PyTorch with CUDA support (adjust CUDA version as needed)
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
# 2. Install torch-geometric dependencies
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric \
--find-links https://data.pyg.org/whl/torch-2.1.0+cu121.html
# 3. Install this package
pip install tabulartextmultimodalfusion
Quick Start:
# Import models
from tabulartextmultimodalfusion.models import (
    CrossAttention,
    CombinedModelConcat4,
    TabularEmbedding,
)
# Import dataset utilities
from tabulartextmultimodalfusion.dataset import (
    prepareTensorDatasetWithTokenizer,
    preprocess_dataset,
)
# Import settings
from tabulartextmultimodalfusion.settings import load_settings
See example.py for a complete working example with synthetic data, model initialization, and training.
Option 2: Development Installation (For Contributors/Researchers)
For development, experiments, or contributing to the project:
Method A: Using Conda (Recommended)
# 1. Clone the repository
git clone https://github.com/nadav22799/TabularTextMultimodalFusion
cd TabularTextMultimodalFusion
# 2. Create conda environment with all dependencies
conda env create -f environment.yaml
# 3. Activate the environment
conda activate TTMF
# 4. Install package in editable mode
pip install -e .
# 5. Run example to verify installation
python example.py
Method B: Using pip only
# 1. Clone the repository
git clone https://github.com/nadav22799/TabularTextMultimodalFusion
cd TabularTextMultimodalFusion
# 2. Install dependencies
pip install -r requirements.txt
# 3. Install package in editable mode
pip install -e .
# 4. Run example to verify installation
python example.py
Model Selection Guide
Base Model Architectures
Choose from the following model families based on your use case:
Cross-Attention Models (Recommended)
Our proposed approaches for optimal text-tabular fusion:
- CrossAttention: Core cross-attention mechanism between text and tabular features
- CrossAttentionSkipNet: Cross-attention enhanced with skip connections for better gradient flow
Fusion-Based Models
Alternative fusion strategies:
- FusionSkipNet: Skip connections with feature fusion
- CombinedModelGAT: Graph Attention Network for combined feature processing
BERT-Based Approaches
Different strategies for incorporating BERT:
- LateFuseBERT: Late fusion of BERT text embeddings with tabular features
- AllTextBERT: Converts tabular data to text for unified BERT processing
- TabularForBert: Tabular data preprocessing optimized for BERT compatibility
- BertWithTabular: BERT with additional tabular feature processing layers
Single-Modality Baselines
For comparison and ablation studies:
- OnlyTabular: Tabular data only (MLP-based)
- OnlyText: Text data only (BERT-based)
Configuration Options
Fusion Methods
Control how text and tabular features are combined:
# Without BERT self-attention on final embeddings
fusion_methods = ['Concat2', 'Concat4', 'SumW2', 'SumW4']
# With BERT self-attention on final embeddings (suffix 's')
fusion_methods = ['Concat2s', 'Concat4s', 'SumW2s', 'SumW4s']
- Concat: Concatenation fusion (2 = 2x dims, 4 = 4x dims)
- SumW: Weighted sum fusion (2 = 2x dims, 4 = 4x dims)
- 's' suffix: Applies BERT self-attention on final token embeddings
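For intuition, here is a minimal NumPy sketch of these fusion modes. This is illustrative only, not the package's implementation; in particular, the exact construction of the 4x-dimension variants is an assumption here (a common choice is [a, b, |a - b|, a * b]).

```python
import numpy as np

def concat2_fusion(text_emb, tab_emb):
    # 'Concat2': plain concatenation doubles the embedding dimension
    return np.concatenate([text_emb, tab_emb], axis=-1)

def concat4_fusion(text_emb, tab_emb):
    # One common way to reach 4x dims: [a, b, |a - b|, a * b]
    # (assumed for illustration; the package may build this differently)
    return np.concatenate(
        [text_emb, tab_emb, np.abs(text_emb - tab_emb), text_emb * tab_emb],
        axis=-1,
    )

def sumw_fusion(text_emb, tab_emb, w_text=0.5, w_tab=0.5):
    # 'SumW': weighted sum; in a real model the weights would be learned
    return w_text * text_emb + w_tab * tab_emb

text_emb = np.random.randn(4, 8)  # batch of 4 text embeddings, dim 8
tab_emb = np.random.randn(4, 8)   # matching tabular embeddings
```

Note that the weighted-sum variants require both modalities to share an embedding dimension, while concatenation does not.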
Numerical Encoders
Transform numerical tabular features for better cross-modal alignment:
- Fourier: Fourier feature encoding for periodic patterns
- FourierVec: Vectorized Fourier encoding
- PosEnVec: Positional encoding vectors
- RBF: Radial Basis Function encoding for non-linear relationships
- RBFVec: Vectorized RBF encoding
- Sigmoid: Sigmoid transformation for bounded features
- Chebyshev: Chebyshev polynomial encoding
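The general idea behind the RBF and Fourier encoders can be sketched in NumPy (illustrative only; the package's encoders may use different parameterizations):

```python
import numpy as np

def rbf_encode(x, centers, gamma=1.0):
    # Each scalar feature becomes a vector of Gaussian similarities to fixed centers
    return np.exp(-gamma * (x[..., None] - centers) ** 2)

def fourier_encode(x, freqs):
    # Each scalar feature becomes [sin(f*x), cos(f*x)] across a set of frequencies
    return np.concatenate(
        [np.sin(x[..., None] * freqs), np.cos(x[..., None] * freqs)], axis=-1
    )

x = np.array([0.0, 0.5, 1.0])        # one numerical column, 3 rows
centers = np.linspace(0.0, 1.0, 8)   # 8 RBF centers over the feature range
freqs = 2.0 ** np.arange(4) * np.pi  # 4 geometrically spaced frequencies
```

Both map each scalar into a higher-dimensional vector, which gives the cross-attention layers richer numerical representations to align with token embeddings.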
Loss Functions
Optimize cross-modal representation learning:
- MMD: Maximum Mean Discrepancy for distribution alignment
- MINE: Mutual Information Neural Estimation
- InfoNCE: Information Noise-Contrastive Estimation
- Contrastive: Standard contrastive learning loss
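For intuition, a minimal NumPy sketch of an InfoNCE-style objective over paired text/tabular embeddings, where matching rows act as positives and all other rows in the batch as negatives (illustrative; the package's loss implementations may differ):

```python
import numpy as np

def info_nce(text_emb, tab_emb, temperature=0.1):
    # L2-normalize both modalities, then treat matching rows (i, i) as positives
    t = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    x = tab_emb / np.linalg.norm(tab_emb, axis=1, keepdims=True)
    logits = t @ x.T / temperature
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))          # cross-entropy on the diagonal

rng = np.random.default_rng(0)
text_emb = rng.normal(size=(16, 32))
```

Minimizing this loss pulls each text embedding toward its paired tabular embedding and pushes it away from the other rows in the batch.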
Model Naming Convention
Models follow the pattern: {BaseModel}{FusionMethod}[{NumericalEncoder}][{LossFunction}]
Examples:
- CrossAttentionConcat4: Cross-attention with 4x-dimension concatenation fusion
- CrossAttentionConcat4s: Same as above, with self-attention on final embeddings
- CrossAttentionConcat4Fourier: Cross-attention + Concat4 + Fourier encoding
- CrossAttentionConcat4MMD: Cross-attention + Concat4 + MMD loss
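A hypothetical helper showing how a name decomposes under this convention (the package itself may resolve model names differently; this is only a sketch of the pattern):

```python
# Suffixes checked longest-first so e.g. 'FourierVec' wins over 'Fourier'
FUSIONS = ["Concat2", "Concat4", "SumW2", "SumW4"]
ENCODERS = ["FourierVec", "Fourier", "RBFVec", "RBF",
            "PosEnVec", "Sigmoid", "Chebyshev"]
LOSSES = ["MMD", "MINE", "InfoNCE", "Contrastive"]

def parse_model_name(name):
    """Split a model name into base / fusion / optional encoder / optional loss."""
    parts = {}
    for loss in LOSSES:                  # strip optional loss suffix
        if name.endswith(loss):
            parts["loss"] = loss
            name = name[: -len(loss)]
            break
    for enc in ENCODERS:                 # strip optional encoder suffix
        if name.endswith(enc):
            parts["encoder"] = enc
            name = name[: -len(enc)]
            break
    if name.endswith("s"):               # 's' = self-attention on final embeddings
        parts["self_attention"] = True
        name = name[:-1]
    for fus in FUSIONS:                  # strip fusion method
        if name.endswith(fus):
            parts["fusion"] = fus
            name = name[: -len(fus)]
            break
    parts["base"] = name                 # whatever remains is the base model
    return parts
```

For example, `parse_model_name("CrossAttentionConcat4Fourier")` yields the base `CrossAttention`, fusion `Concat4`, and encoder `Fourier`.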
Quick Start Recommendations
| Use Case | Recommended Model | Why |
|---|---|---|
| Best Overall Performance | CrossAttentionConcat4s | Optimal fusion with self-attention |
| Limited Compute | CrossAttentionConcat2 | Smaller feature dimensions |
| Periodic/Seasonal Data | CrossAttentionConcat4Fourier | Fourier encoding for patterns |
| High-Dimensional Tabular | CrossAttentionConcat4RBF | RBF handles complex relationships |
| Distribution Alignment | CrossAttentionConcat4MMD | MMD loss for better alignment |
| Baseline Comparison | OnlyText, OnlyTabular | Single-modality benchmarks |
Running Experiments
Quick Start
Run experiments using CLI arguments:
# Run experiment 1 (architecture comparison)
python src/main.py --version exp1
# Run experiment 2 (numerical encoders)
python src/main.py --version exp2
# Run experiment 3 (loss functions)
python src/main.py --version exp3
# Run MIMIC experiments
python src/main_mimic.py --version exp1
Customization Options
1. Version Selection (CLI)
Choose experiment type via command line:
python src/main.py --version exp1 # Architecture comparison
python src/main.py --version exp2 # Numerical encoder comparison
python src/main.py --version exp3 # Loss function comparison
2. Dataset Selection (Manual)
Edit src/main.py to customize datasets:
DATASETS = ["wine_10", "airbnb", "kick"] # Select from supported datasets
3. Model Selection (Automatic by Version)
Models are automatically selected based on the version:
- exp1: Tests all architecture variants and baselines
- exp2: Tests numerical encoders with CrossAttentionConcat4
- exp3: Tests loss functions with CrossAttentionConcat4
Or manually override in src/main.py:
if args.version == "exp1":
    MODELS = ["CrossAttentionConcat4s", "BertWithTabular"]  # Custom selection
Experiment Types
| Version | Focus | Models Compared |
|---|---|---|
| exp1 | Architecture Comparison | All fusion architectures vs baselines |
| exp2 | Numerical Encoders | Different encoders with best architecture |
| exp3 | Loss Functions | Contrastive learning variants |
Experiment 1: Architecture Comparison
Tests fundamental fusion approaches:
MODELS = [
    # Our proposed methods
    "CrossAttentionSumW4", "CrossAttentionConcat4",
    "CrossAttentionConcat4s", "CrossAttentionSumW4s",
    # Alternative fusion
    "FusionSkipNet", "CombinedModelGAT",
    # BERT variants
    "BertWithTabular", "LateFuseBERT", "AllTextBERT",
    # Baselines
    "OnlyTabular", "OnlyText",
]
Experiment 2: Numerical Encoder Ablation
Uses best architecture (CrossAttentionConcat4) with different encoders:
MODELS = [
    "CrossAttentionConcat4Fourier", "CrossAttentionConcat4RBF",
    "CrossAttentionConcat4FourierVec", "CrossAttentionConcat4PosEnVec",
    "CrossAttentionConcat4Chebyshev", "CrossAttentionConcat4Sigmoid",
]
Experiment 3: Loss Function Comparison
Tests contrastive learning approaches:
MODELS = [
    "CrossAttentionConcat4MMD", "CrossAttentionConcat4MINE",
    "CrossAttentionConcat4InfoNCE", "CrossAttentionConcat4Contrastive",
]
Package Structure
TabularTextMultimodalFusion/
├── src/
│   ├── tabulartextmultimodalfusion/   # Main package
│   │   ├── __init__.py
│   │   ├── models.py                  # Model architectures
│   │   ├── dataset.py                 # Data loading and preprocessing
│   │   ├── settings.py                # Configuration
│   │   ├── optimization.py            # Training and optimization
│   │   ├── GridSearch.py              # Hyperparameter tuning
│   │   ├── load_mimic.py              # MIMIC dataset loader
│   │   └── mimic_utils.py             # MIMIC utilities
│   ├── main.py                        # Experiment runner
│   ├── main_mimic.py                  # MIMIC experiment runner
│   └── mimic_pretrain.yaml            # MIMIC configuration
├── example.py                         # Complete working example (Quick Start)
├── environment.yaml                   # Conda environment
├── requirements.txt                   # Dependencies
├── setup.py                           # Package setup
├── LICENSE                            # MIT License
└── README.md                          # This file
Supported Datasets
Dataset Directory Setup
Create a datasets/ directory in the project root and place all dataset files there with their respective filenames as shown below:
mkdir datasets
# Download datasets and place them in the datasets/ directory
Dataset Table
| Dataset Name | Filename | URL |
|---|---|---|
| airbnb | cleansed_listings_dec18.csv | https://www.kaggle.com/datasets/tylerx/melbourne-airbnb-open-data |
| kick | kickstarter_train.csv | https://www.kaggle.com/datasets/codename007/funding-successful-projects?select=train.csv |
| cloth | Womens Clothing E-Commerce Reviews.csv | https://www.kaggle.com/datasets/nicapotato/womens-ecommerce-clothing-reviews |
| wine_10 / wine_100 | winemag-data-130k-v2.csv | https://www.kaggle.com/datasets/zynicide/wine-reviews |
| income | adult.csv | https://www.kaggle.com/datasets/uciml/adult-census-income |
| pet | petfinder_train.csv | https://www.kaggle.com/competitions/petfinder-adoption-prediction/data |
| jigsaw | jigsaw_train_100k.csv | https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification |
| mimic | Special - see below | https://physionet.org/content/mimiciv/3.1/, https://physionet.org/content/mimic-iv-note/2.2/ |
MIMIC Dataset
The MIMIC-IV dataset requires special handling. The configuration and MIMIC utilities (load_mimic.py, mimic_utils.py) are adapted from https://github.com/google-research/lanistr.
- Access: Request access to MIMIC-IV at https://physionet.org/content/mimiciv/3.1/, https://physionet.org/content/mimic-iv-note/2.2/
- Download: Download the MIMIC-IV dataset following PhysioNet instructions
- Preprocessing: Follow the MedFuse extraction pipeline at https://github.com/nyuad-cai/MedFuse/tree/main/mimic4extract
- Configuration: Update the paths in src/mimic_pretrain.yaml to point to where you extracted the MIMIC-IV and MIMIC-IV-Note data (follow the same structure as used in the Lanistr repository). Replace YOUR_PATH placeholders with your actual data directories:
  - task_data_dir: Path to extracted in-hospital-mortality data
  - unimodal_data_dir: Path to unimodal data directory
  - preprocessed_data_dir: Path to preprocessed data
  - normalizer_file: Path to normalizer.csv
  - discretizer_config_path: Path to discretizer_config.json
- Run: Use python src/main_mimic.py --version exp1 for MIMIC experiments
Adding Custom Datasets
You can add any custom dataset by following these steps:
- Add your dataset file to the datasets/ directory
- Define dataset settings in src/tabulartextmultimodalfusion/settings.py by adding a new configuration block
- Implement data loading in src/tabulartextmultimodalfusion/dataset.py following existing preprocessing patterns
- Update your experiments to include the new dataset name
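As a rough illustration of the kind of information a dataset configuration needs to carry, here is a hypothetical settings block. The field names below are assumptions for illustration; the actual schema lives in src/tabulartextmultimodalfusion/settings.py and may differ.

```python
# Hypothetical custom-dataset configuration; field names are illustrative only.
CUSTOM_DATASET = {
    "name": "my_dataset",                   # dataset key used in experiments
    "filename": "my_dataset.csv",           # placed under datasets/
    "text_cols": ["description"],           # free-text fields fed to the language model
    "categorical_cols": ["category"],       # label-encoded tabular features
    "numerical_cols": ["price", "rating"],  # scaled / encoded tabular features
    "target_col": "label",                  # classification target
}
```

The loading code then only needs to know which columns are text, which are tabular, and which column is the target.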
Results and Analysis
Performance Metrics
- Accuracy: Overall classification performance
- F1-Score: Balanced precision and recall
- AUC-ROC: Area under the ROC curve (binary classification)
- Training Time: Computational efficiency
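For reference, the F1-score used above combines precision and recall; a minimal NumPy implementation for the binary case (illustrative; in practice a library such as scikit-learn would typically be used):

```python
import numpy as np

def f1_score_binary(y_true, y_pred):
    # F1 = 2 * precision * recall / (precision + recall)
    tp = np.sum((y_true == 1) & (y_pred == 1))  # true positives
    fp = np.sum((y_true == 0) & (y_pred == 1))  # false positives
    fn = np.sum((y_true == 1) & (y_pred == 0))  # false negatives
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0
```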
Expected Findings
- Cross-attention models typically outperform simple fusion baselines
- Numerical encoders provide significant improvements for datasets with complex numerical relationships
- Contrastive losses help when text and tabular modalities have different distributions
Attribution
Parts of the preprocessing pipeline (settings.py, dataset.py) are adapted from:
Yury Petyushin, Tabular Text Transformer, MIT License
We thank the original authors for their valuable contribution. This project modifies and builds upon that work with new architectures and optimization strategies.
License
This project is licensed under the MIT License. See the LICENSE file for details.
Contributing
Contributions are welcome! Here's how you can help:
- Bug Reports: Open an issue with detailed reproduction steps
- Feature Requests: Suggest new fusion strategies or loss functions
- New Datasets: Add support for additional multimodal datasets
- Code Improvements: Submit pull requests for optimizations
Development Guidelines
- Follow existing code style and documentation patterns
- Add tests for new model architectures
- Update documentation for new features
- Ensure reproducibility with fixed random seeds
Roadmap
- Transformer-based Fusion: Implement transformer layers for cross-modal attention
- Multi-task Learning: Support for multiple prediction tasks
- Hyperparameter Optimization: Automated hyperparameter tuning
- Model Interpretability: Attention visualization and feature importance
- Distributed Training: Multi-GPU support for large-scale experiments
- Pre-trained Models: Release pre-trained checkpoints for common datasets