Rust implementation of Generalized Additive Models with Python bindings

mgcv_rust

A Rust port of R's mgcv package — Generalized Additive Models with automatic smoothing-parameter selection (REML / LAML) and PIRLS fitting — with first-class Python bindings.

The numerical core targets byte-for-byte parity with mgcv on common families, and the Python API is designed to feel like sklearn: a single Gam class, named-predictor DataFrame inputs, fit() / predict() / score(), posterior CI, and subset views for marginal analysis.

Latest release   0.11.0
Parity           554 / 0 / 0 on the mgcv comparison battery
Python tests     211 passed, 1 xfailed
Headline perf    2d_gaussian_additive_n50000_k15_cr: 394 ms → 97 ms (4.05× vs. mgcv)
Python wrapper   mgcv_rust.Gam (canonical), GAMFitter deprecated alias

Why another GAM library?

  • R-equivalent answers, Python ergonomics. If you have R / mgcv users producing models and Python services consuming them, this gives both sides the same numbers from the same fit. The serialize() method produces a portable artefact for deployment.
  • Fast. Aggregate time across 80 parity fixtures: 2.0 s. On the largest Gaussian fixture (n=50,000, k=15), mgcv_rust is ~4× faster than R's mgcv.
  • Real CIs, paired diffs. predict_ci(X) returns (mean, lo, hi) so you never have to "+= intercept" by hand. predict_diff(from_X, to_X, level=...) gives a paired-posterior CI for the difference between two predictions — strictly tighter than the naive predict_ci difference.
  • Frozen serving view. GamPredictor(gam) is an __slots__-locked, inference-only wrapper with strict input validation (feature_names_in_ enforcement) and a check_against(gam, X_sample) round-trip assertion for deployment safety.

Installation

The Python package is built from this repo via maturin:

# Clone, then in a venv:
pip install maturin
maturin develop --features python,blas,blas-system --release

After that, import mgcv_rust works.

For a Rust-only consumer, add to your Cargo.toml:

[dependencies]
mgcv_rust = { git = "https://github.com/AlekJaworski/nn_exploring" }

Quick start

import numpy as np
import pandas as pd
from mgcv_rust import Gam

rng = np.random.default_rng(0)
n = 500
X = pd.DataFrame({
    "x0": rng.uniform(-2, 2, n),
    "x1": rng.uniform(0, 5, n),
})
y = np.sin(X["x0"]) + 0.3 * (X["x1"] - 2.5)**2 + rng.normal(0, 0.1, n)

gam = Gam(family="gaussian").fit(X, y)

gam.predict(X[:5])                  # response-scale predictions
mean, lo, hi = gam.predict_ci(X[:5])  # 95% CI, response scale
gam.score(X, y)                     # adjusted R²
print(gam.summary())                # mgcv-style block

That's it for the simple case. Read on for the curated tour, or jump to docs/GETTING_STARTED.md for a step-by-step tutorial with worked examples.


Tour

1. Build the model

Gam.fit() accepts pandas / polars / numpy inputs. If you pass a DataFrame, the column names become the predictor names; otherwise you can pass them explicitly.

gam = Gam(
    family="gaussian",           # gaussian, binomial, poisson, gamma, tweedie, nb, t-dist, scat, ...
    link=None,                   # None → canonical link for the family
    k_default=10,                # default basis dimension per smooth (mgcv default)
    term_k_mapping={"x0": 25},   # per-predictor overrides
    method="REML",
)
gam.fit(X, y)

After fit, sklearn-style attributes are populated:

gam.coef_              # full β vector (intercept first)
gam.intercept_         # link-scale
gam.intercept_response_ # response-scale (= mean(y) for identity link)
gam.feature_names_in_  # predictor names, fitting order
gam.n_features_in_
gam.lambda_            # per-smooth smoothing parameter
gam.edf_               # per-smooth effective degrees of freedom
gam.k_, gam.bs_        # basis dimension / basis type per smooth
gam.vcov_              # posterior covariance of β̂
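
For orientation on how lambda_ and edf_ relate, here is a minimal textbook sketch, not the library's implementation. For a single penalized least-squares fit with design matrix X and penalty matrix S, the effective degrees of freedom are the trace of the hat matrix. The helper name effective_df and the identity penalty are assumptions made for illustration only.

```python
import numpy as np

def effective_df(X, S, lam):
    """edf = trace((X'X + lam * S)^-1 X'X) for penalized least squares."""
    XtX = X.T @ X
    return float(np.trace(np.linalg.solve(XtX + lam * S, XtX)))

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 10))   # stand-in for a k=10 basis (hypothetical data)
S = np.eye(10)                   # toy penalty, an assumption; not an mgcv penalty

# Larger smoothing parameters shrink the effective degrees of freedom.
edfs = [effective_df(X, S, lam) for lam in (0.0, 1.0, 100.0, 1e6)]
```

With no penalty the edf equals the basis dimension, and it falls towards zero as the smoothing parameter grows, which is why a smooth with k=10 can report an edf well below 10.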

2. Predict on different scales

gam.predict(X)                       # response (default — same as mgcv predict(type="response"))
gam.predict(X, scale="link")         # linear predictor, no inverse link (mgcv type="link")

result = gam.predict(X, type="terms")  # decomposition
result.intercept                       # scalar, link scale
result.contributions                   # DataFrame (n × m_smooths), link-scale, centered on training
result.total                           # response-scale ≡ gam.predict(X)

Invariant: gam.predict(X) == inv_link(intercept + result.contributions.sum(axis=1)).
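
The invariant can be checked by hand on toy numbers. The sketch below assumes a logit link and uses made-up values for the intercept and per-smooth contributions; it does not call the library. It shows that the terms add up on the link scale only, with the inverse link applied once at the end.

```python
import numpy as np

def inv_logit(eta):
    """Inverse of the logit link."""
    return 1.0 / (1.0 + np.exp(-eta))

intercept = -0.4                           # link scale (made-up value)
contributions = np.array([[0.3, -0.1],     # one row per observation,
                          [-1.2, 0.5],     # one column per smooth
                          [0.0, 0.8]])

eta = intercept + contributions.sum(axis=1)   # linear predictor
total = inv_logit(eta)                        # response-scale prediction

# Going back through the link recovers the additive decomposition.
recovered = np.log(total / (1.0 - total)) - intercept
```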

3. Confidence intervals

mean, lo, hi = gam.predict_ci(X, level=0.95)       # response, 95% CI
mean, lo, hi = gam.predict_ci(X, scale="link")     # link-scale CI
mean, lo, hi = gam.predict_ci(X, level=0.5)        # 50% CI (narrower)

mean matches gam.predict(X, scale=scale) exactly. The CI is built from n_samples (default 1000) draws of β ~ N(β̂, vcov).
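
As a rough picture of the sampling approach, the sketch below draws coefficient vectors from N(β̂, vcov), pushes them through a prediction matrix, and reads off quantiles. It is a sketch under assumptions, not the library's code: the function name sampled_ci, the use of equal-tailed quantiles, and the identity link are all choices made here for illustration.

```python
import numpy as np

def sampled_ci(Xp, beta_hat, vcov, level=0.95, n_samples=1000, seed=0):
    """Equal-tailed interval for Xp @ beta from draws of beta ~ N(beta_hat, vcov)."""
    rng = np.random.default_rng(seed)
    draws = rng.multivariate_normal(beta_hat, vcov, size=n_samples)  # (n_samples, p)
    eta = draws @ Xp.T                                               # (n_samples, n)
    alpha = 1.0 - level
    lo, hi = np.quantile(eta, [alpha / 2.0, 1.0 - alpha / 2.0], axis=0)
    mean = Xp @ beta_hat      # point prediction, not the average of the draws
    return mean, lo, hi

# Toy inputs (hypothetical): two rows, intercept plus one coefficient.
Xp = np.array([[1.0, 0.5], [1.0, 2.0]])
beta_hat = np.array([1.0, 2.0])
vcov = np.array([[0.04, 0.01], [0.01, 0.09]])

mean, lo95, hi95 = sampled_ci(Xp, beta_hat, vcov, level=0.95)
_, lo50, hi50 = sampled_ci(Xp, beta_hat, vcov, level=0.5)
```

Because the returned mean is computed from β̂ directly rather than from the draws, it agrees with the point prediction regardless of the number of samples.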

The legacy predict_ci(X, alpha=0.05) → (lo, hi) form still works (emits a DeprecationWarning) so existing callers don't break.

4. Differences between predictions, with paired CI

# Per-row diff: to[i] - from[i]
diff = gam.predict_diff(from_X, to_X)

# With paired posterior CI — much narrower than predict_ci differences
diff, lo, hi = gam.predict_diff(from_X, to_X, level=0.95)

# One row broadcast against many
diff = gam.predict_diff(baseline_row, candidates, broadcast="from")
diff = gam.predict_diff(candidates, target_row, broadcast="to")

Identity-link only — for non-identity links the response-scale difference isn't linear in β, so the closed-form CI argument doesn't transfer.
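
To see why the paired interval is narrower under an identity link, note that the difference of two predictions is (x_to - x_from)ᵀβ, so its variance is a single quadratic form in vcov. The sketch below compares that with the width obtained by subtracting two separate intervals. The function name and the toy numbers are assumptions for illustration; this is the closed-form argument, not the library's code path.

```python
import numpy as np
from statistics import NormalDist

def paired_and_naive_width(x_from, x_to, vcov, level=0.95):
    """Interval widths for the difference x_to @ beta - x_from @ beta."""
    z = NormalDist().inv_cdf(0.5 + level / 2.0)
    d = x_to - x_from
    sd_diff = np.sqrt(d @ vcov @ d)             # shared beta, correlation kept
    sd_from = np.sqrt(x_from @ vcov @ x_from)
    sd_to = np.sqrt(x_to @ vcov @ x_to)
    paired = 2.0 * z * sd_diff
    naive = 2.0 * z * (sd_from + sd_to)         # [lo_to - hi_from, hi_to - lo_from]
    return paired, naive

vcov = np.array([[0.04, 0.01], [0.01, 0.09]])   # hypothetical posterior covariance
x_from = np.array([1.0, 0.5])                   # rows of a prediction matrix
x_to = np.array([1.0, 0.6])

paired, naive = paired_and_naive_width(x_from, x_to, vcov)
```

Here the two rows share the intercept column, so the intercept's uncertainty cancels in the paired form and only the small change in the second column contributes.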

5. Subset / marginal views

gam[names], where names is a list of smooth names, returns a view that includes only the listed smooths (and optionally the intercept, via Gam.INTERCEPT). All predict methods are available on the view.

gam[["x0"]].predict(X)                              # marginal effect of x0, response scale
gam[["x0"]].predict(X, scale="deviation")           # link-scale, intercept zeroed (sum-to-zero on train)
gam[["x0", Gam.INTERCEPT]].predict(X)               # marginal effect + intercept
gam[["x0"]].partial_effect("x0").plot()             # what the smooth looks like

All sklearn attributes filter to the active features:

gam[["x0"]].coef_            # just the x0 block (+ intercept if active)
gam[["x0"]].feature_names_in_  # ["x0"]

6. Plotting, summary, score

gam.plot()                  # figure with one subplot per smooth (CI bands shaded)
gam.plot("x0")              # single-smooth axes
df = gam.partial_effect("x0", level=0.95)  # underlying data: columns x0, effect, lo, hi

print(gam.summary())
# Gam summary  family=gaussian  link=identity  n_obs=500
#   intercept (link)     = -0.0123
#   intercept (response) = -0.0123
#   scale (σ²)           = 0.0101
#   deviance             = 4.8732
#   R² (adj)             = 0.9612
#   smooths:
#     s(          x0)  k= 10  edf=  6.13  λ=2.345e-02
#     s(          x1)  k= 10  edf=  3.07  λ=4.890e-01

gam.score(X, y)              # adjusted R² for regression; accuracy@0.5 for binomial
gam.predict_proba(X)         # binomial only: (n, 2) [[P(0), P(1)]]
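
For readers who want to reproduce the two score() metrics by hand, the following is a sketch of one common definition of each. The exact formula the library uses is not stated here, so treat the degrees-of-freedom correction (n minus total edf, intercept included) and the >= 0.5 threshold as assumptions; both helper names are hypothetical.

```python
import numpy as np

def adjusted_r2(y, y_hat, total_edf):
    """1 - (SS_res / (n - total_edf)) / (SS_tot / (n - 1))."""
    y = np.asarray(y, dtype=float)
    y_hat = np.asarray(y_hat, dtype=float)
    n = y.size
    ss_res = np.sum((y - y_hat) ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    return float(1.0 - (ss_res / (n - total_edf)) / (ss_tot / (n - 1)))

def accuracy_at_half(y, p1):
    """Share of rows where thresholding P(1) at 0.5 matches the 0/1 label."""
    pred = (np.asarray(p1) >= 0.5).astype(int)
    return float(np.mean(pred == np.asarray(y)))

r2 = adjusted_r2([1, 2, 3, 4, 5], [1.1, 1.9, 3.2, 3.8, 5.0], total_edf=2)
acc = accuracy_at_half([0, 1, 1, 0], [0.2, 0.7, 0.4, 0.6])
```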

7. Serving with GamPredictor

For deployment paths, wrap the fitted Gam in a GamPredictor. Same API surface, plus strict column validation and a round-trip assertion.

from mgcv_rust import GamPredictor

predictor = GamPredictor(gam)
predictor.check_against(gam, X[:50])   # raises if predictions diverge
predictor.predict(X_serve)             # ValueError if any expected column is missing
predictor[["x0"]].predict(X_serve)     # subset view returns another GamPredictor

# `__slots__` blocks attribute leaks — once built, the bound Gam's
# coef/vcov/schema are the contract.

This structurally closes two production bug classes:

  • Column index drift when the serving DataFrame's columns differ from training: feature_names_in_ is enforced on every predict call (re-ordered → realigned; missing → ValueError).
  • x-grid clamping in interpolation-based predictors: GamPredictor recomputes the basis at the requested X via evaluate_lpmatrix, so there's no pre-computed grid to clamp against.
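
The column-alignment rule in the first bullet can be sketched in a few lines. This is an illustration of the idea, not GamPredictor's code: align_columns is a hypothetical helper, it works on any column-name mapping (a plain dict here), and dropping extra columns silently is an assumption.

```python
def align_columns(X, feature_names):
    """Return the expected columns in training order; raise if any are missing."""
    missing = [name for name in feature_names if name not in X]
    if missing:
        raise ValueError(f"missing expected columns: {missing}")
    # Re-ordering happens here: output order follows feature_names, not X.
    return {name: X[name] for name in feature_names}

feature_names = ["x0", "x1"]                       # fitting order
serving = {"x1": [4.0, 5.0], "extra": [0, 0], "x0": [0.1, 0.2]}

aligned = align_columns(serving, feature_names)    # columns now x0, x1
```

Selecting by name rather than by position is what prevents a re-ordered serving frame from feeding x1 values into the x0 smooth.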

Families and links

Family                       Default link  Other links  Notes
gaussian                     identity                   bs-spline / cubic-regression / random-effects bases supported
binomial                     logit                      predict_proba() available
poisson                      log
gamma                        inverse       log          log-link via link="log"
tweedie                      log                        tweedie_p= (default 1.5); fixed-p
nb / negbin                  log                        nb profiles θ; negbin is fixed-θ
t-dist / scat                identity                   df profiled if not given (df ∈ [2, 100])
quasipoisson, quasibinomial  log / logit                dispersion-aware
quantile                     identity                   see mgcv_rust.fit_quantile / fit_quantile_lss

Architecture (Rust side)

File            Responsibility
src/gam.rs      GAM struct + fit / predict entry points
src/pirls.rs    Penalized IRLS inner loop
src/reml/       Outer-loop REML / LAML optimization
src/smooth.rs   Basis functions (cubic-regression, B-spline, random effects, tensor products)
src/penalty.rs  Penalty-matrix construction
src/lib.rs      PyO3 bindings: PyGAM, compute_penalty_matrix, newton_pirls_py
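
To make the inner loop concrete, here is a textbook penalized IRLS iteration for a Poisson model with log link and a fixed smoothing parameter, written in Python for consistency with the rest of this page. It is a sketch under stated assumptions, not a transcription of src/pirls.rs: the function name, the polynomial design matrix, the diagonal penalty, and the absence of step halving are all simplifications made here.

```python
import numpy as np

def pirls_poisson(X, y, S, lam, max_iter=50, tol=1e-10):
    """Penalized IRLS for Poisson / log link with a fixed smoothing parameter."""
    beta = np.zeros(X.shape[1])
    for _ in range(max_iter):
        eta = X @ beta
        mu = np.exp(eta)
        z = eta + (y - mu) / mu                  # working response
        w = mu                                   # working weights
        lhs = X.T @ (w[:, None] * X) + lam * S   # penalized normal equations
        rhs = X.T @ (w * z)
        beta_new = np.linalg.solve(lhs, rhs)
        converged = np.max(np.abs(beta_new - beta)) < tol
        beta = beta_new
        if converged:
            break
    return beta

rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, 300)
X = np.column_stack([np.ones_like(x), x, x**2])   # toy basis (assumption)
y = rng.poisson(np.exp(0.5 + 0.8 * x - 0.4 * x**2))
S = np.diag([0.0, 1.0, 1.0])                      # intercept left unpenalized

beta = pirls_poisson(X, y, S, lam=1.0)
# At convergence the penalized score is numerically zero.
score = X.T @ (y - np.exp(X @ beta)) - 1.0 * (S @ beta)
```

The outer REML / LAML loop would wrap a routine like this, re-running it for each candidate smoothing parameter.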

Build features:

  • python — enable PyO3 bindings.
  • blas / blas-system — link against system BLAS for matmul-heavy paths.

Parity and performance

Run the parity battery against R/mgcv:

pytest tests/parity/ -q
# 554 passed, 0 failed, 0 xfailed, 0 skipped

Microbench:

python3 scripts/python/bench_step_blend.py 5

Headline (2d_gaussian_additive_n50000_k15_cr, identity link):

Implementation    Time
R mgcv            ~394 ms
mgcv_rust 0.11.0  97 ms

Status and limitations

  • Joint (ρ, log φ) outer Newton not yet implemented — the binding constraint for closing the remaining performance gap on dispersion-bearing GLMs. Tracked in mgcv_rust - Backlog - Next Numerical Steps (Obsidian).
  • predict_diff is identity-link only; non-identity raises with a workaround pointing at get_posterior_samples.
  • Auto-k tuning is opt-in via Gam(auto_k=True). Default is a single fit at k_default=10 (mgcv's default), with term_k_mapping overrides — closer to mgcv's "tune k, run k.check" convention and avoids hidden multi-fit costs.
  • sklearn BaseEstimator mixin (for Pipeline / GridSearchCV) is not yet wrapped; scikit-learn would be a soft dependency, and the work is deferred.

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. J. R. Stat. Soc. B, 73(1), 3–36.

License

MIT — see LICENSE.
