Rust implementation of Generalized Additive Models with Python bindings

mgcv_rust

A Rust port of R's mgcv package — Generalized Additive Models with automatic smoothing-parameter selection (REML / LAML) and PIRLS fitting — with first-class Python bindings.

The numerical core targets byte-for-byte parity with mgcv on common families, and the Python API is designed to feel like sklearn: a single Gam class, named-predictor DataFrame inputs, fit() / predict() / score(), posterior CIs, and subset views for marginal analysis.

Latest release: 0.11.0
Parity: 554 passed, 0 failed, 0 xfailed, 0 skipped on the mgcv comparison battery
Python tests: 211 passed, 1 xfailed
Headline perf: 2d_gaussian_additive_n50000_k15_cr, 394 ms (R mgcv) → 97 ms (mgcv_rust), 4.05× faster
Python wrapper: mgcv_rust.Gam (canonical); GAMFitter is a deprecated alias

Why another GAM library?

  • R-equivalent answers, Python ergonomics. If you have R / mgcv users producing models and Python services consuming them, this gives both sides the same numbers from the same fit. The serialize() method produces a portable artefact for deployment.
  • Fast. Aggregate time across 80 parity fixtures: 2.0 s. On the largest Gaussian fixture (n=50,000, k=15), mgcv_rust is ~4× faster than R's mgcv.
  • Real CIs, paired diffs. predict_ci(X) returns (mean, lo, hi) so you never have to "+= intercept" by hand. predict_diff(from_X, to_X, level=...) gives a paired-posterior CI for the difference between two predictions — strictly tighter than the naive predict_ci difference.
  • Frozen serving view. GamPredictor(gam) is an __slots__-locked, inference-only wrapper with strict input validation (feature_names_in_ enforcement) and a check_against(gam, X_sample) round-trip assertion for deployment safety.

Installation

Prebuilt wheels are published on PyPI (see Download files below). To build the Python package from a clone of this repo instead, use maturin:

# Clone, then in a venv:
pip install maturin
maturin develop --features python,blas,blas-system --release

After that, import mgcv_rust works.

For a Rust-only consumer, add to your Cargo.toml:

[dependencies]
mgcv_rust = { git = "https://github.com/AlekJaworski/nn_exploring" }

Quick start

import numpy as np
import pandas as pd
from mgcv_rust import Gam

rng = np.random.default_rng(0)
n = 500
X = pd.DataFrame({
    "x0": rng.uniform(-2, 2, n),
    "x1": rng.uniform(0, 5, n),
})
y = np.sin(X["x0"]) + 0.3 * (X["x1"] - 2.5)**2 + rng.normal(0, 0.1, n)

gam = Gam(family="gaussian").fit(X, y)

gam.predict(X[:5])                  # response-scale predictions
mean, lo, hi = gam.predict_ci(X[:5])  # 95% CI, response scale
gam.score(X, y)                     # adjusted R²
print(gam.summary())                # mgcv-style block

That's it for the simple case. Read on for the curated tour, or jump to docs/GETTING_STARTED.md for a step-by-step tutorial with worked examples.


Tour

1. Build the model

Gam() accepts pandas / polars / numpy. If you pass a DataFrame, the column names become the predictor names; otherwise you can pass them explicitly.

gam = Gam(
    family="gaussian",           # gaussian, binomial, poisson, gamma, tweedie, nb, t-dist, scat, ...
    link=None,                   # None → canonical link for the family
    k_default=10,                # default basis dimension per smooth (mgcv default)
    term_k_mapping={"x0": 25},   # per-predictor overrides
    method="REML",
)
gam.fit(X, y)

After fit, sklearn-style attributes are populated:

gam.coef_              # full β vector (intercept first)
gam.intercept_         # link-scale
gam.intercept_response_ # response-scale (= mean(y) for identity link)
gam.feature_names_in_  # predictor names, fitting order
gam.n_features_in_
gam.lambda_            # per-smooth smoothing parameter
gam.edf_               # per-smooth effective degrees of freedom
gam.k_, gam.bs_        # basis dimension / basis type per smooth
gam.vcov_              # posterior covariance of β̂

2. Predict on different scales

gam.predict(X)                       # response (default — same as mgcv predict(type="response"))
gam.predict(X, scale="link")         # linear predictor, no inverse link (mgcv type="link")

result = gam.predict(X, type="terms")  # decomposition
result.intercept                       # scalar, link scale
result.contributions                   # DataFrame (n × m_smooths), link-scale, centered on training
result.total                           # response-scale ≡ gam.predict(X)

Invariant: gam.predict(X) == inv_link(result.intercept + result.contributions.sum(axis=1)), up to floating-point rounding.
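The invariant says the link-scale pieces add up before the inverse link is applied. A minimal NumPy sketch of that bookkeeping, assuming a log link and a prediction matrix whose columns are an intercept followed by one contiguous block per smooth (the layout and names here are illustrative, not mgcv_rust internals):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5

# Illustrative layout: column 0 is the intercept, then 3 basis columns per smooth.
blocks = {"x0": slice(1, 4), "x1": slice(4, 7)}
Xp = np.column_stack([np.ones(n), rng.normal(size=(n, 6))])  # stand-in prediction matrix
beta = rng.normal(scale=0.3, size=7)                          # stand-in fitted coefficients

inv_link = np.exp  # inverse of the log link

intercept = beta[0]
# One link-scale column per smooth, analogous to result.contributions.
contributions = np.column_stack([Xp[:, s] @ beta[s] for s in blocks.values()])

full = inv_link(Xp @ beta)                                 # analogous to gam.predict(X)
rebuilt = inv_link(intercept + contributions.sum(axis=1))  # right-hand side of the invariant
print(np.allclose(full, rebuilt))
```

The two sides sum the same terms in a different order, so they can differ in the last bits; np.allclose is the appropriate comparison rather than exact equality.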

3. Confidence intervals

mean, lo, hi = gam.predict_ci(X, level=0.95)       # response, 95% CI
mean, lo, hi = gam.predict_ci(X, scale="link")     # link-scale CI
mean, lo, hi = gam.predict_ci(X, level=0.5)        # 50% CI (narrower)

mean matches gam.predict(X, scale=scale) exactly. The CI is built from n_samples (default 1000) draws of β ~ N(β̂, vcov).

The legacy predict_ci(X, alpha=0.05) → (lo, hi) form still works (emits a DeprecationWarning) so existing callers don't break.
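The simulation-based CI described above can be sketched in a few lines of NumPy. This sketch assumes a percentile interval over draws of the linear predictor; whether mgcv_rust summarises the draws exactly this way is an assumption, and simulated_ci is a hypothetical helper. The legacy alpha argument corresponds to level = 1 - alpha.

```python
import numpy as np

def simulated_ci(Xp, beta_hat, vcov, level=0.95, n_samples=1000, seed=0):
    """Link-scale CI for Xp @ beta from draws beta ~ N(beta_hat, vcov).

    Sketch: percentile interval over the simulated linear predictors.
    """
    rng = np.random.default_rng(seed)
    draws = rng.multivariate_normal(beta_hat, vcov, size=n_samples)  # (n_samples, p)
    eta = draws @ Xp.T                                                # (n_samples, n)
    tail = (1.0 - level) / 2.0
    lo, hi = np.quantile(eta, [tail, 1.0 - tail], axis=0)
    mean = Xp @ beta_hat  # point prediction at beta_hat, not the average of the draws
    return mean, lo, hi

# Toy inputs: intercept plus one linear term.
Xp = np.column_stack([np.ones(4), np.linspace(0.0, 1.0, 4)])
beta_hat = np.array([1.0, 2.0])
vcov = np.array([[0.04, 0.0], [0.0, 0.09]])

alpha = 0.05                                             # legacy-style argument
mean, lo, hi = simulated_ci(Xp, beta_hat, vcov, level=1.0 - alpha)
_, lo50, hi50 = simulated_ci(Xp, beta_hat, vcov, level=0.5)
```

For an increasing inverse link such as exp or the logistic function, percentiles carry over directly, so response-scale bounds are inv_link(lo) and inv_link(hi); a decreasing inverse link (for example 1/eta) swaps the two bounds.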

4. Differences between predictions, with paired CI

# Per-row diff: to[i] - from[i]
diff = gam.predict_diff(from_X, to_X)

# With paired posterior CI — much narrower than predict_ci differences
diff, lo, hi = gam.predict_diff(from_X, to_X, level=0.95)

# One row broadcast against many
diff = gam.predict_diff(baseline_row, candidates, broadcast="from")
diff = gam.predict_diff(candidates, target_row, broadcast="to")

Identity-link only — for non-identity links the response-scale difference isn't linear in β, so the closed-form CI argument doesn't transfer.
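Why the paired interval is narrower: both predictions share the same β̂, so Var(x_to·β - x_from·β) = Var(x_to·β) + Var(x_from·β) - 2 Cov(x_to·β, x_from·β), and shared uncertainty (the intercept in particular) cancels. A minimal closed-form sketch under a normal approximation, assuming the identity link; the helper names are hypothetical, and predict_diff itself may use posterior draws instead:

```python
import numpy as np

Z95 = 1.959963984540054  # two-sided 95% standard-normal quantile

def row_quad(A, V):
    """Row-wise quadratic form a_i^T V a_i."""
    return np.einsum("ij,jk,ik->i", A, V, A)

def paired_diff_ci(X_from, X_to, beta_hat, vcov, z=Z95):
    """CI for (X_to - X_from) @ beta under an identity link (normal approximation)."""
    D = np.atleast_2d(X_to) - np.atleast_2d(X_from)  # a single row broadcasts against many
    diff = D @ beta_hat
    se = np.sqrt(row_quad(D, vcov))                  # accounts for the covariance term
    return diff, diff - z * se, diff + z * se

def naive_diff_width(X_from, X_to, vcov, z=Z95):
    """Width of [lo_to - hi_from, hi_to - lo_from] built from two separate CIs."""
    se_to = np.sqrt(row_quad(np.atleast_2d(X_to), vcov))
    se_from = np.sqrt(row_quad(np.atleast_2d(X_from), vcov))
    return 2 * z * (se_to + se_from)

# Toy model: the intercept is very uncertain, the slope is well determined.
X_from = np.array([[1.0, 0.2], [1.0, 0.5]])
X_to = np.array([[1.0, 0.3], [1.0, 0.9]])
beta_hat = np.array([0.5, 2.0])
vcov = np.array([[1.0, 0.0], [0.0, 0.01]])

diff, lo, hi = paired_diff_ci(X_from, X_to, beta_hat, vcov)
paired_width = hi - lo
naive_width = naive_diff_width(X_from, X_to, vcov)
```

Passing a single row as X_from (for example X_from[0]) broadcasts it against every row of X_to, which mirrors the broadcast="from" case above.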

5. Subset / marginal views

gam[names] (for example gam[["x0"]]) returns a view that includes only the listed smooths, plus the intercept if Gam.INTERCEPT is listed. All predict methods inherit the restriction.

gam[["x0"]].predict(X)                              # marginal effect of x0, response scale
gam[["x0"]].predict(X, scale="deviation")           # link-scale, intercept zeroed (sum-to-zero on train)
gam[["x0", Gam.INTERCEPT]].predict(X)               # marginal effect + intercept
gam[["x0"]].partial_effect("x0").plot()             # what the smooth looks like

All sklearn attributes filter to the active features:

gam[["x0"]].coef_            # just the x0 block (+ intercept if active)
gam[["x0"]].feature_names_in_  # ["x0"]
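Conceptually, a subset view is the full model with the inactive coefficient blocks zeroed out, and coef_ on the view is the full vector restricted to the active blocks. A minimal link-scale sketch, assuming a hypothetical layout with the intercept in column 0 and one contiguous block per smooth (mgcv_rust's real layout may differ):

```python
import numpy as np

# Hypothetical layout: column 0 is the intercept, then one block per smooth.
BLOCKS = {"x0": slice(1, 4), "x1": slice(4, 7)}

def active_mask(n_coef, blocks, active, include_intercept=False):
    """Boolean mask over coefficients for the active smooths (plus optional intercept)."""
    mask = np.zeros(n_coef, dtype=bool)
    mask[0] = include_intercept
    for name in active:
        mask[blocks[name]] = True
    return mask

def subset_predict_link(Xp, beta, blocks, active, include_intercept=False):
    """Link-scale prediction from the active blocks only; other coefficients are zeroed."""
    mask = active_mask(beta.size, blocks, active, include_intercept)
    return Xp @ np.where(mask, beta, 0.0)

rng = np.random.default_rng(0)
Xp = np.column_stack([np.ones(4), rng.normal(size=(4, 6))])
beta = rng.normal(size=7)

# Link scale, intercept excluded: compare scale="deviation" above.
x0_only = subset_predict_link(Xp, beta, BLOCKS, ["x0"])
coef_view = beta[active_mask(beta.size, BLOCKS, ["x0"])]  # analogous to gam[["x0"]].coef_
```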

6. Plotting, summary, score

gam.plot()                  # figure with one subplot per smooth (CI bands shaded)
gam.plot("x0")              # single-smooth axes
df = gam.partial_effect("x0", level=0.95)  # underlying data: columns x0, effect, lo, hi

print(gam.summary())
# Gam summary  family=gaussian  link=identity  n_obs=500
#   intercept (link)     = -0.0123
#   intercept (response) = -0.0123
#   scale (σ²)           = 0.0101
#   deviance             = 4.8732
#   R² (adj)             = 0.9612
#   smooths:
#     s(          x0)  k= 10  edf=  6.13  λ=2.345e-02
#     s(          x1)  k= 10  edf=  3.07  λ=4.890e-01

gam.score(X, y)              # adjusted R² for regression; accuracy@0.5 for binomial
gam.predict_proba(X)         # binomial only: (n, 2) [[P(0), P(1)]]
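For reference, mgcv's summary.gam computes the adjusted R² as 1 - var(y - μ̂)(n - 1) / (var(y)(n - df_total)), where df_total is the total EDF including 1 for the intercept. The sketch below assumes score follows that unweighted formula, and also shows the two-column predict_proba layout and accuracy at a 0.5 threshold; the helper names and the >= tie rule are assumptions.

```python
import numpy as np

def adjusted_r2(y, mu, edf_smooths):
    """mgcv-style adjusted R^2, unweighted; total EDF counts 1 for the intercept."""
    y = np.asarray(y, dtype=float)
    mu = np.asarray(mu, dtype=float)
    n = y.size
    residual_df = n - (1.0 + float(np.sum(edf_smooths)))
    # The ddof used by np.var cancels in the ratio, so the default (ddof=0) is fine.
    return 1.0 - np.var(y - mu) * (n - 1) / (np.var(y) * residual_df)

def proba_two_column(p1):
    """Stack [P(0), P(1)] per row from P(y = 1), giving shape (n, 2)."""
    p1 = np.asarray(p1, dtype=float)
    return np.column_stack([1.0 - p1, p1])

def accuracy_at_half(y, p1):
    """Share of rows where thresholding P(y = 1) at 0.5 recovers the label."""
    return float(np.mean((np.asarray(p1) >= 0.5) == np.asarray(y, dtype=bool)))
```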

7. Serving with GamPredictor

For deployment paths, wrap the fitted Gam in a GamPredictor. Same API surface, plus strict column validation and a round-trip assertion.

from mgcv_rust import GamPredictor

predictor = GamPredictor(gam)
predictor.check_against(gam, X[:50])   # raises if predictions diverge
predictor.predict(X_serve)             # ValueError if any expected column is missing
predictor[["x0"]].predict(X_serve)     # subset view returns another GamPredictor

# `__slots__` blocks attribute leaks — once built, the bound Gam's
# coef/vcov/schema are the contract.

This structurally closes two production bug classes:

  • Column index drift when the serving DataFrame's columns differ from training: feature_names_in_ is enforced on every predict call (re-ordered → realigned; missing → ValueError).
  • x-grid clamping in interpolation-based predictors: GamPredictor recomputes the basis at the requested X via evaluate_lpmatrix, so there's no pre-computed grid to clamp against.
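Both failure modes are easy to see in miniature. The sketch below assumes serving data arrives as a list of column names plus a 2-D array; align_to_training is a hypothetical helper that realigns reordered columns, raises ValueError on missing ones, and silently drops extras (how GamPredictor treats extra columns is not stated above). The second half shows the clamping problem: np.interp returns the edge value for inputs outside its grid.

```python
import numpy as np

def align_to_training(columns, rows, feature_names_in):
    """Reorder serving columns into training order; raise if any are missing."""
    columns = list(columns)
    missing = [name for name in feature_names_in if name not in columns]
    if missing:
        raise ValueError(f"missing expected columns: {missing}")
    order = [columns.index(name) for name in feature_names_in]
    return np.asarray(rows, dtype=float)[:, order]

# Why a pre-computed grid is risky: np.interp clamps outside [grid[0], grid[-1]].
grid = np.linspace(0.0, 1.0, 11)
curve = grid ** 2                      # f(x) = x^2 tabulated on [0, 1]
clamped = np.interp(1.5, grid, curve)  # 1.0, the last grid value
exact = 1.5 ** 2                       # recomputing at the requested x gives 2.25
```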

Families and links

| Family | Default link | Other links | Notes |
|---|---|---|---|
| gaussian | identity | | B-spline / cubic-regression / random-effects bases supported |
| binomial | logit | | predict_proba() available |
| poisson | log | | |
| gamma | inverse | log | log link via link="log" |
| tweedie | log | | tweedie_p= (default 1.5); fixed p |
| nb / negbin | log | | nb profiles θ; negbin uses a fixed θ |
| t-dist / scat | identity | | df profiled if not given (df ∈ [2, 100]) |
| quasipoisson, quasibinomial | log / logit | | dispersion-aware |
| quantile | identity | | see mgcv_rust.fit_quantile / fit_quantile_lss; OOS presets documented in docs/qgam_oos_presets.md |
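How link=None resolves can be read straight off the table. A small lookup sketch; the dictionaries and resolve_inverse_link are hypothetical illustrations of the table, not mgcv_rust's API, and only the inverse links needed for the default links are included:

```python
import numpy as np

# Default links transcribed from the table above (illustrative lookup only).
DEFAULT_LINK = {
    "gaussian": "identity", "binomial": "logit", "poisson": "log",
    "gamma": "inverse", "tweedie": "log", "nb": "log", "negbin": "log",
    "t-dist": "identity", "scat": "identity",
    "quasipoisson": "log", "quasibinomial": "logit", "quantile": "identity",
}

# Inverse link functions mapping the linear predictor eta to the response scale.
INVERSE_LINK = {
    "identity": lambda eta: eta,
    "log": np.exp,
    "logit": lambda eta: 1.0 / (1.0 + np.exp(-eta)),
    "inverse": lambda eta: 1.0 / eta,
}

def resolve_inverse_link(family, link=None):
    """link=None falls back to the family's default link, as Gam(link=None) does."""
    name = DEFAULT_LINK[family] if link is None else link
    return INVERSE_LINK[name]
```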

Architecture (Rust side)

| File | Responsibility |
|---|---|
| src/gam.rs | GAM struct + fit / predict entry points |
| src/pirls.rs | Penalized IRLS inner loop |
| src/reml/ | Outer-loop REML / LAML optimization |
| src/smooth.rs | Basis functions (cubic-regression, B-spline, random effects, tensor products) |
| src/penalty.rs | Penalty-matrix construction |
| src/lib.rs | PyO3 bindings: PyGAM, compute_penalty_matrix, newton_pirls_py |

Build features:

  • python — enable PyO3 bindings.
  • blas / blas-system — link against system BLAS for matmul-heavy paths.
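For readers new to the inner loop: penalized IRLS repeatedly solves a weighted, penalized least-squares problem at a fixed smoothing parameter λ, while the REML / LAML outer loop chooses λ. Below is a minimal NumPy sketch for a Poisson GLM with log link, plus a P-spline-style second-difference penalty. It illustrates the technique and is not a transcription of src/pirls.rs or src/penalty.rs (mgcv's cubic-regression penalty, for instance, is built differently); the function names are hypothetical.

```python
import numpy as np

def second_difference_penalty(k):
    """S = D^T D, with D the (k - 2) x k second-difference matrix (P-spline style)."""
    D = np.diff(np.eye(k), n=2, axis=0)
    return D.T @ D

def pirls_poisson(X, y, S, lam, n_iter=50, tol=1e-10):
    """Penalized IRLS for a Poisson GLM with log link at a fixed smoothing parameter.

    Each step solves (X^T W X + lam * S) beta = X^T W z, where for the log link
    the working weights are W = diag(mu) and the working response is
    z = eta + (y - mu) / mu.
    """
    y = np.asarray(y, dtype=float)
    mu = y + 0.1                 # starting values in the style of R's poisson()$initialize
    eta = np.log(mu)
    beta = None
    for _ in range(n_iter):
        z = eta + (y - mu) / mu
        XtW = X.T * mu           # X^T W without forming the n x n diagonal matrix
        beta_new = np.linalg.solve(XtW @ X + lam * S, XtW @ z)
        eta = X @ beta_new
        mu = np.exp(eta)
        if beta is not None and np.max(np.abs(beta_new - beta)) < tol:
            return beta_new
        beta = beta_new
    return beta

# Toy usage: quadratic basis, penalizing only the quadratic coefficient.
x = np.linspace(0.0, 1.0, 40)
X = np.column_stack([np.ones_like(x), x, x ** 2])
y = np.random.default_rng(1).poisson(np.exp(1.0 + x))
S = np.diag([0.0, 0.0, 1.0])
beta = pirls_poisson(X, y, S, lam=2.0)
```

At convergence the fixed point satisfies X^T(y - μ) = λSβ, the stationarity condition of the penalized log-likelihood, which makes a convenient correctness check.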

Parity and performance

Run the parity battery against R/mgcv:

pytest tests/parity/ -q
# 554 passed, 0 failed, 0 xfailed, 0 skipped

Microbench:

python3 scripts/python/bench_step_blend.py 5

Headline (2d_gaussian_additive_n50000_k15_cr, identity link):

| Implementation | Time |
|---|---|
| R mgcv | ~394 ms |
| mgcv_rust 0.11.0 | 97 ms |

Status and limitations

  • Joint (ρ, log φ) outer Newton is not yet implemented. It is the main remaining obstacle to closing the performance gap on dispersion-bearing GLMs, and is tracked in the project's Obsidian backlog note "mgcv_rust - Backlog - Next Numerical Steps".
  • predict_diff supports identity links only; for non-identity links it raises an error that points to get_posterior_samples as a workaround.
  • Auto-k tuning is opt-in via Gam(auto_k=True). By default Gam runs a single fit at k_default=10 (mgcv's default), with per-predictor overrides via term_k_mapping. This stays closer to mgcv's convention of choosing k and then checking it with k.check, and avoids hidden multi-fit costs.
  • Wrapping Gam as an sklearn BaseEstimator (for Pipeline / GridSearchCV) is not yet implemented; it is deferred, with scikit-learn intended as a soft dependency.

License

MIT — see LICENSE.

Download files

Source Distribution

  • mgcv_rust-0.22.0.tar.gz (14.0 MB): source

Built Distributions

  • mgcv_rust-0.22.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB): PyPy, manylinux glibc 2.17+ x86-64
  • mgcv_rust-0.22.0-cp314-cp314-win_amd64.whl (3.7 MB): CPython 3.14, Windows x86-64
  • mgcv_rust-0.22.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB): CPython 3.14, manylinux glibc 2.17+ x86-64
  • mgcv_rust-0.22.0-cp314-cp314-macosx_11_0_arm64.whl (931.7 kB): CPython 3.14, macOS 11.0+ ARM64
  • mgcv_rust-0.22.0-cp313-cp313-win_amd64.whl (3.7 MB): CPython 3.13, Windows x86-64
  • mgcv_rust-0.22.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB): CPython 3.13, manylinux glibc 2.17+ x86-64
  • mgcv_rust-0.22.0-cp313-cp313-macosx_11_0_arm64.whl (930.0 kB): CPython 3.13, macOS 11.0+ ARM64
  • mgcv_rust-0.22.0-cp312-cp312-win_amd64.whl (3.7 MB): CPython 3.12, Windows x86-64
  • mgcv_rust-0.22.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB): CPython 3.12, manylinux glibc 2.17+ x86-64
  • mgcv_rust-0.22.0-cp312-cp312-macosx_11_0_arm64.whl (930.4 kB): CPython 3.12, macOS 11.0+ ARM64
  • mgcv_rust-0.22.0-cp311-cp311-win_amd64.whl (3.7 MB): CPython 3.11, Windows x86-64
  • mgcv_rust-0.22.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB): CPython 3.11, manylinux glibc 2.17+ x86-64
  • mgcv_rust-0.22.0-cp311-cp311-macosx_11_0_arm64.whl (932.1 kB): CPython 3.11, macOS 11.0+ ARM64
  • mgcv_rust-0.22.0-cp310-cp310-win_amd64.whl (3.7 MB): CPython 3.10, Windows x86-64
  • mgcv_rust-0.22.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB): CPython 3.10, manylinux glibc 2.17+ x86-64
  • mgcv_rust-0.22.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB): CPython 3.9, manylinux glibc 2.17+ x86-64

File details

Details for the file mgcv_rust-0.22.0.tar.gz.

File metadata

  • Download URL: mgcv_rust-0.22.0.tar.gz
  • Upload date:
  • Size: 14.0 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: maturin/1.13.3

File hashes

Hashes for mgcv_rust-0.22.0.tar.gz
Algorithm Hash digest
SHA256 5f2b681c84d54a4b9b334d69f0495a8d21c89db7676ae0cdfd74b05168adb07d
MD5 32b33bbd4c497283f0e9cc0f926a8dbb
BLAKE2b-256 c9231fff42baef99267b2a999411f1935a75b8ad55062ab07d8e6c1f3f451790

File details

Details for the file mgcv_rust-0.22.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.22.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 0a6f0e1aa66368b615ac7f0231f8018fa11208a201440bd9b6ee97a8b5a0c916
MD5 0be94ec30b5893a4f7128e2fbf335f96
BLAKE2b-256 eb70a5a7c35c72f08b93d5965d32564333c72a893db275a81e2209e587ad2f2d

File details

Details for the file mgcv_rust-0.22.0-cp314-cp314-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.22.0-cp314-cp314-win_amd64.whl
Algorithm Hash digest
SHA256 7a2b619712599e2f9e54a33049bd44d724d4d69a61ee50205717c28b307f47e9
MD5 68996b5f344195ead2cc5cc86d1e3d5b
BLAKE2b-256 6eb1e39e025ef65c9c1ed872e520b8fe7f92545e8c070d867d4c97e4f00aa70b

File details

Details for the file mgcv_rust-0.22.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.22.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 a47011ece77c0b7fdebc2e54bfb0647dca8a0333d03a2f058a8ed202b7f34538
MD5 6e391812985806ffa0b9b73e49bdc02d
BLAKE2b-256 3edda390a18fe3b7d88863e1b9a8c7ddd764a92e49eedd8d50a49dd1db2d7453

File details

Details for the file mgcv_rust-0.22.0-cp314-cp314-macosx_11_0_arm64.whl.

File hashes

Hashes for mgcv_rust-0.22.0-cp314-cp314-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 4be6da7a8ceb33371fdcb3f033804467b6d8442ed06dc5d41c509d4ebfcbd096
MD5 3b3029bdc035582e48f09a4f70e7c7ed
BLAKE2b-256 0ea107d4bd02878e6c22be85ad41311f2d024c1d5397fdb6ac5cef6b2bd598d6

File details

Details for the file mgcv_rust-0.22.0-cp313-cp313-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.22.0-cp313-cp313-win_amd64.whl
Algorithm Hash digest
SHA256 281a534955f5f66fd1fd047ee4868af044d92162a6e9278eec2ed72b6665afb5
MD5 37edcd31d5b42c66f8914db33a0b2ebe
BLAKE2b-256 2eac902b27d575d28982ef4dbe7bba10dd9c53e994f230940af9b6fd020bbeac

File details

Details for the file mgcv_rust-0.22.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.22.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 015f74de36dbe0dc4db9bbb100dc9d87308a2f82eccebc17c2d18164661400f2
MD5 3a0915fe612390e7989254a57afb25d3
BLAKE2b-256 570860f18ff425f28baf0a2c1cdbc19259082287ae11d2f35e22d829cabaad79

File details

Details for the file mgcv_rust-0.22.0-cp313-cp313-macosx_11_0_arm64.whl.

File hashes

Hashes for mgcv_rust-0.22.0-cp313-cp313-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 b83e216cfc122211f644f078bf30d957fbb0d5f294452ac4c8ccc8dff30dd758
MD5 1a77b03a91f2109fa61f24dc6ea70dae
BLAKE2b-256 5ae24bbdfb55da2f3be671a9a89710af653e9caa5aeaae87348ef9c088aa0237

File details

Details for the file mgcv_rust-0.22.0-cp312-cp312-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.22.0-cp312-cp312-win_amd64.whl
Algorithm Hash digest
SHA256 7d17c4b4dd3d45026a98e6f2217bb67d4e2c799c788f61530fc27037c46665fa
MD5 3cc7d2fa2bd06f958fd79c50678636fd
BLAKE2b-256 931b8387b7642a8203d1c834bfc45115676cb8f30ae85eb444a78d52b59e3a50

File details

Details for the file mgcv_rust-0.22.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.22.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 656692a0df540aa494ddbfa8af26d32a8845e74d2fdcedf19bb065d470473404
MD5 33109229eb833a22ea4be0ea4d7f742f
BLAKE2b-256 e72dc0c5d3b3ff61da925c22afcdfcceb39f80b4e4b96cddc3c2105c68539e3f

File details

Details for the file mgcv_rust-0.22.0-cp312-cp312-macosx_11_0_arm64.whl.

File hashes

Hashes for mgcv_rust-0.22.0-cp312-cp312-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 c741fd7d56bd358aa053c769780bcc97349e34284dbfbeef3f47fa1f9b569cfd
MD5 1f1d2d1bd692fbdfdd95447901dba3a7
BLAKE2b-256 dc8a49376ca62607d1be79b2032a4434a509f3772f2b66d90f8c7dae8431dc99

File details

Details for the file mgcv_rust-0.22.0-cp311-cp311-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.22.0-cp311-cp311-win_amd64.whl
Algorithm Hash digest
SHA256 2d0fc90328e6de0656502497a86044e794aa95f62d0543758d39812c51761712
MD5 5a9cb4203958e10c3942df897c7bc665
BLAKE2b-256 c1d6d324cdc89c8f4271033b13b9050c81fe8089b570b22f0405014c9ac57b11

File details

Details for the file mgcv_rust-0.22.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.22.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 4c1cfc337fdb5efcd2e71e01c993e92e0cbd00fd897a7bc67bad9eb2ea4519b7
MD5 21842f5250c51679abb127ce8a3fb433
BLAKE2b-256 a44be1aeb6b8ffc70bf63dba5d2250657f437a9835bdb6c1eb2b3611a6634451

File details

Details for the file mgcv_rust-0.22.0-cp311-cp311-macosx_11_0_arm64.whl.

File hashes

Hashes for mgcv_rust-0.22.0-cp311-cp311-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 2dd4ba219a9af57c7dd5767b838c1a78319fe6c75bc8e80bcbc4868ee1e87564
MD5 36208134c49a0bbea4e99e2e36b7b6f7
BLAKE2b-256 7dde6c49e75537895f4580b077d13f4039c001737affa41e14ad439462c4c084

File details

Details for the file mgcv_rust-0.22.0-cp310-cp310-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.22.0-cp310-cp310-win_amd64.whl
Algorithm Hash digest
SHA256 e6867c323cf5693402bf8cdbf6e5d8f723dbdaa35dfd10ec8fe0d32627b16f74
MD5 c60da7db00daeae90a5e3c256dbde294
BLAKE2b-256 abd46517ad88d6f42b7496a7ee8920a5c432317f2e6f60bdfdf17816ff2e0244

File details

Details for the file mgcv_rust-0.22.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.22.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 1086af79ac9135133bf4556d57edede4b472f43fe9f27c269224af164e212460
MD5 e6f67d4dc3b1d5a1396884d95c940233
BLAKE2b-256 19300d1552f927b219fe542a6441274bc6c17b73301e4214e867f7e74a740b71

File details

Details for the file mgcv_rust-0.22.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.22.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 4fe5773710cfc06b180dd28c3a613fd2dc80c5e1d4b77a2f18b73067c739421c
MD5 fd5dc976752e5a6b4127824ff2d43a5d
BLAKE2b-256 cf4cb106f46b8b5c0cb0c663d71d43e2d911deadc9196280490e5edb07831524
