
Tree-based Market Mix Modeling with SHAP attribution


TreeMMM


Market Mix Modeling that finds what you didn't think to look for.

TreeMMM is a pip-installable Python package that uses gradient-boosted trees (LightGBM, XGBoost, CatBoost) paired with SHAP-based attribution to decompose commercial outcomes into promotional lever contributions. Unlike regression-based MMM tools, TreeMMM automatically discovers non-linear response functions, channel interactions, and heterogeneous customer sensitivity — without requiring the analyst to pre-specify functional forms.

Installation

# Core package (LightGBM + SHAP)
pip install treemmm

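# With XGBoost support (quote extras so shells like zsh don't glob the brackets)
pip install "treemmm[xgboost]"

# With CatBoost support
pip install "treemmm[catboost]"

# With PowerPoint reporting
pip install "treemmm[reporting]"

# With Jupyter widgets
pip install "treemmm[ui]"

# Everything
pip install "treemmm[all]"

# Development
pip install "treemmm[dev]"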

Quickstart

Python API

import pandas as pd

import treemmm
from treemmm.core.config import ColumnSpec, RunConfig

# Panel data: one row per customer (hcp_id) per period (month)
df = pd.read_csv("data.csv")

config = RunConfig(
    columns=ColumnSpec(
        customer_id="hcp_id",
        time_col="month",
        outcome_col="new_patients",
        promo_vars=["rep_visits", "digital", "peer_programs", "samples"],
        control_vars=["seasonality", "market_index"],
    ),
    objective="auto",  # Auto-detects the outcome distribution
)

result = treemmm.run(df, config, output_dir="output/")
print(result.summary())

CLI

# Run pipeline on a CSV
treemmm run data.csv \
    --customer-id hcp_id \
    --time-col month \
    --outcome-col new_patients \
    --promo-vars "rep_visits,digital,peer_programs,samples" \
    --control-vars "seasonality,market_index" \
    --objective auto

# Generate a demo dataset
treemmm demo pharma --n-customers 500 --n-periods 24

# Run the benchmark (TreeMMM vs GLMM)
treemmm benchmark --n-customers 200 --n-periods 12

Jupyter Notebook

from treemmm.ui.notebook_runner import NotebookRunner
from treemmm.core.config import ColumnSpec, RunConfig  # for building `config` in the notebook

# `df` and `config` as in the Python API example above
runner = NotebookRunner(df, config)
result = runner.run()

runner.show_attribution()   # Bar chart + table
runner.show_performance()   # R²/WMAPE per fold
runner.show_temporal()      # Stacked area over time
runner.show_mroi()          # Response curves with CIs

Key Features

Distribution-Aware Modeling

TreeMMM auto-detects the outcome distribution and selects the appropriate objective function:

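| Distribution | Objective (link)     | When to use                                            |
|--------------|----------------------|--------------------------------------------------------|
| Gaussian     | MSE (identity link)  | Continuous, symmetric (revenue, value sales)           |
| Poisson      | Poisson (log link)   | Non-negative counts (Rx, orders, NPS)                  |
| Tweedie      | Tweedie (log link)   | Zero-inflated continuous (revenue with stockouts)      |
| Gamma        | Gamma (log link)     | Strictly positive continuous (per-transaction revenue) |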
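The decision table above can be read as a simple rule chain. The sketch below is a hypothetical heuristic: the function name `detect_objective` and the skewness threshold are assumptions, not TreeMMM's API, and the package's actual auto-detection tests may differ.

```python
import numpy as np


def detect_objective(y, skew_threshold=1.0):
    """Map an outcome vector to one of the four objectives in the table above.

    Hypothetical heuristic for illustration; TreeMMM's own detection rules may differ.
    """
    y = np.asarray(y, dtype=float)
    if np.any(y < 0):
        return "gaussian"  # negative values rule out the log-link families
    if np.allclose(y, np.round(y)):
        return "poisson"  # non-negative integers: counts
    if np.any(y == 0):
        return "tweedie"  # continuous with a point mass at zero
    # Strictly positive continuous: Gaussian if roughly symmetric, Gamma if right-skewed
    std = y.std()
    if std == 0:
        return "gaussian"
    skew = np.mean((y - y.mean()) ** 3) / std**3
    return "gamma" if skew > skew_threshold else "gaussian"


print(detect_objective([0, 3, 1, 0, 7]))         # poisson
print(detect_objective([0.0, 12.5, 0.0, 40.2]))  # tweedie
```

The order of the checks matters: integer-valued data is treated as counts before the zero check, so zero-heavy count data still maps to Poisson rather than Tweedie.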

Link-Function-Aware Attribution

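SHAP values live in different spaces depending on the objective: TreeSHAP explains the model's raw output, which for log-link objectives is the log of the prediction. TreeMMM's decomposer handles this automatically:

  • Identity link (Gaussian): SHAP values are directly additive on the response scale (base value + sum of SHAP values = prediction)
  • Log link (Poisson/Tweedie/Gamma): SHAP values add up in log space, so they are converted by proportional allocation, attribution_i = (|SHAP_i| / Σ_j |SHAP_j|) × prediction, which guarantees that the attributions sum to the predicted outcome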
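A minimal NumPy sketch of the two rules, assuming a SHAP matrix and base value already computed in the model's raw-output space (for example by a TreeExplainer). The `decompose` function is hypothetical and only mirrors the formulas above; it is not the code in `attribution/decomposer.py`.

```python
import numpy as np


def decompose(shap_values, base_value, link):
    """Apply the identity-link or log-link attribution rule to a SHAP matrix.

    shap_values: (n_rows, n_features) array in the model's raw-output space.
    base_value: the explainer's expected value, in the same space.
    Illustrative sketch only, not TreeMMM's decomposer.
    """
    phi = np.asarray(shap_values, dtype=float)
    raw = base_value + phi.sum(axis=1)  # local accuracy: base + sum(SHAP) = raw output
    if link == "identity":
        return phi, raw  # already additive on the response scale
    if link == "log":
        prediction = np.exp(raw)  # raw output is log(prediction)
        weights = np.abs(phi) / np.abs(phi).sum(axis=1, keepdims=True)
        return weights * prediction[:, None], prediction
    raise ValueError(f"unknown link: {link!r}")


attributions, prediction = decompose([[0.2, -0.1, 0.4]], base_value=1.0, link="log")
print(attributions.sum(axis=1), prediction)  # equal: both are exp(1.5)
```

Because the weights use absolute values, a lever with a negative SHAP value still receives a positive share, and in this sketch the base value is folded into the lever attributions under the log link. A row whose SHAP values are all zero would divide by zero and needs separate handling.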

Automatic Interaction Discovery

Regression-based MMM tools typically require interaction terms to be specified by hand. TreeMMM discovers them through the tree split structure: a split on one channel inside a branch defined by another channel is an interaction, with no functional form to specify.
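A toy example of why no interaction term is needed: in the hand-written depth-2 tree below (hypothetical thresholds and leaf values, not a fitted TreeMMM model), the split on `digital` exists only in the high-`rep_visits` branch, so the effect of digital depends on rep activity by construction. With a real fitted model, the `shap` library's `TreeExplainer.shap_interaction_values` method can quantify such pairwise effects.

```python
def toy_tree(rep_visits: float, digital: float) -> float:
    """Hypothetical depth-2 regression tree predicting new patients.

    The split on `digital` only appears on the high-rep branch, so the tree
    encodes a rep_visits x digital interaction that nobody had to specify.
    """
    if rep_visits <= 2.0:
        return 1.0  # low-rep leaf: digital is never consulted
    return 3.0 if digital > 0.5 else 2.0


def digital_lift(rep_visits: float) -> float:
    """Counterfactual effect of switching digital from 0 to 1 at a given rep level."""
    return toy_tree(rep_visits, digital=1.0) - toy_tree(rep_visits, digital=0.0)


print(digital_lift(1.0), digital_lift(4.0))  # 0.0 1.0
```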

mROI Simulation with Extrapolation Safety

Per-customer engagement is capped at values inside the observed range (e.g., the 95th percentile). Higher aggregate engagement is achieved by spreading activity across more customers, not by pushing individuals beyond observed bounds, so every customer-level input stays within the range seen in training.
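The capping idea can be sketched under stated assumptions: one channel, a cap equal to the observed 95th percentile, and extra engagement shared in proportion to each customer's remaining headroom. The function `spread_to_target` and the proportional-headroom rule are hypothetical illustrations, not TreeMMM's optimizer.

```python
import numpy as np


def spread_to_target(current, target_total, cap_quantile=0.95):
    """Raise total engagement to `target_total` without pushing anyone past the cap.

    Hypothetical sketch: the cap is the observed `cap_quantile` of current values,
    and the increase is shared in proportion to each customer's remaining headroom.
    """
    current = np.asarray(current, dtype=float)
    cap = np.quantile(current, cap_quantile)
    extra = target_total - current.sum()
    if extra < 0:
        raise ValueError("this sketch only handles increases")
    headroom = np.clip(cap - current, 0.0, None)  # customers at or above the cap get none
    if extra > headroom.sum():
        raise ValueError("target not reachable within observed-range caps")
    if extra == 0:
        return current.copy()
    return current + headroom * (extra / headroom.sum())


new = spread_to_target([0, 1, 2, 3, 10], target_total=20.0)
print(new.round(2), new.sum())
```

Sharing by headroom means low-engagement customers absorb most of the increase, which is the "spread to more customers" behaviour described above.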

Reverse Causality Detection

Built-in Granger pre-test and lead-variable test for each promotional variable. Variables flagged for targeting bias are automatically switched to lagged temporal alignment.
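A lead-variable test asks whether next period's promotion "predicts" this period's outcome, which should not happen if promotion drives outcomes rather than the reverse. The single-series OLS sketch below is a hypothetical illustration (the function `lead_tstat` and the simulated targeting rule are assumptions); it omits the Granger pre-test and any panel structure.

```python
import numpy as np


def lead_tstat(promo, outcome):
    """t-statistic on promo_{t+1} when regressing outcome_t on promo_t and promo_{t+1}."""
    x = np.asarray(promo, dtype=float)
    y = np.asarray(outcome, dtype=float)
    X = np.column_stack([np.ones(len(x) - 1), x[:-1], x[1:]])  # intercept, current, lead
    target = y[:-1]
    beta, *_ = np.linalg.lstsq(X, target, rcond=None)
    resid = target - X @ beta
    sigma2 = resid @ resid / (len(target) - X.shape[1])  # residual variance
    cov = sigma2 * np.linalg.inv(X.T @ X)  # OLS coefficient covariance
    return beta[2] / np.sqrt(cov[2, 2])


# Simulated targeting: next period's promo chases this period's outcome
rng = np.random.default_rng(0)
outcome = rng.normal(size=120)
promo = np.r_[0.0, outcome[:-1] + 0.3 * rng.normal(size=119)]
print(round(lead_tstat(promo, outcome), 1))  # large positive value flags targeting bias
```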

Demo Datasets

TreeMMM ships with four synthetic datasets with known ground-truth data-generating processes (DGPs):

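| Dataset | Default size               | Distribution | Key features                                                            |
|---------|----------------------------|--------------|-------------------------------------------------------------------------|
| Pharma  | 3,000 HCPs × 36 months     | NegBin       | Rheum/Derm HCS, rep targeting bias, 3 interactions, channel correlation |
| CPG     | 3,000 stores × 36 months   | Tweedie      | S/M/L store-size HCS, digital×trade interaction, zero-inflation         |
| SaaS    | 3,000 accounts × 36 months | ZI-Gamma     | Enterprise/SMB tier HCS, 2 interactions, zero-inflation                 |
| Linear  | 3,000 × 36 months          | Gaussian     | Pure linear; an honesty test where GLMM should win                      |

HCS = heterogeneous customer sensitivity; NegBin = negative binomial; ZI-Gamma = zero-inflated Gamma.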

from treemmm.demo.datasets.pharma_brand import generate_pharma_dataset

ds = generate_pharma_dataset()  # default size: 3,000 HCPs × 36 months
print(ds.ground_truth.attribution_shares)  # true shares to benchmark recovered attributions against
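Because the true shares are known, recovery can be scored directly. The sketch assumes both estimated and ground-truth shares are available as plain dicts keyed by channel name (an assumption about the type of `attribution_shares`); `share_recovery_mae` is a hypothetical helper, and the numbers in the usage example are made up for illustration.

```python
def share_recovery_mae(estimated, truth):
    """Mean absolute error between estimated and ground-truth attribution shares.

    Both arguments are dicts mapping channel name -> share; a channel
    missing from `estimated` counts as a share of 0.
    """
    return sum(abs(estimated.get(ch, 0.0) - share) for ch, share in truth.items()) / len(truth)


# Made-up shares, for illustration only
truth = {"rep_visits": 0.40, "digital": 0.25, "peer_programs": 0.20, "samples": 0.15}
estimate = {"rep_visits": 0.36, "digital": 0.28, "peer_programs": 0.21, "samples": 0.15}
print(round(share_recovery_mae(estimate, truth), 3))  # 0.02
```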

Architecture

treemmm/
├── core/
│   ├── config.py              # RunConfig, ColumnSpec, Objective enum
│   ├── data_handler.py        # Panel diagnostics, distribution detection
│   ├── models/
│   │   ├── base.py            # Abstract BaseModel interface
│   │   ├── lightgbm_model.py  # LightGBM + Optuna + SHAP
│   │   ├── xgboost_model.py   # XGBoost (optional)
│   │   ├── catboost_model.py  # CatBoost (optional)
│   │   └── glmm_baseline.py   # statsmodels MixedLM (naive + oracle)
│   ├── temporal/
│   │   └── splitter.py        # Rolling origin + period-jump CV
│   ├── interpret/
│   │   └── shap_engine.py     # TreeExplainer wrapper
│   ├── attribution/
│   │   └── decomposer.py      # Link-function-aware decomposition
│   └── reporting/
│       ├── csv_exporter.py    # CSV outputs
│       ├── pptx_builder.py    # PowerPoint (optional)
│       └── zip_packager.py    # ZIP bundling
├── mroi/
│   └── simulator.py           # Response curves + constrained optimization
├── demo/
│   ├── generator.py           # Configurable DGP engine
│   ├── benchmark.py           # TreeMMM vs GLMM comparison
│   └── datasets/
│       ├── pharma_brand.py
│       ├── cpg_brand.py
│       ├── saas_brand.py
│       └── linear_baseline.py
├── ui/
│   ├── cli_runner.py          # CLI entry point
│   ├── notebook_runner.py     # Jupyter-optimized runner
│   └── widgets.py             # ipywidgets config builder (optional)
└── pipeline.py                # Main orchestrator: treemmm.run()
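`temporal/splitter.py` is described as rolling-origin CV. The hypothetical generator below sketches that scheme for a long-format panel: each fold trains on all periods up to an origin and tests on the next `horizon` periods, so no future period leaks into training. Names and defaults are assumptions, and the period-jump variant is not shown.

```python
import numpy as np


def rolling_origin_splits(periods, n_folds=3, horizon=1, min_train=6):
    """Yield (train_idx, test_idx) row indices, splitting a panel on time only."""
    periods = np.asarray(periods)
    unique = np.unique(periods)  # sorted distinct periods
    for k in range(n_folds, 0, -1):
        cut = len(unique) - k * horizon  # number of training periods in this fold
        if cut < min_train:
            raise ValueError("not enough periods for the requested folds")
        train_p, test_p = unique[:cut], unique[cut:cut + horizon]
        # Every customer row in a period lands on the same side of the split
        yield (np.flatnonzero(np.isin(periods, train_p)),
               np.flatnonzero(np.isin(periods, test_p)))


months = np.repeat(np.arange(12), 2)  # 2 customers x 12 months, long format
for train_idx, test_idx in rolling_origin_splits(months, n_folds=3, horizon=2):
    print(months[train_idx].max(), "->", np.unique(months[test_idx]).tolist())
```

Splitting on period labels rather than rows keeps all customers observed in a given month together, which is what makes the folds honest for panel data.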

Pipeline Steps

  1. Data Ingestion — Column role declaration and validation
  2. Diagnostics — Panel balance, distribution detection, reverse causality test
  3. Configuration — Objective function, temporal alignment, CV strategy
  4. Training — Optuna-tuned GBT with temporal cross-validation
  5. Attribution — SHAP TreeExplainer + link-function-aware decomposition
  6. Reporting — CSVs, PowerPoint, ZIP bundle
  7. mROI (optional) — Response curves with bootstrap CIs, constrained reallocation
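For step 7, one common way to attach confidence intervals to a response curve is a percentile bootstrap over customers, repeated at each simulated spend level. The sketch below assumes per-customer predicted lifts at a single spend level are already available; `bootstrap_mean_ci` is a hypothetical helper, and TreeMMM's own bootstrap scheme may differ.

```python
import numpy as np


def bootstrap_mean_ci(customer_lifts, n_boot=2000, alpha=0.05, seed=0):
    """Percentile-bootstrap CI for the mean per-customer lift at one spend level."""
    lifts = np.asarray(customer_lifts, dtype=float)
    rng = np.random.default_rng(seed)
    # Each row of idx is one resampled set of customers (drawn with replacement)
    idx = rng.integers(0, len(lifts), size=(n_boot, len(lifts)))
    boot_means = lifts[idx].mean(axis=1)
    lo, hi = np.quantile(boot_means, [alpha / 2, 1 - alpha / 2])
    return lifts.mean(), lo, hi


# Made-up per-customer lifts at one point on a response curve
lifts = np.random.default_rng(1).gamma(2.0, 0.5, size=300)
mean, lo, hi = bootstrap_mean_ci(lifts)
print(f"mean lift {mean:.3f}, 95% CI [{lo:.3f}, {hi:.3f}]")
```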

Supported Models

| Model    | Install                         | Objectives                                          |
|----------|---------------------------------|-----------------------------------------------------|
| LightGBM | Core                            | Gaussian, Poisson, Tweedie, Gamma                   |
| XGBoost  | pip install "treemmm[xgboost]"  | Gaussian, Poisson, Tweedie, Gamma                   |
| CatBoost | pip install "treemmm[catboost]" | Gaussian, Poisson, Tweedie (Gamma→Tweedie fallback) |
| GLMM     | Core (statsmodels)              | Identity link (baseline comparison)                 |

Honest Tradeoffs

TreeMMM is not a universal replacement for Bayesian MMM. Use Bayesian methods when:

  • Strong, validated domain priors exist
  • Data is extremely limited (< 20 time periods)
  • Full posterior distributions are required
  • Classical statistical inference is needed

TreeMMM is strongest when:

  • Managing portfolios of 10+ brands with heterogeneous data
  • Multicollinearity between channels is severe
  • Non-linear response and interactions are expected but unknown
  • Speed of iteration matters (seconds vs. hours)
  • You want to discover patterns rather than confirm pre-specified hypotheses

SHAP and Causality

TreeMMM's SHAP attribution occupies a specific position on the causal identification spectrum: conditional counterfactual simulation. Panel data with temporal alignment establishes causal ordering; monotone constraints enforce domain-consistent directionality; and path-dependent TreeSHAP approximates conditional expectations using the training-sample counts recorded along each tree path, rather than marginalizing features independently as interventional SHAP does. Under conditional exchangeability (no unmeasured confounders given the observed state variables), these attributions approximate conditional causal effects.

For within-distribution budget reallocation (within ±50% of current channel allocations), this is sufficient in practice. For launching entirely new channels, or in settings with severe unobserved confounding, experimental validation remains necessary.

License

MIT



Download files


Source Distribution

treemmm-0.1.0.tar.gz (2.5 MB)

Uploaded Source

Built Distribution


treemmm-0.1.0-py3-none-any.whl (78.9 kB)

Uploaded Python 3

File details

Details for the file treemmm-0.1.0.tar.gz.

File metadata

  • Download URL: treemmm-0.1.0.tar.gz
  • Upload date:
  • Size: 2.5 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.12

File hashes

Hashes for treemmm-0.1.0.tar.gz
Algorithm Hash digest
SHA256 512f4034e7e58807d6eeb107d95cd8506fb29a174df71c0b7f46fd84ac6f4ea4
MD5 dc73a7beaf35062c2a5ad237d6470b14
BLAKE2b-256 5de7e5f1e2082d92f6c9e879f90bb63508910f96dc330a3035c197e21e7a0ac5


File details

Details for the file treemmm-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: treemmm-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 78.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.12

File hashes

Hashes for treemmm-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 1dfc93efa976fdbd22a44b6d09e509d7241df515b2a3066e2231e453f229f19a
MD5 103463b0ef523157202fce02cdf2130c
BLAKE2b-256 d37c8bee251b2e293042dfadc038ef29b56ef32c9e625eeb48bbb97e62ad8b13

