Causal Deep Learning marketing-mix modeling library with response curves
DeepCausalMMM
Advanced Marketing Mix Modeling with Causal Inference and Deep Learning
Development history
A substantial part of the design and prototyping was done locally before this repository was published on GitHub. The public commit history therefore reflects the integration, hardening, documentation, testing, and packaging of that earlier work more than a day-by-day log of the initial exploration. You may see bursts of commits (for example around releases or documentation pushes) separated by quieter periods. That pattern is typical when a working prototype is moved into an open-source layout, and the span of weeks visible on GitHub does not reflect the full development timeline.
Key Features
Architecture
- Config-Driven: Every setting configurable via config.py
- GRU-Based Temporal Modeling: Captures complex time-varying effects
- DAG Learning: Discovers causal relationships between channels
- Learnable Coefficient Bounds: Channel-specific, data-driven constraints
- Data-Driven Seasonality: Automatic seasonal decomposition per region
Robust Statistical Methods
- Huber Loss: Robust to outliers and extreme values
- Multiple Metrics: RMSE, R², MAE, Trimmed RMSE
- Advanced Regularization: L1/L2, sparsity, coefficient-specific penalties
- Gradient Clipping: Parameter-specific clipping for stability
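The robust loss and evaluation metrics listed above can be sketched with NumPy. This is a minimal illustration, not the package's implementation. The trimmed-RMSE definition used here is an assumption: it drops the largest `trim` share of squared errors. The default delta of 0.3 mirrors the huber_delta value shown in the Configuration section.

```python
import numpy as np


def huber_loss(y_true, y_pred, delta=0.3):
    """Mean Huber loss: quadratic for |residual| <= delta, linear beyond it."""
    r = np.abs(np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float))
    quadratic = 0.5 * r**2
    linear = delta * (r - 0.5 * delta)
    return float(np.mean(np.where(r <= delta, quadratic, linear)))


def rmse(y_true, y_pred):
    err = np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)
    return float(np.sqrt(np.mean(err**2)))


def mae(y_true, y_pred):
    err = np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)
    return float(np.mean(np.abs(err)))


def r2(y_true, y_pred):
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)


def trimmed_rmse(y_true, y_pred, trim=0.1):
    """Assumed definition: RMSE after discarding the largest `trim` share of squared errors."""
    if not 0.0 <= trim < 1.0:
        raise ValueError("trim must be in [0, 1)")
    sq = np.sort((np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)) ** 2)
    n_drop = min(int(np.floor(len(sq) * trim)), len(sq) - 1)  # always keep at least one
    return float(np.sqrt(np.mean(sq[: len(sq) - n_drop])))


y = np.array([1.0, 2.0, 3.0, 4.0, 50.0])  # the last point is an outlier
y_hat = np.array([1.1, 1.9, 3.2, 3.8, 4.0])
print(f"Huber={huber_loss(y, y_hat):.3f}  RMSE={rmse(y, y_hat):.3f}  "
      f"MAE={mae(y, y_hat):.3f}  R2={r2(y, y_hat):.3f}  "
      f"TrimmedRMSE={trimmed_rmse(y, y_hat, trim=0.2):.3f}")
```

The Huber loss grows only linearly in the outlier's residual, while RMSE is dominated by it. This gap is why a robust loss is preferred when a few region-weeks contain extreme values.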
Comprehensive Analysis in Examples Folder
- Interactive Visualizations: Example HTML dashboard with plots and insights
- Response Curves: Non-linear saturation analysis with Hill equations
- Budget Optimization: Constrained optimization for optimal channel allocation
- DMA-Level Contributions: True economic impact calculation at DMA level
- Channel Effectiveness: Detailed performance analysis
Quick Start
Installation
From PyPI (Recommended)
pip install deepcausalmmm
From GitHub (Development Version)
# Install the latest development version with all fixes
pip install git+https://github.com/adityapt/deepcausalmmm.git
Manual Installation
# Clone repository
git clone https://github.com/adityapt/deepcausalmmm.git
cd deepcausalmmm
pip install -e .
Dependencies Only
To install only the dependencies, without the deepcausalmmm package itself, use the same version bounds as the [project] dependencies in pyproject.toml:
pip install \
"torch>=2.0" \
"pandas>=1.5" \
"numpy>=1.21,<2.0" \
"plotly>=5.0" \
"networkx>=2.6" \
"scikit-learn>=1.0" \
"scipy>=1.7" \
"statsmodels>=0.13" \
"tqdm>=4.60"
IMPORTANT: Version Compatibility
The code examples in this README require version 1.0.19 or later.
If you installed from PyPI previously (pip install deepcausalmmm), you might have v1.0.18 or below, which has a completely different API. The examples below will NOT work with v1.0.18 and below.
For Current PyPI Users (v1.0.18 and below):
# Option 1: Upgrade to the latest version from PyPI
pip install --upgrade deepcausalmmm
# OR shorthand
pip install -U deepcausalmmm
# Option 2: Install the development version (from GitHub)
pip install --upgrade git+https://github.com/adityapt/deepcausalmmm.git
# Option 3: Force reinstall if you have conflicts
pip install --upgrade --force-reinstall deepcausalmmm
API Changes in v1.0.19:
- New: pipeline parameter in ModelTrainer.train()
- Fixed: Proper data scaling without leakage
- Fixed: Correct attribution calculation
- New: ConfigurableDataGenerator.generate_mmm_dataset() method
- Removed: y_full_for_baseline parameter (data leakage fix)
Basic Usage
import numpy as np
from deepcausalmmm import get_device
from deepcausalmmm.core import get_default_config
from deepcausalmmm.core.trainer import ModelTrainer
from deepcausalmmm.core.data import UnifiedDataPipeline
from deepcausalmmm.utils.data_generator import ConfigurableDataGenerator
# Generate synthetic data for testing
# You can replace this with your own data in the same format
n_regions = 50 # Number of DMAs/regions
n_weeks = 104 # 2 years of weekly data
n_media = 13 # Number of media channels
n_control = 3 # Number of control variables
generator = ConfigurableDataGenerator()
X_media, X_control, y = generator.generate_mmm_dataset(
n_regions=n_regions,
n_weeks=n_weeks,
n_media_channels=n_media,
n_control_channels=n_control
)
# Expected data format:
# X_media: [n_regions, n_weeks, n_media_channels] - Media inputs
# X_control: [n_regions, n_weeks, n_control_variables] - Control variables
# y: [n_regions, n_weeks] - Target variable (KPI units)
# Get configuration
config = get_default_config()
config['n_epochs'] = 200 # Adjust as needed
# Check device availability
device = get_device()
print(f"Using device: {device}")
# Initialize pipeline
pipeline = UnifiedDataPipeline(config)
# Split data temporally (train/holdout)
train_data, holdout_data = pipeline.temporal_split(X_media, X_control, y)
# Process training data
train_tensors = pipeline.fit_and_transform_training(train_data)
holdout_tensors = pipeline.transform_holdout(holdout_data)
# Create trainer and model
trainer = ModelTrainer(config)
model = trainer.create_model(
n_media=train_tensors['X_media'].shape[2],
n_control=train_tensors['X_control'].shape[2],
n_regions=train_tensors['X_media'].shape[0]
)
trainer.create_optimizer_and_scheduler()
# Train model
results = trainer.train(
train_tensors['X_media'], train_tensors['X_control'],
train_tensors['R'], train_tensors['y'],
holdout_tensors['X_media'], holdout_tensors['X_control'],
holdout_tensors['R'], holdout_tensors['y'],
pipeline=pipeline,
verbose=True
)
# View results
print(f"Training R²: {results['final_train_r2']:.3f}")
print(f"Holdout R²: {results['final_holdout_r2']:.3f}")
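The example above calls fit_and_transform_training on the training split and transform_holdout on the holdout split. This is the standard way to avoid leakage: statistics are estimated on the training weeks only and then reused unchanged for the holdout weeks. The sketch below shows that pattern for a per-region y / y_mean scaling. The class RegionalMeanScaler is hypothetical and is not part of the package.

```python
import numpy as np


class RegionalMeanScaler:
    """Hypothetical illustration: scale y[region, week] by each region's training mean."""

    def __init__(self):
        self.y_mean = None

    def fit(self, y_train):
        # Means come from the training window only, so no holdout information leaks in.
        self.y_mean = y_train.mean(axis=1, keepdims=True)
        return self

    def transform(self, y):
        if self.y_mean is None:
            raise RuntimeError("call fit() before transform()")
        return y / self.y_mean

    def inverse_transform(self, y_scaled):
        # Map scaled predictions back to KPI units.
        return y_scaled * self.y_mean


rng = np.random.default_rng(0)
y = rng.uniform(100, 200, size=(3, 20))    # 3 regions x 20 weeks
y_train, y_holdout = y[:, :17], y[:, 17:]  # temporal split: last 3 weeks held out

scaler = RegionalMeanScaler().fit(y_train)
y_train_scaled = scaler.transform(y_train)
y_holdout_scaled = scaler.transform(y_holdout)  # reuses the training means
```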
Running Examples (Requires Cloning Repository)
The examples/ folder is only available if you clone the repository (not included in pip install):
# Clone the repository first
git clone https://github.com/adityapt/deepcausalmmm.git
cd deepcausalmmm
# Install the package
pip install -e .
# Run the comprehensive dashboard (uses real-world anonymized data)
python examples/dashboard_rmse_optimized.py
# Or run other examples
python examples/example_budget_optimization.py
python examples/example_response_curves.py
Package Import Test
# Verify installation works
from deepcausalmmm import DeepCausalMMM, get_device
from deepcausalmmm.core import get_default_config
print("DeepCausalMMM package imported successfully!")
print(f"Device: {get_device()}")
Project Structure
deepcausalmmm/ # Project root
├── pyproject.toml # Package configuration and dependencies
├── README.md # This documentation
├── LICENSE # MIT License
├── CHANGELOG.md # Version history and changes
├── RELEASE_NOTES_1.0.19.md # Latest release notes
├── CONTRIBUTING.md # Development guidelines
├── CODE_OF_CONDUCT.md # Code of conduct
├── CITATION.cff # Citation metadata for Zenodo/GitHub
├── Makefile # Build and development tasks
├── MANIFEST.in # Package manifest for distribution
│
├── deepcausalmmm/ # Main package directory
│ ├── __init__.py # Package initialization and exports
│ ├── cli.py # Command-line interface
│ ├── exceptions.py # Custom exception classes
│ │
│ ├── core/ # Core model components
│ │ ├── __init__.py # Core module initialization
│ │ ├── config.py # Optimized configuration parameters
│ │ ├── unified_model.py # Main DeepCausalMMM model architecture
│ │ ├── trainer.py # ModelTrainer class for training
│ │ ├── data.py # UnifiedDataPipeline for data processing
│ │ ├── scaling.py # SimpleGlobalScaler for data normalization
│ │ ├── seasonality.py # Seasonal decomposition utilities
│ │ ├── dag_model.py # DAG learning and causal inference
│ │ ├── inference.py # Model inference and prediction
│ │ ├── train_model.py # Training functions and utilities
│ │ └── visualization.py # Core visualization components
│ │
│ ├── postprocess/ # Analysis and post-processing
│ │ ├── __init__.py # Postprocess module initialization
│ │ ├── analysis.py # Statistical analysis utilities
│ │ ├── comprehensive_analysis.py # Comprehensive analyzer
│ │ ├── response_curves.py # Non-linear response curve fitting (Hill equations)
│ │ ├── optimization.py # Budget optimization with response curves
│ │ ├── optimization_utils.py # Optimization utility functions
│ │ └── dag_postprocess.py # DAG post-processing and analysis
│ │
│ └── utils/ # Utility functions
│ ├── __init__.py # Utils module initialization
│ ├── device.py # GPU/CPU device detection
│ └── data_generator.py # Synthetic data generation (ConfigurableDataGenerator)
│
├── examples/ # Example scripts and notebooks
│ ├── quickstart.ipynb # Short API intro (Google Colab badge)
│ ├── mmm_three_way_benchmark.ipynb # PyMC / Meridian / DCM / Ridge benchmark (optional; long runtimes)
│ ├── dashboard_rmse_optimized.py # Comprehensive dashboard with 14+ visualizations
│ ├── example_response_curves.py # Response curve fitting examples
│ ├── example_budget_optimization.py # Budget optimization workflow
│ └── data/ # Example data directory
│ └── MMM Data.csv # Anonymized real-world MMM dataset
│
├── tests/ # Test suite
│ ├── __init__.py # Test package initialization
│ ├── unit/ # Unit tests
│ │ ├── __init__.py
│ │ ├── test_config.py # Configuration tests
│ │ ├── test_model.py # Model architecture tests
│ │ ├── test_scaling.py # Data scaling tests
│ │ ├── test_response_curves.py # Response curve fitting tests
│ │ └── test_inference.py # InferenceManager.predict / forward contract
│ └── integration/ # Integration tests
│ ├── __init__.py
│ ├── test_end_to_end.py # End-to-end integration tests
│ └── test_dashboard_rmse_optimized.py # Dashboard script + real CSV data path
│
├── docs/ # Documentation
│ ├── Makefile # Documentation build tasks
│ ├── make.bat # Windows documentation build
│ ├── requirements.txt # Documentation dependencies
│ └── source/ # Sphinx documentation source
│ ├── conf.py # Sphinx configuration
│ ├── index.rst # Documentation index
│ ├── installation.rst # Installation guide
│ ├── quickstart.rst # Quick start guide
│ ├── contributing.rst # Contributing guide
│ ├── api/ # API documentation
│ │ ├── index.rst
│ │ ├── core.rst
│ │ ├── data.rst
│ │ ├── trainer.rst
│ │ ├── inference.rst
│ │ ├── analysis.rst
│ │ ├── response_curves.rst # Response curves API
│ │ ├── optimization.rst # Budget optimization API
│ │ ├── cli.rst # Command-line interface
│ │ ├── visualization.rst # core.visualization (VisualizationManager)
│ │ ├── utils.rst
│ │ └── exceptions.rst
│ ├── examples/ # Example documentation
│ │ ├── index.rst
│ │ ├── retail_mmm.rst
│ │ └── multi_region.rst
│ └── tutorials/ # Tutorial documentation
│ └── index.rst
│
└── JOSS/ # Journal of Open Source Software submission
├── paper.md # JOSS paper manuscript
├── paper.bib # Bibliography
├── figure_dag_professional.png # DAG visualization figure
└── figure_response_curve_simple.png # Response curve figure
Dashboard Contents (Examples Folder)
- Performance Metrics: Training vs Holdout comparison
- Actual vs Predicted: Time series visualization
- Holdout Scatter: Generalization assessment
- Economic Contributions: Total KPI per channel
- Contribution Breakdown: Donut chart with percentages
- Waterfall Analysis: Decomposed contribution flow
- Channel Effectiveness: Coefficient distributions
- DAG Heatmap: Adjacency matrix visualization
- Stacked Contributions: Time-based channel impact
- Individual Channels: Detailed channel analysis
- Scaled Data: Normalized time series
- Response Curves: Non-linear response curves (diminishing returns analysis) with Hill equations
Configuration
Key configuration parameters:
{
# Model Architecture
'hidden_dim': 320, # Optimal hidden dimension
'dropout': 0.08, # Proven stable dropout
'gru_layers': 1, # Single layer for stability
# Training Parameters
'n_epochs': 6500, # Optimal convergence epochs
'learning_rate': 0.009, # Fine-tuned learning rate
'temporal_regularization': 0.04, # Proven regularization
# Loss Function
'use_huber_loss': True, # Robust to outliers
'huber_delta': 0.3, # Optimal delta value
# Data Processing
'holdout_ratio': 0.08, # Optimal train/test split
'burn_in_weeks': 6, # Stabilization period
}
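get_default_config() returns a dict, and the Basic Usage example edits it in place. A plain dict assignment accepts a misspelled key without complaint, so a thin wrapper that rejects unknown keys can catch typos early. The helper with_overrides below is hypothetical. DEFAULTS copies only the values listed above and stands in for the full default config.

```python
import copy

# Stand-in for a subset of the defaults listed above; in practice start from
# deepcausalmmm.core.get_default_config().
DEFAULTS = {
    "hidden_dim": 320,
    "dropout": 0.08,
    "gru_layers": 1,
    "n_epochs": 6500,
    "learning_rate": 0.009,
    "temporal_regularization": 0.04,
    "use_huber_loss": True,
    "huber_delta": 0.3,
    "holdout_ratio": 0.08,
    "burn_in_weeks": 6,
}


def with_overrides(base, **overrides):
    """Hypothetical helper: copy the config and apply overrides, rejecting unknown keys."""
    unknown = set(overrides) - set(base)
    if unknown:
        raise KeyError(f"unknown config keys: {sorted(unknown)}")
    cfg = copy.deepcopy(base)  # leave the defaults untouched
    cfg.update(overrides)
    return cfg


quick_cfg = with_overrides(DEFAULTS, n_epochs=200, learning_rate=0.005)
print(quick_cfg["n_epochs"], DEFAULTS["n_epochs"])  # 200 6500
```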
More Details
Learnable Parameters
- Media Coefficient Bounds: F.softplus(coeff_max_raw) * torch.sigmoid(media_coeffs_raw)
- Control Coefficients: Unbounded with gradient clipping
- Trend Damping: Disabled for now, will be enabled in future releases
- Baseline Components: Non-negative via F.softplus
- Seasonal Coefficient: Learnable seasonal contribution
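The media coefficient bound above combines two squashing functions. softplus turns an unconstrained raw parameter into a positive per-channel ceiling, and sigmoid maps the raw coefficient into (0, 1). Their product therefore always lies strictly between 0 and the ceiling. The NumPy sketch below reproduces that formula outside PyTorch. The array shapes, such as a raw coefficient per time step and channel, are illustrative assumptions.

```python
import numpy as np


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return np.logaddexp(0.0, x)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def bounded_media_coeffs(coeff_max_raw, media_coeffs_raw):
    """coeff = softplus(coeff_max_raw) * sigmoid(media_coeffs_raw), broadcast per channel."""
    return softplus(coeff_max_raw) * sigmoid(media_coeffs_raw)


coeff_max_raw = np.array([0.5, -1.0, 2.0])      # one raw ceiling per channel
media_coeffs_raw = np.array([[3.0, -2.0, 0.0],  # assumed: 2 time steps x 3 channels
                             [-5.0, 10.0, 1.0]])
coeffs = bounded_media_coeffs(coeff_max_raw, media_coeffs_raw)
```

Because both the ceiling and the coefficient are learned, each channel gets a data-driven upper limit rather than a hand-set constant.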
Data Processing
- Linear Scaling: Dependent variable scaled by regional mean (y/y_mean) for balanced training
- SOV Scaling: Share-of-voice normalization for media channels
- Z-Score Normalization: For control variables (weather, events, etc.)
- Min-Max Seasonality: Regional seasonal scaling (0-1) using seasonal_decompose
- DMA-Level Processing: Contributions calculated per region
- Attribution Priors: Media contribution regularization with dynamic loss scaling
- Data-Driven Hill Initialization: Hill parameters initialized from channel-specific SOV percentiles
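Three of the scaling steps listed above can be sketched with NumPy. The exact formulas are assumptions, not the package's code:

- SOV is taken as each channel's share of the media total in the same region-week.
- Controls are standardized per variable over all region-weeks.
- The seasonal series is a synthetic sine wave standing in for the output of seasonal_decompose.

```python
import numpy as np


def sov_scale(x_media, eps=1e-12):
    """Assumed SOV definition: channel share of total media per region-week.

    x_media has shape [n_regions, n_weeks, n_media].
    """
    totals = x_media.sum(axis=2, keepdims=True)
    return x_media / np.maximum(totals, eps)


def zscore(x_control, eps=1e-12):
    """Standardize each control variable over all region-weeks (assumed axis choice)."""
    mean = x_control.mean(axis=(0, 1), keepdims=True)
    std = x_control.std(axis=(0, 1), keepdims=True)
    return (x_control - mean) / np.maximum(std, eps)


def minmax_per_region(seasonal, eps=1e-12):
    """Rescale each region's seasonal component (shape [n_regions, n_weeks]) to [0, 1]."""
    lo = seasonal.min(axis=1, keepdims=True)
    hi = seasonal.max(axis=1, keepdims=True)
    return (seasonal - lo) / np.maximum(hi - lo, eps)


rng = np.random.default_rng(42)
x_media = rng.uniform(0, 1000, size=(4, 52, 5))  # 4 regions, 52 weeks, 5 channels
x_control = rng.normal(20, 5, size=(4, 52, 3))   # 3 control variables
weeks = np.arange(52)
seasonal = np.sin(2 * np.pi * weeks / 52)[None, :] * rng.uniform(1, 3, size=(4, 1))

media_sov = sov_scale(x_media)
controls_z = zscore(x_control)
season_01 = minmax_per_region(seasonal)
```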
Regularization Strategy
- Coefficient L2: Channel-specific regularization
- Sparsity Control: GRU parameter sparsity
- DAG Regularization: Acyclicity constraints
- Gradient Clipping: Parameter-specific clipping
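The Roadmap notes that current DAG learning uses an upper-triangular constraint. That constraint guarantees acyclicity for a simple reason. In a strictly upper-triangular adjacency matrix every edge points from a lower-indexed channel to a higher-indexed one, so no path can return to its start. Equivalently, the 0/1 adjacency matrix is nilpotent.

The sketch below applies such a mask and checks the result. It also computes the NOTEARS measure h(W) = tr(exp(W∘W)) - d mentioned in the Roadmap, which is zero exactly for DAGs. This is an illustration, not the code in dag_model.py. The trace of the matrix exponential is computed here as the sum of exp over eigenvalues.

```python
import numpy as np


def mask_upper_triangular(w_raw):
    """Keep only edges i -> j with i < j (strictly upper-triangular adjacency)."""
    return w_raw * np.triu(np.ones_like(w_raw), k=1)


def is_acyclic(adj):
    """A directed graph on d nodes is acyclic iff its 0/1 adjacency matrix is nilpotent."""
    d = adj.shape[0]
    binary = (adj != 0).astype(int)
    return not np.any(np.linalg.matrix_power(binary, d))


def notears_h(w):
    """NOTEARS acyclicity measure tr(exp(W * W)) - d, using tr(exp(M)) = sum(exp(eigvals))."""
    eigvals = np.linalg.eigvals(w * w)
    return float(np.sum(np.exp(eigvals)).real - w.shape[0])


rng = np.random.default_rng(7)
w_raw = rng.normal(size=(5, 5))  # 5 channels, unconstrained weights
w_dag = mask_upper_triangular(w_raw)

print(is_acyclic(w_raw), is_acyclic(w_dag))  # dense matrix has cycles; masked one does not
```

The trade-off, and the motivation for the planned NOTEARS work, is that a fixed triangular mask imposes a channel ordering in advance. A smooth penalty like h(W) lets the data choose the ordering.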
Response Curves
- Hill Saturation Modeling: Non-linear response curves with Hill equations
- Data-Driven Initialization: Hill g parameter initialized from the channel-specific SOV 60th percentile
- Automatic Curve Fitting: Fits S-shaped saturation curves to channel data
- National-Level Aggregation: Aggregates DMA-week data to national weekly level
- Linear Scaling: Direct scaling with prediction_scale × y_mean for accurate attribution
- Interactive Visualizations: Plotly-based interactive response curve plots
- Performance Metrics: R², slope, and saturation point for each channel
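A common form of the Hill saturation curve is x^a / (x^a + g^a), where g is the half-saturation point and a controls the shape. The sketch below uses that form and initializes g from a channel's 60th SOV percentile, as described above. It then sums region-level responses into a national weekly series. The exact Hill parameterization, the initial shape a = 2, the summation used for aggregation, and the synthetic SOV data are all assumptions for illustration.

```python
import numpy as np


def hill(x, a, g):
    """Hill saturation x**a / (x**a + g**a): 0 at x=0, 0.5 at x=g, approaching 1 as x grows."""
    x = np.asarray(x, dtype=float)
    return x**a / (x**a + g**a)


rng = np.random.default_rng(1)
# Hypothetical SOV values for one channel, shape [n_regions, n_weeks].
sov = rng.uniform(0.0, 0.4, size=(10, 52))

# Data-driven initialization: half-saturation point g from the channel's 60th percentile.
g0 = float(np.percentile(sov, 60))
a0 = 2.0  # assumed initial shape; a > 1 gives an S-shaped curve

response = hill(sov, a0, g0)

# National-level view: sum the region-level response for each week.
national_weekly = response.sum(axis=0)
```

Starting g at a percentile of the observed SOV places the curve's midpoint inside the range of the data. This tends to give the optimizer a usable gradient from the first epoch.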
import pandas as pd
import numpy as np
from deepcausalmmm.postprocess import ResponseCurveFit
# After training your model, prepare channel data for response curve fitting
# The data should have 'week_monday', 'spend', 'impressions', and 'predicted' columns
# Example: Create channel data from your model results
# Replace this with actual data extraction from your trained model
n_weeks = 104
channel_data = pd.DataFrame({
'week_monday': pd.date_range('2024-01-01', periods=n_weeks, freq='W-MON'), # weeks starting on Monday
'spend': np.random.uniform(10000, 50000, n_weeks), # Replace with actual spend
'impressions': np.random.uniform(100000, 500000, n_weeks), # Replace with actual impressions
'predicted': np.random.uniform(1000, 5000, n_weeks) # Replace with model predictions
})
# Fit response curves to channel data
fitter = ResponseCurveFit(
data=channel_data,
model_level='Overall',
date_col='week_monday'
)
# Fit the curve and generate visualization
fitter.fit(
x_label='Impressions',
y_label='Predicted KPI Units',
title='Channel Response Curve',
save_figure=True,
output_path='response_curve.html'
)
print(f"Slope: {fitter.slope:.3f}, Saturation: {fitter.saturation:.0f}, R²: {fitter.r_2:.3f}")
Budget Optimization
- Constrained Optimization: Find optimal budget allocation across channels
- Multiple Methods: SLSQP (default), trust-constr, differential evolution, hybrid
- Hill Equation Integration: Uses fitted response curves for saturation modeling
- Channel Constraints: Set min/max spend limits based on business requirements
- Scenario Comparison: Compare current vs optimal allocations
- ROI Maximization: Maximize predicted response subject to budget and constraints
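The optimizer uses methods such as SLSQP. As a dependency-free illustration of the same problem (splitting a budget across channels under per-channel bounds to maximize total response), the sketch below substitutes a different technique. It allocates the budget greedily in fixed increments, always to the channel with the highest marginal response.

Greedy allocation is only reliable for concave curves, so the sketch assumes Hill shapes with a ≤ 1. It is not how optimize_budget_from_curves works. The (scale, g, a) curve parameters are hypothetical. The bounds mirror the TV/Search/Social constraints in this section's example.

```python
import numpy as np


def response(spend, scale, g, a):
    """Assumed Hill-type response: scale * s**a / (s**a + g**a)."""
    return scale * spend**a / (spend**a + g**a)


def greedy_allocate(budget, curves, bounds, step=1_000.0):
    """Greedy incremental allocation (an illustrative substitute for SLSQP).

    curves: {channel: (scale, g, a)}; bounds: {channel: (lower, upper)}.
    Every channel starts at its lower bound, then each `step` goes to the channel
    whose response increases most without exceeding its upper bound.
    """
    alloc = {ch: float(bounds[ch][0]) for ch in curves}
    remaining = budget - sum(alloc.values())
    if remaining < 0:
        raise ValueError("lower bounds exceed the budget")
    while remaining > 1e-9:
        inc = min(step, remaining)
        best, best_gain = None, -np.inf
        for ch, (scale, g, a) in curves.items():
            if alloc[ch] + inc > bounds[ch][1] + 1e-9:
                continue  # channel is at its upper bound
            gain = response(alloc[ch] + inc, scale, g, a) - response(alloc[ch], scale, g, a)
            if gain > best_gain:
                best, best_gain = ch, gain
        if best is None:
            break  # every channel is capped; the rest stays unspent
        alloc[best] += inc
        remaining -= inc
    return alloc


curves = {"TV": (300_000.0, 400_000.0, 0.8),
          "Search": (250_000.0, 150_000.0, 1.0),
          "Social": (200_000.0, 100_000.0, 0.9)}
bounds = {"TV": (100_000, 600_000), "Search": (150_000, 500_000), "Social": (50_000, 300_000)}
alloc = greedy_allocate(1_000_000, curves, bounds, step=5_000.0)
total = sum(response(alloc[ch], *curves[ch]) for ch in curves)
```

With S-shaped curves (a > 1), marginal gains first rise and then fall. Greedy steps can then get stuck, which is one reason a general constrained optimizer is used instead.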
import pandas as pd
from deepcausalmmm import optimize_budget_from_curves
# After training your model and fitting response curves...
# Create a DataFrame with fitted curve parameters for each channel
# Replace this with actual fitted parameters from your response curve fitting
fitted_curves_df = pd.DataFrame({
'channel': ['TV', 'Search', 'Social', 'Display', 'Radio'],
'top': [2.5, 3.0, 2.2, 1.8, 2.0], # Hill parameter (slope at inflection)
'bottom': [0.0, 0.0, 0.0, 0.0, 0.0], # Minimum response
'saturation': [500000, 300000, 200000, 150000, 400000], # Saturation point (impressions)
'slope': [0.002, 0.003, 0.004, 0.002, 0.001] # Initial slope
})
# Optimize budget allocation
result = optimize_budget_from_curves(
budget=1_000_000,
curve_params=fitted_curves_df,
num_weeks=52,
constraints={
'TV': {'lower': 100000, 'upper': 600000},
'Search': {'lower': 150000, 'upper': 500000},
'Social': {'lower': 50000, 'upper': 300000}
},
method='SLSQP'
)
# View results
if result.success:
print(f"Optimal Allocation: {result.allocation}")
print(f"Predicted Response: {result.predicted_response:,.0f}")
print(result.by_channel)
Example Output:
Optimal Allocation: {'TV': 100000, 'Search': 420000, 'Social': 300000, ...}
Predicted Response: 627,788
Detailed Metrics:
channel total_spend weekly_spend roi spend_pct response_pct saturation_pct
Search 420,000 8,076.92 0.56 42.0% 37.8% 323%
Social 300,000 5,769.23 0.73 30.0% 34.8% 288%
TV 100,000 1,923.08 0.13 10.0% 2.1% 64%
See examples/example_budget_optimization.py for complete workflow and tips.
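The spend columns in the example output follow directly from the allocation:

- weekly_spend is total spend divided by num_weeks (52).
- spend_pct is total spend as a share of the 1,000,000 budget.

The roi column appears consistent with channel response (response_pct × predicted response) divided by channel spend. That reading is inferred from the numbers, not documented. The function derived_columns below is a hypothetical helper that recomputes these columns. The recomputed roi can differ in the second decimal because response_pct is itself rounded.

```python
def derived_columns(allocation, budget, num_weeks, predicted_response, response_pct):
    """Hypothetical helper: recompute weekly spend, spend share and an inferred ROI."""
    rows = {}
    for ch, spend in allocation.items():
        channel_response = predicted_response * response_pct[ch] / 100
        rows[ch] = {
            "weekly_spend": spend / num_weeks,
            "spend_pct": 100 * spend / budget,
            "roi": channel_response / spend,  # inferred: response per unit of spend
        }
    return rows


rows = derived_columns(
    allocation={"Search": 420_000, "Social": 300_000, "TV": 100_000},
    budget=1_000_000,
    num_weeks=52,
    predicted_response=627_788,
    response_pct={"Search": 37.8, "Social": 34.8, "TV": 2.1},
)
for ch, r in rows.items():
    print(f"{ch:<7}{r['weekly_spend']:>10,.2f}{r['roi']:>6.2f}{r['spend_pct']:>7.1f}%")
```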
Benchmarks
Example Validation (190 regions, 109 weeks, 13 channels, 7 controls):
- Training R²: 0.950 | Holdout R²: 0.842
- Generalization Gap: 0.108 in R² (10.8 percentage points between training and holdout), indicating reasonable out-of-sample performance
- Temporal Split: default holdout_ratio = 0.12 in config, giving about 96 training weeks and 13 holdout weeks out of 109 observed weeks (burn-in padding in the pipeline may change the logged lengths slightly)
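The week counts above come from holding out the last round(109 × 0.12) = 13 weeks. The sketch below shows a temporal split along the week axis with that arithmetic. The function temporal_split is hypothetical; the package's UnifiedDataPipeline.temporal_split may round differently and also accounts for burn-in.

```python
import numpy as np


def temporal_split(y, holdout_ratio):
    """Hypothetical illustration: keep the earliest weeks for training, hold out the last ones.

    y has shape [n_regions, n_weeks]; the holdout size is round(n_weeks * holdout_ratio).
    """
    n_weeks = y.shape[1]
    n_holdout = max(1, round(n_weeks * holdout_ratio))
    return y[:, : n_weeks - n_holdout], y[:, n_weeks - n_holdout :]


y = np.zeros((190, 109))  # 190 regions x 109 weeks, as in the benchmark above
y_train, y_holdout = temporal_split(y, holdout_ratio=0.12)
print(y_train.shape, y_holdout.shape)  # (190, 96) (190, 13)
```

Splitting by time rather than at random keeps the holdout strictly in the future, which matches how an MMM is used for forecasting and planning.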
Development
Requirements
- Python 3.9+
- PyTorch 2.0+
- pandas 1.5+
- numpy 1.21+ (package metadata caps below 2.0)
- scipy 1.7+
- plotly 5.0+
- NetworkX 2.6+
- statsmodels 0.13+
- scikit-learn 1.0+
- tqdm 4.60+
Testing
python -m pytest tests/
Contributing
See CONTRIBUTING.md for development guidelines.
License
MIT License - see LICENSE file.
Support
- Documentation: Includes a comprehensive README with examples
- Issues: Use GitHub issues for bug reports and feature requests
Documentation
- Full Documentation: deepcausalmmm.readthedocs.io
- Quick Start Guide: Installation & Usage
- API Reference: Complete API Documentation
- Tutorials: Step-by-step Guides
- Examples: Practical Use Cases
Roadmap
Version 1.0.22 (planned)
- NOTEARS DAG Learning: Full implementation of the NOTEARS (DAGs with NO TEARS) continuous optimization method for discovering arbitrary DAG structures
- Enhanced Causal Discovery: Move beyond upper-triangular constraints to learn more flexible causal relationships between marketing channels
Citation
If you use DeepCausalMMM in your work, please cite:
@article{tirumala2025deepcausalmmm,
title={DeepCausalMMM: A Deep Learning Framework for Marketing Mix Modeling with Causal Inference},
author={Puttaparthi Tirumala, Aditya},
journal={arXiv preprint arXiv:2510.13087},
year={2025}
}
Or click the "Cite this repository" button on GitHub for other citation formats (APA, Chicago, MLA).
Quick Links
- Main Dashboard: dashboard_rmse_optimized.py - Complete analysis pipeline
- Budget Optimization: examples/example_budget_optimization.py - End-to-end optimization workflow
- Core Model: deepcausalmmm/core/unified_model.py - DeepCausalMMM architecture
- Configuration: deepcausalmmm/core/config.py - All tunable parameters
- Data Pipeline: deepcausalmmm/core/data.py - Data processing and scaling
DeepCausalMMM - Where Deep Learning meets Causal Inference for Superior Marketing Mix Modeling
arXiv preprint - https://www.arxiv.org/abs/2510.13087
File details
Details for the file deepcausalmmm-1.0.20.tar.gz.
File metadata
- Download URL: deepcausalmmm-1.0.20.tar.gz
- Size: 1.8 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.25
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ec76c24130170553cf01e6df444328168e2d44e33a464da7b609bfea06c09749 |
| MD5 | 197f610680a7679e59575d6c05feb8d5 |
| BLAKE2b-256 | 34437297d89dbc2ce65643c380158726590f35b088c56ff6d698cf21733b51c6 |
File details
Details for the file deepcausalmmm-1.0.20-py3-none-any.whl.
File metadata
- Download URL: deepcausalmmm-1.0.20-py3-none-any.whl
- Size: 107.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.25
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1e7d67cfce41a95bff1a8c1f051585adfa44cc403cc62c399ba0823a94ed9480 |
| MD5 | 9e4f56bf8d9291c6a31048dbc60ddf86 |
| BLAKE2b-256 | c9ee93696ed9284ae61dc9feb906953f985b7d645aeda97dddbdf01a5442ca69 |