semi_supervised_gmm
Semi-supervised Gaussian mixture classification with a weighted unlabeled likelihood.
Companion to the paper:
Semi-Supervised Generative Classification via a Weighted Unlabeled Likelihood (under review, Statistics and Computing)
Overview
Many classification problems have cheap features but expensive labels. This package fits a two-component Gaussian mixture model by maximising a weighted log-likelihood:
J(θ) = ℓ_sup(θ) + λ · ℓ_unl(θ)
where ℓ_sup is the supervised log-likelihood over labeled data and ℓ_unl is the unlabeled marginal mixture log-likelihood. The scalar λ controls how much the unlabeled corpus influences the fit. EM yields closed-form updates with a guaranteed monotone increase in the objective at every iteration.
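For intuition, here is a minimal sketch of what the objective computes for two Gaussian components with mixing weight pi. The function and parameter names are hypothetical, not the package's internals:

import numpy as np
from scipy.stats import multivariate_normal as mvn

def weighted_objective(X_pos, X_neg, X_u, pi, mu1, mu0, S1, S0, lam):
    # Supervised term: joint log-likelihood of the labeled points
    ll_sup = (np.log(pi) + mvn.logpdf(X_pos, mu1, S1)).sum() \
           + (np.log(1 - pi) + mvn.logpdf(X_neg, mu0, S0)).sum()
    # Unlabeled term: marginal two-component mixture log-likelihood
    mix = pi * mvn.pdf(X_u, mu1, S1) + (1 - pi) * mvn.pdf(X_u, mu0, S0)
    ll_unl = np.log(mix).sum()
    return ll_sup + lam * ll_unl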
The paper treats λ as the primary object of study rather than a tuning nuisance:
- λ = 0 recovers the purely supervised MLE.
- λ > 0 borrows geometric structure from the unlabeled distribution.
- A pre-fitting diagnostic A(0): a computable score whose sign indicates whether unlabeled data push the estimator in a helpful or harmful direction — before any semi-supervised fitting.
Installation
pip install semi-supervised-gmm
Or from source:
git clone https://github.com/aaronjdanielson/semi_supervised_gmm
cd semi_supervised_gmm
pip install -e .
Dependencies: numpy, scipy, scikit-learn
Quick start
import numpy as np
from semi_supervised_gmm import SemiSupervisedGMM, make_semi_supervised
rng = np.random.default_rng(0)
# Generate data: y=1 (positive), y=0 (negative), y=-1 (unlabeled)
X_pos = rng.multivariate_normal([2, 0], np.eye(2), 30)
X_neg = rng.multivariate_normal([-2, 0], np.eye(2), 30)
X_u = rng.multivariate_normal([0, 0], np.eye(2), 300)
X, y = make_semi_supervised(
    np.vstack([X_pos, X_neg]),
    np.array([1]*30 + [0]*30),
    X_u,
)
# Fit
model = SemiSupervisedGMM(lambda_=1.0).fit(X, y)
# Evaluate on held-out data drawn from the same mixture
X_test = np.vstack([
    rng.multivariate_normal([2, 0], np.eye(2), 50),
    rng.multivariate_normal([-2, 0], np.eye(2), 50),
])
y_test = np.array([1]*50 + [0]*50)
proba = model.predict_proba(X_test)  # shape (n, 2): [P(y=0|x), P(y=1|x)]
labels = model.predict(X_test)
auc = model.score(X_test, y_test)  # AUROC
Replicating paper results
All paper figures and tables are generated by three scripts in replication/.
Run them in order:
# Step 1: run simulation experiments and cache results (~10–30 min)
python replication/make_simulation_data.py
# Step 2: generate simulation figures and tables from cache (fast)
python replication/run_simulations_for_paper.py
# Step 3: run UCI benchmark and generate empirical figures and tables
# (~20–60 min depending on hardware; uses joblib parallelism)
python replication/run_empirical_for_paper.py
All outputs land in papers/figures/ and papers/tables/.
Simulation outputs (run_simulations_for_paper.py)
| File | Content |
|---|---|
| papers/figures/fig_diagnostic.pdf | Alignment coefficient diagnostic validity (§Simulations) |
| papers/figures/fig_bias_variance.pdf | Bias–variance tradeoff as a function of λ |
| papers/figures/fig_decision.pdf | Use/discard decision accuracy |
| papers/figures/fig_conf_ablation.pdf | Confidence-weighted ablation |
| papers/figures/fig_dimension_sweep.pdf | High-dimensional geometry sweep |
| papers/figures/fig_alignment_covariance.pdf | Alignment under covariance misspecification |
| papers/tables/tab_misspec.tex | Misspecification results (Student-t, skew-normal) |
| papers/tables/tab_lambda_learn.tex | λ-learning gradient vs grid-search comparison |
| papers/tables/tab_decision.tex | Decision accuracy table |
| papers/tables/tab_regime.tex | Regime-grid summary |
| papers/tables/tab_dimension.tex | High-dimensional geometry results |
Empirical outputs (run_empirical_for_paper.py)
| File | Content |
|---|---|
| papers/figures/fig_peff_scatter.pdf | SSL gain vs N_lab/N_unlab ratio, 13 datasets |
| papers/figures/fig_auroc_curves.pdf | AUROC vs label budget, all 13 datasets |
| papers/figures/fig_distribution_grid_semi.pdf | Per-dataset distribution, Semi(λ=1) |
| papers/figures/fig_distribution_grid_learned.pdf | Per-dataset distribution, Learned-λ |
| papers/figures/fig_distribution_grid_local.pdf | Per-dataset distribution, Local-λ |
| papers/tables/tab_uci_benchmark.tex | Primary 9-dataset benchmark table |
| papers/tables/tab_uci_expanded.tex | Expanded 13-dataset table with failure-mode annotations |
Incremental caching
run_empirical_for_paper.py saves results incrementally to
replication/cache/uci_distribution_methods_cache.csv after each dataset
completes. If the run is interrupted, restart with the same command — completed
datasets are skipped automatically.
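The resume logic reduces to reading the cache and filtering the dataset list. A rough sketch of the pattern (the cache column name "dataset" is an assumption here, not guaranteed to match the script):

import os
import pandas as pd

CACHE = "replication/cache/uci_distribution_methods_cache.csv"

def pending_datasets(all_datasets):
    # Datasets already present in the cache are treated as completed
    if not os.path.exists(CACHE):
        return list(all_datasets)
    done = set(pd.read_csv(CACHE)["dataset"].unique())
    return [d for d in all_datasets if d not in done]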
Dependencies for replication
pip install -e ".[dev]"
pip install pandas matplotlib ucimlrepo
Four estimators
All estimators follow the sklearn interface: fit / predict / predict_proba / score.
Use y = -1 as the unlabeled sentinel, matching sklearn.semi_supervised.LabelPropagation.
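By hand, the sentinel encoding is just a concatenation. A sketch assuming separate X_labeled, y_labeled, X_unlabeled arrays (make_semi_supervised in the quick start produces the same layout):

import numpy as np

# Stack labeled and unlabeled features; mark unlabeled rows with the -1 sentinel
X = np.vstack([X_labeled, X_unlabeled])
y = np.concatenate([y_labeled, -np.ones(len(X_unlabeled), dtype=int)])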
SupervisedGMM
Purely supervised MLE (λ = 0). Ignores unlabeled observations.
from semi_supervised_gmm import SupervisedGMM
model = SupervisedGMM().fit(X, y)
SemiSupervisedGMM
Fixed global λ. Grid-search on a validation set via fit_cv.
from semi_supervised_gmm import SemiSupervisedGMM
# Fixed lambda
model = SemiSupervisedGMM(lambda_=2.0).fit(X, y)
# Grid-searched lambda
model = SemiSupervisedGMM().fit_cv(X, y, X_val, y_val,
                                   lam_grid=np.logspace(-2, 2, 20))
print(model.lambda_used_) # selected value
Compatible with GridSearchCV:
from sklearn.model_selection import GridSearchCV
gs = GridSearchCV(SemiSupervisedGMM(),
                  {"lambda_": [0.1, 0.5, 1.0, 2.0, 5.0]},
                  scoring="accuracy")
gs.fit(X_labeled, y_labeled)
LearnedLambdaGMM
Learns λ by gradient ascent on the validation log-likelihood, with the gradient obtained via the implicit function theorem (IFT), as described in the paper.
from semi_supervised_gmm import LearnedLambdaGMM
model = LearnedLambdaGMM(lambda_init=1.0, n_steps=10).fit(
    X, y, X_val=X_val, y_val=y_val
)
print(model.lambda_) # learned value
LocalLambdaGMM
Per-point confidence-weighted λ: λ(x) = λ · max(γ(x), 1−γ(x))^α. Downweights ambiguous unlabeled points.
from semi_supervised_gmm import LocalLambdaGMM
model = LocalLambdaGMM(lambda_=1.0, alpha=1.0).fit(X, y)
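The weighting rule is simple to state in code. A sketch of λ(x), assuming gamma holds the posterior P(y=1|x) for each unlabeled point (the helper name is hypothetical, not part of the package API):

import numpy as np

def local_lambda(gamma, lambda_=1.0, alpha=1.0):
    # Confident points (gamma near 0 or 1) keep weight close to lambda_;
    # ambiguous points (gamma near 0.5) are shrunk toward lambda_ * 0.5**alpha.
    conf = np.maximum(gamma, 1.0 - gamma)
    return lambda_ * conf ** alpha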
Pre-fitting diagnostic: should you use unlabeled data?
The alignment coefficient A(0) is computable before any semi-supervised fitting.
Its sign indicates the direction in which unlabeled data will push the estimator.
from semi_supervised_gmm import SupervisedGMM, SemiSupervisedGMM, alignment_score, alignment_reliability
# Step 1 (optional): check reliability before computing A(0)
n_lab = int((y_train != -1).sum())  # number of labeled points
rel = alignment_reliability(n_lab=n_lab, d=X_train.shape[1], covariance_type="full")
# rel["reliability_regime"] -> "reliable" | "marginal" | "unreliable"
# rel["covariance_recommendation"] -> e.g. "diagonal"
# Step 2: compute alignment coefficient
result = alignment_score(X_train, y_train, X_val, y_val)
# {
#   "A0": float,                        # alignment coefficient (nan if rank-deficient)
#   "g0_norm": float,                   # ||g_0|| score residual norm
#   "recommendation": str,              # "use" | "discard" | "unreliable" |
#                                       #   "use_with_caution" | "discard_with_caution"
#   "n_unlabeled": int,
#   "p_eff": int,                       # effective parameter count
#   "reliability_ratio": float,         # p_eff / n_lab
#   "reliability_regime": str,          # "reliable" | "marginal" | "unreliable"
#   "covariance_recommendation": str,   # suggested covariance_type (or None)
# }
if result["recommendation"] == "use":
model = SemiSupervisedGMM(lambda_=1.0).fit(X_train, y_train)
else:
model = SupervisedGMM().fit(X_train, y_train)
Interpretation:
- A(0) > 0 — unlabeled geometry is aligned with the classification task; consider semi-supervised learning.
- A(0) < 0 — unlabeled geometry is misaligned; prefer the supervised MLE.
- g0_norm large (> 15) — the Gaussian assumption is likely violated; treat the diagnostic as unreliable.
A(0) is a partial diagnostic signal, not a decision rule. Its reliability degrades as the ratio of effective parameter dimension to labeled sample size grows. The paper identifies three structural failure modes where semi-supervised learning is inadvisable regardless of the A(0) sign: misalignment of unlabeled geometry with the task, estimator instability when p_eff/N_lab is large, and prior mismatch under class imbalance.
Reliability regimes (based on p_eff / N_lab):
| p_eff / N_lab | Regime | Interpretation |
|---|---|---|
| < 10 | "reliable" | Fisher matrix well-conditioned; A(0) informative |
| 10–65 | "marginal" | Use shrinkage covariance ("ledoit_wolf"); treat A(0) as advisory |
| > 65 | "unreliable" | n_lab ≪ d; A(0) returns nan; prefer supervised-only or diagonal covariance |
Effective parameter counts: p_eff = 1 + 2d + d(d+1) (full covariance) or 1 + 4d (diagonal).
Example: Sonar (d=60, full) → p_eff = 3781; at n_lab = 20, ratio = 189 → unreliable.
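The arithmetic behind that example, as a small helper mirroring the formulas above (the function is illustrative, not part of the package API):

def p_eff(d, covariance_type="full"):
    if covariance_type == "full":
        # 1 mixing weight + two d-dim means + two symmetric covariances (d(d+1)/2 each)
        return 1 + 2 * d + d * (d + 1)
    # diagonal: 1 mixing weight + two d-dim means + two d-dim diagonal covariances
    return 1 + 4 * d

print(p_eff(60))          # 3781 (Sonar, full covariance)
print(p_eff(60) // 20)    # 189  -> "unreliable" at n_lab = 20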
Benchmark summary
Results from the 13-dataset benchmark (250 seeds, covariance_type="ledoit_wolf", StandardScaler).
See papers/tables/tab_uci_expanded.tex for the full table with failure-mode annotations.
When it helps: In the label-scarce regime, semi-supervised learning consistently improves AUROC on near-Gaussian datasets (Breast Cancer, Ionosphere, Heart Disease, Diabetes, Wine) by +0.03 to +0.13. The choice of variant matters little at small label budgets.
When it doesn't: Three failure modes are documented in the paper:
- Misalignment (Sonar, d=60): unlabeled geometry does not align with the boundary; ‖g₀‖ is large and correctly flags the diagnostic as unreliable.
- Estimability (SPECTF, d=44, small N): rank-deficient supervised baseline; gains at very small N_lab do not persist as N_lab grows.
- Prior mismatch (Shuttle, 84% class imbalance): fixed λ=1 degrades AUROC by −0.44; learned-λ partially recovers by shrinking λ toward zero.
Covariance options
All estimators accept covariance_type:
| Value | When to use |
|---|---|
| "full" (default) | N ≫ d; unrestricted covariance |
| "diag" | d > N or when features are approximately independent |
| "ledoit_wolf" | d ≈ N; shrinkage toward scaled identity (recommended for real data) |
model = SemiSupervisedGMM(lambda_=1.0, covariance_type="ledoit_wolf").fit(X, y)
Convenience: split-data interface
All estimators expose fit_semi for callers who have labeled and unlabeled data in separate arrays:
model = SemiSupervisedGMM(lambda_=1.0).fit_semi(X_labeled, y_labeled, X_unlabeled)
Persistence
model.save("model.pkl")
loaded = SemiSupervisedGMM.load("model.pkl")
Algorithmic notes
Cholesky caching. Each E-step uses a single Cholesky factorisation per class per iteration, where the reference implementation performs two O(d³) operations (slogdet + solve). This yields approximately a 1.5× E-step speedup.
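Schematically, one factorisation serves both the log-determinant and the quadratic form. A sketch of the idea (not the package's exact code):

import numpy as np
from scipy.linalg import cho_factor, cho_solve

def gaussian_logpdf(X, mu, Sigma):
    d = X.shape[1]
    c, low = cho_factor(Sigma)                   # the single O(d^3) step
    logdet = 2.0 * np.log(np.diag(c)).sum()      # factor reused for the log-determinant
    diff = X - mu
    maha = (diff * cho_solve((c, low), diff.T).T).sum(axis=1)  # ...and for the solve
    return -0.5 * (d * np.log(2 * np.pi) + logdet + maha)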
Precomputed labeled scatter. X_posᵀX_pos and s = X_pos.sum(0) are computed once outside the EM loop. The centered scatter is recovered via:
(X − μ)ᵀ(X − μ) = XᵀX − outer(s, μ) − outer(μ, s) + N·outer(μ, μ)
Note: the simpler form XᵀX − N·outer(μ,μ) only holds when μ = x̄. In EM, μ_new includes unlabeled contributions, so the full identity is required.
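A quick numerical check of the identity, with mu deliberately different from the sample mean:

import numpy as np

rng = np.random.default_rng(1)
X = rng.standard_normal((100, 3))
mu = rng.standard_normal(3)          # not the sample mean of X
s = X.sum(axis=0)

direct = (X - mu).T @ (X - mu)
cached = X.T @ X - np.outer(s, mu) - np.outer(mu, s) + len(X) * np.outer(mu, mu)
assert np.allclose(direct, cached)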
Analytic Fisher for A(0). The production implementation uses the closed-form mean-block Fisher:
F_μ = Σ⁻¹ · S_sample · Σ⁻¹
Cost: O(N·d²) vs O(N·d⁴) for numerical Jacobians. At d=50 this is a ~5000× reduction and eliminates floating-point error that degrades diagnostic accuracy at high dimension.
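In code, the mean block is two linear solves away from the sample second-moment matrix. A sketch assuming Sigma and S_sample are already available (the helper name is illustrative, not the package's diagnostics API):

import numpy as np

def fisher_mean_block(Sigma, S_sample):
    # F_mu = Sigma^{-1} S_sample Sigma^{-1}, without forming Sigma^{-1} explicitly
    left = np.linalg.solve(Sigma, S_sample)      # Sigma^{-1} S_sample
    return np.linalg.solve(Sigma, left.T).T      # apply Sigma^{-1} on the right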
Package structure
semi_supervised_gmm/
_params.py GMMParams dataclass
_em.py Pure NumPy EM loop, E/M-step, posterior
_lambda.py grid_search_lambda, gradient_lambda
_diagnostics.py alignment_A0, score_residual_g0 — analytic Fisher
_data.py encode_labels: y=-1 sentinel → (X_pos, X_neg, X_u)
_base.py BaseGMM: sklearn mixins, predict*, score, persistence
_estimators.py Four estimator classes
exceptions.py ConvergenceWarning, InsufficientLabeledDataError
Running tests
python -m pytest tests/ -q
61 tests: unit tests for EM numerics, data encoding, and diagnostics; integration tests for all four estimators including numerical agreement with the reference implementation, GridSearchCV compatibility, and save/load roundtrip.