
Compute WeightWatcher-style correlation matrices (W1/W2/W7/W8/W9) for XGBoost via OOF margin increments.


Why XGBoost2WW?

XGBoost2WW lets you apply WeightWatcher-style spectral diagnostics to XGBoost models.

XGBoost models don’t have traditional neural network weight matrices — so you can’t directly run tools like WeightWatcher on them.
XGBoost2WW bridges that gap by converting a trained XGBoost model into structured matrices (W1/W2/W7/W8/W9) derived from out-of-fold margin increments along the boosting trajectory.

These matrices behave like neural weight matrices, so you can analyze them with WeightWatcher.
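To make "margin increments along the boosting trajectory" concrete, here is a minimal NumPy sketch. It assumes you already have a matrix of out-of-fold raw margins recorded at a few boosting checkpoints. The array `F_oof`, its values, and the helper `margin_increments` are illustrative only and are not part of the xgboost2ww API; how the package actually builds and normalizes its matrices may differ.

```python
import numpy as np

def margin_increments(F_oof: np.ndarray) -> np.ndarray:
    """Return the change in margin between consecutive checkpoints.

    F_oof has shape (n_samples, n_checkpoints); the result has one
    column fewer, one per step along the boosting trajectory.
    """
    return np.diff(F_oof, axis=1)

# Toy out-of-fold margins for 4 samples at 4 checkpoints (made-up numbers).
F_oof = np.array([
    [0.0,  0.4,  0.7,  0.8],
    [0.0, -0.3, -0.5, -0.6],
    [0.0,  0.1,  0.3,  0.2],
    [0.0, -0.2, -0.1, -0.4],
])

dF_oof = margin_increments(F_oof)
print(dF_oof.shape)
```

Each row of `dF_oof` sums back to that sample's total margin change, so the matrix is a per-step decomposition of what the booster learned about held-out rows.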


Why would a production ML engineer care?

Because traditional metrics (accuracy, AUC, logloss) often look fine right up until a model fails in production.

Spectral diagnostics can help detect:

  • Overfitting that standard validation doesn’t reveal
  • Correlation traps in boosted trees
  • Excessive memorization
  • Unstable training dynamics
  • Data leakage patterns
  • Models that are brittle to distribution shift

In short:

XGBoost2WW gives you a structural diagnostic signal — not just a performance metric.

That means you can:

  • Compare model candidates beyond accuracy
  • Detect problematic models before deployment
  • Monitor structural drift over time
  • Add an extra safety layer to your MLOps pipeline

If you deploy XGBoost models in production,
XGBoost2WW gives you a new lens to inspect them.

xgboost2ww

Convert XGBoost boosting dynamics into WeightWatcher-style operators (W1/W2/W7/W8/W9).

Install

Development install:

pip install -e .
pip install weightwatcher torch

Minimal runtime install:

pip install xgboost2ww
pip install weightwatcher

Google Colab Notebooks

Single Good Model

  • High test and training accuracy, good WW metrics: SingleGoodModelWWXGBoost2WW.ipynb

Realistic End-to-End Example

  • Interpreting α and traps in a realistic, non-trivial setting: XGBoost2WWAdultIncomeExample.ipynb

Stress Test Across 100 Random Models

  • Trains 100 random models and compares WeightWatcher alpha against test accuracy: GoodModelsXGBoost2WW.ipynb

Poorly Trained Credit Model

  • Small data set where high test accuracy is hard to reach; the model shows high alpha: PoorlyTrainedCreditModel.ipynb

Diagnostic Example

  • Overly simple model that strongly overfits the training data: XGBoost2WWDiagnosticExample.ipynb

SpamBase Alpha=2 Targeted Sweep

Random100 Long-Run Alpha Tracking

Quickstart (compute_matrices)

import numpy as np
import xgboost as xgb

from xgboost2ww import compute_matrices

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 12)).astype(np.float32)
logits = 1.5 * X[:, 0] - 0.8 * X[:, 1] + 0.3 * rng.normal(size=300)
y = (logits > 0).astype(np.int32)

dtrain = xgb.DMatrix(X, label=y)

params = {
    "objective": "binary:logistic",
    "eval_metric": "logloss",
    "max_depth": 3,
    "eta": 0.1,
    "subsample": 1.0,
    "colsample_bytree": 1.0,
    "seed": 0,
    "verbosity": 0,
}
rounds = 40
bst = xgb.train(params, dtrain, num_boost_round=rounds)

# Reproducibility knobs for fold training inside compute_matrices / convert
train_params = params
num_boost_round = rounds

mats = compute_matrices(
    bst,
    X,
    y,
    nfolds=5,
    t_points=40,
    random_state=0,
    train_params=train_params,
    num_boost_round=num_boost_round,
)

W9 = mats.W9
print(W9.shape)

Quickstart (convert + WeightWatcher)

import weightwatcher as ww

from xgboost2ww import convert

layer = convert(
    bst,
    X,
    y,
    W="W9",
    return_type="torch",
    nfolds=5,
    t_points=40,
    random_state=0,
    train_params=train_params,
    num_boost_round=num_boost_round,
)

watcher = ww.WeightWatcher(model=layer)
details_df = watcher.analyze(randomize=True, plot=False)

alpha = details_df["alpha"].iloc[0]
rand_num_spikes = details_df["rand_num_spikes"].iloc[0]
print({"alpha": alpha, "rand_num_spikes": rand_num_spikes})

For an initial evaluation, you do not need detX=True. If you want determinant-based diagnostics, pass detX=True to watcher.analyze(...).

What the Diagnostics Look Like

Below is an example of a WeightWatcher spectral analysis applied to an XGBoost model via xgboost2ww.

In this example:

  • (Left) The power-law fit produces an α value near 2
  • (Right) The ERG detX condition is approximately satisfied (the red and purple lines are close)
  • (Not shown) No significant traps are detected

But unlike many neural networks, the full empirical spectral density (ESD) shows a very heavy-tailed structure. This indicates that the training data was very easy to memorize. That is acceptable here, and even expected.

This is what a structurally healthy model looks like.
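For intuition about where an α value comes from, the sketch below computes an ESD from a matrix and applies the standard continuous power-law maximum-likelihood estimate for a fixed lower cutoff. This is a simplified stand-in, not WeightWatcher's fitting procedure (which also selects the cutoff automatically); the helper names `esd` and `powerlaw_alpha_mle`, the 1/N scaling, and the median cutoff are assumptions made for illustration.

```python
import numpy as np

def esd(W: np.ndarray) -> np.ndarray:
    """Eigenvalues of W^T W / N, via squared singular values, sorted ascending."""
    n_rows = W.shape[0]
    sv = np.linalg.svd(W, compute_uv=False)
    return np.sort(sv**2 / n_rows)

def powerlaw_alpha_mle(values: np.ndarray, xmin: float) -> float:
    """Continuous power-law MLE for a fixed xmin: 1 + n / sum(log(x / xmin))."""
    tail = values[values >= xmin]
    return 1.0 + tail.size / np.sum(np.log(tail / xmin))

# Random Gaussian matrix as a placeholder for a W matrix from xgboost2ww.
rng = np.random.default_rng(0)
W = rng.normal(size=(200, 40))

evals = esd(W)
# Fixed cutoff at the median eigenvalue: an arbitrary choice for this sketch.
alpha_hat = powerlaw_alpha_mle(evals, xmin=float(np.median(evals)))
print(alpha_hat)
```

A random Gaussian matrix is not heavy-tailed, so the number printed here is not meaningful as a diagnostic; the point is only to show the mechanics of turning a spectrum into a tail exponent.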

When α drifts upward or traps appear, it is often a signal of:

  • Overfitting
  • Correlation traps
  • Memorization
  • Instability in training
  • Data leakage
  • Structural brittleness
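One way to turn these signals into a simple check is sketched below. The function `structural_flags`, the baseline comparison, and the tolerance of 0.5 are hypothetical choices made for illustration; they are not thresholds recommended by xgboost2ww or WeightWatcher, and treating any nonzero `rand_num_spikes` as a trap indicator is an assumption.

```python
def structural_flags(alpha, rand_num_spikes, baseline_alpha, alpha_tol=0.5):
    """Return a list of warning labels for one model.

    alpha and rand_num_spikes are the values read from the WeightWatcher
    details table; baseline_alpha is the alpha of a model you trust.
    alpha_tol is an arbitrary illustrative tolerance.
    """
    flags = []
    if alpha - baseline_alpha > alpha_tol:
        flags.append("alpha_drifted_up")
    if rand_num_spikes > 0:
        flags.append("traps_detected")
    return flags

# Made-up values for two candidate models compared against a baseline of 2.0.
print(structural_flags(alpha=2.1, rand_num_spikes=0, baseline_alpha=2.0))
print(structural_flags(alpha=3.4, rand_num_spikes=2, baseline_alpha=2.0))
```

A flag is a prompt to look at the model more closely (for example by checking for leakage), not a verdict on its own.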

Matrix defaults and stability

  • convert(...) now defaults to W="W1" (instead of W="W7").
  • W1 is the recommended non-experimental default.
  • All other W variants (W2, W7, W8, W9) are currently experimental.

Matrix definitions at a glance

  • W8: legacy practical surrogate based on weighted/centered W7.
  • W9: canonical regularizer-whitened Fisher OOF trajectory matrix.

For binary classification, with raw OOF increments dF_oof and final OOF margin m_final:

p = sigmoid(m_final)
h = clip(p * (1 - p), eps, None)
A = weighted_center_cols(dF_oof, h)
gamma_diag[j] = lambda * sum_{trees in endpoint block j}(leaf_value^2)
W9 = diag(sqrt(h)) · A · diag(gamma_diag^{-1/2})

W9 v1 uses the local-quadratic L2 regularizer mass (XGBoost's lambda) and intentionally ignores the non-smooth L1 term (XGBoost's alpha, which is unrelated to the WeightWatcher power-law exponent α).
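The W9 recipe above can be written out in NumPy as a rough sketch. Two things here are assumptions: `weighted_center_cols` is taken to mean "subtract each column's h-weighted mean", and `gamma_diag` is passed in as a ready-made positive vector rather than computed from leaf values. The function `w9_sketch` is illustrative and is not the package's implementation.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def weighted_center_cols(M, w):
    # Assumed meaning: remove the w-weighted mean from every column.
    mu = (w[:, None] * M).sum(axis=0) / w.sum()
    return M - mu

def w9_sketch(dF_oof, m_final, gamma_diag, eps=1e-12):
    """W9 = diag(sqrt(h)) @ A @ diag(gamma_diag ** -0.5) for binary logistic loss."""
    p = sigmoid(m_final)
    h = np.clip(p * (1.0 - p), eps, None)   # per-sample logistic Hessian
    A = weighted_center_cols(dF_oof, h)
    # Row scaling by sqrt(h), column scaling by 1 / sqrt(gamma_diag).
    return np.sqrt(h)[:, None] * A / np.sqrt(gamma_diag)[None, :]

# Toy inputs: 6 samples, 3 endpoint blocks (all values are placeholders).
rng = np.random.default_rng(0)
dF_oof = rng.normal(size=(6, 3))
m_final = dF_oof.sum(axis=1)
gamma_diag = np.array([0.5, 1.0, 2.0])  # must be strictly positive

W9 = w9_sketch(dF_oof, m_final, gamma_diag)
print(W9.shape)
```

Under the assumed centering, every column of W9 is orthogonal to the vector sqrt(h), which is a quick sanity check on the weighting.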

Notes / limitations

  • Binary classification is the default workflow.
  • Multiclass requires setting multiclass explicitly (supported modes: "per_class", "stack", "avg").
  • convert(..., multiclass="per_class", return_type="torch") is unsupported and raises; for multiclass per-class output, use return_type="numpy".
  • torch is optional unless you need convert(..., return_type="torch").
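Since torch is optional and the per-class multiclass mode does not support torch output, a small helper can choose a `return_type` before calling convert. This is only a sketch: `pick_return_type` is a hypothetical function, not part of xgboost2ww, and it encodes just the two constraints listed above.

```python
import importlib.util

def pick_return_type(multiclass=None):
    """Choose "torch" only when torch is installed and the mode allows it."""
    has_torch = importlib.util.find_spec("torch") is not None
    if multiclass == "per_class" or not has_torch:
        return "numpy"
    return "torch"

print(pick_return_type())             # depends on whether torch is installed
print(pick_return_type("per_class"))  # numpy
```

The result can then be passed as the `return_type` argument of convert alongside the matching `multiclass` setting.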
