Causal Deep Learning marketing-mix modeling library with response curves

DeepCausalMMM

Advanced Marketing Mix Modeling with Causal Inference and Deep Learning

Development history

A substantial part of the design and prototyping was done locally before this repository was published on GitHub. The public commit history therefore reflects the integration, hardening, documentation, testing, and packaging of that prior work rather than a day-by-day log of initial exploration. You may see bursts of commits (e.g. around releases or documentation pushes) with quieter periods in between. That pattern is typical when a working prototype is integrated into an open-source layout, and it does not mean that development spanned only the weeks visible on GitHub.

Key Features

Architecture

  • Config-Driven: Every setting configurable via config.py
  • GRU-Based Temporal Modeling: Captures complex time-varying effects
  • DAG Learning: Discovers causal relationships between channels
  • Learnable Coefficient Bounds: Channel-specific, data-driven constraints
  • Data-Driven Seasonality: Automatic seasonal decomposition per region

Robust Statistical Methods

  • Huber Loss: Robust to outliers and extreme values
  • Multiple Metrics: RMSE, R², MAE, Trimmed RMSE
  • Advanced Regularization: L1/L2, sparsity, coefficient-specific penalties
  • Gradient Clipping: Parameter-specific clipping for stability
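
To make the robustness claim concrete, here is a minimal sketch of the standard Huber loss for a single residual, using the delta of 0.3 listed under Configuration. The function name `huber_loss` is illustrative only; this is the textbook definition, not the package's internal implementation.

```python
def huber_loss(residual: float, delta: float = 0.3) -> float:
    """Textbook Huber loss: quadratic near zero, linear in the tails."""
    abs_r = abs(residual)
    if abs_r <= delta:
        return 0.5 * residual ** 2
    return delta * (abs_r - 0.5 * delta)


residuals = [0.1, -0.2, 5.0]  # the last value plays the role of an outlier
for r in residuals:
    # Compare against half the squared error, which grows quadratically
    print(f"residual={r:+.1f}  huber={huber_loss(r):.3f}  half_squared={0.5 * r ** 2:.3f}")
```

For the outlier, the Huber value grows linearly with the residual while the squared error grows quadratically, which is why a single extreme week pulls the fit less.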

Comprehensive Analysis in Examples Folder

  • Interactive Visualizations: Example HTML dashboard with plots and insights
  • Response Curves: Non-linear saturation analysis with Hill equations
  • Budget Optimization: Constrained optimization for optimal channel allocation
  • DMA-Level Contributions: True economic impact calculation at DMA level
  • Channel Effectiveness: Detailed performance analysis

Quick Start

Installation

From PyPI (Recommended)

pip install deepcausalmmm

From GitHub (Development Version)

# Install the latest development version with all fixes
pip install git+https://github.com/adityapt/deepcausalmmm.git

Manual Installation

# Clone repository
git clone https://github.com/adityapt/deepcausalmmm.git
cd deepcausalmmm
pip install -e .

Dependencies Only

To install only the required libraries, without the deepcausalmmm package itself, use the same lower/upper version bounds as the [project] dependencies in pyproject.toml:

pip install \
  "torch>=2.0" \
  "pandas>=1.5" \
  "numpy>=1.21,<2.0" \
  "plotly>=5.0" \
  "networkx>=2.6" \
  "scikit-learn>=1.0" \
  "scipy>=1.7" \
  "statsmodels>=0.13" \
  "tqdm>=4.60"

IMPORTANT: Version Compatibility

The code examples in this README are for version 1.0.19+

If you previously installed from PyPI (pip install deepcausalmmm), you may have v1.0.18 or below, which has a completely different API. The examples below will NOT work with v1.0.18 and below.

For Current PyPI Users (v1.0.18 and below):

# Option 1: Upgrade to the latest version from PyPI
pip install --upgrade deepcausalmmm
# OR shorthand
pip install -U deepcausalmmm

# Option 2: Install the development version (from GitHub)
pip install --upgrade git+https://github.com/adityapt/deepcausalmmm.git

# Option 3: Force reinstall if you have conflicts
pip install --upgrade --force-reinstall deepcausalmmm

API Changes in v1.0.19:

  • New: pipeline parameter in ModelTrainer.train()
  • Fixed: Proper data scaling without leakage
  • Fixed: Correct attribution calculation
  • New: ConfigurableDataGenerator.generate_mmm_dataset() method
  • Removed: y_full_for_baseline parameter (data leakage fix)

Basic Usage

import numpy as np
from deepcausalmmm import get_device
from deepcausalmmm.core import get_default_config
from deepcausalmmm.core.trainer import ModelTrainer
from deepcausalmmm.core.data import UnifiedDataPipeline
from deepcausalmmm.utils.data_generator import ConfigurableDataGenerator

# Generate synthetic data for testing
# You can replace this with your own data in the same format
n_regions = 50        # Number of DMAs/regions
n_weeks = 104         # 2 years of weekly data
n_media = 13          # Number of media channels
n_control = 3         # Number of control variables

generator = ConfigurableDataGenerator()
X_media, X_control, y = generator.generate_mmm_dataset(
    n_regions=n_regions,
    n_weeks=n_weeks,
    n_media_channels=n_media,
    n_control_channels=n_control
)

# Expected data format:
# X_media: [n_regions, n_weeks, n_media_channels] - Media inputs
# X_control: [n_regions, n_weeks, n_control_variables] - Control variables
# y: [n_regions, n_weeks] - Target variable (KPI units)

# Get configuration
config = get_default_config()
config['n_epochs'] = 200  # Adjust as needed

# Check device availability
device = get_device()
print(f"Using device: {device}")

# Initialize pipeline
pipeline = UnifiedDataPipeline(config)

# Split data temporally (train/holdout)
train_data, holdout_data = pipeline.temporal_split(X_media, X_control, y)

# Process training data
train_tensors = pipeline.fit_and_transform_training(train_data)
holdout_tensors = pipeline.transform_holdout(holdout_data)

# Create trainer and model
trainer = ModelTrainer(config)
model = trainer.create_model(
    n_media=train_tensors['X_media'].shape[2],
    n_control=train_tensors['X_control'].shape[2],
    n_regions=train_tensors['X_media'].shape[0]
)
trainer.create_optimizer_and_scheduler()

# Train model
results = trainer.train(
    train_tensors['X_media'], train_tensors['X_control'],
    train_tensors['R'], train_tensors['y'],
    holdout_tensors['X_media'], holdout_tensors['X_control'],
    holdout_tensors['R'], holdout_tensors['y'],
    pipeline=pipeline,
    verbose=True
)

# View results
print(f"Training R²: {results['final_train_r2']:.3f}")
print(f"Holdout R²: {results['final_holdout_r2']:.3f}")
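
When replacing the synthetic data with your own, it can help to check the array layout before calling the pipeline. The helper below is a hypothetical sketch (`check_mmm_shapes` is not part of the package); it only relies on the `.shape` attribute and the layout documented in the comments above.

```python
def check_mmm_shapes(X_media, X_control, y):
    """Validate the [regions, weeks, features] layout; return (n_regions, n_weeks)."""
    if len(X_media.shape) != 3 or len(X_control.shape) != 3:
        raise ValueError("X_media and X_control must be 3-D: [regions, weeks, features]")
    if len(y.shape) != 2:
        raise ValueError("y must be 2-D: [regions, weeks]")
    n_regions, n_weeks = y.shape
    # Both inputs must share the target's region and week dimensions
    for name, arr in (("X_media", X_media), ("X_control", X_control)):
        leading = tuple(arr.shape[:2])
        if leading != (n_regions, n_weeks):
            raise ValueError(
                f"{name} has leading shape {leading}, expected {(n_regions, n_weeks)}"
            )
    return n_regions, n_weeks
```

Calling `check_mmm_shapes(X_media, X_control, y)` right before `pipeline.temporal_split(...)` surfaces transposed or mismatched arrays early, with a readable error message.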

Running Examples (Requires Cloning Repository)

The examples/ folder is only available if you clone the repository (not included in pip install):

# Clone the repository first
git clone https://github.com/adityapt/deepcausalmmm.git
cd deepcausalmmm

# Install the package
pip install -e .

# Run the comprehensive dashboard (uses real-world anonymized data)
python examples/dashboard_rmse_optimized.py

# Or run other examples
python examples/example_budget_optimization.py
python examples/example_response_curves.py

Package Import Test

# Verify installation works
from deepcausalmmm import DeepCausalMMM, get_device
from deepcausalmmm.core import get_default_config

print("DeepCausalMMM package imported successfully!")
print(f"Device: {get_device()}")

Project Structure

deepcausalmmm/                      # Project root
├── pyproject.toml                  # Package configuration and dependencies
├── README.md                       # This documentation
├── LICENSE                         # MIT License
├── CHANGELOG.md                    # Version history and changes
├── RELEASE_NOTES_1.0.19.md         # Latest release notes
├── CONTRIBUTING.md                 # Development guidelines
├── CODE_OF_CONDUCT.md              # Code of conduct
├── CITATION.cff                    # Citation metadata for Zenodo/GitHub
├── Makefile                        # Build and development tasks
├── MANIFEST.in                     # Package manifest for distribution
│
├── deepcausalmmm/                  # Main package directory
│   ├── __init__.py                 # Package initialization and exports
│   ├── cli.py                      # Command-line interface
│   ├── exceptions.py               # Custom exception classes
│   │
│   ├── core/                       # Core model components
│   │   ├── __init__.py            # Core module initialization
│   │   ├── config.py              # Optimized configuration parameters
│   │   ├── unified_model.py       # Main DeepCausalMMM model architecture
│   │   ├── trainer.py             # ModelTrainer class for training
│   │   ├── data.py                # UnifiedDataPipeline for data processing
│   │   ├── scaling.py             # SimpleGlobalScaler for data normalization
│   │   ├── seasonality.py         # Seasonal decomposition utilities
│   │   ├── dag_model.py           # DAG learning and causal inference
│   │   ├── inference.py           # Model inference and prediction
│   │   ├── train_model.py         # Training functions and utilities
│   │   └── visualization.py       # Core visualization components
│   │
│   ├── postprocess/                # Analysis and post-processing
│   │   ├── __init__.py            # Postprocess module initialization
│   │   ├── analysis.py            # Statistical analysis utilities
│   │   ├── comprehensive_analysis.py  # Comprehensive analyzer
│   │   ├── response_curves.py     # Non-linear response curve fitting (Hill equations)
│   │   ├── optimization.py        # Budget optimization with response curves
│   │   ├── optimization_utils.py  # Optimization utility functions
│   │   └── dag_postprocess.py     # DAG post-processing and analysis
│   │
│   └── utils/                      # Utility functions
│       ├── __init__.py            # Utils module initialization
│       ├── device.py              # GPU/CPU device detection
│       └── data_generator.py      # Synthetic data generation (ConfigurableDataGenerator)
│
├── examples/                       # Example scripts and notebooks
│   ├── quickstart.ipynb           # Short API intro (Google Colab badge)
│   ├── mmm_three_way_benchmark.ipynb  # PyMC / Meridian / DCM / Ridge benchmark (optional; long runtimes)
│   ├── dashboard_rmse_optimized.py # Comprehensive dashboard with 14+ visualizations
│   ├── example_response_curves.py  # Response curve fitting examples
│   ├── example_budget_optimization.py  # Budget optimization workflow
│   └── data/                      # Example data directory
│       └── MMM Data.csv           # Anonymized real-world MMM dataset
│
├── tests/                          # Test suite
│   ├── __init__.py                # Test package initialization
│   ├── unit/                      # Unit tests
│   │   ├── __init__.py
│   │   ├── test_config.py         # Configuration tests
│   │   ├── test_model.py          # Model architecture tests
│   │   ├── test_scaling.py        # Data scaling tests
│   │   ├── test_response_curves.py # Response curve fitting tests
│   │   └── test_inference.py      # InferenceManager.predict / forward contract
│   └── integration/               # Integration tests
│       ├── __init__.py
│       ├── test_end_to_end.py     # End-to-end integration tests
│       └── test_dashboard_rmse_optimized.py  # Dashboard script + real CSV data path
│
├── docs/                           # Documentation
│   ├── Makefile                   # Documentation build tasks
│   ├── make.bat                   # Windows documentation build
│   ├── requirements.txt           # Documentation dependencies
│   └── source/                    # Sphinx documentation source
│       ├── conf.py               # Sphinx configuration
│       ├── index.rst             # Documentation index
│       ├── installation.rst      # Installation guide
│       ├── quickstart.rst        # Quick start guide
│       ├── contributing.rst      # Contributing guide
│       ├── api/                  # API documentation
│       │   ├── index.rst
│       │   ├── core.rst
│       │   ├── data.rst
│       │   ├── trainer.rst
│       │   ├── inference.rst
│       │   ├── analysis.rst
│       │   ├── response_curves.rst # Response curves API
│       │   ├── optimization.rst    # Budget optimization API
│       │   ├── cli.rst             # Command-line interface
│       │   ├── visualization.rst   # core.visualization (VisualizationManager)
│       │   ├── utils.rst
│       │   └── exceptions.rst
│       ├── examples/             # Example documentation
│       │   ├── index.rst
│       │   ├── retail_mmm.rst
│       │   └── multi_region.rst
│       └── tutorials/            # Tutorial documentation
│           └── index.rst
│
└── JOSS/                           # Journal of Open Source Software submission
    ├── paper.md                   # JOSS paper manuscript
    ├── paper.bib                  # Bibliography
    ├── figure_dag_professional.png # DAG visualization figure
    └── figure_response_curve_simple.png # Response curve figure

What the Examples Folder Dashboard Includes

  1. Performance Metrics: Training vs Holdout comparison
  2. Actual vs Predicted: Time series visualization
  3. Holdout Scatter: Generalization assessment
  4. Economic Contributions: Total KPI per channel
  5. Contribution Breakdown: Donut chart with percentages
  6. Waterfall Analysis: Decomposed contribution flow
  7. Channel Effectiveness: Coefficient distributions
  8. DAG Heatmap: Adjacency matrix visualization
  9. Stacked Contributions: Time-based channel impact
  10. Individual Channels: Detailed channel analysis
  11. Scaled Data: Normalized time series
  12. Response Curves: Non-linear response curves (diminishing returns analysis) with Hill equations

Configuration

Key configuration parameters:

{
    # Model Architecture
    'hidden_dim': 320,           # Optimal hidden dimension
    'dropout': 0.08,             # Proven stable dropout
    'gru_layers': 1,             # Single layer for stability
    
    # Training Parameters  
    'n_epochs': 6500,            # Optimal convergence epochs
    'learning_rate': 0.009,      # Fine-tuned learning rate
    'temporal_regularization': 0.04,  # Proven regularization
    
    # Loss Function
    'use_huber_loss': True,      # Robust to outliers
    'huber_delta': 0.3,          # Optimal delta value
    
    # Data Processing
    'holdout_ratio': 0.08,       # Optimal train/test split
    'burn_in_weeks': 6,          # Stabilization period
}
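
Basic Usage changes settings by assigning into the dictionary returned by get_default_config(). As a hedged sketch of a slightly safer pattern, the hypothetical helper below (`with_overrides` is not a package function) rejects misspelled keys instead of silently adding them. `DEFAULTS` is a stand-in that copies only the values shown above, not the package's full configuration.

```python
# Stand-in for the library defaults; values copied from the listing above
DEFAULTS = {
    'hidden_dim': 320, 'dropout': 0.08, 'gru_layers': 1,
    'n_epochs': 6500, 'learning_rate': 0.009, 'temporal_regularization': 0.04,
    'use_huber_loss': True, 'huber_delta': 0.3,
    'holdout_ratio': 0.08, 'burn_in_weeks': 6,
}


def with_overrides(defaults: dict, **overrides) -> dict:
    """Return a new config dict; raise on keys that are not in the defaults."""
    unknown = set(overrides) - set(defaults)
    if unknown:
        raise KeyError(f"Unknown config keys: {sorted(unknown)}")
    return {**defaults, **overrides}


config = with_overrides(DEFAULTS, n_epochs=200, learning_rate=0.005)
print(config['n_epochs'], config['learning_rate'])
```

Because the helper returns a new dictionary, the defaults stay untouched and several experiment configurations can be derived from the same base.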

More Details

Learnable Parameters

  • Media Coefficient Bounds: F.softplus(coeff_max_raw) * torch.sigmoid(media_coeffs_raw)
  • Control Coefficients: Unbounded with gradient clipping
  • Trend Damping: Currently disabled; planned for a future release
  • Baseline Components: Non-negative via F.softplus
  • Seasonal Coefficient: Learnable seasonal contribution
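
The media coefficient expression above keeps each coefficient between zero and a learnable upper bound. The sketch below is a scalar, standard-library stand-in for the tensor expression (the helper names are illustrative, and the real model applies this per channel with PyTorch): softplus makes the bound positive, and sigmoid picks a fraction of it.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x)); always positive."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x: float) -> float:
    """Logistic function; maps any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def bounded_coefficient(coeff_max_raw: float, media_coeff_raw: float) -> float:
    """Scalar analogue of F.softplus(coeff_max_raw) * torch.sigmoid(media_coeffs_raw)."""
    upper_bound = softplus(coeff_max_raw)
    return upper_bound * sigmoid(media_coeff_raw)


for raw in (-5.0, 0.0, 5.0):
    value = bounded_coefficient(coeff_max_raw=1.0, media_coeff_raw=raw)
    print(f"raw={raw:+.1f} -> coefficient={value:.4f} (upper bound {softplus(1.0):.4f})")
```

Both raw parameters are unconstrained, so the optimizer can move them freely while the resulting coefficient stays non-negative and below its channel-specific bound.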

Data Processing

  • Linear Scaling: Dependent variable scaled by regional mean (y/y_mean) for balanced training
  • SOV Scaling: Share-of-voice normalization for media channels
  • Z-Score Normalization: For control variables (weather, events, etc.)
  • Min-Max Seasonality: Regional seasonal scaling (0-1) using seasonal_decompose
  • DMA-Level Processing: Contributions calculated per region
  • Attribution Priors: Media contribution regularization with dynamic loss scaling
  • Data-Driven Hill Initialization: Hill parameters initialized from channel-specific SOV percentiles
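
The first three transforms above can be illustrated on plain lists. This is a simplified sketch under stated assumptions, not the pipeline's code: share of voice is taken here to mean each channel's share of the weekly total across channels, and the z-score uses the population standard deviation. All function names are hypothetical.

```python
from statistics import mean, pstdev


def scale_by_mean(series):
    """Linear scaling: divide a region's target series by its own mean (y / y_mean)."""
    m = mean(series)
    return [v / m for v in series]


def z_score(series):
    """Standardize a control variable; a constant series maps to zeros."""
    m, s = mean(series), pstdev(series)
    return [(v - m) / s if s else 0.0 for v in series]


def share_of_voice(channels_by_week):
    """Assumed SOV definition: each channel's share of that week's total."""
    shares = []
    for row in channels_by_week:
        total = sum(row)
        shares.append([v / total if total else 0.0 for v in row])
    return shares


print(scale_by_mean([80.0, 100.0, 120.0]))
print(z_score([1.0, 2.0, 3.0]))
print(share_of_voice([[10.0, 30.0, 60.0], [5.0, 5.0, 10.0]]))
```

In practice the means and standard deviations should be computed on training weeks only and then reused on the holdout weeks, which is the kind of leakage the v1.0.19 scaling fix refers to.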

Regularization Strategy

  • Coefficient L2: Channel-specific regularization
  • Sparsity Control: GRU parameter sparsity
  • DAG Regularization: Acyclicity constraints
  • Gradient Clipping: Parameter-specific clipping
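
The Roadmap notes that the current DAG learning relies on upper triangular constraints. The sketch below illustrates why that is enough to guarantee acyclicity for a fixed channel ordering: if edges may only point from a lower index to a higher one, no directed cycle can form. It assumes the diagonal is excluded (no self-loops) and uses hypothetical helper names; it is not the package's DAG code.

```python
def mask_upper_triangular(weights):
    """Zero every entry on or below the diagonal of a square matrix."""
    n = len(weights)
    return [[weights[i][j] if j > i else 0.0 for j in range(n)] for i in range(n)]


def is_acyclic(adj):
    """Kahn's algorithm: True if the weighted adjacency matrix has no directed cycle."""
    n = len(adj)
    indegree = [sum(1 for i in range(n) if adj[i][j] != 0) for j in range(n)]
    ready = [j for j in range(n) if indegree[j] == 0]
    visited = 0
    while ready:
        i = ready.pop()
        visited += 1
        for j in range(n):
            if adj[i][j] != 0:
                indegree[j] -= 1
                if indegree[j] == 0:
                    ready.append(j)
    return visited == n  # every node was ordered, so no cycle exists


# Dense channel-to-channel weights contain cycles such as 0 -> 1 -> 0
raw = [[0.0, 0.4, 0.1],
       [0.7, 0.0, 0.3],
       [0.2, 0.5, 0.0]]
print(is_acyclic(raw))
print(is_acyclic(mask_upper_triangular(raw)))
```

The trade-off is that the mask fixes the causal ordering of channels in advance, which is the restriction the planned NOTEARS work aims to lift.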

Response Curves

  • Hill Saturation Modeling: Non-linear response curves with Hill equations
  • Data-Driven Initialization: Hill g parameter initialized from channel-specific SOV 60th percentile
  • Automatic Curve Fitting: Fits S-shaped saturation curves to channel data
  • National-Level Aggregation: Aggregates DMA-week data to national weekly level
  • Linear Scaling: Direct scaling with prediction_scale × y_mean for accurate attribution
  • Interactive Visualizations: Plotly-based interactive response curve plots
  • Performance Metrics: R², slope, and saturation point for each channel

import pandas as pd
import numpy as np
from deepcausalmmm.postprocess import ResponseCurveFit

# After training your model, prepare channel data for response curve fitting
# The data should have 'week_monday', 'spend', 'impressions', and 'predicted' columns

# Example: Create channel data from your model results
# Replace this with actual data extraction from your trained model
n_weeks = 104
channel_data = pd.DataFrame({
    'week_monday': pd.date_range('2024-01-01', periods=n_weeks, freq='W-MON'),  # Weekly dates anchored on Mondays
    'spend': np.random.uniform(10000, 50000, n_weeks),       # Replace with actual spend
    'impressions': np.random.uniform(100000, 500000, n_weeks),  # Replace with actual impressions
    'predicted': np.random.uniform(1000, 5000, n_weeks)      # Replace with model predictions
})

# Fit response curves to channel data
fitter = ResponseCurveFit(
    data=channel_data,
    model_level='Overall',
    date_col='week_monday'
)

# Fit the curve and generate visualization
fitter.fit(
    x_label='Impressions',
    y_label='Predicted KPI Units',
    title='Channel Response Curve',
    save_figure=True,
    output_path='response_curve.html'
)

print(f"Slope: {fitter.slope:.3f}, Saturation: {fitter.saturation:.0f}, R²: {fitter.r_2:.3f}")
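
For intuition about the Hill saturation curves mentioned above, here is a minimal sketch of a common textbook Hill form. The parameter names mirror the columns used in this README (top, bottom, saturation, slope), but the exact parameterization inside ResponseCurveFit is an assumption and may differ. In this form, saturation is the input level at which the response reaches the midpoint between bottom and top.

```python
def hill_response(x, top, saturation, slope, bottom=0.0):
    """Textbook Hill curve: bottom + (top - bottom) * x^s / (saturation^s + x^s)."""
    if x <= 0:
        return bottom
    ratio = (x / saturation) ** slope
    return bottom + (top - bottom) * ratio / (1.0 + ratio)


# Illustrative parameters only, not fitted values
for impressions in (50_000, 200_000, 800_000):
    response = hill_response(impressions, top=1000.0, saturation=200_000, slope=2.0)
    print(f"{impressions:>7,} impressions -> {response:,.1f} KPI units")
```

With a slope above 1 the curve is S-shaped: response accelerates at low exposure, then flattens as it approaches top, which is the diminishing-returns pattern the fitted curves are meant to capture.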

Budget Optimization

  • Constrained Optimization: Find optimal budget allocation across channels
  • Multiple Methods: SLSQP (default), trust-constr, differential evolution, hybrid
  • Hill Equation Integration: Uses fitted response curves for saturation modeling
  • Channel Constraints: Set min/max spend limits based on business requirements
  • Scenario Comparison: Compare current vs optimal allocations
  • ROI Maximization: Maximize predicted response subject to budget and constraints

import pandas as pd
from deepcausalmmm import optimize_budget_from_curves

# After training your model and fitting response curves...
# Create a DataFrame with fitted curve parameters for each channel
# Replace this with actual fitted parameters from your response curve fitting

fitted_curves_df = pd.DataFrame({
    'channel': ['TV', 'Search', 'Social', 'Display', 'Radio'],
    'top': [2.5, 3.0, 2.2, 1.8, 2.0],              # Hill parameter (slope at inflection)
    'bottom': [0.0, 0.0, 0.0, 0.0, 0.0],           # Minimum response
    'saturation': [500000, 300000, 200000, 150000, 400000],  # Saturation point (impressions)
    'slope': [0.002, 0.003, 0.004, 0.002, 0.001]   # Initial slope
})

# Optimize budget allocation
result = optimize_budget_from_curves(
    budget=1_000_000,
    curve_params=fitted_curves_df,
    num_weeks=52,
    constraints={
        'TV': {'lower': 100000, 'upper': 600000},
        'Search': {'lower': 150000, 'upper': 500000},
        'Social': {'lower': 50000, 'upper': 300000}
    },
    method='SLSQP'
)

# View results
if result.success:
    print(f"Optimal Allocation: {result.allocation}")
    print(f"Predicted Response: {result.predicted_response:,.0f}")
    print(result.by_channel)

Example Output:

Optimal Allocation: {'TV': 100000, 'Search': 420000, 'Social': 300000, ...}
Predicted Response: 627,788

Detailed Metrics:
  channel  total_spend  weekly_spend  roi  spend_pct  response_pct  saturation_pct
   Search      420,000      8,076.92  0.56      42.0%        37.8%           323%
   Social      300,000      5,769.23  0.73      30.0%        34.8%           288%
       TV      100,000      1,923.08  0.13      10.0%         2.1%            64%

See examples/example_budget_optimization.py for complete workflow and tips.
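
Two of the columns in the example output can be re-derived by hand, which is a quick sanity check on any allocation. The sketch below assumes weekly_spend is total_spend divided by num_weeks and spend_pct is total_spend as a share of the overall budget; both assumptions are consistent with the figures shown above. The helper name is hypothetical, and roi, response_pct, and saturation_pct depend on the fitted curves, so they are not reproduced here.

```python
def derived_spend_metrics(allocation, budget, num_weeks):
    """Return {channel: (weekly_spend, spend_pct)} for a total-spend allocation."""
    return {
        channel: (total / num_weeks, 100.0 * total / budget)
        for channel, total in allocation.items()
    }


# Channels listed in the example output (the remaining channels are elided there)
allocation = {'Search': 420_000, 'Social': 300_000, 'TV': 100_000}
metrics = derived_spend_metrics(allocation, budget=1_000_000, num_weeks=52)
for channel, (weekly, pct) in metrics.items():
    print(f"{channel:>7}  weekly_spend={weekly:,.2f}  spend_pct={pct:.1f}%")
```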

Benchmarks

Example Validation (190 regions, 109 weeks, 13 channels, 7 controls):

  • Training R²: 0.950 | Holdout R²: 0.842
  • Generalization Gap: 0.108 in R² (10.8 percentage points between training and holdout), indicating reasonable out-of-sample performance
  • Temporal Split: With holdout_ratio = 0.12 in config, the 109 observed weeks split into about 96 training weeks and 13 holdout weeks (burn-in padding in the pipeline may change logged lengths slightly)
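
The week counts in the temporal split follow from simple arithmetic. The sketch below assumes the holdout length is the rounded product of the number of weeks and holdout_ratio, and that the most recent weeks are held out; the rounding rule is an assumption, and `split_last_weeks` is a hypothetical helper rather than the pipeline's method.

```python
def split_last_weeks(weeks, holdout_ratio):
    """Hold out the most recent weeks; assumes the holdout length is rounded."""
    n_holdout = int(round(len(weeks) * holdout_ratio))
    cut = len(weeks) - n_holdout
    return weeks[:cut], weeks[cut:]


train, holdout = split_last_weeks(list(range(109)), holdout_ratio=0.12)
print(len(train), len(holdout))  # 96 13
```

Holding out the final weeks rather than a random sample keeps the evaluation honest for time series, since the model never sees data from after the period it is asked to predict.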

Development

Requirements

  • Python 3.9+
  • PyTorch 2.0+
  • pandas 1.5+
  • numpy 1.21+ (package metadata caps below 2.0)
  • scipy 1.7+
  • plotly 5.0+
  • NetworkX 2.6+
  • statsmodels 0.13+
  • scikit-learn 1.0+
  • tqdm 4.60+

Testing

python -m pytest tests/

Contributing

See CONTRIBUTING.md for development guidelines.

License

MIT License - see LICENSE file.

Support

  • Documentation: Includes a comprehensive README with examples
  • Issues: Use GitHub issues for bug reports and feature requests

Roadmap

Version 1.0.22 (planned)

  • NOTEARS DAG Learning: Full implementation of the NOTEARS (DAGs with NO TEARS) continuous optimization method for discovering arbitrary DAG structures
  • Enhanced Causal Discovery: Move beyond upper triangular constraints to learn more flexible causal relationships between marketing channels

Citation

If you use DeepCausalMMM in your work, please cite:

@article{tirumala2025deepcausalmmm,
  title={DeepCausalMMM: A Deep Learning Framework for Marketing Mix Modeling with Causal Inference},
  author={Puttaparthi Tirumala, Aditya},
  journal={arXiv preprint arXiv:2510.13087},
  year={2025}
}

Or click the "Cite this repository" button on GitHub for other citation formats (APA, Chicago, MLA).

Quick Links

  • Main Dashboard: examples/dashboard_rmse_optimized.py - Complete analysis pipeline
  • Budget Optimization: examples/example_budget_optimization.py - End-to-end optimization workflow
  • Core Model: deepcausalmmm/core/unified_model.py - DeepCausalMMM architecture
  • Configuration: deepcausalmmm/core/config.py - All tunable parameters
  • Data Pipeline: deepcausalmmm/core/data.py - Data processing and scaling

DeepCausalMMM - Where Deep Learning meets Causal Inference for Superior Marketing Mix Modeling

arXiv preprint - https://www.arxiv.org/abs/2510.13087
