Rust implementation of Generalized Additive Models with Python bindings

mgcv_rust

A Rust port of R's mgcv package — Generalized Additive Models with automatic smoothing-parameter selection (REML / LAML) and PIRLS fitting — with first-class Python bindings.

The numerical core targets byte-for-byte parity with mgcv on common families, and the Python API is designed to feel like sklearn: a single Gam class, named-predictor DataFrame inputs, fit() / predict() / score(), posterior CI, and subset views for marginal analysis.

Latest release | 0.11.0
Parity | 554 / 0 / 0 on the mgcv comparison battery
Python tests | 211 passed, 1 xfailed
Headline perf | 2d_gaussian_additive_n50000_k15_cr: 394 ms → 97 ms (4.05× vs. mgcv)
Python wrapper | mgcv_rust.Gam (canonical); GAMFitter is a deprecated alias

Why another GAM library?

  • R-equivalent answers, Python ergonomics. If you have R / mgcv users producing models and Python services consuming them, this gives both sides the same numbers from the same fit. The serialize() method produces a portable artefact for deployment.
  • Fast. Aggregate time across 80 parity fixtures: 2.0 s. On the largest Gaussian fixture (n=50,000, k=15), mgcv_rust is ~4× faster than R's mgcv.
  • Real CIs, paired diffs. predict_ci(X) returns (mean, lo, hi) so you never have to "+= intercept" by hand. predict_diff(from_X, to_X, level=...) gives a paired-posterior CI for the difference between two predictions — strictly tighter than the naive predict_ci difference.
  • Frozen serving view. GamPredictor(gam) is an __slots__-locked, inference-only wrapper with strict input validation (feature_names_in_ enforcement) and a check_against(gam, X_sample) round-trip assertion for deployment safety.

Installation

The Python package is built from this repo via maturin:

# Clone, then in a venv:
pip install maturin
maturin develop --features python,blas,blas-system --release

After that, import mgcv_rust works.

For a Rust-only consumer, add to your Cargo.toml:

[dependencies]
mgcv_rust = { git = "https://github.com/AlekJaworski/nn_exploring" }

Quick start

import numpy as np
import pandas as pd
from mgcv_rust import Gam

rng = np.random.default_rng(0)
n = 500
X = pd.DataFrame({
    "x0": rng.uniform(-2, 2, n),
    "x1": rng.uniform(0, 5, n),
})
y = np.sin(X["x0"]) + 0.3 * (X["x1"] - 2.5)**2 + rng.normal(0, 0.1, n)

gam = Gam(family="gaussian").fit(X, y)

gam.predict(X[:5])                  # response-scale predictions
mean, lo, hi = gam.predict_ci(X[:5])  # 95% CI, response scale
gam.score(X, y)                     # adjusted R²
print(gam.summary())                # mgcv-style block

That's it for the simple case. Read on for the curated tour, or jump to docs/GETTING_STARTED.md for a step-by-step tutorial with worked examples.


Tour

1. Build the model

Gam() accepts pandas / polars / numpy. If you pass a DataFrame, the column names become the predictor names; otherwise you can pass them explicitly.

gam = Gam(
    family="gaussian",           # gaussian, binomial, poisson, gamma, tweedie, nb, t-dist, scat, ...
    link=None,                   # None → canonical link for the family
    k_default=10,                # default basis dimension per smooth (mgcv default)
    term_k_mapping={"x0": 25},   # per-predictor overrides
    method="REML",
)
gam.fit(X, y)

After fit, sklearn-style attributes are populated:

gam.coef_              # full β vector (intercept first)
gam.intercept_         # link-scale
gam.intercept_response_ # response-scale (= mean(y) for identity link)
gam.feature_names_in_  # predictor names, fitting order
gam.n_features_in_
gam.lambda_            # per-smooth smoothing parameter
gam.edf_               # per-smooth effective degrees of freedom
gam.k_, gam.bs_        # basis dimension / basis type per smooth
gam.vcov_              # posterior covariance of β̂

2. Predict on different scales

gam.predict(X)                       # response (default — same as mgcv predict(type="response"))
gam.predict(X, scale="link")         # linear predictor, no inverse link (mgcv type="link")

result = gam.predict(X, type="terms")  # decomposition
result.intercept                       # scalar, link scale
result.contributions                   # DataFrame (n × m_smooths), link-scale, centered on training
result.total                           # response-scale ≡ gam.predict(X)

Invariant: gam.predict(X) == inv_link(intercept + result.contributions.sum(axis=1)).
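
The invariant can be checked by hand on a toy decomposition. The sketch below does not call mgcv_rust: the intercept, the per-smooth contributions, and the inv_logit helper are made-up values for illustration, and a logit link is assumed.

```python
import math

def inv_logit(eta):
    """Inverse of the logit link: maps a linear predictor to a probability."""
    return 1.0 / (1.0 + math.exp(-eta))

# Hypothetical link-scale pieces for three rows and two smooths.
intercept = -0.4
contributions = {
    "x0": [0.9, -0.2, 0.1],
    "x1": [-0.3, 0.5, 0.0],
}

# Link-scale total: intercept plus the row-wise sum over smooths.
eta = [intercept + sum(col[i] for col in contributions.values()) for i in range(3)]

# Response-scale prediction: apply the inverse link once, after summing.
response = [inv_logit(e) for e in eta]

# Applying the inverse link per term and then summing is not equivalent.
wrong = [sum(inv_logit(col[i]) for col in contributions.values()) for i in range(3)]
```

The contributions only add up on the link scale, which is why the inverse link is applied after the sum rather than to each term.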

3. Confidence intervals

mean, lo, hi = gam.predict_ci(X, level=0.95)       # response, 95% CI
mean, lo, hi = gam.predict_ci(X, scale="link")     # link-scale CI
mean, lo, hi = gam.predict_ci(X, level=0.5)        # 50% CI (narrower)

mean matches gam.predict(X, scale=scale) exactly. The CI is built from n_samples (default 1000) draws of β ~ N(β̂, vcov).
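
A rough illustration of a simulation-based interval, under stated assumptions: two coefficients, a hand-written 2x2 Cholesky factor, and percentile bounds. The function names, the toy β̂ and vcov values, and the percentile rule are all hypothetical; the library may form its bounds differently.

```python
import math
import random

def cholesky_2x2(v):
    """Lower-triangular L with L @ L.T == v, for a 2x2 SPD matrix."""
    l00 = math.sqrt(v[0][0])
    l10 = v[1][0] / l00
    l11 = math.sqrt(v[1][1] - l10 * l10)
    return [[l00, 0.0], [l10, l11]]

def simulate_ci(x_row, beta_hat, vcov, level=0.95, n_samples=1000, seed=0):
    """Percentile interval for x_row . beta under beta ~ N(beta_hat, vcov)."""
    rng = random.Random(seed)
    L = cholesky_2x2(vcov)
    draws = []
    for _ in range(n_samples):
        z0, z1 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        b0 = beta_hat[0] + L[0][0] * z0
        b1 = beta_hat[1] + L[1][0] * z0 + L[1][1] * z1
        draws.append(x_row[0] * b0 + x_row[1] * b1)
    draws.sort()
    tail = (1.0 - level) / 2.0
    lo = draws[int(tail * (n_samples - 1))]
    hi = draws[int((1.0 - tail) * (n_samples - 1))]
    # The point estimate comes from beta_hat itself, not from the draws.
    mean = x_row[0] * beta_hat[0] + x_row[1] * beta_hat[1]
    return mean, lo, hi

beta_hat = [1.0, 0.5]
vcov = [[0.04, 0.01], [0.01, 0.09]]
x_row = [1.0, 2.0]

mean, lo95, hi95 = simulate_ci(x_row, beta_hat, vcov, level=0.95)
_, lo50, hi50 = simulate_ci(x_row, beta_hat, vcov, level=0.5)
```

Because the point estimate is computed from β̂ directly, it agrees with the plain prediction regardless of how many draws are taken; only the bounds carry Monte Carlo noise.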

The legacy predict_ci(X, alpha=0.05) → (lo, hi) form still works (emits a DeprecationWarning) so existing callers don't break.

4. Differences between predictions, with paired CI

# Per-row diff: to[i] - from[i]
diff = gam.predict_diff(from_X, to_X)

# With paired posterior CI — much narrower than predict_ci differences
diff, lo, hi = gam.predict_diff(from_X, to_X, level=0.95)

# One row broadcast against many
diff = gam.predict_diff(baseline_row, candidates, broadcast="from")
diff = gam.predict_diff(candidates, target_row, broadcast="to")

Identity-link only — for non-identity links the response-scale difference isn't linear in β, so the closed-form CI argument doesn't transfer.
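
To see why pairing helps under an identity link, here is a small closed-form sketch. It assumes a normal approximation with a toy β̂ and vcov; quad_form, linpred, and all numbers are hypothetical, and this is not claimed to be how predict_diff is implemented.

```python
import math

def quad_form(d, v):
    """d^T V d for a vector d and a square matrix V (lists of floats)."""
    n = len(d)
    return sum(d[i] * v[i][j] * d[j] for i in range(n) for j in range(n))

Z95 = 1.959963984540054  # standard normal 97.5% quantile

beta_hat = [2.0, 0.5, -0.3]
vcov = [
    [0.050, 0.004, 0.002],
    [0.004, 0.020, 0.001],
    [0.002, 0.001, 0.030],
]
x_from = [1.0, 0.2, 0.7]  # design rows, intercept column first
x_to = [1.0, 0.6, 0.7]

def linpred(x):
    """Identity-link prediction x . beta_hat."""
    return sum(xi * bi for xi, bi in zip(x, beta_hat))

# Paired: columns shared by both rows cancel in d, so their
# uncertainty drops out of the variance of the difference.
d = [t - f for t, f in zip(x_to, x_from)]
diff = linpred(x_to) - linpred(x_from)
paired_half = Z95 * math.sqrt(quad_form(d, vcov))

# Naive: subtract opposite ends of two separately built intervals,
# which adds the two half-widths together.
half_from = Z95 * math.sqrt(quad_form(x_from, vcov))
half_to = Z95 * math.sqrt(quad_form(x_to, vcov))
naive_half = half_from + half_to
```

In this toy case the intercept and the unchanged third column contribute nothing to the paired half-width, while the naive interval counts their uncertainty twice. With a non-identity link the difference is no longer x-difference times β, so this cancellation argument does not apply.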

5. Subset / marginal views

gam[names] takes a list of smooth names and returns a view that includes only those smooths (plus the intercept when Gam.INTERCEPT is listed). Every predict method called on the view respects that selection.

gam[["x0"]].predict(X)                              # marginal effect of x0, response scale
gam[["x0"]].predict(X, scale="deviation")           # link-scale, intercept zeroed (sum-to-zero on train)
gam[["x0", Gam.INTERCEPT]].predict(X)               # marginal effect + intercept
gam[["x0"]].partial_effect("x0").plot()             # what the smooth looks like

All sklearn attributes filter to the active features:

gam[["x0"]].coef_            # just the x0 block (+ intercept if active)
gam[["x0"]].feature_names_in_  # ["x0"]

6. Plotting, summary, score

gam.plot()                  # figure with one subplot per smooth (CI bands shaded)
gam.plot("x0")              # single-smooth axes
df = gam.partial_effect("x0", level=0.95)  # underlying data: columns x0, effect, lo, hi

print(gam.summary())
# Gam summary  family=gaussian  link=identity  n_obs=500
#   intercept (link)     = -0.0123
#   intercept (response) = -0.0123
#   scale (σ²)           = 0.0101
#   deviance             = 4.8732
#   R² (adj)             = 0.9612
#   smooths:
#     s(          x0)  k= 10  edf=  6.13  λ=2.345e-02
#     s(          x1)  k= 10  edf=  3.07  λ=4.890e-01

gam.score(X, y)              # adjusted R² for regression; accuracy@0.5 for binomial
gam.predict_proba(X)         # binomial only: (n, 2) [[P(0), P(1)]]
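
For orientation, the two score flavours can be sketched in plain Python. The formulas here are textbook forms chosen for illustration, not confirmed to match the library: the adjusted R² divides by n minus a total edf that counts the intercept, and ties at exactly 0.5 are assigned to class 1. Both function names are hypothetical.

```python
def adjusted_r2(y, fitted, edf_total):
    """1 - (RSS / (n - edf_total)) / (TSS / (n - 1)); edf_total counts the intercept."""
    n = len(y)
    y_bar = sum(y) / n
    rss = sum((yi - fi) ** 2 for yi, fi in zip(y, fitted))
    tss = sum((yi - y_bar) ** 2 for yi in y)
    return 1.0 - (rss / (n - edf_total)) / (tss / (n - 1))

def accuracy_at_half(y, p1):
    """Share of rows where thresholding P(1) at 0.5 recovers the 0/1 label."""
    hits = sum(1 for yi, pi in zip(y, p1) if (pi >= 0.5) == (yi == 1))
    return hits / len(y)

# Regression example: six observations, a fit using two effective parameters.
y = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
fitted = [1.1, 1.9, 3.2, 3.8, 5.1, 5.9]
r2_adj = adjusted_r2(y, fitted, edf_total=2.0)

# Binomial example: the third row is misclassified at the 0.5 threshold.
labels = [0, 1, 1, 0]
p1 = [0.2, 0.7, 0.4, 0.1]
acc = accuracy_at_half(labels, p1)
```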

7. Serving with GamPredictor

For deployment paths, wrap the fitted Gam in a GamPredictor. Same API surface, plus strict column validation and a round-trip assertion.

from mgcv_rust import GamPredictor

predictor = GamPredictor(gam)
predictor.check_against(gam, X[:50])   # raises if predictions diverge
predictor.predict(X_serve)             # ValueError if any expected column is missing
predictor[["x0"]].predict(X_serve)     # subset view returns another GamPredictor

# `__slots__` blocks attribute leaks — once built, the bound Gam's
# coef/vcov/schema are the contract.

This structurally closes two production bug classes:

  • Column index drift when the serving DataFrame's columns differ from training: feature_names_in_ is enforced on every predict call (re-ordered → realigned; missing → ValueError).
  • x-grid clamping in interpolation-based predictors: GamPredictor recomputes the basis at the requested X via evaluate_lpmatrix, so there's no pre-computed grid to clamp against.
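
The name-based alignment described in the first bullet can be sketched as follows. This is a standalone illustration using plain dicts rather than DataFrames; align_features is a hypothetical helper, and silently ignoring extra columns is an assumption of the sketch, not documented behaviour.

```python
def align_features(row, feature_names_in):
    """Return values ordered as in training; raise if a column is missing."""
    missing = [name for name in feature_names_in if name not in row]
    if missing:
        raise ValueError(f"missing expected columns: {missing}")
    # Look up by name, so input order and extra columns do not matter.
    return [row[name] for name in feature_names_in]

feature_names_in = ["x0", "x1"]

# Re-ordered input with an extra column: realigned to training order.
reordered = {"x1": 4.0, "x0": -1.5, "extra": 9.9}
aligned = align_features(reordered, feature_names_in)

# Missing column: rejected instead of being filled positionally.
try:
    align_features({"x0": 0.3}, feature_names_in)
    error_message = None
except ValueError as exc:
    error_message = str(exc)
```

Looking columns up by name on every call is what rules out index drift: a positional lookup would have paired the value 4.0 with x0 in the re-ordered case.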

Families and links

Family | Default link | Other links | Notes
gaussian | identity | | bs-spline / cubic-regression / random-effects bases supported
binomial | logit | | predict_proba() available
poisson | log | |
gamma | inverse | log | log-link via link="log"
tweedie | log | | tweedie_p= (default 1.5); fixed-p
nb / negbin | log | | nb profiles θ; negbin is fixed-θ
t-dist / scat | identity | | df profiled if not given (df ∈ [2, 100])
quasipoisson, quasibinomial | log / logit | | dispersion-aware
quantile | identity | | see mgcv_rust.fit_quantile / fit_quantile_lss; OOS presets documented in docs/qgam_oos_presets.md
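
As a reading aid for the table, the default-link fallback and the link / inverse-link pairs can be written out in a few lines. This is a standalone sketch, not the library's internals: LINKS, DEFAULT_LINK, resolve_link, and to_response are hypothetical names, and only five families from the table are transcribed.

```python
import math

# (link, inverse link) pairs for the default links named in the table.
LINKS = {
    "identity": (lambda mu: mu, lambda eta: eta),
    "log": (math.log, math.exp),
    "logit": (lambda mu: math.log(mu / (1.0 - mu)),
              lambda eta: 1.0 / (1.0 + math.exp(-eta))),
    "inverse": (lambda mu: 1.0 / mu, lambda eta: 1.0 / eta),
}

# Default link per family, transcribed from the table above.
DEFAULT_LINK = {
    "gaussian": "identity",
    "binomial": "logit",
    "poisson": "log",
    "gamma": "inverse",
    "tweedie": "log",
}

def resolve_link(family, link=None):
    """Mimic link=None falling back to the family's default link."""
    return link if link is not None else DEFAULT_LINK[family]

def to_response(family, eta, link=None):
    """Map a linear predictor to the response scale for the given family."""
    _, inverse = LINKS[resolve_link(family, link)]
    return inverse(eta)
```

For example, resolve_link("gamma") falls back to "inverse", while resolve_link("gamma", "log") keeps the explicit override, mirroring the link="log" note in the gamma row.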

Architecture (Rust side)

File | Responsibility
src/gam.rs | GAM struct + fit / predict entry points
src/pirls.rs | Penalized IRLS inner loop
src/reml/ | Outer-loop REML / LAML optimization
src/smooth.rs | Basis functions (cubic-regression, B-spline, random effects, tensor products)
src/penalty.rs | Penalty-matrix construction
src/lib.rs | PyO3 bindings: PyGAM, compute_penalty_matrix, newton_pirls_py

Build features:

  • python — enable PyO3 bindings.
  • blas / blas-system — link against system BLAS for matmul-heavy paths.

Parity and performance

Run the parity battery against R/mgcv:

pytest tests/parity/ -q
# 554 passed, 0 failed, 0 xfailed, 0 skipped

Microbench:

python3 scripts/python/bench_step_blend.py 5

Headline (2d_gaussian_additive_n50000_k15_cr, identity link):

Implementation | Time
R mgcv | ~394 ms
mgcv_rust 0.11.0 | 97 ms

Status and limitations

  • Joint (ρ, log φ) outer Newton not yet implemented — the binding constraint for closing the remaining performance gap on dispersion-bearing GLMs. Tracked in mgcv_rust - Backlog - Next Numerical Steps (Obsidian).
  • predict_diff is identity-link only; non-identity raises with a workaround pointing at get_posterior_samples.
  • Auto-k tuning is opt-in via Gam(auto_k=True). Default is a single fit at k_default=10 (mgcv's default), with term_k_mapping overrides — closer to mgcv's "tune k, run k.check" convention and avoids hidden multi-fit costs.
  • sklearn BaseEstimator mixin (for Pipeline / GridSearchCV) is not yet wrapped — soft-dep, deferred.

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. J. R. Stat. Soc. B, 73(1), 3–36.

License

MIT — see LICENSE.
