
Semi-supervised Gaussian mixture classifier with a weighted unlabeled likelihood


semi_supervised_gmm

Semi-supervised Gaussian mixture classification with a weighted unlabeled likelihood.

Companion to the paper:

Semi-Supervised Generative Classification via a Weighted Unlabeled Likelihood (under review, Statistics and Computing)


Overview

Many classification problems have cheap features but expensive labels. This package fits a two-component Gaussian mixture model by maximising a weighted log-likelihood:

J(θ) = ℓ_sup(θ)  +  λ · ℓ_unl(θ)

where ℓ_sup is the supervised log-likelihood over labeled data and ℓ_unl is the marginal mixture log-likelihood over unlabeled data. The scalar λ ≥ 0 controls how much the unlabeled corpus influences the fit. EM yields closed-form updates, and the objective is guaranteed not to decrease from one iteration to the next.
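To make the objective concrete, the sketch below evaluates J(θ) for a two-class Gaussian model with NumPy and SciPy. It assumes θ consists of a class prior π = P(y = 1), two means and two covariance matrices, and that ℓ_sup is the joint log-likelihood log p(x, y) of the labeled points. The function `weighted_objective` is a hypothetical illustration, not part of the package API.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal


def weighted_objective(X_pos, X_neg, X_u, pi, mu_pos, mu_neg, cov_pos, cov_neg, lam):
    """Hypothetical helper: J(theta) = l_sup(theta) + lam * l_unl(theta).

    theta = (pi, mu_pos, mu_neg, cov_pos, cov_neg), with pi = P(y = 1).
    """

    def log_joint_pos(X):
        # log pi + log N(x | mu_pos, cov_pos)
        return np.log(pi) + multivariate_normal.logpdf(X, mu_pos, cov_pos)

    def log_joint_neg(X):
        # log(1 - pi) + log N(x | mu_neg, cov_neg)
        return np.log1p(-pi) + multivariate_normal.logpdf(X, mu_neg, cov_neg)

    # Supervised term: joint log-likelihood log p(x, y) over the labeled points
    l_sup = np.sum(log_joint_pos(X_pos)) + np.sum(log_joint_neg(X_neg))

    # Unlabeled term: marginal log-likelihood log(sum_k pi_k N(x | mu_k, cov_k))
    l_unl = np.sum(logsumexp(np.stack([log_joint_pos(X_u), log_joint_neg(X_u)]), axis=0))

    return l_sup + lam * l_unl
```

Because J is linear in λ, setting `lam=0.0` returns exactly the supervised log-likelihood, matching the λ = 0 case listed below.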

The paper treats λ as the primary object of study rather than a tuning nuisance:

  • λ = 0 recovers the purely supervised MLE.
  • λ > 0 borrows geometric structure from the unlabeled distribution.
  • Pre-fitting diagnostic A(0): a computable score whose sign indicates whether unlabeled data push the estimator in a helpful or harmful direction — before any semi-supervised fitting.

Installation

pip install semi-supervised-gmm

Or from source:

git clone https://github.com/aaronjdanielson/semi_supervised_gmm
cd semi_supervised_gmm
pip install -e .

Dependencies: numpy, scipy, scikit-learn


Quick start

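import numpy as np
from semi_supervised_gmm import SemiSupervisedGMM, make_semi_supervised

rng = np.random.default_rng(0)
cov = np.eye(2)

# Labeled data: y=1 (positive), y=0 (negative)
X_pos = rng.multivariate_normal([2, 0], cov, 30)
X_neg = rng.multivariate_normal([-2, 0], cov, 30)

# Unlabeled data (y=-1): drawn from the same two-component mixture, labels hidden
X_u = np.vstack([
    rng.multivariate_normal([2, 0], cov, 150),
    rng.multivariate_normal([-2, 0], cov, 150),
])

X, y = make_semi_supervised(
    np.vstack([X_pos, X_neg]),
    np.array([1] * 30 + [0] * 30),
    X_u,
)

# Held-out labeled test set from the same mixture
X_test = np.vstack([
    rng.multivariate_normal([2, 0], cov, 100),
    rng.multivariate_normal([-2, 0], cov, 100),
])
y_test = np.array([1] * 100 + [0] * 100)

# Fit
model = SemiSupervisedGMM(lambda_=1.0).fit(X, y)

# Predict
proba  = model.predict_proba(X_test)   # shape (n, 2): [P(y=0|x), P(y=1|x)]
labels = model.predict(X_test)
auc    = model.score(X_test, y_test)   # AUROC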

Replicating paper results

All paper figures and tables are generated by three scripts in replication/. Run them in order:

# Step 1: run simulation experiments and cache results (~10–30 min)
python replication/make_simulation_data.py

# Step 2: generate simulation figures and tables from cache (fast)
python replication/run_simulations_for_paper.py

# Step 3: run UCI benchmark and generate empirical figures and tables
#         (~20–60 min depending on hardware; uses joblib parallelism)
python replication/run_empirical_for_paper.py

All outputs land in papers/figures/ and papers/tables/.

Simulation outputs (run_simulations_for_paper.py)

File | Contents
papers/figures/fig_diagnostic.pdf | Alignment coefficient diagnostic validity (§Simulations)
papers/figures/fig_bias_variance.pdf | Bias–variance tradeoff as a function of λ
papers/figures/fig_decision.pdf | Use/discard decision accuracy
papers/figures/fig_conf_ablation.pdf | Confidence-weighted ablation
papers/figures/fig_dimension_sweep.pdf | High-dimensional geometry sweep
papers/figures/fig_alignment_covariance.pdf | Alignment under covariance misspecification
papers/tables/tab_misspec.tex | Misspecification results (Student-t, skew-normal)
papers/tables/tab_lambda_learn.tex | λ-learning: gradient vs grid-search comparison
papers/tables/tab_decision.tex | Decision accuracy table
papers/tables/tab_regime.tex | Regime-grid summary
papers/tables/tab_dimension.tex | High-dimensional geometry results

Empirical outputs (run_empirical_for_paper.py)

File | Contents
papers/figures/fig_peff_scatter.pdf | SSL gain vs N_lab/N_unlab ratio, 13 datasets
papers/figures/fig_auroc_curves.pdf | AUROC vs label budget, all 13 datasets
papers/figures/fig_distribution_grid_semi.pdf | Per-dataset distribution, Semi(λ=1)
papers/figures/fig_distribution_grid_learned.pdf | Per-dataset distribution, Learned-λ
papers/figures/fig_distribution_grid_local.pdf | Per-dataset distribution, Local-λ
papers/tables/tab_uci_benchmark.tex | Primary 9-dataset benchmark table
papers/tables/tab_uci_expanded.tex | Expanded 13-dataset table with failure-mode annotations

Incremental caching

run_empirical_for_paper.py saves results incrementally to replication/cache/uci_distribution_methods_cache.csv after each dataset completes. If the run is interrupted, restart with the same command — completed datasets are skipped automatically.
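A minimal sketch of the skip-if-cached pattern described above, assuming a CSV cache with a `dataset` column. The helper `run_with_cache`, the `run_dataset` callback and the result columns are hypothetical and not taken from the replication scripts.

```python
import os

import pandas as pd


def run_with_cache(datasets, run_dataset, cache_path):
    """Run run_dataset(name) for every dataset not already in the CSV cache.

    run_dataset(name) is assumed to return a list of result dicts. Rows are
    appended after each dataset finishes, so an interrupted run can be
    restarted with the same call and completed datasets are skipped.
    """
    done = set()
    if os.path.exists(cache_path):
        done = set(pd.read_csv(cache_path)["dataset"])

    for name in datasets:
        if name in done:
            continue  # cached by a previous (possibly interrupted) run
        rows = pd.DataFrame(run_dataset(name))
        rows["dataset"] = name
        # Append; write the header only when the cache file is new
        rows.to_csv(cache_path, mode="a", header=not os.path.exists(cache_path), index=False)
        done.add(name)

    return pd.read_csv(cache_path)
```

Appending one dataset at a time keeps the cache consistent at dataset granularity: a crash mid-dataset loses only that dataset's work.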

Dependencies for replication

pip install -e ".[dev]"
pip install pandas matplotlib ucimlrepo

Four estimators

All estimators follow the sklearn interface: fit / predict / predict_proba / score. Use y = -1 as the unlabeled sentinel, matching sklearn.semi_supervised.LabelPropagation.
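The y = -1 convention amounts to a boolean split of one (X, y) pair into the three groups the objective needs: positives, negatives and unlabeled points. The helper below is a hypothetical sketch of that split; the package's own `encode_labels` in `_data.py` may differ in details such as input validation.

```python
import numpy as np

UNLABELED = -1  # sentinel, as in sklearn.semi_supervised.LabelPropagation


def split_by_sentinel(X, y):
    """Hypothetical helper: split (X, y) into positive, negative and unlabeled rows."""
    X = np.asarray(X)
    y = np.asarray(y)
    X_pos = X[y == 1]
    X_neg = X[y == 0]
    X_u = X[y == UNLABELED]
    # Any other label value is an error for a binary classifier
    if len(X_pos) + len(X_neg) + len(X_u) != len(y):
        raise ValueError("y must contain only 1, 0, or -1")
    return X_pos, X_neg, X_u
```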

SupervisedGMM

Purely supervised MLE (λ = 0). Ignores unlabeled observations.

from semi_supervised_gmm import SupervisedGMM

model = SupervisedGMM().fit(X, y)

SemiSupervisedGMM

Uses a single global λ, either fixed by the caller or grid-searched on a validation set via fit_cv.

import numpy as np
from semi_supervised_gmm import SemiSupervisedGMM

# Fixed lambda
model = SemiSupervisedGMM(lambda_=2.0).fit(X, y)

# Grid-searched lambda
model = SemiSupervisedGMM().fit_cv(X, y, X_val, y_val,
                                   lam_grid=np.logspace(-2, 2, 20))
print(model.lambda_used_)   # selected value

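Compatible with GridSearchCV. Note that when only labeled rows are passed, as below, the unlabeled term ℓ_unl is empty and every λ in the grid gives the same fit; fit_cv is the direct route for tuning λ with unlabeled data.

from sklearn.model_selection import GridSearchCV
from semi_supervised_gmm import SemiSupervisedGMM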

gs = GridSearchCV(SemiSupervisedGMM(),
                  {"lambda_": [0.1, 0.5, 1.0, 2.0, 5.0]},
                  scoring="accuracy")
gs.fit(X_labeled, y_labeled)

LearnedLambdaGMM

Learns λ by gradient ascent on the validation log-likelihood, computing the gradient with respect to λ via the implicit function theorem (IFT), as described in the paper.

from semi_supervised_gmm import LearnedLambdaGMM

model = LearnedLambdaGMM(lambda_init=1.0, n_steps=10).fit(
    X, y, X_val=X_val, y_val=y_val
)
print(model.lambda_)   # learned value
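The IFT step can be illustrated on a deliberately simplified model: a single unit-variance Gaussian mean θ, for which the weighted objective has the closed-form maximiser θ*(λ) = (Σx + λΣu) / (n_l + λ n_u). Setting ∇θJ = 0 and differentiating in λ gives dθ*/dλ = −H⁻¹ ∇θℓ_unl with H = ∇²θJ, and the chain rule then yields the gradient of the validation log-likelihood. This toy is not the package's GMM implementation; all function names are hypothetical.

```python
import numpy as np


def theta_star(x_lab, x_unl, lam):
    """Maximiser of J(theta) = l_sup + lam * l_unl for a unit-variance Gaussian mean."""
    return (x_lab.sum() + lam * x_unl.sum()) / (len(x_lab) + lam * len(x_unl))


def ift_hypergradient(x_lab, x_unl, x_val, lam):
    """d/d(lam) of the validation log-likelihood, via the implicit function theorem."""
    theta = theta_star(x_lab, x_unl, lam)
    hess = -(len(x_lab) + lam * len(x_unl))   # d^2 J / d theta^2
    grad_unl = np.sum(x_unl - theta)           # d l_unl / d theta at theta*
    dtheta_dlam = -grad_unl / hess             # from H * dtheta + grad_unl * dlam = 0
    grad_val = np.sum(x_val - theta)           # d l_val / d theta
    return grad_val * dtheta_dlam


def val_loglik(x_lab, x_unl, x_val, lam):
    """Validation log-likelihood at theta*(lam), up to an additive constant."""
    theta = theta_star(x_lab, x_unl, lam)
    return -0.5 * np.sum((x_val - theta) ** 2)
```

A central finite difference of `val_loglik` in λ should agree with `ift_hypergradient`; in the full GMM the same identity holds, but H is the Hessian over all mixture parameters.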

LocalLambdaGMM

Per-point confidence-weighted λ: λ(x) = λ · max(γ(x), 1−γ(x))^α. Downweights ambiguous unlabeled points.

from semi_supervised_gmm import LocalLambdaGMM

model = LocalLambdaGMM(lambda_=1.0, alpha=1.0).fit(X, y)
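The weighting rule λ(x) = λ · max(γ(x), 1−γ(x))^α vectorises directly. In the sketch below, γ is assumed to be the current posterior P(y = 1 | x) for each unlabeled point, and `local_lambda` is a hypothetical helper name. Since max(γ, 1−γ) lies in [0.5, 1], each weight falls between λ · 0.5^α (a maximally ambiguous point) and λ (a confidently assigned point), and α = 0 recovers a constant global λ.

```python
import numpy as np


def local_lambda(gamma, lam=1.0, alpha=1.0):
    """Per-point weights lam * max(gamma, 1 - gamma) ** alpha.

    gamma: posterior probabilities P(y=1 | x) for the unlabeled points
    (an assumption about what gamma denotes).
    """
    gamma = np.asarray(gamma, dtype=float)
    confidence = np.maximum(gamma, 1.0 - gamma)   # in [0.5, 1]
    return lam * confidence ** alpha
```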

Pre-fitting diagnostic: should you use unlabeled data?

The alignment coefficient A(0) is computable before any semi-supervised fitting. Its sign indicates the direction in which unlabeled data will push the estimator.

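import numpy as np
from semi_supervised_gmm import SupervisedGMM, SemiSupervisedGMM, alignment_score, alignment_reliability

# Number of labeled training points (assumes y_train marks unlabeled rows with -1)
n_lab = int(np.sum(y_train != -1))

# Step 1 (optional): check reliability before computing A(0)
rel = alignment_reliability(n_lab=n_lab, d=X_train.shape[1], covariance_type="full")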
# rel["reliability_regime"]  ->  "reliable" | "marginal" | "unreliable"
# rel["covariance_recommendation"]  ->  e.g. "diagonal"

# Step 2: compute alignment coefficient
result = alignment_score(X_train, y_train, X_val, y_val)
# {
#   "A0":                        float,  # alignment coefficient (nan if rank-deficient)
#   "g0_norm":                   float,  # ||g_0|| score residual norm
#   "recommendation":            str,    # "use" | "discard" | "unreliable" |
#                                        # "use_with_caution" | "discard_with_caution"
#   "n_unlabeled":               int,
#   "p_eff":                     int,    # effective parameter count
#   "reliability_ratio":         float,  # p_eff / n_lab
#   "reliability_regime":        str,    # "reliable" | "marginal" | "unreliable"
#   "covariance_recommendation": str,    # suggested covariance_type (or None)
# }

if result["recommendation"] == "use":
    model = SemiSupervisedGMM(lambda_=1.0).fit(X_train, y_train)
else:
    model = SupervisedGMM().fit(X_train, y_train)

Interpretation:

  • A(0) > 0 — unlabeled geometry is aligned with the classification task; consider semi-supervised learning
  • A(0) < 0 — unlabeled geometry is misaligned; prefer supervised MLE
  • g0_norm large (> 15) — Gaussian assumption likely violated; treat the diagnostic as unreliable

A(0) is a partial diagnostic signal, not a decision rule. Its reliability degrades as the ratio of effective parameter dimension to labeled sample size grows. The paper identifies three structural failure modes where semi-supervised learning is inadvisable regardless of the A(0) sign: misalignment of unlabeled geometry with the task, estimator instability when p_eff/N_lab is large, and prior mismatch under class imbalance.

Reliability regimes (based on p_eff / N_lab):

p_eff / N_lab | Regime | Interpretation
< 10 | "reliable" | Fisher matrix well-conditioned; A(0) informative
10–65 | "marginal" | Use shrinkage covariance ("ledoit_wolf"); treat A(0) as advisory
> 65 | "unreliable" | n_lab ≪ p_eff; A(0) returns nan; prefer supervised-only or diagonal covariance

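Effective parameter counts: p_eff = 1 + 2d + d(d+1) for full covariance (one class prior, two d-dimensional means, and d(d+1)/2 covariance entries per class) or 1 + 4d for diagonal covariance. Example: Sonar (d = 60, full) gives p_eff = 1 + 120 + 3660 = 3781; at n_lab = 20 the ratio is 3781 / 20 ≈ 189, which is in the unreliable regime.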
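The parameter count and regime thresholds in the table above can be reproduced in a few lines. This is a sketch mirroring the table, not the package's `alignment_reliability`; assigning the boundary values 10 and 65 to the "marginal" regime is an assumption, and "ledoit_wolf" is left out because no count is given for it here.

```python
def p_eff(d, covariance_type="full"):
    """Effective parameter count for the two-class Gaussian model."""
    if covariance_type == "full":
        # 1 prior + 2d means + 2 * d(d+1)/2 covariance entries
        return 1 + 2 * d + d * (d + 1)
    if covariance_type == "diag":
        # 1 prior + 2d means + 2d variances
        return 1 + 4 * d
    raise ValueError(f"unsupported covariance_type: {covariance_type!r}")


def reliability_regime(n_lab, d, covariance_type="full"):
    """Map p_eff / n_lab onto the reliability regimes."""
    ratio = p_eff(d, covariance_type) / n_lab
    if ratio < 10:
        regime = "reliable"
    elif ratio <= 65:
        regime = "marginal"
    else:
        regime = "unreliable"
    return ratio, regime
```

For Sonar with diagonal covariance at n_lab = 20, p_eff = 241 and the ratio drops to about 12, which lands in the marginal regime.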


Benchmark summary

Results from the 13-dataset benchmark (250 seeds, covariance_type="ledoit_wolf", StandardScaler). See papers/tables/tab_uci_expanded.tex for the full table with failure-mode annotations.

When it helps: In the label-scarce regime, semi-supervised learning consistently improves AUROC on near-Gaussian datasets (Breast Cancer, Ionosphere, Heart Disease, Diabetes, Wine) by 0.03 to 0.13. The choice of variant matters little at small label budgets.

When it doesn't: Three failure modes are documented in the paper:

  • Misalignment (Sonar, d=60): unlabeled geometry does not align with the boundary; ‖g₀‖ is large and correctly flags the diagnostic as unreliable.
  • Estimability (SPECTF, d=44, small N): rank-deficient supervised baseline; gains at very small N_lab do not persist as N_lab grows.
  • Prior mismatch (Shuttle, 84% class imbalance): fixed λ=1 reduces AUROC by 0.44; learned-λ partially recovers by shrinking λ toward zero.

Covariance options

All estimators accept covariance_type:

Value | When to use
"full" (default) | N ≫ d; unrestricted covariance
"diag" | d > N, or features approximately independent
"ledoit_wolf" | d ≈ N; shrinkage toward a scaled identity (recommended for real data)

Example:

model = SemiSupervisedGMM(lambda_=1.0, covariance_type="ledoit_wolf").fit(X, y)
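To see why the choice matters when d approaches or exceeds N, the sketch below compares the three covariance estimates on a single class of points using NumPy and scikit-learn's `sklearn.covariance.ledoit_wolf`. Inside the package the estimates are built from EM-weighted statistics, so this illustrates the estimators themselves, not the package's M-step.

```python
import numpy as np
from sklearn.covariance import ledoit_wolf

rng = np.random.default_rng(0)
n, d = 20, 30                      # fewer points than dimensions (d > N)
X = rng.normal(size=(n, d))

cov_full = np.cov(X, rowvar=False, bias=True)   # unrestricted MLE, rank <= n - 1
cov_diag = np.diag(X.var(axis=0))               # independent features
cov_lw, shrinkage = ledoit_wolf(X)              # shrink toward a scaled identity

for name, cov in [("full", cov_full), ("diag", cov_diag), ("ledoit_wolf", cov_lw)]:
    eigvals = np.linalg.eigvalsh(cov)
    print(f"{name:12s} min eigenvalue = {eigvals.min():.2e}")
```

The full estimate is singular here (its smallest eigenvalue is numerically zero), so Gaussian log-densities under it are undefined, while the diagonal and shrinkage estimates stay positive definite.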

Convenience: split-data interface

All estimators expose fit_semi for callers who have labeled and unlabeled data in separate arrays:

model = SemiSupervisedGMM(lambda_=1.0).fit_semi(X_labeled, y_labeled, X_unlabeled)
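Presumably fit_semi is equivalent to stacking the arrays, marking the unlabeled rows with the y = -1 sentinel, and calling fit. The hypothetical helper below shows that stacking; `make_semi_supervised` in the quick start appears to serve the same purpose, but its exact behaviour is not documented here.

```python
import numpy as np


def stack_with_sentinel(X_labeled, y_labeled, X_unlabeled):
    """Hypothetical helper: combine split arrays into one (X, y), with y = -1 for unlabeled rows."""
    X = np.vstack([X_labeled, X_unlabeled])
    y = np.concatenate([
        np.asarray(y_labeled, dtype=int),
        np.full(len(X_unlabeled), -1, dtype=int),
    ])
    return X, y
```

Under that assumption, `model.fit(*stack_with_sentinel(X_labeled, y_labeled, X_unlabeled))` should give the same fit as `fit_semi`.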

Persistence

model.save("model.pkl")
loaded = SemiSupervisedGMM.load("model.pkl")

Algorithmic notes

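Cholesky caching. Each E-step computes one Cholesky factorisation per class per iteration and reuses it for both the log-determinant and the Mahalanobis solve, whereas the reference implementation performs two O(d³) factorisations (one for slogdet, one for solve). This gives roughly a 1.5× E-step speedup.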
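The single-factorisation trick can be sketched as follows: factor Σ = LLᵀ once, read the log-determinant off the diagonal of L (log|Σ| = 2 Σᵢ log Lᵢᵢ), and reuse L in a triangular solve for the Mahalanobis term. This is an illustrative SciPy version, not the package's code in `_em.py`.

```python
import numpy as np
from scipy.linalg import cholesky, solve_triangular


def gaussian_logpdf_chol(X, mu, cov):
    """Row-wise log N(x | mu, cov) using a single Cholesky factorisation."""
    d = X.shape[1]
    L = cholesky(cov, lower=True)                     # cov = L @ L.T, O(d^3) once
    log_det = 2.0 * np.sum(np.log(np.diag(L)))        # log|cov| from the factor
    Z = solve_triangular(L, (X - mu).T, lower=True)   # solve L z = (x - mu), O(d^2) per point
    maha = np.sum(Z ** 2, axis=0)                     # (x - mu)^T cov^{-1} (x - mu)
    return -0.5 * (d * np.log(2.0 * np.pi) + log_det + maha)
```

The Mahalanobis term works because (x − μ)ᵀΣ⁻¹(x − μ) = ‖L⁻¹(x − μ)‖², so no explicit inverse or second factorisation is needed.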

Precomputed labeled scatter. X_posᵀX_pos and X_pos.sum(0) are computed once outside the EM loop. The scatter about the current mean μ is then recovered via:

(X − μ)ᵀ(X − μ) = XᵀX − outer(Σx, μ) − outer(μ, Σx) + N·outer(μ, μ)

Note: the simpler form XᵀX − N·outer(μ,μ) only holds when μ = x̄. In EM, μ_new includes unlabeled contributions, so the full identity is required.
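A quick numerical check of the identity, using a centre μ that is deliberately not the sample mean (as happens when μ_new mixes in unlabeled contributions). Here μ is just a random vector for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
mu = rng.normal(size=3)              # arbitrary centre, deliberately != X.mean(0)

# Quantities that can be computed once, outside the EM loop
XtX = X.T @ X
sx = X.sum(axis=0)
N = len(X)

direct = (X - mu).T @ (X - mu)
from_cache = XtX - np.outer(sx, mu) - np.outer(mu, sx) + N * np.outer(mu, mu)
naive = XtX - N * np.outer(mu, mu)   # only valid when mu == X.mean(0)

print(np.allclose(direct, from_cache))  # True
print(np.allclose(direct, naive))       # False for this mu
```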

Analytic Fisher for A(0). The production implementation uses the closed-form mean-block Fisher:

F_μ = Σ⁻¹ · S_sample · Σ⁻¹

Cost: O(N·d²) versus O(N·d⁴) for numerical Jacobians, a factor of d² in the leading term (2,500× at d = 50). The closed form also avoids finite-difference error, which degrades diagnostic accuracy at high dimension.
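The closed form follows because the score of a Gaussian log-density with respect to μ is Σ⁻¹(x − μ), so the average outer product of scores is Σ⁻¹ S Σ⁻¹. The check below reads the formula as this outer-product-of-scores (empirical) Fisher and assumes S_sample is the average scatter (1/N) Σᵢ (xᵢ − μ)(xᵢ − μ)ᵀ. If the package uses the summed scatter instead, both sides simply scale by N.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 200, 4
A = rng.normal(size=(d, d))
Sigma = A @ A.T + d * np.eye(d)
mu = rng.normal(size=d)
X = rng.multivariate_normal(mu, Sigma, size=N)

Sigma_inv = np.linalg.inv(Sigma)
R = X - mu

# Closed form Sigma^{-1} S Sigma^{-1}, with S the average scatter about mu: O(N d^2)
S_sample = R.T @ R / N
F_closed = Sigma_inv @ S_sample @ Sigma_inv

# Reference: average outer product of per-sample scores Sigma^{-1}(x_i - mu)
scores = R @ Sigma_inv   # row i equals (Sigma^{-1}(x_i - mu))^T since Sigma_inv is symmetric
F_scores = scores.T @ scores / N
```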


Package structure

semi_supervised_gmm/
    _params.py        GMMParams dataclass
    _em.py            Pure NumPy EM loop, E/M-step, posterior
    _lambda.py        grid_search_lambda, gradient_lambda
    _diagnostics.py   alignment_A0, score_residual_g0 — analytic Fisher
    _data.py          encode_labels: y=-1 sentinel → (X_pos, X_neg, X_u)
    _base.py          BaseGMM: sklearn mixins, predict*, score, persistence
    _estimators.py    Four estimator classes
    exceptions.py     ConvergenceWarning, InsufficientLabeledDataError

Running tests

python -m pytest tests/ -q

61 tests: unit tests for EM numerics, data encoding, and diagnostics; integration tests for all four estimators including numerical agreement with the reference implementation, GridSearchCV compatibility, and save/load roundtrip.



Download files

  • Source distribution: semi_supervised_gmm-0.2.0.tar.gz (42.0 kB)
  • Built distribution: semi_supervised_gmm-0.2.0-py3-none-any.whl (27.6 kB, Python 3)

File hashes

semi_supervised_gmm-0.2.0.tar.gz (42.0 kB; uploaded via twine/6.2.0 CPython/3.14.2, not via Trusted Publishing)

  • SHA256: 2ab3b0cf377c1da3c1c71e55cdcde0f27ebcbc1cdd0779644f419a2d5f3e6277
  • MD5: d3f41041b4bf16f3eb1d6fed73f840cd
  • BLAKE2b-256: 39222102025aa30f644adb17737e76f91d01141fd29381b95cddce877f81037b

semi_supervised_gmm-0.2.0-py3-none-any.whl

  • SHA256: 7ffae4ac2151a1ddec7f9788f7d5ed7591a05510a791b6c65419c13b65df17d2
  • MD5: 136e8eebdb94d991e1475f103a5f1f0e
  • BLAKE2b-256: 468538ad5d2d987c9b2ae0d39df3191a7cad495203ec6c41788611be1408e5dc
