
Rust implementation of Generalized Additive Models with Python bindings


mgcv_rust

A Rust port of R's mgcv package — Generalized Additive Models with automatic smoothing-parameter selection (REML / LAML) and PIRLS fitting — with first-class Python bindings.

The numerical core targets byte-for-byte parity with mgcv on common families, and the Python API is designed to feel like sklearn: a single Gam class, named-predictor DataFrame inputs, fit() / predict() / score(), posterior confidence intervals, and subset views for marginal analysis.

  • Latest release: 0.11.0
  • Parity: 554 passed, 0 failed (no xfails or skips) on the mgcv comparison battery
  • Python tests: 211 passed, 1 xfailed
  • Headline perf: 2d_gaussian_additive_n50000_k15_cr runs in 97 ms vs. ~394 ms in R's mgcv (4.05× faster)
  • Python wrapper: mgcv_rust.Gam (canonical); GAMFitter is a deprecated alias

Why another GAM library?

  • R-equivalent answers, Python ergonomics. If you have R / mgcv users producing models and Python services consuming them, this gives both sides the same numbers from the same fit. The serialize() method produces a portable artefact for deployment.
  • Fast. Aggregate time across 80 parity fixtures: 2.0 s. On the largest Gaussian fixture (n=50,000, k=15), mgcv_rust is ~4× faster than R's mgcv.
  • Real CIs, paired diffs. predict_ci(X) returns (mean, lo, hi) so you never have to "+= intercept" by hand. predict_diff(from_X, to_X, level=...) gives a paired-posterior CI for the difference between two predictions — strictly tighter than the naive predict_ci difference.
  • Frozen serving view. GamPredictor(gam) is an __slots__-locked, inference-only wrapper with strict input validation (feature_names_in_ enforcement) and a check_against(gam, X_sample) round-trip assertion for deployment safety.

Installation

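Prebuilt wheels are published on PyPI (see Download files below), so pip install mgcv_rust should work on the listed platforms. To build the Python package from this repo instead, use maturin: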

# Clone, then in a venv:
pip install maturin
maturin develop --features python,blas,blas-system --release

After that, import mgcv_rust works.

For a Rust-only consumer, add to your Cargo.toml:

[dependencies]
mgcv_rust = { git = "https://github.com/AlekJaworski/nn_exploring" }

Quick start

import numpy as np
import pandas as pd
from mgcv_rust import Gam

rng = np.random.default_rng(0)
n = 500
X = pd.DataFrame({
    "x0": rng.uniform(-2, 2, n),
    "x1": rng.uniform(0, 5, n),
})
y = np.sin(X["x0"]) + 0.3 * (X["x1"] - 2.5)**2 + rng.normal(0, 0.1, n)

gam = Gam(family="gaussian").fit(X, y)

gam.predict(X[:5])                  # response-scale predictions
mean, lo, hi = gam.predict_ci(X[:5])  # 95% CI, response scale
gam.score(X, y)                     # adjusted R²
print(gam.summary())                # mgcv-style block

That's it for the simple case. Read on for the curated tour, or jump to docs/GETTING_STARTED.md for a step-by-step tutorial with worked examples.


Tour

1. Build the model

Gam accepts pandas, polars, or numpy inputs. If you pass a DataFrame, its column names become the predictor names; otherwise you can pass the names explicitly.

gam = Gam(
    family="gaussian",           # gaussian, binomial, poisson, gamma, tweedie, nb, t-dist, scat, ...
    link=None,                   # None → the family's default link (see Families and links)
    k_default=10,                # default basis dimension per smooth (mgcv default)
    term_k_mapping={"x0": 25},   # per-predictor overrides
    method="REML",
)
gam.fit(X, y)

After fit, sklearn-style attributes are populated:

gam.coef_              # full β vector (intercept first)
gam.intercept_         # link-scale
gam.intercept_response_ # response-scale (= mean(y) for identity link)
gam.feature_names_in_  # predictor names, fitting order
gam.n_features_in_
gam.lambda_            # per-smooth smoothing parameter
gam.edf_               # per-smooth effective degrees of freedom
gam.k_, gam.bs_        # basis dimension / basis type per smooth
gam.vcov_              # posterior covariance of β̂
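Why intercept_response_ equals mean(y) in the identity-link case follows directly from the penalized least-squares objective. A short derivation, assuming an unweighted Gaussian fit, a model matrix $[\mathbf{1}, Z]$ whose smooth columns $Z$ are centered to sum to zero over the training rows (the "centered on training" convention used throughout), and a penalty $S_\lambda$ that does not touch the intercept:

```latex
(\hat\beta_0, \hat\gamma) = \arg\min_{\beta_0, \gamma}\;
  \lVert y - \mathbf{1}\beta_0 - Z\gamma \rVert^2 + \gamma^\top S_\lambda \gamma,
  \qquad \mathbf{1}^\top Z = 0
\\[4pt]
\frac{\partial}{\partial \beta_0}:\;
  -2\,\mathbf{1}^\top\!\left(y - \mathbf{1}\hat\beta_0 - Z\hat\gamma\right) = 0
  \;\Longrightarrow\;
  n\,\hat\beta_0 = \mathbf{1}^\top y - \underbrace{\mathbf{1}^\top Z}_{=\,0}\,\hat\gamma
  \;\Longrightarrow\;
  \hat\beta_0 = \bar{y}
```

With prior weights, or an identity link on a non-Gaussian family, the weighted normal equations no longer cancel the $Z$ term, so the identity need not hold exactly.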

2. Predict on different scales

gam.predict(X)                       # response (default — same as mgcv predict(type="response"))
gam.predict(X, scale="link")         # linear predictor, no inverse link (mgcv type="link")

result = gam.predict(X, type="terms")  # decomposition
result.intercept                       # scalar, link scale
result.contributions                   # DataFrame (n × m_smooths), link-scale, centered on training
result.total                           # response-scale ≡ gam.predict(X)

Invariant: gam.predict(X) == inv_link(intercept + result.contributions.sum(axis=1)).
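The invariant can be checked mechanically. A minimal sketch, assuming the decomposition is available as a plain link-scale intercept plus a per-row list of link-scale smooth contributions; compose_prediction is a hypothetical helper, not part of mgcv_rust:

```python
import math
from typing import Callable, List, Sequence


def compose_prediction(intercept: float,
                       contributions: Sequence[Sequence[float]],
                       inv_link: Callable[[float], float]) -> List[float]:
    """Rebuild response-scale predictions from a type="terms"-style decomposition.

    contributions[i][j] is the link-scale contribution of smooth j to row i
    (hypothetical layout mirroring result.contributions).
    """
    return [inv_link(intercept + sum(row)) for row in contributions]


# Illustrative numbers for a log-link model with two smooths and two rows.
intercept = 0.5
contributions = [[0.2, -0.1], [-0.3, 0.4]]
predictions = compose_prediction(intercept, contributions, math.exp)
```

For the identity link, pass `lambda eta: eta` as inv_link and the function reduces to a row sum plus the intercept.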

3. Confidence intervals

mean, lo, hi = gam.predict_ci(X, level=0.95)       # response, 95% CI
mean, lo, hi = gam.predict_ci(X, scale="link")     # link-scale CI
mean, lo, hi = gam.predict_ci(X, level=0.5)        # 50% CI (narrower)

mean matches gam.predict(X, scale=...) at the same scale exactly. The interval is built from n_samples (default 1000) draws of β ~ N(β̂, vcov).

The legacy predict_ci(X, alpha=0.05) → (lo, hi) form still works (emits a DeprecationWarning) so existing callers don't break.
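A sketch of how a posterior-simulation interval of this kind can be computed for a single model-matrix row, in pure Python: draw β from N(β̂, vcov) through a Cholesky factor, map each draw through the inverse link, and take empirical quantiles. The helper names are hypothetical, and mgcv_rust's internals (quantile rule, seeding, vectorization) may differ:

```python
import math
import random


def cholesky(a):
    """Lower-triangular L with L L^T == a (a must be symmetric positive definite)."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(a[i][i] - s)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L


def quantile(sorted_vals, q):
    """Linear-interpolation quantile of an already sorted list."""
    pos = q * (len(sorted_vals) - 1)
    lo = math.floor(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    return sorted_vals[lo] + (pos - lo) * (sorted_vals[hi] - sorted_vals[lo])


def simulate_ci(x, beta_hat, vcov, inv_link=lambda eta: eta,
                level=0.95, n_samples=1000, seed=0):
    """CI for inv_link(x . beta) from draws beta ~ N(beta_hat, vcov).

    Returns (plug-in prediction, lower, upper); the plug-in value is
    inv_link(x . beta_hat), not the average of the draws.
    """
    rng = random.Random(seed)
    L = cholesky(vcov)
    p = len(beta_hat)
    draws = []
    for _ in range(n_samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(p)]
        # beta = beta_hat + L z has covariance L L^T = vcov.
        beta = [beta_hat[i] + sum(L[i][k] * z[k] for k in range(i + 1)) for i in range(p)]
        draws.append(inv_link(sum(xi * bi for xi, bi in zip(x, beta))))
    draws.sort()
    alpha = 1.0 - level
    mean = inv_link(sum(xi * bi for xi, bi in zip(x, beta_hat)))
    return mean, quantile(draws, alpha / 2), quantile(draws, 1 - alpha / 2)
```

Returning the plug-in value rather than the sample mean keeps the first element identical to the point prediction, which matches the "mean matches gam.predict" guarantee above.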

4. Differences between predictions, with paired CI

# Per-row diff: to[i] - from[i]
diff = gam.predict_diff(from_X, to_X)

# With paired posterior CI — much narrower than predict_ci differences
diff, lo, hi = gam.predict_diff(from_X, to_X, level=0.95)

# One row broadcast against many
diff = gam.predict_diff(baseline_row, candidates, broadcast="from")
diff = gam.predict_diff(candidates, target_row, broadcast="to")

Identity-link only — for non-identity links the response-scale difference isn't linear in β, so the closed-form CI argument doesn't transfer.
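For an identity link the difference is linear in β, so its variance is dᵀ·vcov·d with d = x_to − x_from, where x_from and x_to are model-matrix rows (the basis evaluated at each input). The sketch below contrasts that paired interval with naively subtracting two marginal intervals. It assumes a normal approximation rather than the library's sampling scheme, and the function names are hypothetical:

```python
import math
from statistics import NormalDist


def quad_form(d, vcov):
    """d^T V d for a vector d and a covariance matrix V."""
    n = len(d)
    return sum(d[i] * vcov[i][j] * d[j] for i in range(n) for j in range(n))


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def paired_diff_ci(x_from, x_to, beta_hat, vcov, level=0.95):
    """Normal-approximation CI for (x_to - x_from) . beta (identity link)."""
    d = [t - f for t, f in zip(x_to, x_from)]
    z = NormalDist().inv_cdf(0.5 + level / 2)
    diff = dot(d, beta_hat)
    half = z * math.sqrt(quad_form(d, vcov))
    return diff, diff - half, diff + half


def naive_diff_ci(x_from, x_to, beta_hat, vcov, level=0.95):
    """Subtract two marginal CIs, ignoring that both rows share the same beta."""
    z = NormalDist().inv_cdf(0.5 + level / 2)
    m_from, m_to = dot(x_from, beta_hat), dot(x_to, beta_hat)
    h_from = z * math.sqrt(quad_form(x_from, vcov))
    h_to = z * math.sqrt(quad_form(x_to, vcov))
    return (m_to - h_to) - (m_from + h_from), (m_to + h_to) - (m_from - h_from)
```

Because sd(a − b) ≤ sd(a) + sd(b) for any correlation, the paired interval is never wider than the naive one, and it is typically much narrower when the two rows are close, since their predictions are strongly positively correlated.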

5. Subset / marginal views

gam[names], where names is a list of smooth names, returns a view that includes only the listed smooths (and, optionally, the intercept via Gam.INTERCEPT). All predict methods respect the view.

gam[["x0"]].predict(X)                              # marginal effect of x0, response scale
gam[["x0"]].predict(X, scale="deviation")           # link-scale, intercept zeroed (sum-to-zero on train)
gam[["x0", Gam.INTERCEPT]].predict(X)               # marginal effect + intercept
gam[["x0"]].partial_effect("x0").plot()             # what the smooth looks like
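The sum-to-zero convention can be illustrated by centering each basis column on its training mean, which makes a smooth's contributions sum to zero over the training rows for any coefficients. A minimal sketch with a toy polynomial basis (hypothetical, not mgcv's cubic-regression basis). mgcv itself absorbs the constraint through a QR-based reparameterization that drops one basis dimension; centering is just the simplest way to see why training-set contributions sum to zero:

```python
def center_columns(basis_train):
    """Subtract each column's training mean; return the centered basis and the means."""
    n = len(basis_train)
    p = len(basis_train[0])
    means = [sum(row[j] for row in basis_train) / n for j in range(p)]
    centered = [[row[j] - means[j] for j in range(p)] for row in basis_train]
    return centered, means


def smooth_contribution(basis_rows, means, coef):
    """Link-scale contribution of one smooth, centered with the *training* means."""
    return [sum((row[j] - means[j]) * coef[j] for j in range(len(coef))) for row in basis_rows]


# Toy basis [x, x^2, x^3] on a small training grid (illustrative only).
x_train = [-2.0, -1.0, 0.0, 1.0, 2.5]
basis_train = [[x, x * x, x ** 3] for x in x_train]
centered_train, train_means = center_columns(basis_train)
contrib = smooth_contribution(basis_train, train_means, [0.7, -0.2, 0.05])
```

New rows reuse the stored training means, so their contributions are on the same "deviation from the training average" scale that scale="deviation" reports.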

All sklearn attributes filter to the active features:

gam[["x0"]].coef_            # just the x0 block (+ intercept if active)
gam[["x0"]].feature_names_in_  # ["x0"]

6. Plotting, summary, score

gam.plot()                  # figure with one subplot per smooth (CI bands shaded)
gam.plot("x0")              # single-smooth axes
df = gam.partial_effect("x0", level=0.95)  # underlying data: columns x0, effect, lo, hi

print(gam.summary())
# Gam summary  family=gaussian  link=identity  n_obs=500
#   intercept (link)     = -0.0123
#   intercept (response) = -0.0123
#   scale (σ²)           = 0.0101
#   deviance             = 4.8732
#   R² (adj)             = 0.9612
#   smooths:
#     s(          x0)  k= 10  edf=  6.13  λ=2.345e-02
#     s(          x1)  k= 10  edf=  3.07  λ=4.890e-01

gam.score(X, y)              # adjusted R² for regression; accuracy@0.5 for binomial
gam.predict_proba(X)         # binomial only: (n, 2) [[P(0), P(1)]]
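For a logit-link binomial model, both predict_proba and the accuracy score follow from the link-scale predictions. A sketch, assuming P(1) is the inverse logit of the linear predictor and that a row is classed as 1 when P(1) ≥ 0.5 (the tie rule is an assumption); the helper names are hypothetical:

```python
import math


def inv_logit(eta):
    """Numerically stable logistic function."""
    if eta >= 0:
        return 1.0 / (1.0 + math.exp(-eta))
    e = math.exp(eta)
    return e / (1.0 + e)


def predict_proba_from_link(eta):
    """Map link-scale predictions to [[P(0), P(1)], ...] rows."""
    return [[1.0 - inv_logit(e), inv_logit(e)] for e in eta]


def accuracy_at_half(eta, y):
    """Share of 0/1 labels matched by predicting 1 when P(1) >= 0.5."""
    preds = [1 if p1 >= 0.5 else 0 for _, p1 in predict_proba_from_link(eta)]
    return sum(int(p == t) for p, t in zip(preds, y)) / len(y)
```

The two-branch logistic avoids overflow in math.exp for large negative linear predictors.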

7. Serving with GamPredictor

For deployment paths, wrap the fitted Gam in a GamPredictor. Same API surface, plus strict column validation and a round-trip assertion.

from mgcv_rust import GamPredictor

predictor = GamPredictor(gam)
predictor.check_against(gam, X[:50])   # raises if predictions diverge
predictor.predict(X_serve)             # ValueError if any expected column is missing
predictor[["x0"]].predict(X_serve)     # subset view returns another GamPredictor

# `__slots__` blocks attribute leaks — once built, the bound Gam's
# coef/vcov/schema are the contract.

This structurally closes two production bug classes:

  • Column index drift when the serving DataFrame's columns differ from training: feature_names_in_ is enforced on every predict call (re-ordered → realigned; missing → ValueError).
  • x-grid clamping in interpolation-based predictors: GamPredictor recomputes the basis at the requested X via evaluate_lpmatrix, so there's no pre-computed grid to clamp against.
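The column rule in the first bullet amounts to indexing by name rather than by position. A minimal sketch of that check, assuming column-oriented input (a dict of equal-length lists standing in for a DataFrame); align_columns is hypothetical and is not GamPredictor's actual code:

```python
def align_columns(columns, feature_names_in_):
    """Return rows ordered by feature_names_in_ from a {name: values} mapping.

    Re-ordered input is realigned by name; a missing name raises ValueError.
    Extra columns are simply ignored here (an assumption, not documented
    GamPredictor behaviour).
    """
    missing = [name for name in feature_names_in_ if name not in columns]
    if missing:
        raise ValueError(f"missing expected columns: {missing}")
    ordered = [list(columns[name]) for name in feature_names_in_]
    if len({len(col) for col in ordered}) > 1:
        raise ValueError("columns have different lengths")
    return [list(row) for row in zip(*ordered)]
```

Because lookup is by name, a serving frame with columns in a different order produces the same rows as the training frame, which is exactly the drift the bullet describes.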

Families and links

| Family | Default link | Other links | Notes |
|---|---|---|---|
| gaussian | identity | | B-spline / cubic-regression / random-effects bases supported |
| binomial | logit | | predict_proba() available |
| poisson | log | | |
| gamma | inverse | log | log link via link="log" |
| tweedie | log | | tweedie_p= (default 1.5); p is fixed, not estimated |
| nb / negbin | log | | nb profiles θ; negbin uses a fixed θ |
| t-dist / scat | identity | | df profiled if not given (df ∈ [2, 100]) |
| quasipoisson, quasibinomial | log / logit | | dispersion-aware |
| quantile | identity | | see mgcv_rust.fit_quantile / fit_quantile_lss; OOS presets documented in docs/qgam_oos_presets.md |
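The link=None rule in the constructor can be read as a lookup into this table. A sketch of that lookup plus the link / inverse-link pairs, in plain Python; DEFAULT_LINK, LINKS, and resolve_link are illustrative names, not mgcv_rust APIs, and the library may validate family/link combinations more strictly:

```python
import math

# Default links, transcribed from the table above.
DEFAULT_LINK = {
    "gaussian": "identity", "binomial": "logit", "poisson": "log",
    "gamma": "inverse", "tweedie": "log", "nb": "log", "negbin": "log",
    "t-dist": "identity", "scat": "identity",
    "quasipoisson": "log", "quasibinomial": "logit", "quantile": "identity",
}

# name -> (link g(mu), inverse link g^-1(eta))
LINKS = {
    "identity": (lambda mu: mu, lambda eta: eta),
    "log": (math.log, math.exp),
    "logit": (lambda mu: math.log(mu / (1.0 - mu)),
              lambda eta: 1.0 / (1.0 + math.exp(-eta))),
    "inverse": (lambda mu: 1.0 / mu, lambda eta: 1.0 / eta),
}


def resolve_link(family, link=None):
    """link=None falls back to the family default; an explicit link overrides it."""
    name = link if link is not None else DEFAULT_LINK[family]
    if name not in LINKS:
        raise ValueError(f"unknown link {name!r}")
    return name, LINKS[name]
```

For example, resolve_link("gamma") yields the inverse link, while resolve_link("gamma", "log") mirrors passing link="log".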

Architecture (Rust side)

| File | Responsibility |
|---|---|
| src/gam.rs | GAM struct + fit / predict entry points |
| src/pirls.rs | Penalized IRLS inner loop |
| src/reml/ | Outer-loop REML / LAML optimization |
| src/smooth.rs | Basis functions (cubic-regression, B-spline, random effects, tensor products) |
| src/penalty.rs | Penalty-matrix construction |
| src/lib.rs | PyO3 bindings: PyGAM, compute_penalty_matrix, newton_pirls_py |
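To make the PIRLS entry concrete, here is a deliberately tiny sketch of a penalized IRLS loop: a Poisson log-link model with an intercept and one slope, where only the slope is penalized. It is not the code in src/pirls.rs (no basis expansion, step halving, or other safeguards), just the same fixed-point iteration in pure Python:

```python
import math


def solve2(a, b):
    """Solve the 2x2 system a @ x = b by Cramer's rule."""
    det = a[0][0] * a[1][1] - a[0][1] * a[1][0]
    return [(b[0] * a[1][1] - a[0][1] * b[1]) / det,
            (a[0][0] * b[1] - b[0] * a[1][0]) / det]


def pirls_poisson(x, y, lam, max_iter=100, tol=1e-10):
    """Penalized IRLS for y ~ Poisson(exp(b0 + b1 * x)).

    Minimizes deviance + lam * b1**2. Each step solves
    (X'WX + S) beta = X'Wz with W = diag(mu), S = diag(0, lam) and
    working response z = eta + (y - mu) / mu (log-link IRLS forms).
    """
    beta = [math.log(sum(y) / len(y) + 0.1), 0.0]  # +0.1 guards against log(0)
    for _ in range(max_iter):
        a = [[0.0, 0.0], [0.0, lam]]  # accumulates X'WX + S
        b = [0.0, 0.0]                # accumulates X'Wz
        for xi, yi in zip(x, y):
            eta = beta[0] + beta[1] * xi
            mu = math.exp(eta)
            z = eta + (yi - mu) / mu
            row = (1.0, xi)
            for j in range(2):
                b[j] += row[j] * mu * z
                for k in range(2):
                    a[j][k] += row[j] * mu * row[k]
        new_beta = solve2(a, b)
        if max(abs(n - o) for n, o in zip(new_beta, beta)) < tol:
            return new_beta
        beta = new_beta
    return beta
```

At convergence the penalized score equations Xᵀ(y − μ) = Sβ hold, which makes a convenient property to assert in a test.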

Build features:

  • python — enable PyO3 bindings.
  • blas / blas-system — link against system BLAS for matmul-heavy paths.

Parity and performance

Run the parity battery against R/mgcv:

pytest tests/parity/ -q
# 554 passed, 0 failed, 0 xfailed, 0 skipped

Microbench:

python3 scripts/python/bench_step_blend.py 5

Headline (2d_gaussian_additive_n50000_k15_cr, identity link):

| Implementation | Time |
|---|---|
| R mgcv | ~394 ms |
| mgcv_rust 0.11.0 | 97 ms |

Status and limitations

  • Joint (ρ, log φ) outer Newton not yet implemented — the binding constraint for closing the remaining performance gap on dispersion-bearing GLMs. Tracked in mgcv_rust - Backlog - Next Numerical Steps (Obsidian).
  • predict_diff is identity-link only; for non-identity links it raises an error that points to get_posterior_samples as a workaround.
  • Auto-k tuning is opt-in via Gam(auto_k=True). By default Gam runs a single fit at k_default=10 (mgcv's default), with term_k_mapping overrides. This follows mgcv's convention of choosing k yourself and checking it with k.check, and avoids hidden multi-fit costs.
  • Gam does not yet wrap sklearn's BaseEstimator mixin, so Pipeline / GridSearchCV integration is not available; this is deferred to keep scikit-learn a soft (optional) dependency.
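The workaround named in the second bullet replaces the closed-form argument with simulation: evaluate both rows under each posterior draw of β, take the difference on the response scale, and read off quantiles. A sketch, assuming the draws are available as a list of coefficient vectors and the inputs are already expanded into model-matrix rows; the exact return type of get_posterior_samples is not documented here, so sample_diff_ci is hypothetical:

```python
import math
import random


def sample_diff_ci(x_from, x_to, beta_draws, inv_link=math.exp, level=0.95):
    """Response-scale difference (median, lower, upper) from posterior draws of beta."""
    diffs = sorted(
        inv_link(sum(a * b for a, b in zip(x_to, beta)))
        - inv_link(sum(a * b for a, b in zip(x_from, beta)))
        for beta in beta_draws
    )
    alpha = 1.0 - level
    lo_idx = math.floor(alpha / 2 * (len(diffs) - 1))
    hi_idx = math.ceil((1 - alpha / 2) * (len(diffs) - 1))
    return diffs[len(diffs) // 2], diffs[lo_idx], diffs[hi_idx]


# Illustrative draws around beta_hat = (1.0, 0.5); a real run would use the model's posterior.
rng = random.Random(1)
draws = [[1.0 + 0.05 * rng.gauss(0.0, 1.0), 0.5 + 0.05 * rng.gauss(0.0, 1.0)]
         for _ in range(2000)]
median, lo, hi = sample_diff_ci([1.0, 0.0], [1.0, 1.0], draws)
```

Unlike the identity-link case, nothing here needs the difference to be linear in β, which is why simulation is the natural fallback for log or logit links.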

License

MIT — see LICENSE.

Download files

Source Distribution

| File | Size |
|---|---|
| mgcv_rust-0.18.0.tar.gz | 13.7 MB |

Built Distributions

| File | Size | Python | Platform |
|---|---|---|---|
| mgcv_rust-0.18.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 1.1 MB | PyPy 3.11 | manylinux: glibc 2.17+ x86-64 |
| mgcv_rust-0.18.0-cp314-cp314-win_amd64.whl | 3.7 MB | CPython 3.14 | Windows x86-64 |
| mgcv_rust-0.18.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 1.1 MB | CPython 3.14 | manylinux: glibc 2.17+ x86-64 |
| mgcv_rust-0.18.0-cp314-cp314-macosx_11_0_arm64.whl | 892.2 kB | CPython 3.14 | macOS 11.0+ ARM64 |
| mgcv_rust-0.18.0-cp313-cp313-win_amd64.whl | 3.7 MB | CPython 3.13 | Windows x86-64 |
| mgcv_rust-0.18.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 1.1 MB | CPython 3.13 | manylinux: glibc 2.17+ x86-64 |
| mgcv_rust-0.18.0-cp313-cp313-macosx_11_0_arm64.whl | 891.4 kB | CPython 3.13 | macOS 11.0+ ARM64 |
| mgcv_rust-0.18.0-cp312-cp312-win_amd64.whl | 3.7 MB | CPython 3.12 | Windows x86-64 |
| mgcv_rust-0.18.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 1.1 MB | CPython 3.12 | manylinux: glibc 2.17+ x86-64 |
| mgcv_rust-0.18.0-cp312-cp312-macosx_11_0_arm64.whl | 891.8 kB | CPython 3.12 | macOS 11.0+ ARM64 |
| mgcv_rust-0.18.0-cp311-cp311-win_amd64.whl | 3.7 MB | CPython 3.11 | Windows x86-64 |
| mgcv_rust-0.18.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 1.1 MB | CPython 3.11 | manylinux: glibc 2.17+ x86-64 |
| mgcv_rust-0.18.0-cp311-cp311-macosx_11_0_arm64.whl | 893.4 kB | CPython 3.11 | macOS 11.0+ ARM64 |
| mgcv_rust-0.18.0-cp310-cp310-win_amd64.whl | 3.7 MB | CPython 3.10 | Windows x86-64 |
| mgcv_rust-0.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 1.1 MB | CPython 3.10 | manylinux: glibc 2.17+ x86-64 |
| mgcv_rust-0.18.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 1.1 MB | CPython 3.9 | manylinux: glibc 2.17+ x86-64 |

File details

Details for the file mgcv_rust-0.18.0.tar.gz.

File metadata

  • Download URL: mgcv_rust-0.18.0.tar.gz
  • Upload date:
  • Size: 13.7 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: maturin/1.13.3

File hashes

Hashes for mgcv_rust-0.18.0.tar.gz
Algorithm Hash digest
SHA256 8c6c0f98839d93306da9f4c06f489c24da5d7a8404434c559be3d5e87902f40d
MD5 649e12b0c620fb934509ad1b8792317d
BLAKE2b-256 5c6670d2309146f9b723efd44f77686bd78463b0fbc800dfc903d30c503f4874

File details

Details for the file mgcv_rust-0.18.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.18.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 3c8a2c7a8c848ee84455652d88068ffcb6bf47c797e58fe4f4836f68c433ff22
MD5 5caada1c1c67f38c7a4db41b97a750d3
BLAKE2b-256 025ca0ec832478872a9de24333a6b1b3dde665134ee142779f8806604efd200d

File details

Details for the file mgcv_rust-0.18.0-cp314-cp314-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.18.0-cp314-cp314-win_amd64.whl
Algorithm Hash digest
SHA256 2f15fbc12c512b9fdd9f0a8f51810dfb9bdbee26b6d4fabded68c45b760afbe2
MD5 6088645e899703a1d283c7ed083df2f6
BLAKE2b-256 f2a3197585f264092b02278681f082693d9cf7a1347c70a49f55fabb599b94eb

File details

Details for the file mgcv_rust-0.18.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.18.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 c8fc50b14bf4a6ce48d4fcd21db02e8a7c51e77c218e3ffffd466caab885a69e
MD5 900fbf1011581ce8faa09546053feb6d
BLAKE2b-256 396d2aa088195550cd518dac197c41e24dc0f77ce409ad59ce02c78d080d419c

File details

Details for the file mgcv_rust-0.18.0-cp314-cp314-macosx_11_0_arm64.whl.

File hashes

Hashes for mgcv_rust-0.18.0-cp314-cp314-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 81bf4dd86085c55295b34a8d3186b09002a2f095d59d4e1d43c84494100a3083
MD5 f2500e6cf2b63385e846c61b9d468ace
BLAKE2b-256 175e207c1c784c398e4e6efa6d17ca32f93df22709aa81a6e6b39fb2f47293e2

File details

Details for the file mgcv_rust-0.18.0-cp313-cp313-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.18.0-cp313-cp313-win_amd64.whl
Algorithm Hash digest
SHA256 02b1f8e98021f95b333cab7d89da5dc7242a0df415edfe98ec60e1a3850b0adf
MD5 1aea2bd3c83c6f63ba3898fc8f52211a
BLAKE2b-256 01ab520b83b0206605ec9e35252ea11a6301eb3940aef4816e3334943d19ad77

File details

Details for the file mgcv_rust-0.18.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.18.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 500e3048ff6eaf281d19b25801d8bf6088cdea5f248ab6cf7b5ea8cfcddd4bc6
MD5 28a88c21db51f96abdd94f31581140b3
BLAKE2b-256 62e0f75d7c078ec4c437f09e32d00a7ebd925c9a7ec6418ed9ffc4f437a4f124

File details

Details for the file mgcv_rust-0.18.0-cp313-cp313-macosx_11_0_arm64.whl.

File hashes

Hashes for mgcv_rust-0.18.0-cp313-cp313-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 f9ff6cb60613c79fb8dacbd709a57edf46e89166a51d3b6682e831a006baec74
MD5 7f3cecf6874e2bfcf87fcbfb690e1e8a
BLAKE2b-256 f5f9f042a71fa18601a5cf0f8bfdda45aea8ec227d14a5629d148500859d1f71

File details

Details for the file mgcv_rust-0.18.0-cp312-cp312-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.18.0-cp312-cp312-win_amd64.whl
Algorithm Hash digest
SHA256 b4343f4db2b783e69c506e7acac2787bbf5f05ce3f73d96687c1e3788666947d
MD5 26bf7e8624b4046358e4b80cb4c45402
BLAKE2b-256 7c461cd7f0d9160b9f68de978c14f561dab14a4efdf3535136bfd49ed408176d

File details

Details for the file mgcv_rust-0.18.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.18.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 f1799d949cc45a4b7a55fc590e6baf463ce81898fd8e8424f63e62a9c1ef963f
MD5 bdc94f4b1ecb1101ac216b7229b0fba3
BLAKE2b-256 d426076ce48b47edcfbfef35d34a36baf2ec8e5d8de8d0aa442de7c68a619690

File details

Details for the file mgcv_rust-0.18.0-cp312-cp312-macosx_11_0_arm64.whl.

File hashes

Hashes for mgcv_rust-0.18.0-cp312-cp312-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 aa33f344301d76b358a30b6fcb15d40a645c3d2ba10e69e163e01bf8559e4163
MD5 0d6022ab192d7198ea64f6e74d7c63ad
BLAKE2b-256 31bdeb60c5e72b3e91b281f2129abe92065988b518a5e0437f13069f306e17ab

File details

Details for the file mgcv_rust-0.18.0-cp311-cp311-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.18.0-cp311-cp311-win_amd64.whl
Algorithm Hash digest
SHA256 06b7be033af5a5fa960dc6be7ae95eeafe5ee1a1d96825a4da56c8a5ed6d8052
MD5 9ca46d19803b64a825e86120b75592b5
BLAKE2b-256 7602c71d08583e31691abe6adca44130f62367fee8556651e953c03fdeca0045

File details

Details for the file mgcv_rust-0.18.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.18.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 5ae30aa32d16e6359fba89e1e16990e10a36f34c804c03ed1c0f39b663b1c5f8
MD5 44686c471b9962c287cfaf7b2c6150bb
BLAKE2b-256 f578236b18020f4acd6b06cdf6c283881c949091a9d1f9cad7ff10ff7b8e81df

File details

Details for the file mgcv_rust-0.18.0-cp311-cp311-macosx_11_0_arm64.whl.

File hashes

Hashes for mgcv_rust-0.18.0-cp311-cp311-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 b1060379fcbc5e0b8a3682a90710252c0c682eb0865113d47d50c01652a1675a
MD5 f34414578dd284c2e1b83fd239d01818
BLAKE2b-256 3c3fa5cad846c55ae645aefcd9eb03efb636cba99326ead646e0ac3b20ede0de

File details

Details for the file mgcv_rust-0.18.0-cp310-cp310-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.18.0-cp310-cp310-win_amd64.whl
Algorithm Hash digest
SHA256 f99446c9e7a143ac35e79de4309b1fac12bcfce57705e1b3aac72ea3d9678221
MD5 0822e296688cf3acd086812af7216379
BLAKE2b-256 192b7f78605d597ceaa2eff547585755a76cceae60f75b691310069516950906

File details

Details for the file mgcv_rust-0.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 6d45e809119dfa1ef8fe478e6758f443af980b9bf59b7fc80d1a0357dc904dd5
MD5 04c82bd4e7dc524ca10addaca1701034
BLAKE2b-256 e03d27cf4f65405d73025284515a2311febba35b296675ea07cbb60b9b78179c

File details

Details for the file mgcv_rust-0.18.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.18.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 d891530829e46cdad57ab50cda608898226df290ad02d79e83e82157f575ff3c
MD5 22c57474be0b0784f142f1fde8c0c13c
BLAKE2b-256 f76014defa834c6456ef3e14017fb74a652903e992a778debe01d2a6b9711251
