Rust implementation of Generalized Additive Models with Python bindings

mgcv_rust

A Rust port of R's mgcv package — Generalized Additive Models with automatic smoothing-parameter selection (REML / LAML) and PIRLS fitting — with first-class Python bindings.

The numerical core targets byte-for-byte parity with mgcv on common families, and the Python API is designed to feel like sklearn: a single Gam class, named-predictor DataFrame inputs, fit() / predict() / score(), posterior CI, and subset views for marginal analysis.

Latest release: 0.11.0
Parity: 554 passed / 0 failed / 0 xfailed on the mgcv comparison battery
Python tests: 211 passed, 1 xfailed
Headline perf: 2d_gaussian_additive_n50000_k15_cr, 394 ms → 97 ms (4.05× vs. mgcv)
Python wrapper: mgcv_rust.Gam (canonical); GAMFitter is a deprecated alias

Why another GAM library?

  • R-equivalent answers, Python ergonomics. If you have R / mgcv users producing models and Python services consuming them, this gives both sides the same numbers from the same fit. The serialize() method produces a portable artefact for deployment.
  • Fast. Aggregate time across 80 parity fixtures: 2.0 s. On the largest Gaussian fixture (n=50,000, k=15), mgcv_rust is ~4× faster than R's mgcv.
  • Real CIs, paired diffs. predict_ci(X) returns (mean, lo, hi) so you never have to "+= intercept" by hand. predict_diff(from_X, to_X, level=...) gives a paired-posterior CI for the difference between two predictions — strictly tighter than the naive predict_ci difference.
  • Frozen serving view. GamPredictor(gam) is an __slots__-locked, inference-only wrapper with strict input validation (feature_names_in_ enforcement) and a check_against(gam, X_sample) round-trip assertion for deployment safety.

Installation

The Python package is built from this repo via maturin:

# Clone, then in a venv:
pip install maturin
maturin develop --features python,blas,blas-system --release

After that, import mgcv_rust works.

For a Rust-only consumer, add to your Cargo.toml:

[dependencies]
mgcv_rust = { git = "https://github.com/AlekJaworski/nn_exploring" }

Quick start

import numpy as np
import pandas as pd
from mgcv_rust import Gam

rng = np.random.default_rng(0)
n = 500
X = pd.DataFrame({
    "x0": rng.uniform(-2, 2, n),
    "x1": rng.uniform(0, 5, n),
})
y = np.sin(X["x0"]) + 0.3 * (X["x1"] - 2.5)**2 + rng.normal(0, 0.1, n)

gam = Gam(family="gaussian").fit(X, y)

gam.predict(X[:5])                  # response-scale predictions
mean, lo, hi = gam.predict_ci(X[:5])  # 95% CI, response scale
gam.score(X, y)                     # adjusted R²
print(gam.summary())                # mgcv-style block

That's it for the simple case. Read on for the curated tour, or jump to docs/GETTING_STARTED.md for a step-by-step tutorial with worked examples.


Tour

1. Build the model

fit() accepts pandas or polars DataFrames and NumPy arrays. If you pass a DataFrame, the column names become the predictor names; otherwise you can pass the names explicitly.

gam = Gam(
    family="gaussian",           # gaussian, binomial, poisson, gamma, tweedie, nb, t-dist, scat, ...
    link=None,                   # None → canonical link for the family
    k_default=10,                # default basis dimension per smooth (mgcv default)
    term_k_mapping={"x0": 25},   # per-predictor overrides
    method="REML",
)
gam.fit(X, y)
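
The interplay between k_default and term_k_mapping can be pictured with a small sketch. It assumes an override simply takes precedence over the default and that an override for an unknown predictor is an error; resolve_k is a hypothetical helper written for illustration, not a function exported by mgcv_rust.

```python
def resolve_k(feature_names, k_default=10, term_k_mapping=None):
    """Return the basis dimension for each predictor (hypothetical helper)."""
    overrides = term_k_mapping or {}
    # Assumption: naming a predictor that was not fitted is treated as a mistake.
    unknown = sorted(set(overrides) - set(feature_names))
    if unknown:
        raise ValueError(f"term_k_mapping names unknown predictors: {unknown}")
    # An override wins; everything else falls back to the default.
    return {name: overrides.get(name, k_default) for name in feature_names}


ks = resolve_k(["x0", "x1"], k_default=10, term_k_mapping={"x0": 25})
print(ks)
```

Here x0 gets the larger basis while x1 stays at the default.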

After fit, sklearn-style attributes are populated:

gam.coef_              # full β vector (intercept first)
gam.intercept_         # link-scale
gam.intercept_response_ # response-scale (= mean(y) for identity link)
gam.feature_names_in_  # predictor names, fitting order
gam.n_features_in_
gam.lambda_            # per-smooth smoothing parameter
gam.edf_               # per-smooth effective degrees of freedom
gam.k_, gam.bs_        # basis dimension / basis type per smooth
gam.vcov_              # posterior covariance of β̂

2. Predict on different scales

gam.predict(X)                       # response (default — same as mgcv predict(type="response"))
gam.predict(X, scale="link")         # linear predictor, no inverse link (mgcv type="link")

result = gam.predict(X, type="terms")  # decomposition
result.intercept                       # scalar, link scale
result.contributions                   # DataFrame (n × m_smooths), link-scale, centered on training
result.total                           # response-scale ≡ gam.predict(X)

Invariant: gam.predict(X) == inv_link(intercept + result.contributions.sum(axis=1)).
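
The invariant can be checked by hand for a single row. The sketch below uses the standard inverse links and made-up numbers; total_from_terms and INV_LINK are illustrative names, not part of the mgcv_rust API.

```python
import math

# Standard inverse link functions.
INV_LINK = {
    "identity": lambda eta: eta,
    "log": math.exp,
    "logit": lambda eta: 1.0 / (1.0 + math.exp(-eta)),
    "inverse": lambda eta: 1.0 / eta,
}


def total_from_terms(intercept, contributions, link):
    """Recombine a terms decomposition: inv_link(intercept + sum of link-scale terms)."""
    eta = intercept + sum(contributions)
    return INV_LINK[link](eta)


# Made-up values for one row of a log-link model with two smooths.
intercept = 0.5
row_contributions = [0.2, -0.1]
total = total_from_terms(intercept, row_contributions, "log")
print(total)  # equals exp(0.5 + 0.2 - 0.1)
```

The sum is taken on the link scale first; applying the inverse link to each term separately and adding the results would not reproduce the prediction for a non-identity link.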

3. Confidence intervals

mean, lo, hi = gam.predict_ci(X, level=0.95)       # response, 95% CI
mean, lo, hi = gam.predict_ci(X, scale="link")     # link-scale CI
mean, lo, hi = gam.predict_ci(X, level=0.5)        # 50% CI (narrower)

mean matches gam.predict(X, scale=scale) exactly. The CI is built from n_samples (default 1000) draws of β ~ N(β̂, vcov).
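
A rough picture of the sampling approach, for a toy model with two coefficients and one prediction row. This is a sketch under stated assumptions: it uses percentile intervals of the draws on the link scale, and the function names and quantile rule are illustrative rather than taken from the library.

```python
import math
import random


def cholesky_2x2(v):
    """Lower-triangular factor (l11, l21, l22) of a 2x2 positive-definite matrix."""
    l11 = math.sqrt(v[0][0])
    l21 = v[1][0] / l11
    l22 = math.sqrt(v[1][1] - l21 * l21)
    return l11, l21, l22


def sampled_ci(x_row, beta_hat, vcov, level=0.95, n_samples=1000, seed=0):
    """Percentile CI for x_row . beta using draws beta ~ N(beta_hat, vcov)."""
    rng = random.Random(seed)
    l11, l21, l22 = cholesky_2x2(vcov)
    etas = []
    for _ in range(n_samples):
        z1, z2 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        b0 = beta_hat[0] + l11 * z1
        b1 = beta_hat[1] + l21 * z1 + l22 * z2
        etas.append(x_row[0] * b0 + x_row[1] * b1)
    etas.sort()
    tail = (1.0 - level) / 2.0
    lo = etas[int(tail * (n_samples - 1))]
    hi = etas[int(round((1.0 - tail) * (n_samples - 1)))]
    # The centre is the point estimate, not the average of the draws.
    mean = x_row[0] * beta_hat[0] + x_row[1] * beta_hat[1]
    return mean, lo, hi


beta_hat = [1.0, 0.5]
vcov = [[0.04, 0.01], [0.01, 0.09]]
mean, lo, hi = sampled_ci([1.0, 2.0], beta_hat, vcov)
mean50, lo50, hi50 = sampled_ci([1.0, 2.0], beta_hat, vcov, level=0.5)
```

Returning the point estimate as the centre is what lets mean agree exactly with the plain prediction, independent of sampling noise.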

The legacy predict_ci(X, alpha=0.05) → (lo, hi) form still works (emits a DeprecationWarning) so existing callers don't break.
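
One common way to keep a legacy keyword alive is a small shim that converts it and warns. The sketch below assumes level = 1 - alpha and a default level of 0.95; resolve_level is a hypothetical helper, not the library's actual code.

```python
import warnings


def resolve_level(level=None, alpha=None):
    """Map the legacy alpha argument onto level (hypothetical helper)."""
    if alpha is not None:
        if level is not None:
            raise TypeError("pass either level or alpha, not both")
        warnings.warn(
            "alpha is deprecated; use level=1 - alpha",
            DeprecationWarning,
            stacklevel=2,
        )
        return 1.0 - alpha
    return 0.95 if level is None else level


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy_level = resolve_level(alpha=0.05)

print(legacy_level, [w.category.__name__ for w in caught])
```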

4. Differences between predictions, with paired CI

# Per-row diff: to[i] - from[i]
diff = gam.predict_diff(from_X, to_X)

# With paired posterior CI — much narrower than predict_ci differences
diff, lo, hi = gam.predict_diff(from_X, to_X, level=0.95)

# One row broadcast against many
diff = gam.predict_diff(baseline_row, candidates, broadcast="from")
diff = gam.predict_diff(candidates, target_row, broadcast="to")

Identity-link only — for non-identity links the response-scale difference isn't linear in β, so the closed-form CI argument doesn't transfer.
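
Why the paired interval is narrower follows from Var(to - from) = Var(to) + Var(from) - 2 Cov(to, from): predictions from the same fit are usually positively correlated, so the covariance term shrinks the variance of the difference. The sketch below compares half-widths under a normal approximation. It assumes "naive" means subtracting the endpoints of two separate intervals; the numbers are made up.

```python
import math


def diff_ci_halfwidths(var_from, var_to, cov, z=1.96):
    """Half-widths of a CI for (to - from): paired vs. naive endpoint subtraction."""
    # Paired: uses the posterior covariance between the two predictions.
    paired = z * math.sqrt(var_from + var_to - 2.0 * cov)
    # Naive: (hi_to - lo_from) and (lo_to - hi_from), i.e. the half-widths add up.
    naive = z * (math.sqrt(var_from) + math.sqrt(var_to))
    return paired, naive


paired, naive = diff_ci_halfwidths(var_from=0.04, var_to=0.05, cov=0.035)
print(round(paired, 3), round(naive, 3))
```

Because the covariance can never be below -sqrt(var_from * var_to), the paired half-width never exceeds the naive one under this definition.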

5. Subset / marginal views

gam[names], where names is a list of smooth names, returns a view that includes only the listed smooths (and the intercept, if Gam.INTERCEPT is in the list). All predict methods respect the subset.

gam[["x0"]].predict(X)                              # marginal effect of x0, response scale
gam[["x0"]].predict(X, scale="deviation")           # link-scale, intercept zeroed (sum-to-zero on train)
gam[["x0", Gam.INTERCEPT]].predict(X)               # marginal effect + intercept
gam[["x0"]].partial_effect("x0").plot()             # what the smooth looks like

All sklearn attributes filter to the active features:

gam[["x0"]].coef_            # just the x0 block (+ intercept if active)
gam[["x0"]].feature_names_in_  # ["x0"]

6. Plotting, summary, score

gam.plot()                  # figure with one subplot per smooth (CI bands shaded)
gam.plot("x0")              # single-smooth axes
df = gam.partial_effect("x0", level=0.95)  # underlying data: columns x0, effect, lo, hi

print(gam.summary())
# Gam summary  family=gaussian  link=identity  n_obs=500
#   intercept (link)     = -0.0123
#   intercept (response) = -0.0123
#   scale (σ²)           = 0.0101
#   deviance             = 4.8732
#   R² (adj)             = 0.9612
#   smooths:
#     s(          x0)  k= 10  edf=  6.13  λ=2.345e-02
#     s(          x1)  k= 10  edf=  3.07  λ=4.890e-01

gam.score(X, y)              # adjusted R² for regression; accuracy@0.5 for binomial
gam.predict_proba(X)         # binomial only: (n, 2) [[P(0), P(1)]]
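
For the binomial case, accuracy@0.5 can be reproduced from the predict_proba layout described above. This is a minimal sketch with made-up probabilities; accuracy_at_half is an illustrative name, and treating P(1) = 0.5 exactly as class 1 is an assumption.

```python
def accuracy_at_half(proba, y_true):
    """Fraction of rows where thresholding P(1) at 0.5 matches the label."""
    # Each row of proba is [P(0), P(1)]; ties at 0.5 are assigned to class 1.
    predicted = [1 if p1 >= 0.5 else 0 for _, p1 in proba]
    hits = sum(int(p == y) for p, y in zip(predicted, y_true))
    return hits / len(y_true)


proba = [[0.9, 0.1], [0.3, 0.7], [0.45, 0.55], [0.8, 0.2]]
acc = accuracy_at_half(proba, [0, 1, 0, 0])
print(acc)  # 0.75: three of the four rows are classified correctly
```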

7. Serving with GamPredictor

For deployment paths, wrap the fitted Gam in a GamPredictor. Same API surface, plus strict column validation and a round-trip assertion.

from mgcv_rust import GamPredictor

predictor = GamPredictor(gam)
predictor.check_against(gam, X[:50])   # raises if predictions diverge
predictor.predict(X_serve)             # ValueError if any expected column is missing
predictor[["x0"]].predict(X_serve)     # subset view returns another GamPredictor

# `__slots__` blocks attribute leaks — once built, the bound Gam's
# coef/vcov/schema are the contract.

This structurally closes two production bug classes:

  • Column index drift when the serving DataFrame's columns differ from training: feature_names_in_ is enforced on every predict call (re-ordered → realigned; missing → ValueError).
  • x-grid clamping in interpolation-based predictors: GamPredictor recomputes the basis at the requested X via evaluate_lpmatrix, so there's no pre-computed grid to clamp against.
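
Both ideas, locking attributes with __slots__ and enforcing the training column names on every call, can be illustrated with a toy class. This is a sketch only: ToyPredictor is a stand-in written for this example, uses a plain dict of columns instead of a DataFrame, and does not reflect the real GamPredictor internals.

```python
class ToyPredictor:
    """Toy inference-only wrapper (not the real GamPredictor)."""

    __slots__ = ("_feature_names", "_coef")

    def __init__(self, feature_names, coef):
        self._feature_names = tuple(feature_names)
        self._coef = tuple(coef)

    def _align(self, columns):
        """Return columns in training order; fail loudly if one is missing."""
        missing = [name for name in self._feature_names if name not in columns]
        if missing:
            raise ValueError(f"missing expected columns: {missing}")
        return [columns[name] for name in self._feature_names]

    def predict(self, columns):
        """Linear toy prediction over a dict of column name -> list of values."""
        aligned = self._align(columns)
        n_rows = len(aligned[0])
        return [
            sum(c * col[i] for c, col in zip(self._coef, aligned))
            for i in range(n_rows)
        ]


predictor = ToyPredictor(["x0", "x1"], [2.0, -1.0])
# Columns arrive re-ordered, with an extra one; they are realigned by name.
preds = predictor.predict({"x1": [1.0, 2.0], "extra": [9.0, 9.0], "x0": [3.0, 4.0]})
print(preds)  # [5.0, 6.0]
```

Because the class declares __slots__ and has no instance __dict__, assigning any attribute outside the two declared slots raises AttributeError, which is the "no attribute leaks" property in miniature.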

Families and links

| Family | Default link | Other links | Notes |
|---|---|---|---|
| gaussian | identity | | bs-spline / cubic-regression / random-effects bases supported |
| binomial | logit | | predict_proba() available |
| poisson | log | | |
| gamma | inverse | log | log-link via link="log" |
| tweedie | log | | tweedie_p= (default 1.5); fixed-p |
| nb / negbin | log | | nb profiles θ; negbin is fixed-θ |
| t-dist / scat | identity | | df profiled if not given (df ∈ [2, 100]) |
| quasipoisson, quasibinomial | log / logit | | dispersion-aware |
| quantile | identity | | see mgcv_rust.fit_quantile / fit_quantile_lss |
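
The defaults listed above, combined with the link=None convention from the tour, amount to a simple lookup. The sketch below copies the defaults from the table; resolve_link and DEFAULT_LINK are hypothetical names used for illustration, and whether the library validates families this way is an assumption.

```python
# Default link per family, as listed in the table above.
DEFAULT_LINK = {
    "gaussian": "identity",
    "binomial": "logit",
    "poisson": "log",
    "gamma": "inverse",
    "tweedie": "log",
    "nb": "log",
    "negbin": "log",
    "t-dist": "identity",
    "scat": "identity",
    "quasipoisson": "log",
    "quasibinomial": "logit",
    "quantile": "identity",
}


def resolve_link(family, link=None):
    """Return the explicit link if given, else the family's default (hypothetical helper)."""
    if family not in DEFAULT_LINK:
        raise ValueError(f"unknown family: {family!r}")
    return DEFAULT_LINK[family] if link is None else link


print(resolve_link("gamma"))              # inverse
print(resolve_link("gamma", link="log"))  # log
```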

Architecture (Rust side)

| File | Responsibility |
|---|---|
| src/gam.rs | GAM struct + fit / predict entry points |
| src/pirls.rs | Penalized IRLS inner loop |
| src/reml/ | Outer-loop REML / LAML optimization |
| src/smooth.rs | Basis functions (cubic-regression, B-spline, random effects, tensor products) |
| src/penalty.rs | Penalty-matrix construction |
| src/lib.rs | PyO3 bindings: PyGAM, compute_penalty_matrix, newton_pirls_py |

Build features:

  • python — enable PyO3 bindings.
  • blas / blas-system — link against system BLAS for matmul-heavy paths.

Parity and performance

Run the parity battery against R/mgcv:

pytest tests/parity/ -q
# 554 passed, 0 failed, 0 xfailed, 0 skipped

Microbench:

python3 scripts/python/bench_step_blend.py 5

Headline (2d_gaussian_additive_n50000_k15_cr, identity link):

| | Time |
|---|---|
| R mgcv | ~394 ms |
| mgcv_rust 0.11.0 | 97 ms |

Status and limitations

  • Joint (ρ, log φ) outer Newton not yet implemented — the binding constraint for closing the remaining performance gap on dispersion-bearing GLMs. Tracked in mgcv_rust - Backlog - Next Numerical Steps (Obsidian).
  • predict_diff is identity-link only; non-identity raises with a workaround pointing at get_posterior_samples.
  • Auto-k tuning is opt-in via Gam(auto_k=True). Default is a single fit at k_default=10 (mgcv's default), with term_k_mapping overrides — closer to mgcv's "tune k, run k.check" convention and avoids hidden multi-fit costs.
  • sklearn BaseEstimator mixin (for Pipeline / GridSearchCV) is not yet wrapped — soft-dep, deferred.

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. J. R. Stat. Soc. B, 73(1), 3–36.

License

MIT — see LICENSE.
