Rust implementation of Generalized Additive Models with Python bindings

mgcv_rust

A Rust port of R's mgcv package — Generalized Additive Models with automatic smoothing-parameter selection (REML / LAML) and PIRLS fitting — with first-class Python bindings.

The numerical core targets byte-for-byte parity with mgcv on common families, and the Python API is designed to feel like sklearn: a single Gam class, named-predictor DataFrame inputs, fit() / predict() / score(), posterior CI, and subset views for marginal analysis.

Latest release | 0.11.0
Parity | 554 / 0 / 0 on the mgcv comparison battery
Python tests | 211 passed, 1 xfailed
Headline perf | 2d_gaussian_additive_n50000_k15_cr: 394 ms → 97 ms (4.05× vs. mgcv)
Python wrapper | mgcv_rust.Gam (canonical), GAMFitter deprecated alias

Why another GAM library?

  • R-equivalent answers, Python ergonomics. If you have R / mgcv users producing models and Python services consuming them, this gives both sides the same numbers from the same fit. The serialize() method produces a portable artefact for deployment.
  • Fast. Aggregate time across 80 parity fixtures: 2.0 s. On the largest Gaussian fixture (n=50,000, k=15), mgcv_rust is ~4× faster than R's mgcv.
  • Real CIs, paired diffs. predict_ci(X) returns (mean, lo, hi) so you never have to "+= intercept" by hand. predict_diff(from_X, to_X, level=...) gives a paired-posterior CI for the difference between two predictions — strictly tighter than the naive predict_ci difference.
  • Frozen serving view. GamPredictor(gam) is an __slots__-locked, inference-only wrapper with strict input validation (feature_names_in_ enforcement) and a check_against(gam, X_sample) round-trip assertion for deployment safety.

Installation

The Python package is built from this repo via maturin:

# Clone, then in a venv:
pip install maturin
maturin develop --features python,blas,blas-system --release

After that, import mgcv_rust works.

For a Rust-only consumer, add to your Cargo.toml:

[dependencies]
mgcv_rust = { git = "https://github.com/AlekJaworski/nn_exploring" }

Quick start

import numpy as np
import pandas as pd
from mgcv_rust import Gam

rng = np.random.default_rng(0)
n = 500
X = pd.DataFrame({
    "x0": rng.uniform(-2, 2, n),
    "x1": rng.uniform(0, 5, n),
})
y = np.sin(X["x0"]) + 0.3 * (X["x1"] - 2.5)**2 + rng.normal(0, 0.1, n)

gam = Gam(family="gaussian").fit(X, y)

gam.predict(X[:5])                  # response-scale predictions
mean, lo, hi = gam.predict_ci(X[:5])  # 95% CI, response scale
gam.score(X, y)                     # adjusted R²
print(gam.summary())                # mgcv-style block

That's it for the simple case. Read on for the curated tour, or jump to docs/GETTING_STARTED.md for a step-by-step tutorial with worked examples.


Tour

1. Build the model

Gam.fit() accepts pandas, polars, or numpy inputs. If you pass a DataFrame, the column names become the predictor names; otherwise you can pass them explicitly.

gam = Gam(
    family="gaussian",           # gaussian, binomial, poisson, gamma, tweedie, nb, t-dist, scat, ...
    link=None,                   # None → canonical link for the family
    k_default=10,                # default basis dimension per smooth (mgcv default)
    term_k_mapping={"x0": 25},   # per-predictor overrides
    method="REML",
)
gam.fit(X, y)

After fit, sklearn-style attributes are populated:

gam.coef_              # full β vector (intercept first)
gam.intercept_         # link-scale
gam.intercept_response_ # response-scale (= mean(y) for identity link)
gam.feature_names_in_  # predictor names, fitting order
gam.n_features_in_
gam.lambda_            # per-smooth smoothing parameter
gam.edf_               # per-smooth effective degrees of freedom
gam.k_, gam.bs_        # basis dimension / basis type per smooth
gam.vcov_              # posterior covariance of β̂

2. Predict on different scales

gam.predict(X)                       # response (default — same as mgcv predict(type="response"))
gam.predict(X, scale="link")         # linear predictor, no inverse link (mgcv type="link")

result = gam.predict(X, type="terms")  # decomposition
result.intercept                       # scalar, link scale
result.contributions                   # DataFrame (n × m_smooths), link-scale, centered on training
result.total                           # response-scale ≡ gam.predict(X)

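Invariant: gam.predict(X) equals inv_link(result.intercept + result.contributions.sum(axis=1)) up to floating-point rounding, where inv_link is the inverse of the model's link function.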
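As a rough illustration of why the decomposition is reported on the link scale, the following sketch rebuilds a response-scale prediction from an intercept and per-smooth contributions under a log link. All numbers are made up, and inv_link here is a local helper rather than a library function.

```python
import numpy as np

# Hypothetical values standing in for a fitted log-link model.
intercept = 0.5                                   # link scale
contributions = np.array([[0.20, -0.10],          # one row per observation,
                          [-0.30, 0.40],          # one column per smooth
                          [0.00, 0.05]])


def inv_link(eta):
    """Inverse of the log link."""
    return np.exp(eta)


# Smooths are additive on the link scale; apply the inverse link once at the end.
eta = intercept + contributions.sum(axis=1)
total = inv_link(eta)

# Adding contributions after the inverse link gives a different answer.
wrong = inv_link(intercept) + contributions.sum(axis=1)
print(total)
print(np.allclose(total, wrong))
```

For an identity link the two computations coincide, which is why the distinction only matters for non-identity links.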

3. Confidence intervals

mean, lo, hi = gam.predict_ci(X, level=0.95)       # response, 95% CI
mean, lo, hi = gam.predict_ci(X, scale="link")     # link-scale CI
mean, lo, hi = gam.predict_ci(X, level=0.5)        # 50% CI (narrower)

mean matches gam.predict(X, scale=scale) exactly. The CI is built from n_samples (default 1000) draws of β ~ N(β̂, vcov).
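The simulation approach can be sketched in a few lines of numpy. This is a minimal illustration under stated assumptions, not the library's implementation: beta_hat, vcov, the design rows Xp, and the function simulate_ci are all hypothetical stand-ins.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical fitted quantities (not produced by mgcv_rust).
beta_hat = np.array([1.0, 0.5, -0.2])             # coefficient estimate
vcov = np.array([[0.04, 0.01, 0.00],
                 [0.01, 0.09, 0.02],
                 [0.00, 0.02, 0.16]])             # posterior covariance
Xp = np.array([[1.0, 0.3, 0.7],
               [1.0, -1.2, 0.1]])                 # design rows to predict at


def simulate_ci(Xp, beta_hat, vcov, level=0.95, n_samples=1000, inv_link=None):
    """Interval from draws of beta ~ N(beta_hat, vcov)."""
    draws = rng.multivariate_normal(beta_hat, vcov, size=n_samples)  # (n_samples, p)
    sims = draws @ Xp.T                                              # (n_samples, n)
    mean = Xp @ beta_hat
    if inv_link is not None:
        # Transform every draw, then take quantiles on the response scale.
        sims, mean = inv_link(sims), inv_link(mean)
    alpha = 1.0 - level
    lo, hi = np.quantile(sims, [alpha / 2, 1 - alpha / 2], axis=0)
    return mean, lo, hi


mean, lo, hi = simulate_ci(Xp, beta_hat, vcov, level=0.95)
mean50, lo50, hi50 = simulate_ci(Xp, beta_hat, vcov, level=0.5)
```

The point estimate comes from β̂ directly rather than from the draws, which is one way the returned mean can match the plain prediction exactly.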

The legacy predict_ci(X, alpha=0.05) → (lo, hi) form still works (emits a DeprecationWarning) so existing callers don't break.

4. Differences between predictions, with paired CI

# Per-row diff: to[i] - from[i]
diff = gam.predict_diff(from_X, to_X)

# With paired posterior CI — much narrower than predict_ci differences
diff, lo, hi = gam.predict_diff(from_X, to_X, level=0.95)

# One row broadcast against many
diff = gam.predict_diff(baseline_row, candidates, broadcast="from")
diff = gam.predict_diff(candidates, target_row, broadcast="to")

predict_diff supports the identity link only. For non-identity links the response-scale difference isn't linear in β, so the closed-form CI argument doesn't transfer.
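To see why a paired interval is narrower than subtracting two separate intervals, consider the identity-link case, where each prediction is a linear function of β. The sketch below compares the two half-widths; vcov, x_from, and x_to are hypothetical numbers, and the closed-form normal interval is an assumption made for illustration.

```python
import numpy as np

# Hypothetical stand-ins for a fitted identity-link model.
vcov = np.array([[0.04, 0.01, 0.00],
                 [0.01, 0.09, 0.02],
                 [0.00, 0.02, 0.16]])             # posterior covariance of beta
x_from = np.array([1.0, 0.3, 0.7])                # design row for the baseline
x_to = np.array([1.0, 0.5, 0.6])                  # design row for the candidate
z = 1.959963984540054                             # 97.5% standard-normal quantile


def sd(x):
    """Posterior standard deviation of x @ beta."""
    return float(np.sqrt(x @ vcov @ x))


# Paired: the difference is (x_to - x_from) @ beta, so shared uncertainty cancels.
paired_half_width = z * sd(x_to - x_from)

# Naive: subtract two separate intervals, [lo_to - hi_from, hi_to - lo_from].
naive_half_width = z * (sd(x_to) + sd(x_from))

print(paired_half_width, naive_half_width)
```

The intercept column cancels in x_to - x_from, so uncertainty shared by both predictions drops out of the paired interval.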

5. Subset / marginal views

Indexing with a list of names, gam[names], returns a view that includes only the listed smooths (plus the intercept when Gam.INTERCEPT is in the list). All predict methods are available on the view.

gam[["x0"]].predict(X)                              # marginal effect of x0, response scale
gam[["x0"]].predict(X, scale="deviation")           # link-scale, intercept zeroed (sum-to-zero on train)
gam[["x0", Gam.INTERCEPT]].predict(X)               # marginal effect + intercept
gam[["x0"]].partial_effect("x0").plot()             # what the smooth looks like

All sklearn attributes filter to the active features:

gam[["x0"]].coef_            # just the x0 block (+ intercept if active)
gam[["x0"]].feature_names_in_  # ["x0"]

6. Plotting, summary, score

gam.plot()                  # figure with one subplot per smooth (CI bands shaded)
gam.plot("x0")              # single-smooth axes
df = gam.partial_effect("x0", level=0.95)  # underlying data: columns x0, effect, lo, hi

print(gam.summary())
# Gam summary  family=gaussian  link=identity  n_obs=500
#   intercept (link)     = -0.0123
#   intercept (response) = -0.0123
#   scale (σ²)           = 0.0101
#   deviance             = 4.8732
#   R² (adj)             = 0.9612
#   smooths:
#     s(          x0)  k= 10  edf=  6.13  λ=2.345e-02
#     s(          x1)  k= 10  edf=  3.07  λ=4.890e-01

gam.score(X, y)              # adjusted R² for regression; accuracy@0.5 for binomial
gam.predict_proba(X)         # binomial only: (n, 2) [[P(0), P(1)]]
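For reference, the two score conventions can be written out by hand. The sketch below uses one common definition of adjusted R² (residual degrees of freedom taken as n minus the total edf) and thresholded accuracy. Whether score() uses exactly these formulas is an assumption, and both function names are hypothetical.

```python
def adjusted_r2(y, fitted, total_edf):
    """1 - (RSS / (n - edf)) / (TSS / (n - 1)); one common definition."""
    n = len(y)
    y_bar = sum(y) / n
    rss = sum((yi - fi) ** 2 for yi, fi in zip(y, fitted))
    tss = sum((yi - y_bar) ** 2 for yi in y)
    return 1.0 - (rss / (n - total_edf)) / (tss / (n - 1))


def accuracy_at_half(labels, prob_one):
    """Share of rows whose thresholded P(1) matches the 0/1 label."""
    preds = [1 if p > 0.5 else 0 for p in prob_one]
    return sum(int(p == t) for p, t in zip(preds, labels)) / len(labels)


r2 = adjusted_r2([1, 2, 3, 4, 5], [1.1, 1.9, 3.2, 3.8, 5.0], total_edf=2)
acc = accuracy_at_half([0, 1, 1, 1], [0.2, 0.7, 0.6, 0.4])
print(r2, acc)
```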

7. Serving with GamPredictor

For deployment paths, wrap the fitted Gam in a GamPredictor. Same API surface, plus strict column validation and a round-trip assertion.

from mgcv_rust import GamPredictor

predictor = GamPredictor(gam)
predictor.check_against(gam, X[:50])   # raises if predictions diverge
predictor.predict(X_serve)             # ValueError if any expected column is missing
predictor[["x0"]].predict(X_serve)     # subset view returns another GamPredictor

# `__slots__` blocks attribute leaks — once built, the bound Gam's
# coef/vcov/schema are the contract.

This structurally closes two production bug classes:

  • Column index drift when the serving DataFrame's columns differ from training: feature_names_in_ is enforced on every predict call (re-ordered → realigned; missing → ValueError).
  • x-grid clamping in interpolation-based predictors: GamPredictor recomputes the basis at the requested X via evaluate_lpmatrix, so there's no pre-computed grid to clamp against.
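The column-realignment and locked-attribute ideas can be sketched with the standard library alone. The class below is a hypothetical illustration of the pattern, not the GamPredictor source: FrozenPredictor, its constructor arguments, and the dict-of-columns input format are all assumptions.

```python
class FrozenPredictor:
    """Hypothetical inference-only wrapper with name-based column checks."""

    # No instance __dict__, so new attributes cannot be attached after construction.
    __slots__ = ("_feature_names", "_predict_fn")

    def __init__(self, feature_names, predict_fn):
        self._feature_names = tuple(feature_names)
        self._predict_fn = predict_fn

    def predict(self, columns):
        """`columns` maps a column name to a list of values."""
        missing = [name for name in self._feature_names if name not in columns]
        if missing:
            raise ValueError(f"missing columns: {missing}")
        # Select by name in training order, so input column order is irrelevant.
        aligned = [columns[name] for name in self._feature_names]
        return [self._predict_fn(row) for row in zip(*aligned)]


predictor = FrozenPredictor(["x0", "x1"], lambda row: 2.0 * row[0] + row[1])
result = predictor.predict({"x1": [10.0, 20.0], "x0": [1.0, 2.0]})
print(result)
```

Selecting by name rather than by position is what removes the column-drift class of bugs: a reordered input produces the same predictions, and a missing column fails loudly.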

Families and links

Family | Default link | Other links | Notes
--- | --- | --- | ---
gaussian | identity | | bs-spline / cubic-regression / random-effects bases supported
binomial | logit | | predict_proba() available
poisson | log | |
gamma | inverse | log | log-link via link="log"
tweedie | log | | tweedie_p= (default 1.5); fixed-p
nb / negbin | log | | nb profiles θ; negbin is fixed-θ
t-dist / scat | identity | | df profiled if not given (df ∈ [2, 100])
quasipoisson, quasibinomial | log / logit | | dispersion-aware
quantile | identity | | see mgcv_rust.fit_quantile / fit_quantile_lss; OOS presets documented in docs/qgam_oos_presets.md

Architecture (Rust side)

File | Responsibility
--- | ---
src/gam.rs | GAM struct + fit / predict entry points
src/pirls.rs | Penalized IRLS inner loop
src/reml/ | Outer-loop REML / LAML optimization
src/smooth.rs | Basis functions (cubic-regression, B-spline, random effects, tensor products)
src/penalty.rs | Penalty-matrix construction
src/lib.rs | PyO3 bindings: PyGAM, compute_penalty_matrix, newton_pirls_py
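For readers unfamiliar with the inner loop, here is a minimal numpy sketch of penalized IRLS for a Poisson model with a log link and a fixed smoothing parameter. It is a textbook illustration, not a transcription of src/pirls.rs: the function name pirls_poisson, the polynomial basis, and the penalty matrix are assumptions, and there is no step-halving or other safeguard.

```python
import numpy as np


def pirls_poisson(X, y, S, lam, max_iter=50, tol=1e-10):
    """Maximize loglik(beta) - 0.5 * lam * beta' S beta by penalized IRLS."""
    beta = np.zeros(X.shape[1])
    beta[0] = np.log(y.mean())            # assumes column 0 is the intercept
    for _ in range(max_iter):
        eta = X @ beta                    # linear predictor
        mu = np.exp(eta)                  # inverse log link
        w = mu                            # working weights for Poisson / log
        z = eta + (y - mu) / mu           # working response
        lhs = X.T @ (w[:, None] * X) + lam * S
        rhs = X.T @ (w * z)
        beta_new = np.linalg.solve(lhs, rhs)   # penalized weighted least squares
        converged = np.max(np.abs(beta_new - beta)) < tol
        beta = beta_new
        if converged:
            break
    return beta


rng = np.random.default_rng(0)
x = rng.uniform(0, 1, 300)
X = np.vander(x, 4, increasing=True)      # hypothetical basis: 1, x, x^2, x^3
S = np.diag([0.0, 1.0, 1.0, 1.0])         # intercept left unpenalized
y = rng.poisson(np.exp(0.5 + np.sin(2 * x))).astype(float)

lam = 0.5
beta = pirls_poisson(X, y, S, lam)
# At convergence the penalized score X'(y - mu) - lam * S @ beta is near zero.
score = X.T @ (y - np.exp(X @ beta)) - lam * S @ beta
print(np.max(np.abs(score)))
```

The outer REML / LAML loop then treats the converged β as a function of λ and optimizes over λ, which this sketch leaves out.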

Build features:

  • python — enable PyO3 bindings.
  • blas / blas-system — link against system BLAS for matmul-heavy paths.

Parity and performance

Run the parity battery against R/mgcv:

pytest tests/parity/ -q
# 554 passed, 0 failed, 0 xfailed, 0 skipped

Microbench:

python3 scripts/python/bench_step_blend.py 5

Headline (2d_gaussian_additive_n50000_k15_cr, identity link):

Implementation | Time
--- | ---
R mgcv | ~394 ms
mgcv_rust 0.11.0 | 97 ms

Status and limitations

  • Joint (ρ, log φ) outer Newton not yet implemented — the binding constraint for closing the remaining performance gap on dispersion-bearing GLMs. Tracked in mgcv_rust - Backlog - Next Numerical Steps (Obsidian).
  • predict_diff is identity-link only; non-identity raises with a workaround pointing at get_posterior_samples.
  • Auto-k tuning is opt-in via Gam(auto_k=True). Default is a single fit at k_default=10 (mgcv's default), with term_k_mapping overrides — closer to mgcv's "tune k, run k.check" convention and avoids hidden multi-fit costs.
  • sklearn BaseEstimator mixin (for Pipeline / GridSearchCV) is not yet wrapped — soft-dep, deferred.

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. J. R. Stat. Soc. B, 73(1), 3–36.

License

MIT — see LICENSE.
