Generalized Additive Models fit via ADMM
gamdist
A modular toolkit for the GLM/GAM zoo — binary, continuous, or count outcomes; continuous, categorical, or spline-transformed features; arbitrary regularization (ridge, L1, group lasso, network lasso, network ridge, curvature penalties) attached to whichever terms want it. Every model is a single convex optimization problem.
The joint Hessian for that problem is intractable to derive and solve as a monolith. The ADMM decomposition of Chu, Keshavarz, and Boyd makes the zoo tractable by splitting the problem into a per-feature primal step plus a per-outcome proximal step, coordinated by dual variables. Parallelism is a side effect; the real prize is modularity: outcomes, features, and regularizers are independent components that mix and match in any combination, and new ones plug in without disturbing the rest.
Supported families: normal, binomial, poisson, gamma,
exponential (= gamma with dispersion 1), inverse_gaussian.
Supported links: identity, logistic, probit,
complementary_log_log, log, reciprocal, reciprocal_squared.
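Each link g maps a mean μ to the linear predictor η = g(μ), and the model is fit on the η scale. The stdlib-only sketch below writes out the textbook definitions of the seven links listed above together with their inverses. The `LINKS` table and its layout are hypothetical illustrations, not gamdist's internals.

```python
import math
from statistics import NormalDist

# Hypothetical table: link name -> (g, g_inv). g maps the mean mu to the
# linear predictor eta; g_inv maps eta back to mu. Textbook definitions only.
_STD_NORMAL = NormalDist()

LINKS = {
    "identity": (lambda mu: mu, lambda eta: eta),
    "logistic": (
        lambda mu: math.log(mu / (1.0 - mu)),
        lambda eta: 1.0 / (1.0 + math.exp(-eta)),
    ),
    # probit: g is the standard normal quantile, g_inv its CDF
    "probit": (_STD_NORMAL.inv_cdf, _STD_NORMAL.cdf),
    "complementary_log_log": (
        lambda mu: math.log(-math.log(1.0 - mu)),
        lambda eta: 1.0 - math.exp(-math.exp(eta)),
    ),
    "log": (math.log, math.exp),
    "reciprocal": (lambda mu: 1.0 / mu, lambda eta: 1.0 / eta),
    "reciprocal_squared": (
        lambda mu: 1.0 / mu**2,
        lambda eta: 1.0 / math.sqrt(eta),
    ),
}

if __name__ == "__main__":
    mu = 0.3  # lies in (0, 1), so it is a valid mean for every link above
    for name, (g, g_inv) in LINKS.items():
        # Round trip: g_inv(g(mu)) should recover mu.
        print(f"{name:>22}: eta = {g(mu):+.4f}, g_inv(eta) = {g_inv(g(mu)):.4f}")
```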
Feature types and regularization
Three feature types are available, each with its own set of penalties applied inside the per-feature ADMM step:
- linear: continuous feature with a single coefficient. Supports ridge (l2).
- categorical: per-level offset for a categorical feature. Supports l1, l2, group lasso (group_lasso) for variable selection, network lasso (network_lasso) for clustering connected categories to identical coefficients, and network ridge (network_ridge) for smoothly shrinking connected categories toward each other.
- spline: cubic regression spline with an integrated curvature penalty. Smoothing is set via rel_dof, the target effective degrees of freedom.
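What separates the three non-network penalties is what they add to a feature's objective for a coefficient vector β and strength λ. The numpy sketch below uses the standard definitions. The helper names are hypothetical, and gamdist's own scaling conventions (for example λ/2 on the ridge term) may differ.

```python
import numpy as np

# Hypothetical helpers: the value each penalty adds for coefficients beta.

def ridge(beta: np.ndarray, lam: float) -> float:
    """l2 / ridge: lam * sum(beta_i^2). Shrinks coefficients, never zeros them."""
    return lam * float(np.sum(beta**2))

def lasso(beta: np.ndarray, lam: float) -> float:
    """l1: lam * sum(|beta_i|). Can zero individual levels."""
    return lam * float(np.sum(np.abs(beta)))

def group_lasso(beta: np.ndarray, lam: float) -> float:
    """group lasso: lam * ||beta||_2. Zeros all levels of a feature together."""
    return lam * float(np.linalg.norm(beta))

beta = np.array([3.0, -4.0, 0.0])  # e.g. offsets for a 3-level categorical
print(ridge(beta, 0.5))        # 0.5 * (9 + 16) = 12.5
print(lasso(beta, 0.5))        # 0.5 * (3 + 4) = 3.5
print(group_lasso(beta, 0.5))  # 0.5 * 5 = 2.5
```

The group-lasso norm is not a sum over levels. That coupling is why it removes a categorical feature as a unit rather than level by level.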
The network lasso is a good illustration of why the modular design
matters. Pass an edges DataFrame describing which categories should
have similar coefficients (neighboring counties, related products,
friends in a social graph), and the categorical feature's optimization
step adds an L1 penalty on the edge differences. No other component of
the model needs to change.
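"An L1 penalty on the edge differences" can be written as λ Σ_e w_e |β_i − β_j| = λ ‖diag(w) D β‖₁, where D is the edge-by-level incidence matrix. The numpy sketch below builds D from a plain list of (node1, node2, weight) rows, mirroring the columns of the edges DataFrame. The helper names are hypothetical and this is not gamdist's implementation.

```python
import numpy as np

def incidence_matrix(levels, edges):
    """Hypothetical helper: one row per edge (node1, node2, weight) with +1 at
    node1 and -1 at node2, so (D @ beta)[e] = beta_node1 - beta_node2."""
    index = {level: k for k, level in enumerate(levels)}
    D = np.zeros((len(edges), len(levels)))
    weights = np.empty(len(edges))
    for e, (node1, node2, weight) in enumerate(edges):
        D[e, index[node1]] = 1.0
        D[e, index[node2]] = -1.0
        weights[e] = weight
    return D, weights

def network_lasso_penalty(beta, D, weights, lam):
    """lam * sum over edges of w_e * |beta_i - beta_j|."""
    return lam * float(np.sum(weights * np.abs(D @ beta)))

levels = ["a", "b", "c"]
edges = [("a", "b", 1.0), ("b", "c", 2.0)]  # same shape as node1/node2/weight
D, w = incidence_matrix(levels, edges)
beta = np.array([1.0, 1.0, 0.5])
print(network_lasso_penalty(beta, D, w, lam=1.0))  # 1*|0| + 2*|0.5| = 1.0
```

Only the categorical feature sees D and the weights. The rest of the model never needs to know the graph exists.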
Install
Requires Python 3.11+.
pip install gamdist-py
For development, clone the repo and use uv:
uv sync # runtime deps
uv sync --extra dev # plus pytest, mypy, ruff
Quickstart
import numpy as np
import pandas as pd
from gamdist import GAM
X = pd.DataFrame(
{
"purchases": np.random.choice([0, 3, 10, 16], size=1000),
"gender": np.random.choice(["male", "female"], size=1000),
}
)
y = (
0.1 * np.log1p(X["purchases"].values)
+ np.where(X["gender"].values == "male", 0.1, -0.5)
+ np.random.normal(size=1000) * 0.1
)
mdl = GAM(family="normal")
mdl.add_feature(name="purchases", type="linear", transform=np.log1p)
mdl.add_feature(name="gender", type="categorical")
mdl.fit(X, y)
mdl.summary()
yhat = mdl.predict(X)
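As a sanity check on what the quickstart estimates, assume that add_feature applies no regularization by default. The Gaussian quickstart model is then an ordinary least-squares problem, so a plain numpy fit on the same simulated design should recover the generating values: slope 0.1 on log1p(purchases) and a male-minus-female gap of 0.1 - (-0.5) = 0.6. This sketch does not call gamdist and uses a seeded generator. gamdist's per-level categorical offsets may split the constant between an intercept and the levels differently, so compare the gap between levels rather than the raw offsets.

```python
import numpy as np

# Same data-generating process as the quickstart, with a seeded generator.
rng = np.random.default_rng(0)
n = 1000
purchases = rng.choice([0, 3, 10, 16], size=n)
male = rng.choice(["male", "female"], size=n) == "male"
y = 0.1 * np.log1p(purchases) + np.where(male, 0.1, -0.5) + 0.1 * rng.normal(size=n)

# Design: intercept, log1p(purchases), and a male indicator. Female is the
# baseline, so the intercept absorbs the female offset (-0.5).
A = np.column_stack([np.ones(n), np.log1p(purchases), male.astype(float)])
coef, *_ = np.linalg.lstsq(A, y, rcond=None)
intercept, slope, male_shift = coef
print(f"slope={slope:.3f} intercept={intercept:.3f} male_shift={male_shift:.3f}")
```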
Network lasso on a spatial categorical
A second example showing the modular regularization story: 12 regions arranged in a chain, with a true effect that drifts smoothly along the chain. The network lasso shrinks neighboring regions toward identical coefficients without any change to the rest of the model.
import numpy as np
import pandas as pd
from gamdist import GAM
regions = [f"r{i:02d}" for i in range(12)]
true_effect = dict(zip(regions, np.linspace(-1.0, 1.0, len(regions))))
edges = pd.DataFrame(
{"node1": regions[:-1], "node2": regions[1:], "weight": 1.0}
)
n = 2000
X = pd.DataFrame({"region": np.random.choice(regions, size=n)})
y = (
np.array([true_effect[r] for r in X["region"]])
+ np.random.normal(scale=0.3, size=n)
)
mdl = GAM(family="normal")
mdl.add_feature(
name="region",
type="categorical",
regularization={"network_lasso": {"coef": 1.0, "edges": edges}},
)
mdl.fit(X, y)
mdl.summary()
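A single edge shows why network lasso produces identical coefficients rather than merely similar ones. The sketch below assumes a simplified subproblem: a unit-weight squared-error term per level plus λ|β₁ − β₂|. gamdist's actual feature step also carries ADMM terms. Under that assumption the minimizer has a closed form, and two neighbors fuse exactly whenever their unpenalized estimates differ by at most 2λ.

```python
import math

def fuse_pair(a: float, b: float, lam: float) -> tuple[float, float]:
    """Exact minimizer of 0.5*(b1 - a)**2 + 0.5*(b2 - b)**2 + lam*|b1 - b2|.

    Writing b1 = m + d and b2 = m - d, the objective separates: m stays at
    the average (a + b) / 2, and d is the soft-threshold of (a - b) / 2 at lam.
    """
    m = 0.5 * (a + b)
    d = 0.5 * (a - b)
    d_shrunk = math.copysign(max(abs(d) - lam, 0.0), d)
    return m + d_shrunk, m - d_shrunk

print(fuse_pair(1.0, 0.5, lam=0.5))  # gap 0.5 <= 2*lam: fused to (0.75, 0.75)
print(fuse_pair(2.0, 0.0, lam=0.5))  # gap 2.0 >  2*lam: (1.5, 0.5), still distinct
```

On the 12-region chain the same mechanism acts edge by edge. Adjacent regions whose estimates are close collapse into shared blocks, and larger jumps survive with reduced size.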
Swap network_lasso for network_ridge on the same edges DataFrame
to get the smooth-shrinkage variant: a quadratic penalty
λ · Σ w_ij · (β_i − β_j)² (= λ · βᵀ L β for the graph Laplacian
L) that pulls neighboring coefficients toward each other instead
of clustering them to identical values.
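The identity Σ w_ij (β_i − β_j)² = βᵀ L β holds for the weighted Laplacian L = diag(degree) − W. The numpy sketch below checks it on a small chain and shows the matching proximal step. Minimizing ½‖β − a‖² + λ βᵀ L β gives the linear system (I + 2λL) β = a. Edges are given as (i, j, w) index triples instead of the edges DataFrame. The proximal step is the textbook solution, not necessarily how gamdist's feature step solves it.

```python
import numpy as np

# A 4-level chain with unit weights, like the regions example but smaller.
edges = [(0, 1, 1.0), (1, 2, 1.0), (2, 3, 1.0)]
k = 4

# Weighted graph Laplacian L = diag(degree) - W.
L = np.zeros((k, k))
for i, j, w in edges:
    L[i, i] += w
    L[j, j] += w
    L[i, j] -= w
    L[j, i] -= w

beta = np.array([-1.0, -0.2, 0.4, 1.0])
edge_sum = sum(w * (beta[i] - beta[j]) ** 2 for i, j, w in edges)
quad_form = beta @ L @ beta
print(np.isclose(edge_sum, quad_form))  # True: the two forms agree

# Proximal step for lam * beta^T L beta around the point beta.
lam = 1.0
smoothed = np.linalg.solve(np.eye(k) + 2.0 * lam * L, beta)
# Neighbors are pulled closer, but no two values become equal. The mean is
# preserved because every column of L sums to zero.
print(smoothed, smoothed.mean(), beta.mean())
```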
Multi-task: one feature set, multiple outcomes
MultiTaskGAM fits K outcomes jointly under a shared feature set, with
each task picking its own family and link and an optional coupling
regularizer that ties the K coefficients of each feature together.
The headline use case: identify a feature set that's simultaneously
predictive of every outcome — not a feature set that's brilliant on
some outcomes and useless on others.
import numpy as np
import pandas as pd
from gamdist import MultiTaskGAM
rng = np.random.default_rng(1)
n = 400
signal = rng.normal(size=n)
noise = rng.normal(size=n)
X = pd.DataFrame({"signal": signal, "noise": noise})
y_continuous = 1.5 * signal + 0.1 * rng.normal(size=n)
y_binary = (
rng.uniform(size=n) < 1.0 / (1.0 + np.exp(-0.8 * signal))
).astype(float)
mdl = MultiTaskGAM(families=["normal", "binomial"])
mdl.add_feature(
"signal",
type="linear",
regularization={"group_lasso_across_tasks": {"coef": 0.5}},
)
mdl.add_feature(
"noise",
type="linear",
regularization={"group_lasso_across_tasks": {"coef": 0.5}},
)
mdl.fit([X, X], [y_continuous, y_binary])
yhats = mdl.predict([X, X])
# yhats[0] is the continuous task's fitted mean; yhats[1] is the binary
# task's fitted probability in (0, 1). With this λ the `noise` feature has
# been dropped from BOTH tasks simultaneously: group lasso across tasks
# zeros the entire K-vector of slopes at once rather than each task's slope
# independently. That's the multi-task variable-selection story you
# can't get from running K independent GAMs.
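The "drop from both tasks at once" behavior comes from the proximal operator of λ‖β_j‖₂, applied to the K-vector of one feature's coefficients across tasks. This operator is known as block soft-thresholding. The numpy sketch below places the textbook operator next to elementwise L1 shrinkage for contrast. It is not taken from gamdist, and the slope vectors are made up.

```python
import numpy as np

def block_soft_threshold(v: np.ndarray, lam: float) -> np.ndarray:
    """Prox of lam * ||v||_2: zero the whole vector if ||v|| <= lam,
    otherwise shrink its length by lam while keeping its direction."""
    norm = np.linalg.norm(v)
    if norm <= lam:
        return np.zeros_like(v)
    return (1.0 - lam / norm) * v

def elementwise_soft_threshold(v: np.ndarray, lam: float) -> np.ndarray:
    """Prox of lam * ||v||_1: each task's entry is shrunk on its own."""
    return np.sign(v) * np.maximum(np.abs(v) - lam, 0.0)

lam = 0.5
# Hypothetical per-feature slope vectors, one entry per task (K = 2).
noise_slopes = np.array([0.3, -0.2])  # norm ~ 0.36 <= lam
mixed_slopes = np.array([0.6, 0.1])   # strong on task 1, weak on task 2

print(block_soft_threshold(noise_slopes, lam))        # zero for both tasks
print(block_soft_threshold(mixed_slopes, lam))        # both tasks kept, shrunk
print(elementwise_soft_threshold(mixed_slopes, lam))  # task 2 dropped alone
```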
Tasks may have different (family, link) pairs, different observation
counts, even different feature data per task — they only have to agree
on the names of the features they share. The per-task ADMM splits
stay independent except for the cross-task penalty inside each shared
feature, so the seam principle (CLAUDE.md) carries over without
changes to the orchestrator.
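To show how tasks with different observation counts combine, here is a numpy sketch of a joint multi-task objective. It assumes an unweighted sum of per-task losses, and squared error stands in for each task's own family. The only cross-task term is a group lasso on each shared feature's row of the coefficient matrix. The task data, names, and the `joint_objective` function are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
features = ["signal", "noise"]

# Two hypothetical tasks with different observation counts (300 and 120)
# and their own feature data; they share only the feature names.
tasks = []
for n_k in (300, 120):
    X_k = {name: rng.normal(size=n_k) for name in features}
    tasks.append({"X": X_k, "y": rng.normal(size=n_k)})

def joint_objective(B: np.ndarray, lam: float) -> float:
    """B[j, k] is feature j's slope in task k.

    Each task contributes its own loss (squared error here); the only term
    that couples tasks is the group lasso on each feature's row of B.
    """
    total = 0.0
    for k, task in enumerate(tasks):
        X = np.column_stack([task["X"][name] for name in features])  # n_k x p
        resid = task["y"] - X @ B[:, k]
        total += 0.5 * float(resid @ resid)
    total += lam * float(np.sum(np.linalg.norm(B, axis=1)))
    return total

B = np.array([[1.0, 1.0], [0.0, 0.0]])
print(joint_objective(B, lam=0.5))
```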
Other coupling penalties (trace / nuclear norm, network-on-tasks, hierarchical pooling) and other convex ways to combine the per-task losses (weighted sum, log-sum-exp, minimax / CVaR) are tracked as follow-ups in issue #39 and issue #71.
Development
uv run pytest # run the test suite (96 tests)
uv run pytest --cov=gamdist # with coverage
uv run mypy gamdist # type check
uv run ruff check gamdist tests # lint
CI runs all of the above on Python 3.11 and 3.12 (see
.github/workflows/ci.yml).
Extending gamdist
The modular design means new components plug in along well-defined seams without touching the rest of the system:
- New outcome distribution / link: add a proximal operator entry in gamdist/proximal_operators.py for the (family, link) pair. Nothing on the feature side changes.
- New feature type: subclass _Feature (see gamdist/feature.py) and implement the standard interface (initialize, optimize, compute_dual_tol, num_params, dof, predict, _save, _load). The ADMM loop in gamdist.py doesn't need to know.
- New regularizer: add it inside a feature's optimize step (alongside the existing L1 / L2 / group-lasso / network-lasso / curvature terms), scaled by smoothing in initialize. The global loop never sees a penalty coefficient.
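To make the "new regularizer" seam concrete, here is a sketch of a per-feature step for a linear feature with a ridge penalty. It assumes the ADMM feature update has the form "penalty plus (ρ/2) times the squared distance to a target vector v". The class is hypothetical: it does not subclass gamdist's _Feature, and while its method names echo the interface listed above, the signatures are invented. The penalty strength λ never leaves the feature; the outer loop would only pass ρ and v.

```python
import numpy as np

class RidgeLinearFeatureSketch:
    """Hypothetical stand-in for a linear feature with a ridge penalty."""

    def __init__(self, x: np.ndarray, lam: float):
        self.x = np.asarray(x, dtype=float)
        self.lam = lam  # the penalty lives entirely inside this feature
        self.beta = 0.0

    def optimize(self, v: np.ndarray, rho: float) -> np.ndarray:
        """Solve min_beta lam*beta^2 + (rho/2)*||x*beta - v||^2 in closed form.

        Setting the derivative 2*lam*beta + rho*x.(x*beta - v) to zero gives
        beta = rho * x.v / (rho * x.x + 2*lam).
        """
        self.beta = rho * (self.x @ v) / (rho * (self.x @ self.x) + 2.0 * self.lam)
        return self.predict()

    def predict(self) -> np.ndarray:
        """This feature's contribution to the linear predictor."""
        return self.beta * self.x

    def num_params(self) -> int:
        return 1

rng = np.random.default_rng(0)
feature = RidgeLinearFeatureSketch(rng.normal(size=200), lam=5.0)
fitted = feature.optimize(v=rng.normal(size=200), rho=1.0)
print(feature.beta, fitted.shape)
```

Swapping the ridge term for another convex penalty would change only the body of optimize.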
Every per-component subproblem must be convex.
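Convexity is what makes each per-outcome proximal step reliable. The sketch below shows what a proximal-operator entry for the (poisson, log) pair could compute. It assumes the operator works on the linear predictor η and receives the ADMM target v, the observations y, and ρ; this signature is hypothetical, not gamdist's. Because exp is convex, the objective is strictly convex, so the elementwise Newton iteration converges to the unique minimizer.

```python
import numpy as np

def prox_poisson_log(v: np.ndarray, y: np.ndarray, rho: float,
                     tol: float = 1e-10, max_iter: int = 100) -> np.ndarray:
    """Elementwise argmin_eta  exp(eta) - y*eta + (rho/2)*(eta - v)**2.

    This is the Poisson negative log-likelihood (up to a constant) in terms
    of the linear predictor eta under the log link, plus the ADMM proximity
    term. The second derivative exp(eta) + rho is always positive.
    """
    eta = np.asarray(v, dtype=float).copy()
    for _ in range(max_iter):
        grad = np.exp(eta) - y + rho * (eta - v)
        hess = np.exp(eta) + rho
        step = grad / hess
        eta -= step
        if np.max(np.abs(step)) < tol:
            break
    return eta

y = np.array([0.0, 1.0, 5.0])
v = np.array([0.0, 0.5, 1.0])
print(prox_poisson_log(v, y, rho=1.0))
```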
Caveats
- The package uses pickle for save/load and is not designed for untrusted input.
- confidence_intervals() is not yet implemented and raises NotImplementedError.
- Convexity of every per-component subproblem is a hard requirement. Non-convex (family, link) combinations are out of scope; the existing scipy.optimize.minimize_scalar fallback for non-canonical pairs predates this principle and is scheduled for removal (see issue #19).
- Gamma + reciprocal can produce non-positive mu on small datasets (a numerical edge case in an otherwise supported combination).
Download files
File details
Details for the file gamdist_py-0.2.0.tar.gz.
File metadata
- Download URL: gamdist_py-0.2.0.tar.gz
- Upload date:
- Size: 125.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7c1ad01f060b74b97981ac92ee65bca1d35e432bf133cbce7a757be72f88f85e |
| MD5 | 486d371306a444a9629cf03577ae58fb |
| BLAKE2b-256 | 532b487c25a25d39e3f1696932785b4a0747bf10ff2e38457ca5773a7abb1299 |
Provenance
The following attestation bundles were made for gamdist_py-0.2.0.tar.gz:
Publisher: publish.yml on rwilson4/gamdist
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: gamdist_py-0.2.0.tar.gz
- Subject digest: 7c1ad01f060b74b97981ac92ee65bca1d35e432bf133cbce7a757be72f88f85e
- Sigstore transparency entry: 1467603236
- Sigstore integration time:
- Permalink: rwilson4/gamdist@b32c0306f8c151728c984995e76e1ca075b742c6
- Branch / Tag: refs/tags/v0.2.0
- Owner: https://github.com/rwilson4
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@b32c0306f8c151728c984995e76e1ca075b742c6
- Trigger Event: push
File details
Details for the file gamdist_py-0.2.0-py3-none-any.whl.
File metadata
- Download URL: gamdist_py-0.2.0-py3-none-any.whl
- Upload date:
- Size: 75.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 603940de18b2429f447712e88aa8a749c4434957e4fc7ca31153536b0cdd9e47 |
| MD5 | 247d4da0b2b9b4904aa5ddf8952ee696 |
| BLAKE2b-256 | 8bb49b97bcff9f54f5ae00b15ec0dccbc126c29d1da192765efe0e750cd739e1 |
Provenance
The following attestation bundles were made for gamdist_py-0.2.0-py3-none-any.whl:
Publisher: publish.yml on rwilson4/gamdist
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: gamdist_py-0.2.0-py3-none-any.whl
- Subject digest: 603940de18b2429f447712e88aa8a749c4434957e4fc7ca31153536b0cdd9e47
- Sigstore transparency entry: 1467603390
- Sigstore integration time:
- Permalink: rwilson4/gamdist@b32c0306f8c151728c984995e76e1ca075b742c6
- Branch / Tag: refs/tags/v0.2.0
- Owner: https://github.com/rwilson4
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@b32c0306f8c151728c984995e76e1ca075b742c6
- Trigger Event: push