
Exact Integrated Gradients for tree ensembles.


TreeIG

TreeIG computes exact Integrated Gradients for tree ensembles. It decomposes the change in a fitted tree model's scalar output between a baseline input $x_0$ and an observation $x$ into additive feature contributions.

For each observation, TreeIG returns feature attributions $\phi_j$ satisfying

$$\sum_j \phi_j = F(x) - F(x_0)$$

where $F$ is the scalar model output being explained. For regression models, $F$ is the prediction. For supported classifiers, $F$ is the raw margin/logit, not the predicted probability.

TreeIG extends the Integrated Gradients framework of Sundararajan, Taly, and Yan (2017) to tree ensembles by exploiting the piecewise-constant structure of tree models.

To handle piecewise-constant models, TreeIG uses generalized gradients. Along the input path, the integral of a feature's generalized gradient equals the sum of the prediction jumps at that feature's split crossings. TreeIG uses this equivalence to compute Integrated Gradients for tree models exactly and efficiently.
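The equivalence can be written out as a short derivation. This is a sketch under two assumptions made here: each crossing time involves exactly one feature, and neither $x_0$ nor $x$ lies exactly on a split threshold. Write $g(t) = F(x(t))$ for the model output along the straight-line path $x(t) = x_0 + t(x - x_0)$, let $t_1 < \dots < t_K$ be the times at which the path crosses a split threshold, $j_k$ the feature crossed at $t_k$, and $\Delta_k = g(t_k^+) - g(t_k^-)$ the jump in output there. Then

$$
\phi_j = (x_j - x_{0,j}) \int_0^1 \partial_j F\big(x(t)\big)\,dt = \sum_{k:\, j_k = j} \Delta_k,
\qquad
\sum_j \phi_j = \sum_{k=1}^{K} \Delta_k = g(1) - g(0) = F(x) - F(x_0).
$$

Here $\partial_j F$ is a generalized (distributional) derivative: a split on feature $j$ at threshold $\tau$ contributes a Dirac delta at $x_j = \tau$ weighted by the jump in $F$, and the factor $x_j - x_{0,j}$ cancels the $1/\lvert x_j - x_{0,j}\rvert$ from the change of variables while orienting the jump along increasing $t$.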

References

TreeIG:

Integrated Gradients:

  • Sundararajan, Mukund, Ankur Taly, and Qiqi Yan. 2017. "Axiomatic Attribution for Deep Networks." International Conference on Machine Learning (ICML).

SHAP and TreeSHAP:

  • Lundberg, Scott M., and Su-In Lee. 2017. "A Unified Approach to Interpreting Model Predictions." Advances in Neural Information Processing Systems (NeurIPS).

  • Lundberg, Scott M., Gabriel Erion, and Su-In Lee. 2020. "From Local Explanations to Global Understanding with Explainable AI for Trees." Nature Machine Intelligence.

Popular implementations of Integrated Gradients for smooth models include:

Why TreeIG?

Standard Integrated Gradients defines feature contributions by integrating model gradients along a path from a baseline input to the observation. Tree models are piecewise constant, so ordinary gradients are zero almost everywhere and undefined at split boundaries.

TreeIG uses the tree structure directly. Along the straight-line path

x(t) = x0 + t * (x - x0),    0 <= t <= 1,

a tree prediction changes only when the path crosses a split threshold. TreeIG finds those crossings exactly and assigns each jump in prediction to the feature responsible for the crossing. For ensembles, contributions are summed across trees.

This gives an exact additive decomposition for tree models without numerical quadrature.
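The crossing-and-jump procedure above can be sketched in plain Python and NumPy for a single tree and for an additive ensemble. This is an illustrative sketch, not TreeIG's implementation: the parallel-array tree layout (`feature`, `threshold`, `left`, `right`, `value`), the helper names `predict_tree`, `tree_ig`, and `ensemble_ig`, and the equal split of a jump among features that cross at the same $t$ are hypothetical choices made here. Routing follows the scikit-learn convention `x[j] <= threshold`.

```python
import numpy as np

def predict_tree(tree, x):
    """Route one input to a leaf; go left when x[f] <= threshold (scikit-learn style)."""
    node = 0
    while tree["feature"][node] >= 0:  # internal nodes have a feature index >= 0
        f = tree["feature"][node]
        if x[f] <= tree["threshold"][node]:
            node = tree["left"][node]
        else:
            node = tree["right"][node]
    return tree["value"][node]

def tree_ig(tree, x0, x):
    """Attribute F(x) - F(x0) for one tree by assigning each prediction jump
    along x(t) = x0 + t * (x - x0) to the feature(s) crossing a threshold there."""
    x0 = np.asarray(x0, dtype=float)
    x = np.asarray(x, dtype=float)
    d = x - x0
    # Map each crossing time t in [0, 1] to the set of features crossing at t.
    events = {}
    for f, thr in zip(tree["feature"], tree["threshold"]):
        if f < 0 or d[f] == 0.0:  # leaf, or feature constant along the path
            continue
        t = (thr - x0[f]) / d[f]
        if 0.0 <= t <= 1.0:
            events.setdefault(t, set()).add(int(f))
    # Breakpoints include both endpoints; the output is constant between them.
    bps = sorted(set(events) | {0.0, 1.0})
    phi = np.zeros_like(x0)
    prev = predict_tree(tree, x0)  # g(0)
    for i, b in enumerate(bps):
        if i + 1 < len(bps):
            mid = 0.5 * (b + bps[i + 1])
            nxt = predict_tree(tree, x0 + mid * d)  # g on the interval after b
        else:
            nxt = predict_tree(tree, x)  # g(1)
        feats = events.get(b)
        if feats:
            # Hypothetical tie rule: split a simultaneous jump equally.
            for f in feats:
                phi[f] += (nxt - prev) / len(feats)
        prev = nxt
    # The jumps telescope to g(1) - g(0); an endpoint with no crossing is
    # assumed to carry no jump (ignores floating-point edge cases).
    return phi

def ensemble_ig(trees, x0, x):
    """Sum-of-trees ensemble: attributions add across trees. Assumes leaf values
    already include any learning-rate scaling; a constant base score cancels."""
    return sum(tree_ig(tree, x0, x) for tree in trees)

# Root splits feature 0 at 0.5; its right child splits feature 1 at 0.5.
example_tree = {
    "feature":   [0, -1, 1, -1, -1],
    "threshold": [0.5, 0.0, 0.5, 0.0, 0.0],
    "left":      [1, -1, 3, -1, -1],
    "right":     [2, -1, 4, -1, -1],
    "value":     [0.0, 1.0, 0.0, 2.0, 5.0],
}
print(tree_ig(example_tree, [0.0, 0.0], [1.0, 2.0]))  # [4. 0.]
print(tree_ig(example_tree, [0.0, 0.0], [2.0, 1.0]))  # [1. 3.]
```

In the first call, feature 1 crosses its threshold while the path is still in the left branch, where the tree ignores feature 1, so it receives zero attribution. In the second call the crossing order is reversed and feature 1 receives the final jump.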

Relation to SHAP and TreeSHAP

TreeIG and TreeSHAP answer different attribution questions.

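TreeSHAP computes Shapley-value attributions. Each feature's contribution averages its marginal effect over subsets of the other features, with absent features handled either by following the tree's own training paths (path-dependent) or by intervening with a background dataset (interventional). The contributions are therefore measured relative to an expected model output under a reference distribution.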

TreeIG instead explains the realized change in model output along a specific path from a baseline input $x_0$ to an observation $x$. The attribution is therefore path-based rather than subset-based.

For smooth models, TreeIG reduces to ordinary Integrated Gradients. For tree models, TreeIG computes the exact path decomposition implied by split crossings.

Neither framework dominates the other. They address different counterfactual questions and therefore produce different decompositions.

Supported models

TreeIG currently supports finite numeric inputs for these model classes.

Regression

  • sklearn.tree.DecisionTreeRegressor
  • sklearn.ensemble.RandomForestRegressor
  • sklearn.ensemble.ExtraTreesRegressor
  • sklearn.ensemble.GradientBoostingRegressor
  • xgboost.XGBRegressor
  • xgboost.Booster
  • lightgbm.LGBMRegressor
  • lightgbm.Booster

Classification, raw margins only

  • sklearn.ensemble.GradientBoostingClassifier
  • xgboost.XGBClassifier
  • lightgbm.LGBMClassifier

For classification models, TreeIG attributes raw scores, margins, or logits. It does not currently attribute predicted probabilities.

Not currently supported

TreeIG does not yet support:

  • probability-output attribution;
  • missing-value routing;
  • categorical splits;
  • CatBoost;
  • probability-averaging or vote-share classifiers such as DecisionTreeClassifier, RandomForestClassifier, and ExtraTreesClassifier.

Installation

pip install treeig

Or, from a local clone of the repository:

pip install -e .

Basic usage

import numpy as np
import treeig as tig

# model is a fitted supported tree model; X_train and X_test are finite numeric NumPy arrays
x0 = X_train.mean(axis=0)
X_eval = X_test[:100]

ig = tig.TreeIG(model, baseline=x0)
phi = ig.attribute(X_eval)

phi has the same shape as X_eval. Row i, column j is the contribution of feature j to the model-output change from x0 to X_eval[i].

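For regression models, each row of phi sums to the change in prediction:

np.testing.assert_allclose(
    phi.sum(axis=1),
    model.predict(X_eval) - model.predict(x0.reshape(1, -1))[0],
    rtol=1e-6,
    atol=1e-6,  # allow for float32 predictions (e.g. XGBoost)
)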

Diagnostics

Use explain when you want attributions together with completeness diagnostics.

ig = tig.TreeIG(model, baseline=x0)
phi, infos, summary = ig.explain(X_eval)

print(summary)

Each entry in infos contains diagnostics for one observation:

{
    "n_events": ...,          # number of split-crossing events
    "endpoint_delta": ...,    # F(x) - F(x0)
    "attribution_sum": ...,   # sum_j phi_j
    "residual": ...,          # attribution_sum - endpoint_delta
    "abs_residual": ...,      # abs(residual)
}

The summary dictionary reports aggregate residual and event-count statistics.
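The per-observation fields can be reproduced from phi and the two endpoint outputs, which gives a quick completeness check that does not depend on explain. A minimal sketch: the function name completeness_diagnostics and the summary key mean_abs_residual are hypothetical, while max_abs_residual is the summary key used in this README's XGBoost example. n_events is omitted because it requires the crossing search itself.

```python
import numpy as np

def completeness_diagnostics(phi, f_x, f_x0):
    """Rebuild residual diagnostics from attributions and endpoint outputs.

    phi:  (n, p) attributions, f_x: (n,) outputs F(x), f_x0: scalar F(x0).
    """
    phi = np.asarray(phi, dtype=float)
    f_x = np.asarray(f_x, dtype=float)
    endpoint_delta = f_x - float(f_x0)      # F(x) - F(x0)
    attribution_sum = phi.sum(axis=1)       # sum_j phi_j
    residual = attribution_sum - endpoint_delta
    infos = [
        {
            "endpoint_delta": float(e),
            "attribution_sum": float(s),
            "residual": float(r),
            "abs_residual": float(abs(r)),
        }
        for e, s, r in zip(endpoint_delta, attribution_sum, residual)
    ]
    summary = {
        "max_abs_residual": float(np.max(np.abs(residual))),
        "mean_abs_residual": float(np.mean(np.abs(residual))),  # hypothetical key
    }
    return infos, summary

infos, summary = completeness_diagnostics([[1.0, 3.0], [4.0, 0.0]], [5.0, 5.5], 1.0)
print(summary)
```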

Classification targets

For binary additive-score classifiers, target=None and target=1 both attribute the positive-class margin. target=0 attributes the negative-class margin, defined as the negative of the positive-class margin, so its attributions are the negatives of the target=1 attributions.

ig = tig.TreeIG(model, baseline=x0, target=1)
phi_pos = ig.attribute(X_eval)

ig = tig.TreeIG(model, baseline=x0, target=0)
phi_neg = ig.attribute(X_eval)

For multiclass classifiers, pass the class index explicitly.

ig = tig.TreeIG(model, baseline=x0, target=2)
phi_class_2 = ig.attribute(X_eval)

TreeIG attributes raw class margins. If probability-space explanations are needed, users should transform or interpret the margin-level contributions separately.
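For a binary classifier trained with a logistic (log-loss) objective, so that $p = \sigma(F)$, the margin contributions map back to a probability only through their total. A minimal sketch under that assumption; the helper names sigmoid and probability_from_margins are hypothetical. It also shows why per-feature probability effects are not additive, and how the target=0 margin relates to the complementary probability.

```python
import numpy as np

def sigmoid(z):
    """Logistic link: maps a raw margin to a positive-class probability."""
    return 1.0 / (1.0 + np.exp(-z))

def probability_from_margins(f_x0, phi_row):
    """Recover p(x) from the baseline margin F(x0) and margin attributions,
    using completeness: F(x) = F(x0) + sum_j phi_j."""
    return sigmoid(f_x0 + np.sum(phi_row))

f_x0 = 0.0
phi_row = np.array([1.0, 1.0])

p_x = probability_from_margins(f_x0, phi_row)

# Per-feature probability changes do not add up to the total change,
# because the sigmoid is nonlinear.
per_feature = sigmoid(f_x0 + phi_row) - sigmoid(f_x0)
print(p_x - sigmoid(f_x0), per_feature.sum())

# The target=0 margin is -F, and sigmoid(-F) = 1 - sigmoid(F).
p_neg = sigmoid(-(f_x0 + phi_row.sum()))
```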

Warmup

TreeIG uses Numba for fast attribution kernels. The first call may include compilation time. You can compile the kernels in advance with warmup.

ig = tig.TreeIG(model, baseline=x0).warmup(X_eval[:3])
phi = ig.attribute(X_eval)

Functional interface

TreeIG also provides a direct functional interface.

phi, infos, summary = tig.compute(
    model,
    baseline=x0,
    X=X_eval,
)

For backward compatibility, the following aliases are also available:

from treeig import (
    exact_gb_ig_batch_fast,
    warmup_exact_gb_ig,
    timed_call,
)

Numerical conventions

TreeIG follows each backend's split-routing convention as closely as possible.

  • scikit-learn trees route left when x[j] <= threshold;
  • LightGBM numeric splits route left when x[j] <= threshold;
  • XGBoost numeric splits route left when x[j] < threshold using float32-style comparisons.

Inputs must be finite numeric arrays. Missing-value routing is not currently implemented, so NaN and Inf values raise errors.
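The three comparison rules can be written as a small predicate. This sketch only illustrates the operators listed above: the function names goes_left and check_finite are hypothetical, and casting both operands with np.float32 is an assumption about what "float32-style" means, not a description of TreeIG's code. The second usage line shows a value just below 0.1 that routes left under a float64 `<=` but right under the float32 `<`, because it rounds to the same float32 as the threshold.

```python
import numpy as np

def goes_left(x_j, threshold, backend):
    """Return True if the value routes to the left child under the backend's rule."""
    if backend in ("sklearn", "lightgbm"):
        # scikit-learn and LightGBM numeric splits: left when x[j] <= threshold
        return bool(x_j <= threshold)
    if backend == "xgboost":
        # XGBoost numeric splits: left when x[j] < threshold, compared in float32
        return bool(np.float32(x_j) < np.float32(threshold))
    raise ValueError(f"unknown backend: {backend!r}")

def check_finite(X):
    """Reject NaN/Inf, since missing-value routing is not implemented."""
    X = np.asarray(X, dtype=float)
    if not np.all(np.isfinite(X)):
        raise ValueError("inputs must be finite; NaN and Inf are not supported")
    return X

# A value exactly on the threshold
print(goes_left(0.5, 0.5, "sklearn"), goes_left(0.5, 0.5, "xgboost"))  # True False
# A value just below 0.1 that rounds to float32(0.1)
print(goes_left(0.0999999999, 0.1, "sklearn"), goes_left(0.0999999999, 0.1, "xgboost"))  # True False
```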

Baselines

The baseline x0 defines the reference point for the decomposition. Common choices include:

  • the training-sample mean;
  • a median or representative observation;
  • a domain-specific neutral input;
  • a fixed benchmark case.

The attribution always explains the difference between the model output at the observation and the model output at the chosen baseline. Different baselines answer different questions.
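The first two choices, and one way to pick a representative observation, can be computed directly from the training matrix. A minimal NumPy sketch: candidate_baselines is a hypothetical helper, and taking the training row nearest (in Euclidean distance) to the feature-wise median is only one possible definition of "representative". Unlike the mean or median, that row is an actual observation.

```python
import numpy as np

def candidate_baselines(X_train):
    """Return a few common baseline vectors computed from training data."""
    X_train = np.asarray(X_train, dtype=float)
    mean = X_train.mean(axis=0)
    median = np.median(X_train, axis=0)
    # Hypothetical "representative" rule: the training row closest to the median
    idx = int(np.argmin(np.linalg.norm(X_train - median, axis=1)))
    return {"mean": mean, "median": median, "representative": X_train[idx]}

X_demo = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 10.0]])
baselines = candidate_baselines(X_demo)
print(baselines["representative"])  # [1. 1.]
```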

Interpretation

For an observation x, TreeIG reports how much each feature contributes to moving the model output from F(x0) to F(x) along the straight-line path from x0 to x.

Positive contributions increase the scalar output relative to the baseline. Negative contributions decrease it. The contributions are additive by construction.
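Because the contributions are signed and sum to $F(x) - F(x_0)$, one simple way to read a single row of phi is to list the features with the largest absolute contributions while keeping their signs. A minimal sketch; top_contributions and the feature names are hypothetical.

```python
import numpy as np

def top_contributions(phi_row, feature_names, k=3):
    """Return the k (name, contribution) pairs with the largest |contribution|."""
    phi_row = np.asarray(phi_row, dtype=float)
    # Sort by descending absolute value; a stable sort keeps ties in column order
    order = np.argsort(-np.abs(phi_row), kind="stable")[:k]
    return [(feature_names[i], float(phi_row[i])) for i in order]

phi_row = [0.5, -2.0, 1.0]
print(top_contributions(phi_row, ["age", "income", "tenure"], k=2))
# [('income', -2.0), ('tenure', 1.0)]
```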

Example: XGBoost regression

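import numpy as np
import xgboost as xgb
import treeig as tig
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

# Synthetic regression data so the example runs end to end
X, y = make_regression(n_samples=1000, n_features=8, noise=0.1, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

model = xgb.XGBRegressor(
    n_estimators=100,
    max_depth=3,
    learning_rate=0.05,
    objective="reg:squarederror",
    random_state=0,
)

model.fit(X_train, y_train)

x0 = X_train.mean(axis=0)
X_eval = X_test[:100]

# Compile the Numba kernels on a few rows before the main call
ig = tig.TreeIG(model, baseline=x0).warmup(X_eval[:3])

phi, infos, summary = ig.explain(X_eval)

print(phi.shape)                    # (100, 8)
print(summary["max_abs_residual"])  # should be near zero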

Example: multiclass classification margins

import lightgbm as lgb
import treeig as tig

model = lgb.LGBMClassifier(...)
model.fit(X_train, y_train)

x0 = X_train.mean(axis=0)
X_eval = X_test[:100]

# Attribute class-2 raw margin
ig = tig.TreeIG(model, baseline=x0, target=2)

phi = ig.attribute(X_eval)

Project status

TreeIG is intended for exact additive attribution of fitted tree models in raw-output space. The current implementation focuses on correctness, backend-specific routing consistency, and a compact API.

Future extensions may include:

  • probability-space attribution;
  • missing-value routing;
  • categorical splits;
  • CatBoost support;
  • additional attribution paths and allocation rules.



Download files

Source Distribution

treeig-0.1.1.tar.gz (23.2 kB)

Uploaded Source

Built Distribution

treeig-0.1.1-py3-none-any.whl (20.7 kB)

Uploaded Python 3

File details

Details for the file treeig-0.1.1.tar.gz.

File metadata

  • Download URL: treeig-0.1.1.tar.gz
  • Upload date:
  • Size: 23.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.11

File hashes

Hashes for treeig-0.1.1.tar.gz
Algorithm Hash digest
SHA256 bc44dbba2667430c925e8a22618c99773e44237c1ca979ba1203683b4e0773cd
MD5 fbc43aeb454f10b019f8dc8d07872dd9
BLAKE2b-256 54988f87106533b6e5d3533a6b03c2ac84a7bfe35735e1bfd850f7f949e89924


File details

Details for the file treeig-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: treeig-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 20.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.11

File hashes

Hashes for treeig-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 b8d8e6dee1c4a35d58cd83e950ceeaf756261b95a4806c4aa0d39875016ab48e
MD5 05c4535615492c6eb2fae7071f6f14a0
BLAKE2b-256 b75631a0cbe8f4d6b8713a57cba7b5bfa92884501c8b4119859db61f193c26e7
