Compute WeightWatcher-style correlation matrices (W1/W2/W7/W8/W9) for XGBoost via OOF margin increments.
Why XGBoost2WW?
XGBoost2WW lets you apply WeightWatcher-style spectral diagnostics to XGBoost models.
XGBoost models don’t have traditional neural network weight matrices — so you can’t directly run tools like WeightWatcher on them.
XGBoost2WW bridges that gap by converting a trained XGBoost model into structured matrices (W1/W2/W7/W8/W9) derived from out-of-fold margin increments along the boosting trajectory.
These matrices can be treated like neural-network layer weight matrices, so you can analyze their spectra with WeightWatcher.
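To make "out-of-fold margin increments along the boosting trajectory" concrete, here is a minimal NumPy sketch of how such a matrix could be assembled. It is our illustration, not xgboost2ww's implementation. `fit_fold` is a hypothetical callback; with XGBoost it could train a fold model with `xgb.train` and read checkpoint margins via `Booster.predict(dmat, output_margin=True, iteration_range=(0, t))`. Measuring the first increment from a zero margin (ignoring `base_score`) is an assumption.

```python
import numpy as np

def oof_margin_increments(n_samples, nfolds, t_points, fit_fold, seed=0):
    """Assemble an (n_samples, t_points) matrix of out-of-fold margin increments.

    fit_fold(train_mask, held_mask) is a hypothetical callback: it trains on the
    rows in train_mask and returns the raw margins of the held-out rows at
    t_points boosting checkpoints, shape (held_mask.sum(), t_points).
    """
    rng = np.random.default_rng(seed)
    fold_id = rng.permutation(n_samples) % nfolds  # balanced random fold labels
    margins = np.empty((n_samples, t_points))
    for k in range(nfolds):
        held = fold_id == k
        # Each row is predicted only by the fold model that never saw it.
        margins[held] = fit_fold(~held, held)
    # Increment between consecutive checkpoints; the first column is measured
    # from a zero margin (assumption: base_score handling is ignored here).
    dF_oof = np.diff(margins, axis=1, prepend=0.0)
    m_final = margins[:, -1]
    return dF_oof, m_final
```

By construction each row of `dF_oof` sums to that row's final OOF margin, which is a cheap sanity check on any real implementation.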
Why would a production ML engineer care?
Because traditional metrics (accuracy, AUC, logloss) often look fine right up until a model fails in production.
Spectral diagnostics can help detect:
- Overfitting that standard validation doesn’t reveal
- Correlation traps in boosted trees
- Excessive memorization
- Unstable training dynamics
- Data leakage patterns
- Models that are brittle to distribution shift
In short:
XGBoost2WW gives you a structural diagnostic signal — not just a performance metric.
That means you can:
- Compare model candidates beyond accuracy
- Detect problematic models before deployment
- Monitor structural drift over time
- Add an extra safety layer to your MLOps pipeline
If you deploy XGBoost models in production, XGBoost2WW gives you a new lens to inspect them.
xgboost2ww
Convert XGBoost boosting dynamics into WeightWatcher-style operators (W1/W2/W7/W8/W9).
Install
Development install:
pip install -e .
pip install weightwatcher torch
Minimal runtime install:
pip install xgboost2ww
pip install weightwatcher
Google Colab Notebooks
Single Good Model
- High test and training accuracy, good WW metrics
SingleGoodModelWWXGBoost2WW.ipynb
Realistic End-to-End Example
- Interpreting α and traps in a realistic, non-trivial setting
XGBoost2WWAdultIncomeExample.ipynb
Stress Test Across 100 Random Models
- Trains 100 random models and compares WeightWatcher α against test accuracy
GoodModelsXGBoost2WW.ipynb
Poorly Trained Credit Model
- Small dataset where high test accuracy is hard to reach; shows a high α
PoorlyTrainedCreditModel.ipynb
Diagnostic Example
- Overly simple model where the training data is strongly overfit
XGBoost2WWDiagnosticExample.ipynb
SpamBase Alpha=2 Targeted Sweep
- Targeted hyperparameter sweep to maximize validation accuracy near α≈2.0
SpamBase_Hyperparameter_Sweep_Alpha2_Targeted.ipynb
Random100 Long-Run Alpha Tracking
- Long-running Random100 catalog training with Google Drive checkpoints and restart/resume support
XGBWW_Catalog_Random100_XGBoost_Accuracy_LongRun_AlphaTracking.ipynb
Quickstart (compute_matrices)
import numpy as np
import xgboost as xgb
from xgboost2ww import compute_matrices
rng = np.random.default_rng(0)
X = rng.normal(size=(300, 12)).astype(np.float32)
logits = 1.5 * X[:, 0] - 0.8 * X[:, 1] + 0.3 * rng.normal(size=300)
y = (logits > 0).astype(np.int32)
dtrain = xgb.DMatrix(X, label=y)
params = {
    "objective": "binary:logistic",
    "eval_metric": "logloss",
    "max_depth": 3,
    "eta": 0.1,
    "subsample": 1.0,
    "colsample_bytree": 1.0,
    "seed": 0,
    "verbosity": 0,
}
rounds = 40
bst = xgb.train(params, dtrain, num_boost_round=rounds)
# Reproducibility knobs for fold training inside compute_matrices / convert
train_params = params
num_boost_round = rounds
mats = compute_matrices(
    bst,
    X,
    y,
    nfolds=5,
    t_points=40,
    random_state=0,
    train_params=train_params,
    num_boost_round=num_boost_round,
)
W9 = mats.W9
print(W9.shape)
Quickstart (convert + WeightWatcher)
import weightwatcher as ww
from xgboost2ww import convert
layer = convert(
    bst,
    X,
    y,
    W="W9",
    return_type="torch",
    nfolds=5,
    t_points=40,
    random_state=0,
    train_params=train_params,
    num_boost_round=num_boost_round,
)
watcher = ww.WeightWatcher(model=layer)
details_df = watcher.analyze(randomize=True, plot=False)
alpha = details_df["alpha"].iloc[0]
rand_num_spikes = details_df["rand_num_spikes"].iloc[0]
print({"alpha": alpha, "rand_num_spikes": rand_num_spikes})
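With `randomize=True`, WeightWatcher also analyzes an element-wise randomized copy of the matrix, and `rand_num_spikes` roughly counts eigenvalues that still stand out in that randomized spectrum. The sketch below illustrates the idea in NumPy under simplifying assumptions of our own: it compares the shuffled matrix's eigenvalues against the Marchenko-Pastur upper edge λ+ = σ²(1 + √(M/N))², with an arbitrary 10% slack. `randomized_spikes` is a hypothetical helper; WeightWatcher's own spike count may be computed differently.

```python
import numpy as np

def randomized_spikes(W, rng=None, edge_slack=1.1):
    """Count eigenvalues of an element-shuffled W that sit above the MP bulk edge.

    Shuffling destroys genuine correlations, so a well-behaved matrix should
    produce a randomized spectrum inside the Marchenko-Pastur bulk. Eigenvalues
    that still escape it point to a few very large entries (a correlation trap).
    edge_slack is an arbitrary safety margin (assumption).
    """
    rng = np.random.default_rng(0) if rng is None else rng
    N, M = W.shape  # assumes a tall matrix, N >= M
    W_rand = rng.permutation(W.ravel()).reshape(N, M)
    # Center so that a nonzero global mean does not show up as a spike.
    W_rand = W_rand - W_rand.mean()
    evals = np.linalg.eigvalsh(W_rand.T @ W_rand / N)
    sigma2 = W_rand.var()
    lam_plus = sigma2 * (1.0 + np.sqrt(M / N)) ** 2  # MP upper bulk edge
    return int(np.sum(evals > edge_slack * lam_plus))
```

A pure Gaussian matrix should yield zero spikes, while planting a single huge entry produces at least one, which is the "trap" signature the notebooks look for.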
For an initial evaluation, you do not need detX=True. If you want determinant-based diagnostics, pass detX=True to watcher.analyze(...).
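The determinant condition can be sketched in a few lines. This is our reading of the ERG "det X = 1" idea, not WeightWatcher's implementation: walk down from the largest eigenvalue, accumulate log λ, and stop where the running sum is closest to zero, so that the product of the tail eigenvalues is about 1. The function name `detx_threshold` is hypothetical, and the eigenvalues are assumed to be normalized the way the analysis expects.

```python
import numpy as np

def detx_threshold(evals):
    """Return (threshold eigenvalue, tail size) where the tail log-determinant is ~0.

    Sketch of the ERG-style condition: sort eigenvalues in descending order,
    accumulate log(lambda), and pick the cut where the running sum is nearest
    zero, i.e. where the product of the tail eigenvalues is closest to 1.
    """
    lam = np.sort(np.asarray(evals, dtype=float))[::-1]
    lam = lam[lam > 0]  # log is undefined for zero eigenvalues
    running = np.cumsum(np.log(lam))
    k = int(np.argmin(np.abs(running)))
    return lam[k], k + 1
```

When the threshold found this way lands close to the power-law x_min, the two tail estimates agree; that agreement is presumably what the "red and purple lines are close" remark below refers to.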
What the Diagnostics Look Like
Below is an example of a WeightWatcher spectral analysis applied to an XGBoost model via xgboost2ww.
In this example:
- (Left) The power-law fit produces an α value near 2.
- (Right) The ERG detX condition is nearly satisfied (the red and purple lines are close).
- (Not shown) No significant traps are detected.
But unlike many neural networks, the full empirical spectral density (ESD) shows a very heavy-tailed structure. This indicates that the training data was very easy to memorize. That's OK, and even expected.
This is what a structurally healthy model looks like.
When α drifts upward or traps appear, it is often a signal of:
- Overfitting
- Correlation traps
- Memorization
- Instability in training
- Data leakage
- Structural brittleness
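WeightWatcher fits α with the `powerlaw` package; the NumPy sketch below is a simplified stand-in, not its code. It computes the ESD of X = WᵀW / N (dividing by the row count N is an assumed convention), then fits p(λ) ∝ λ^(−α) to the tail with the standard continuous maximum-likelihood estimate α = 1 + n / Σ ln(λᵢ / λ_min), choosing λ_min by minimum Kolmogorov-Smirnov distance. The helpers `esd` and `powerlaw_alpha` are hypothetical, and their α may differ from the value WeightWatcher reports.

```python
import numpy as np

def esd(W):
    # Eigenvalues of X = W^T W / N; dividing by N (rows) is an assumed convention.
    N = W.shape[0]
    return np.linalg.eigvalsh(W.T @ W / N)

def powerlaw_alpha(evals, min_tail=50):
    """Fit p(x) ~ x^-alpha to the ESD tail with the continuous MLE and a KS-chosen xmin."""
    x = np.sort(np.asarray(evals, dtype=float))
    x = x[x > 0]
    best_ks, best_alpha, best_xmin = np.inf, np.nan, np.nan
    for i in range(len(x) - min_tail + 1):
        tail = x[i:]  # candidate tail starting at xmin = tail[0]
        log_ratio_sum = np.sum(np.log(tail / tail[0]))
        if log_ratio_sum <= 0:
            continue
        n = tail.size
        alpha = 1.0 + n / log_ratio_sum  # MLE for this fixed xmin
        cdf_fit = 1.0 - (tail / tail[0]) ** (1.0 - alpha)
        cdf_emp = np.arange(n) / n  # empirical CDF just below each point
        ks = np.max(np.abs(cdf_emp - cdf_fit))
        if ks < best_ks:
            best_ks, best_alpha, best_xmin = ks, alpha, tail[0]
    return best_alpha, best_xmin
```

For example, `powerlaw_alpha(esd(W))` on a NumPy matrix returned by `compute_matrices` gives a rough α to compare against the "α near 2" reading above. The scan over every candidate x_min is quadratic in the number of eigenvalues, which is fine for matrices of the size used in the quickstart.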
Matrix defaults and stability
- convert(...) now defaults to W="W1" (instead of W="W7"). W1 is the recommended non-experimental default.
- All other W variants (W2, W7, W8, W9) are currently experimental.
Matrix definitions at a glance
- W8: legacy practical surrogate based on a weighted/centered W7.
- W9: canonical regularizer-whitened Fisher OOF trajectory matrix.
For binary classification, with raw OOF increments dF_oof and final OOF margin m_final:
p = sigmoid(m_final)
h = clip(p * (1 - p), eps, None)
A = weighted_center_cols(dF_oof, h)
gamma_diag[j] = lambda * sum_{trees in endpoint block j}(leaf_value^2)
W9 = diag(sqrt(h)) · A · diag(gamma_diag^{-1/2})
W9 v1 uses the local-quadratic L2 regularizer mass (XGBoost's lambda) and intentionally ignores the non-smooth L1 term (XGBoost's alpha, which is unrelated to the WeightWatcher α).
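A minimal NumPy sketch of the W9 recipe above, assuming dF_oof (n_samples × n_blocks), m_final, and per-tree sums of squared leaf values are already in hand. Two pieces are our assumptions, not the package's code: weighted_center_cols is taken to subtract each column's h-weighted mean, and eps defaults to 1e-6. The helper `gamma_from_leaves` and its inputs `leaf_sq_per_tree` and `block_starts` are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def weighted_center_cols(M, w):
    # Assumed meaning: subtract each column's w-weighted mean.
    means = (w[:, None] * M).sum(axis=0) / w.sum()
    return M - means

def gamma_from_leaves(leaf_sq_per_tree, block_starts, reg_lambda):
    # gamma_diag[j] = lambda * (sum of squared leaf values over the trees in block j).
    # block_starts holds the first tree index of each endpoint block (hypothetical layout).
    return reg_lambda * np.add.reduceat(np.asarray(leaf_sq_per_tree, dtype=float), block_starts)

def w9_sketch(dF_oof, m_final, gamma_diag, eps=1e-6):
    p = sigmoid(m_final)
    h = np.clip(p * (1.0 - p), eps, None)  # per-sample logistic Hessian weight
    A = weighted_center_cols(dF_oof, h)
    # diag(sqrt(h)) . A . diag(gamma^-1/2), applied by broadcasting (rows, then columns).
    return np.sqrt(h)[:, None] * A / np.sqrt(gamma_diag)[None, :]
```

Applying both diagonal matrices by broadcasting avoids materializing an n × n matrix. gamma_diag must be strictly positive, so the whitening step breaks down when lambda is 0.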
Notes / limitations
- Binary classification is the default workflow.
- Multiclass requires setting multiclass explicitly (supported modes: "per_class", "stack", "avg").
- convert(..., multiclass="per_class", return_type="torch") is unsupported and raises an error; for multiclass per-class output, use return_type="numpy".
- torch is optional unless you need convert(..., return_type="torch").
Download files
File details
Details for the file xgboost2ww-0.1.1.tar.gz.
File metadata
- Download URL: xgboost2ww-0.1.1.tar.gz
- Upload date:
- Size: 26.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | c57cdcc04d3fec1b28eb16d6e8c58650ba705391ce0345e5cfbd8646434115f2 |
| MD5 | 9a729b00a7f50dfee9c882cade3ded0d |
| BLAKE2b-256 | bc9f5918b3e010d667594902b4e1f1d554e6ff5412593a985c66cc68f88e9254 |
File details
Details for the file xgboost2ww-0.1.1-py3-none-any.whl.
File metadata
- Download URL: xgboost2ww-0.1.1-py3-none-any.whl
- Upload date:
- Size: 18.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ce38451b0d9a1623426c477fd1e14408c5918a76580f58f26c4f0f5ef2ec3cb3 |
| MD5 | b6fad47ed15705795cd48f83034b5bf6 |
| BLAKE2b-256 | a43f4cb3d78992c093198865309f7ce30f03f7908837a6d2274cb36c6d349220 |