
Compute WeightWatcher-style correlation matrices (W1/W2/W7/W8) for XGBoost via out-of-fold (OOF) margin increments.


Why XGBoost2WW?

XGBoost2WW lets you apply WeightWatcher-style spectral diagnostics to XGBoost models.

XGBoost models don’t have traditional neural network weight matrices, so you can’t run tools like WeightWatcher on them directly.
XGBoost2WW bridges that gap by converting a trained XGBoost model into structured matrices (W1/W2/W7/W8) derived from out-of-fold margin increments along the boosting trajectory.

These matrices are intended to play the role of neural network weight matrices, so you can analyze them with WeightWatcher.


Why would a production ML engineer care?

Traditional metrics (accuracy, AUC, logloss) often look fine right up until a model fails in production.

Spectral diagnostics can help detect:

  • Overfitting that standard validation doesn’t reveal
  • Correlation traps in boosted trees
  • Excessive memorization
  • Unstable training dynamics
  • Data leakage patterns
  • Models that are brittle to distribution shift

In short:

XGBoost2WW gives you a structural diagnostic signal, not just a performance metric.

That means you can:

  • Compare model candidates beyond accuracy
  • Detect problematic models before deployment
  • Monitor structural drift over time
  • Add an extra safety layer to your MLOps pipeline

If you deploy XGBoost models in production, XGBoost2WW gives you a new lens to inspect them.

xgboost2ww

Convert XGBoost boosting dynamics into WeightWatcher-style correlation matrices (W1/W2/W7/W8).

Install

Development install:

pip install -e .
pip install weightwatcher torch

Minimal runtime install (from PyPI):

pip install xgboost2ww
pip install weightwatcher

Quickstart (compute_matrices)

import numpy as np
import xgboost as xgb

from xgboost2ww import compute_matrices

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 12)).astype(np.float32)
logits = 1.5 * X[:, 0] - 0.8 * X[:, 1] + 0.3 * rng.normal(size=300)
y = (logits > 0).astype(np.int32)

dtrain = xgb.DMatrix(X, label=y)

params = {
    "objective": "binary:logistic",
    "eval_metric": "logloss",
    "max_depth": 3,
    "eta": 0.1,
    "subsample": 1.0,
    "colsample_bytree": 1.0,
    "seed": 0,
    "verbosity": 0,
}
rounds = 40
bst = xgb.train(params, dtrain, num_boost_round=rounds)

# Reproducibility knobs for fold training inside compute_matrices / convert
train_params = params
num_boost_round = rounds

mats = compute_matrices(
    bst,
    X,
    y,
    nfolds=5,
    t_points=40,
    random_state=0,
    train_params=train_params,
    num_boost_round=num_boost_round,
)

W7 = mats.W7
print(W7.shape)
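Once you have a matrix such as W7 as a 2D array, you can inspect its spectrum directly with NumPy before handing it to WeightWatcher. The sketch below is an illustration, not part of xgboost2ww or WeightWatcher: the helper names `esd` and `mp_bulk_edge` are hypothetical, and the Marchenko-Pastur edge assumes the entries behave like i.i.d. zero-mean noise with the matrix's empirical variance. Eigenvalues well above that edge indicate structure that pure noise of the same shape would not produce.

```python
import numpy as np


def esd(W):
    """Eigenvalues of the correlation matrix W^T W / N, sorted ascending."""
    W = np.asarray(W, dtype=np.float64)
    if W.shape[0] < W.shape[1]:
        W = W.T  # make W tall so that N >= M
    n_rows = W.shape[0]
    singular_values = np.linalg.svd(W, compute_uv=False)
    return np.sort(singular_values ** 2 / n_rows)


def mp_bulk_edge(W):
    """Upper Marchenko-Pastur edge for i.i.d. noise with W's shape and variance."""
    W = np.asarray(W, dtype=np.float64)
    n_rows, n_cols = max(W.shape), min(W.shape)
    sigma2 = W.var()
    return sigma2 * (1.0 + np.sqrt(n_cols / n_rows)) ** 2


# Example on synthetic data: pure noise versus noise plus a rank-one signal.
rng = np.random.default_rng(0)
noise = rng.normal(size=(400, 50))
signal = noise + 3.0 * np.outer(rng.normal(size=400), rng.normal(size=50))

noise_spikes = int(np.sum(esd(noise) > 1.25 * mp_bulk_edge(noise)))
signal_spikes = int(np.sum(esd(signal) > 1.25 * mp_bulk_edge(signal)))
print({"noise_spikes": noise_spikes, "signal_spikes": signal_spikes})
```

To apply the same check to the quickstart output, pass `mats.W7` in place of the synthetic arrays. The 1.25 safety factor on the edge is an arbitrary choice for this sketch.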

Quickstart (convert + WeightWatcher)

# Reuses bst, X, y, train_params, and num_boost_round from the compute_matrices quickstart.
import weightwatcher as ww

from xgboost2ww import convert

layer = convert(
    bst,
    X,
    y,
    W="W7",
    return_type="torch",
    nfolds=5,
    t_points=40,
    random_state=0,
    train_params=train_params,
    num_boost_round=num_boost_round,
)

watcher = ww.WeightWatcher(model=layer)
details_df = watcher.analyze(randomize=True, plot=False)

alpha = details_df["alpha"].iloc[0]
rand_num_spikes = details_df["rand_num_spikes"].iloc[0]
print({"alpha": alpha, "rand_num_spikes": rand_num_spikes})

For initial evaluation, you do not need detX=True. If you want determinant-based diagnostics, pass detX=True to watcher.analyze(...).
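The alpha value printed in the quickstart is a power-law exponent fitted to the tail of the eigenvalue spectrum. As a rough illustration of what such an exponent measures, the sketch below applies the standard maximum-likelihood estimator for a continuous power law with a fixed lower cutoff. This is not WeightWatcher's fitting procedure, which also has to choose the cutoff. The function name `powerlaw_alpha` and the `xmin` argument are assumptions made for this example.

```python
import numpy as np


def powerlaw_alpha(eigenvalues, xmin):
    """MLE exponent for a continuous power law p(x) ~ x^(-alpha) on x >= xmin."""
    tail = np.asarray(eigenvalues, dtype=np.float64)
    tail = tail[tail >= xmin]
    if tail.size < 2:
        raise ValueError("need at least two eigenvalues at or above xmin")
    log_sum = np.sum(np.log(tail / xmin))
    if log_sum <= 0.0:
        raise ValueError("tail values must not all equal xmin")
    return 1.0 + tail.size / log_sum


# Sanity check on synthetic data drawn from a power law with alpha = 3.
rng = np.random.default_rng(0)
true_alpha, xmin = 3.0, 1.0
u = rng.random(20000)
samples = xmin * (1.0 - u) ** (-1.0 / (true_alpha - 1.0))
estimate = powerlaw_alpha(samples, xmin)
print(round(float(estimate), 2))
```

A smaller alpha means a heavier tail, so more of the spectrum's mass sits in a few large eigenvalues.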

Notes / limitations

  • Binary classification is the default workflow.
  • Multiclass requires setting multiclass explicitly (supported modes: "per_class", "stack", "avg").
  • convert(..., multiclass="per_class", return_type="torch") is unsupported and raises; for multiclass per-class output, use return_type="numpy".
  • torch is optional unless you need convert(..., return_type="torch").
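The last two notes can be folded into a small guard that decides which return_type to request before calling convert. This is a sketch using only the standard library. The helper name `choose_return_type` is hypothetical, and treating `None` as "no multiclass mode set" is an assumption, since the default value of the multiclass argument is not stated here.

```python
import importlib.util


def choose_return_type(multiclass=None):
    """Pick "torch" only when it is both supported and installed."""
    # Per the notes above, per-class multiclass output cannot be returned as torch.
    if multiclass == "per_class":
        return "numpy"
    # torch is an optional dependency, so fall back to numpy when it is missing.
    if importlib.util.find_spec("torch") is None:
        return "numpy"
    return "torch"


print(choose_return_type("per_class"))
```

The result can then be passed as `return_type=choose_return_type(...)` in a convert call, so the same script runs with or without torch installed.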
