
Average marginal effects and bootstrap standard errors for any machine learning model.


marginfx


Get OLS-style interpretability from scikit-learn, XGBoost, TensorFlow, and PyTorch. One function call. One tidy table.

import marginfx as mfx
from sklearn.ensemble import RandomForestClassifier

model = RandomForestClassifier().fit(X_train, y_train)
result = mfx.fit(model, X, y, feature_names=feature_names)
result.summary()
=================================================================
marginfx: Average Marginal Effects
=================================================================
Observations: 1000
Bootstrap replicates: 200
Confidence level: 95%
-----------------------------------------------------------------
        term  estimate  std_error  statistic   p_value  conf_low  conf_high
         age     0.032      0.004      8.100     0.000     0.024      0.040
      income     0.008      0.001      6.300     0.000     0.006      0.010
      female    -0.012      0.003     -3.900     0.000    -0.018     -0.006
   education     0.021      0.005      4.200     0.000     0.011      0.031
=================================================================

What is this?

In classical econometrics, OLS gives you a coefficient table — estimates, standard errors, p-values — in units that are immediately interpretable. A one-unit increase in age increases income by $X. Everyone understands that.

Modern ML models (random forests, neural nets, gradient boosting) give you better predictions but no such table. You get a black box.

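marginfx bridges the gap. It computes average marginal effects (AMEs) for any model. In a linear model with only main effects, the AME of a feature equals its OLS coefficient, so AMEs carry the familiar coefficient interpretation over to nonlinear models. For example, an AME of 0.032 for age in a default model means that a one-unit increase in age raises P(default) by 0.032 on average (3.2 percentage points), whether the underlying model is a random forest or a neural net.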
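The link between AMEs and OLS coefficients can be checked on simulated data with plain numpy (an illustration, not marginfx code; `ols` is a hypothetical helper). With main effects only, the pointwise derivative ∂y/∂x1 is the same for every observation, so the AME is the coefficient b1. Once an interaction x1·x2 is added, the pointwise derivative b1 + b3·x2_i varies across observations: b1 is then the slope only where x2 = 0, and the AME is the average slope over the observed data.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2_000
x1 = rng.normal(size=n)
x2 = rng.normal(loc=2.0, size=n)  # nonzero mean, so the interaction term matters
y = 1.0 + 0.5 * x1 - 1.0 * x2 + 0.3 * x1 * x2 + rng.normal(scale=0.1, size=n)
ones = np.ones(n)

def ols(design, y):
    """Least-squares coefficients for the given design matrix."""
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return coef

# Main effects only: dy/dx1 = b1 for every observation, so AME(x1) is b1 itself
b_main = ols(np.column_stack([ones, x1, x2]), y)
ame_main = b_main[1]

# With an interaction: dy/dx1 = b1 + b3 * x2_i differs across observations
b_int = ols(np.column_stack([ones, x1, x2, x1 * x2]), y)
me_x1 = b_int[1] + b_int[3] * x2  # pointwise marginal effects ME_i(x1)
ame_int = me_x1.mean()            # their average is the AME

print(f"main-effects model: b1 = AME = {ame_main:.3f}")
print(f"interaction model:  b1 = {b_int[1]:.3f}, AME = {ame_int:.3f}")
```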

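Standard errors come from a nonparametric bootstrap in which each replicate is refit starting from the original model's fitted state (a warm start) rather than from scratch, which makes the computation practical even for expensive models. For TensorFlow and PyTorch models, exact gradients from automatic differentiation replace finite differences automatically.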

The output is a tidy DataFrame in the style of R's broom and marginaleffects packages, brought to the Python ML ecosystem.


Installation

pip install marginfx

Install with the ML frameworks you use:

pip install "marginfx[sklearn]"            # scikit-learn + XGBoost + LightGBM
pip install "marginfx[tensorflow]"         # TensorFlow / Keras
pip install "marginfx[pytorch]"            # PyTorch
pip install "marginfx[all]"                # everything

Quick start

scikit-learn

import marginfx as mfx
from sklearn.ensemble import RandomForestClassifier

model = RandomForestClassifier(n_estimators=100).fit(X_train, y_train)

result = mfx.fit(
    model, X, y,
    feature_names=feature_names,
    n_bootstrap=200,
    seed=42,
)

result.summary()          # formatted table
result.tidy()             # pandas DataFrame

XGBoost

import marginfx as mfx
import xgboost as xgb

model = xgb.XGBClassifier().fit(X_train, y_train)
result = mfx.fit(model, X, y, feature_names=feature_names, seed=42)
result.summary()

TensorFlow / Keras

import marginfx as mfx
import tensorflow as tf

model = tf.keras.models.load_model("my_model.keras")

result = mfx.fit(
    model, X, y,
    feature_names=feature_names,
    n_epochs=10,       # bootstrap warm-start epochs
    seed=42,
)
result.summary()

PyTorch

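import marginfx as mfx
import torch
import torch.nn as nn

# model: a trained torch.nn.Module whose output passes through a sigmoid,
# since nn.BCELoss expects probabilities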

result = mfx.fit(
    model, X, y,
    feature_names=feature_names,
    loss_fn=nn.BCELoss(),
    optimizer_fn=lambda p: torch.optim.Adam(p, lr=1e-3),
    n_epochs=10,
    seed=42,
)
result.summary()

Pandas DataFrames

# Column names are picked up automatically
result = mfx.fit(model, df[features], df["target"])
result.tidy()

Categorical features

# Categorical features use first differences (0 -> 1) instead of derivatives
result = mfx.fit(
    model, X, y,
    feature_names=feature_names,
    categorical_features=["female", "married", "has_degree"],
)

The tidy output

result.tidy() returns a pandas DataFrame modeled on broom::tidy() in R:

      term  estimate  std_error  statistic  p_value  conf_low  conf_high
       age     0.032      0.004       8.10    0.000     0.024      0.040
    income     0.008      0.001       6.30    0.000     0.006      0.010
    female    -0.012      0.003      -3.90    0.000    -0.018     -0.006

  • estimate — the average marginal effect (AME): mean of pointwise dy/dx across all observations
  • std_error — bootstrap standard deviation across replicates
  • statistic — estimate / std_error (normal approximation)
  • p_value — two-tailed p-value under normal approximation
  • conf_low / conf_high — percentile bootstrap confidence interval
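These column definitions translate directly into code. A minimal sketch, assuming `replicates` holds the B bootstrap AMEs for one feature; the sample-SD convention (ddof=1) and numpy's default quantile interpolation are assumptions, since the README does not specify them, and `tidy_row` is a hypothetical name:

```python
import math

import numpy as np

def tidy_row(term, estimate, replicates, alpha=0.05):
    """One row of a broom-style table from a point estimate and its bootstrap replicates."""
    replicates = np.asarray(replicates, dtype=float)
    std_error = float(replicates.std(ddof=1))           # SD across bootstrap replicates
    statistic = estimate / std_error                     # estimate / std_error
    p_value = math.erfc(abs(statistic) / math.sqrt(2))   # two-tailed: 2 * (1 - Phi(|z|))
    conf_low, conf_high = np.quantile(replicates, [alpha / 2, 1 - alpha / 2])
    return {
        "term": term,
        "estimate": estimate,
        "std_error": std_error,
        "statistic": statistic,
        "p_value": p_value,
        "conf_low": float(conf_low),    # percentile bootstrap interval
        "conf_high": float(conf_high),
    }

# Usage with simulated replicates (made-up numbers, for illustration only)
rng = np.random.default_rng(0)
row = tidy_row("age", 0.032, rng.normal(0.032, 0.004, size=200))
print(row)
```

A list of such dicts can be passed to pandas.DataFrame to get a table with the same columns as result.tidy().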

How it works

Average marginal effects

For a continuous feature x_j, the marginal effect at observation i is:

ME_i(x_j) = ∂f(x_i) / ∂x_j

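For models without exact gradients, this is approximated by central finite differences, where e_j is the unit vector for feature j and h is a small step size (the h argument of mfx.fit):

ME_i(x_j) ≈ [f(x_i + h·e_j) - f(x_i - h·e_j)] / (2h)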

The AME is the mean across all observations:

AME(x_j) = (1/n) Σ_{i=1..n} ME_i(x_j)
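A minimal numpy sketch of the central-difference and averaging formulas, for any prediction function `predict` that maps an (n, p) array to n outputs (the function name and signature are illustrative assumptions, not marginfx's API). The test function f(x) = x0² + 3·x1 has a known answer: ME_i(x0) = 2·x0_i varies across observations, so AME(x0) is 2 times the mean of x0, while AME(x1) is 3.

```python
import numpy as np

def average_marginal_effects(predict, X, h=1e-4):
    """Central-difference AME for every column of X.

    Returns an array of length p whose j-th entry is the mean over rows of
    [f(x_i + h e_j) - f(x_i - h e_j)] / (2h).
    """
    X = np.asarray(X, dtype=float)
    _, p = X.shape
    ames = np.empty(p)
    for j in range(p):
        step = np.zeros(p)
        step[j] = h                                        # h * e_j
        me_i = (predict(X + step) - predict(X - step)) / (2 * h)
        ames[j] = me_i.mean()                              # (1/n) * sum_i ME_i(x_j)
    return ames

def f(X):
    """Nonlinear test function with known derivatives: df/dx0 = 2*x0, df/dx1 = 3."""
    return X[:, 0] ** 2 + 3 * X[:, 1]

rng = np.random.default_rng(1)
X = rng.normal(loc=1.0, size=(500, 2))
ame = average_marginal_effects(f, X)
print(ame, 2 * X[:, 0].mean())
```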

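For binary/categorical features, a first difference replaces the derivative: x_j is set to 1 and then to 0 for every observation, holding the other features at their observed values: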

ME_i(x_j) = f(x_i | x_j=1) - f(x_i | x_j=0)
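A sketch of the first-difference rule, with a hypothetical `predict_proba` (made-up logistic coefficients) standing in for a fitted classifier; column 1 plays the role of a 0/1 feature such as female. Every row is evaluated twice, once with the feature set to 1 and once set to 0, and the AME is the mean difference:

```python
import numpy as np

def first_difference_ame(predict, X, j):
    """AME of binary column j: mean of f(x_i with x_j=1) - f(x_i with x_j=0)."""
    X = np.asarray(X, dtype=float)
    X1, X0 = X.copy(), X.copy()
    X1[:, j] = 1.0   # every observation "switched on"
    X0[:, j] = 0.0   # every observation "switched off"
    return float(np.mean(predict(X1) - predict(X0)))

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def predict_proba(X):
    """Hypothetical fitted model: P(y=1) = sigmoid(-1 + 0.8*x0 + 0.5*female)."""
    return sigmoid(-1.0 + 0.8 * X[:, 0] + 0.5 * X[:, 1])

rng = np.random.default_rng(2)
X = np.column_stack([rng.normal(size=1_000), rng.integers(0, 2, size=1_000)])
ame_female = first_difference_ame(predict_proba, X, j=1)
print(f"AME(female) = {ame_female:.4f}")
```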

For TensorFlow and PyTorch models, tf.GradientTape and torch.autograd provide exact gradients, replacing finite differences automatically.
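For a logistic model the exact derivative has a closed form, ∂p/∂x_j = p(1 − p)·β_j, which is what automatic differentiation would return for the same computation. The numpy sketch below (hypothetical coefficients; no TensorFlow or PyTorch required) compares it with the central-difference approximation to show how small the finite-difference error is for a smooth model with h = 1e-4:

```python
import numpy as np

beta0, beta = -0.3, np.array([1.5, -2.0, 0.7])  # hypothetical coefficients

def predict_proba(X):
    return 1.0 / (1.0 + np.exp(-(beta0 + X @ beta)))

def exact_gradient(X):
    """Closed-form dp/dx for logistic regression: p(1-p) * beta, one row per observation."""
    p = predict_proba(X)
    return (p * (1.0 - p))[:, None] * beta

def central_difference_gradient(X, h=1e-4):
    """Pointwise central differences, one column per feature."""
    grads = np.empty_like(X)
    for j in range(X.shape[1]):
        step = np.zeros(X.shape[1])
        step[j] = h
        grads[:, j] = (predict_proba(X + step) - predict_proba(X - step)) / (2 * h)
    return grads

rng = np.random.default_rng(3)
X = rng.normal(size=(1_000, 3))
exact = exact_gradient(X)
approx = central_difference_gradient(X)
print("max abs error:", np.abs(exact - approx).max())
print("AME (exact):", exact.mean(axis=0))
```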

Bootstrap standard errors

Standard errors come from a nonparametric bootstrap:

  1. Resample the data with replacement
  2. Refit the model, warm-starting from the original fitted model (faster convergence)
  3. Compute AMEs on the bootstrap sample
  4. Repeat B times
  5. SE = standard deviation of the B AME estimates
  6. CI = percentile interval of the B AME estimates (the alpha/2 and 1 - alpha/2 quantiles)
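The six steps can be sketched in numpy. Assumptions for illustration: the "model" is a quadratic OLS fit, chosen so that its AME (b1 + 2·b2·mean(x)) differs from any single coefficient, and because OLS has a closed-form solution there is nothing to warm-start; `fit`, `ame`, and `bootstrap_ame` are hypothetical names, not marginfx functions.

```python
import numpy as np

def fit(x, y):
    """Hypothetical model: OLS of y on [1, x, x**2]; returns (b0, b1, b2)."""
    design = np.column_stack([np.ones_like(x), x, x ** 2])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return coef

def ame(coef, x):
    """AME of x: mean over observations of the pointwise derivative b1 + 2 * b2 * x_i."""
    return float(np.mean(coef[1] + 2 * coef[2] * x))

def bootstrap_ame(x, y, n_bootstrap=200, alpha=0.05, seed=None):
    rng = np.random.default_rng(seed)
    n = len(x)
    estimate = ame(fit(x, y), x)                # point estimate on the full data
    draws = np.empty(n_bootstrap)
    for b in range(n_bootstrap):                # step 4: repeat B times
        idx = rng.integers(0, n, size=n)        # step 1: resample rows with replacement
        coef_b = fit(x[idx], y[idx])            # step 2: refit (closed form, no warm start)
        draws[b] = ame(coef_b, x[idx])          # step 3: AME on the bootstrap sample
    std_error = float(draws.std(ddof=1))        # step 5: SD of the B estimates
    conf_low, conf_high = np.quantile(draws, [alpha / 2, 1 - alpha / 2])  # step 6
    return estimate, std_error, (float(conf_low), float(conf_high))

rng = np.random.default_rng(0)
x = rng.normal(size=500)
y = 1.0 + 0.5 * x + 0.2 * x ** 2 + rng.normal(scale=0.5, size=500)
est, se, (lo, hi) = bootstrap_ame(x, y, seed=42)
print(f"AME = {est:.3f}, SE = {se:.3f}, 95% CI = ({lo:.3f}, {hi:.3f})")
```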

Warm-starting from the original model makes the bootstrap practical for expensive models: because the original fit is typically already close to each replicate's optimum, bootstrap replicates converge in far fewer iterations than cold retraining from scratch.
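The warm-start effect can be illustrated with a toy logistic regression fit by plain gradient descent (an assumption for illustration, not marginfx's refit code for any framework). Each bootstrap replicate is fit twice, from zeros and from the original coefficients. Because the logistic loss is convex, both starts reach the same optimum; only the number of gradient steps differs.

```python
import numpy as np

def fit_logistic(X, y, beta_init, lr=1.0, tol=1e-6, max_iter=10_000):
    """Plain gradient descent on the mean logistic loss.

    Returns the coefficients and the number of gradient steps taken.
    """
    beta = np.array(beta_init, dtype=float)  # copy, so the caller's start is untouched
    for step in range(max_iter):
        p = 1.0 / (1.0 + np.exp(-(X @ beta)))
        grad = X.T @ (p - y) / len(y)
        if np.linalg.norm(grad) < tol:       # converged
            return beta, step
        beta -= lr * grad
    return beta, max_iter

rng = np.random.default_rng(0)
n = 500
X = np.column_stack([np.ones(n), rng.normal(size=(n, 2))])
true_beta = np.array([0.5, 1.0, -1.0])
y = (rng.random(n) < 1.0 / (1.0 + np.exp(-(X @ true_beta)))).astype(float)

beta_hat, _ = fit_logistic(X, y, np.zeros(3))  # the "original" model

cold_steps = warm_steps = 0
for _ in range(5):                             # a few bootstrap replicates
    idx = rng.integers(0, n, size=n)
    beta_cold, k_cold = fit_logistic(X[idx], y[idx], np.zeros(3))  # cold start
    beta_warm, k_warm = fit_logistic(X[idx], y[idx], beta_hat)     # warm start
    cold_steps += k_cold
    warm_steps += k_warm

print(f"gradient steps over 5 replicates: cold = {cold_steps}, warm = {warm_steps}")
```

For non-convex models such as neural nets, the starting point can also influence which solution a replicate converges to, not only how quickly.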


Supported models

Framework     Models                                      Gradient method         Warm-start
scikit-learn  RandomForest, GradientBoosting,             Finite differences      Yes (where supported)
              LogisticRegression, LinearRegression,
              SVC, and all sklearn-compatible models
XGBoost       XGBClassifier, XGBRegressor                 Finite differences      Yes (native)
LightGBM      LGBMClassifier, LGBMRegressor               Finite differences      Yes (native)
TensorFlow    tf.keras.Model                              Exact (GradientTape)    Yes (continued training)
PyTorch       torch.nn.Module                             Exact (autograd)        Yes (continued training)

Model type is detected automatically. No need to specify the engine.
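The README does not describe how detection works. One plausible approach, sketched here as an assumption rather than marginfx's actual logic, inspects the top-level module names in the model's class hierarchy; `detect_engine` and its engine labels are hypothetical. The stand-in classes only mimic where the real framework classes live, so the example runs without any ML framework installed.

```python
def detect_engine(model) -> str:
    """Guess a model's framework from the top-level modules of its class hierarchy."""
    roots = {cls.__module__.split(".")[0] for cls in type(model).__mro__}
    # Order matters: XGBoost and LightGBM estimators can also subclass scikit-learn classes.
    checks = [
        ("torch", "pytorch"),
        ("keras", "tensorflow"),
        ("tf_keras", "tensorflow"),
        ("tensorflow", "tensorflow"),
        ("xgboost", "xgboost"),
        ("lightgbm", "lightgbm"),
        ("sklearn", "sklearn"),
    ]
    for root, engine in checks:
        if root in roots:
            return engine
    if hasattr(model, "predict"):
        return "sklearn"  # duck-typed fallback for scikit-learn-compatible models
    raise TypeError(f"unsupported model type: {type(model).__name__}")

# Stand-in classes that mimic where real framework classes live
SkBase = type("BaseEstimator", (), {"__module__": "sklearn.base"})
XGBLike = type("XGBClassifier", (SkBase,), {"__module__": "xgboost.sklearn"})
TorchModule = type("Module", (), {"__module__": "torch.nn.modules.module"})

class Net(TorchModule):
    pass

print(detect_engine(XGBLike()))  # xgboost
print(detect_engine(Net()))      # pytorch
```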


API reference

mfx.fit()

mfx.fit(
    model,                        # fitted model — any supported type
    X,                            # feature matrix (numpy array or pandas DataFrame)
    y,                            # target vector
    feature_names=None,           # list of feature names (auto from DataFrame columns)
    categorical_features=None,    # list of categorical feature indices or names
    n_bootstrap=200,              # number of bootstrap replicates
    alpha=0.05,                   # significance level (0.05 = 95% CI)
    seed=None,                    # random seed for reproducibility
    verbose=True,                 # print bootstrap progress
    h=1e-4,                       # finite difference step size (scikit-learn, XGBoost, LightGBM)
    n_epochs=10,                  # bootstrap refit epochs (TF/PyTorch)
    batch_size=32,                # bootstrap refit batch size (TF/PyTorch)
    optimizer_fn=None,            # callable: model parameters -> optimizer (PyTorch only)
    loss_fn=None,                 # loss function (PyTorch only)
)

Returns a MarginfxResult object.

MarginfxResult

result.tidy()        # pandas DataFrame with estimates, SEs, CIs
result.summary()     # formatted summary table printed to stdout
result.estimates     # dict of feature -> AME estimate
result.std_errors    # dict of feature -> bootstrap SE
result.conf_int      # dict of feature -> (conf_low, conf_high)
result.n_obs         # number of observations
result.n_bootstrap   # number of bootstrap replicates

Citation

If you use marginfx in published research, please cite:

@inproceedings{marginfx2026,
  title     = {marginfx: Average Marginal Effects for Any Machine Learning Model},
  author    = {Your Name},
  booktitle = {Proceedings of the 26th IEEE International Conference on Data Mining (ICDM)},
  year      = {2026},
  address   = {Shenyang, China},
}

Related work

  • marginaleffects — the R package that inspired this project
  • broom — tidy model output in R
  • shap — SHAP values for model explanation
  • lime — local interpretable model-agnostic explanations

License

MIT



Download files


Source Distribution

marginfx-0.1.0.tar.gz (220.7 kB)

Uploaded Source

Built Distribution


marginfx-0.1.0-py3-none-any.whl (22.7 kB)

Uploaded Python 3

File details

Details for the file marginfx-0.1.0.tar.gz.

File metadata

  • Download URL: marginfx-0.1.0.tar.gz
  • Upload date:
  • Size: 220.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.0

File hashes

Hashes for marginfx-0.1.0.tar.gz
Algorithm Hash digest
SHA256 9611cdaddc7443e0cfb5336e1e25ba7e5ed1353b8eda387043cf8e3bacfb0ec9
MD5 4f04178ed2f42c00b7bd59ec52a1be0b
BLAKE2b-256 5763f88f8a9c9edaee80c24e4e183e2e933b434e8af826f5c79d59e004595927


File details

Details for the file marginfx-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: marginfx-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 22.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.0

File hashes

Hashes for marginfx-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 df26b8e5a6b48e3a819d265b20821672c12f5b3916cc1dcbf844e44fc8385479
MD5 53211f9da4b8d29d0b46a22fd122001a
BLAKE2b-256 b04f63b89d0eb73a27f3d984950f27cb2772dbe3c3023e119f5dab6b876f793b

