Rust implementation of Generalized Additive Models with Python bindings

mgcv_rust

A Rust port of R's mgcv package — Generalized Additive Models with automatic smoothing-parameter selection (REML / LAML) and PIRLS fitting — with first-class Python bindings.

The numerical core targets byte-for-byte parity with mgcv on common families, and the Python API is designed to feel like sklearn: a single Gam class, named-predictor DataFrame inputs, fit() / predict() / score(), posterior CI, and subset views for marginal analysis.

Latest release: 0.11.0
Parity: 554 / 0 / 0 on the mgcv comparison battery
Python tests: 211 passed, 1 xfailed
Headline perf: 394 ms → 97 ms on 2d_gaussian_additive_n50000_k15_cr (4.05× vs. mgcv)
Python wrapper: mgcv_rust.Gam (canonical), GAMFitter deprecated alias

Why another GAM library?

  • R-equivalent answers, Python ergonomics. If you have R / mgcv users producing models and Python services consuming them, this gives both sides the same numbers from the same fit. The serialize() method produces a portable artefact for deployment.
  • Fast. Aggregate time across 80 parity fixtures: 2.0 s. On the largest Gaussian fixture (n=50,000, k=15), mgcv_rust is ~4× faster than R's mgcv.
  • Real CIs, paired diffs. predict_ci(X) returns (mean, lo, hi) so you never have to "+= intercept" by hand. predict_diff(from_X, to_X, level=...) gives a paired-posterior CI for the difference between two predictions — strictly tighter than the naive predict_ci difference.
  • Frozen serving view. GamPredictor(gam) is an __slots__-locked, inference-only wrapper with strict input validation (feature_names_in_ enforcement) and a check_against(gam, X_sample) round-trip assertion for deployment safety.

Installation

The Python package is built from this repo via maturin:

# Clone, then in a venv:
pip install maturin
maturin develop --features python,blas,blas-system --release

After that, import mgcv_rust works.

For a Rust-only consumer, add to your Cargo.toml:

[dependencies]
mgcv_rust = { git = "https://github.com/AlekJaworski/nn_exploring" }

Quick start

import numpy as np
import pandas as pd
from mgcv_rust import Gam

rng = np.random.default_rng(0)
n = 500
X = pd.DataFrame({
    "x0": rng.uniform(-2, 2, n),
    "x1": rng.uniform(0, 5, n),
})
y = np.sin(X["x0"]) + 0.3 * (X["x1"] - 2.5)**2 + rng.normal(0, 0.1, n)

gam = Gam(family="gaussian").fit(X, y)

gam.predict(X[:5])                  # response-scale predictions
mean, lo, hi = gam.predict_ci(X[:5])  # 95% CI, response scale
gam.score(X, y)                     # adjusted R²
print(gam.summary())                # mgcv-style block

That's it for the simple case. Read on for the curated tour, or jump to docs/GETTING_STARTED.md for a step-by-step tutorial with worked examples.


Tour

1. Build the model

Gam accepts pandas or polars DataFrames and NumPy arrays as input to fit(). If you pass a DataFrame, the column names become the predictor names; otherwise you can pass the names explicitly.

gam = Gam(
    family="gaussian",           # gaussian, binomial, poisson, gamma, tweedie, nb, t-dist, scat, ...
    link=None,                   # None → canonical link for the family
    k_default=10,                # default basis dimension per smooth (mgcv default)
    term_k_mapping={"x0": 25},   # per-predictor overrides
    method="REML",
)
gam.fit(X, y)

After fit, sklearn-style attributes are populated:

gam.coef_              # full β vector (intercept first)
gam.intercept_         # link-scale
gam.intercept_response_ # response-scale (= mean(y) for identity link)
gam.feature_names_in_  # predictor names, fitting order
gam.n_features_in_
gam.lambda_            # per-smooth smoothing parameter
gam.edf_               # per-smooth effective degrees of freedom
gam.k_, gam.bs_        # basis dimension / basis type per smooth
gam.vcov_              # posterior covariance of β̂

2. Predict on different scales

gam.predict(X)                       # response (default — same as mgcv predict(type="response"))
gam.predict(X, scale="link")         # linear predictor, no inverse link (mgcv type="link")

result = gam.predict(X, type="terms")  # decomposition
result.intercept                       # scalar, link scale
result.contributions                   # DataFrame (n × m_smooths), link-scale, centered on training
result.total                           # response-scale ≡ gam.predict(X)

Invariant: gam.predict(X) == inv_link(intercept + result.contributions.sum(axis=1)).
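
The invariant can be checked by hand. The following is a minimal sketch with made-up numbers and a log link chosen for illustration; `intercept`, `contributions`, and `inv_link` are plain Python stand-ins, not objects returned by mgcv_rust.

```python
import math

# Hypothetical link-scale pieces for three rows and two smooths (made-up numbers).
intercept = 0.5
contributions = [
    {"x0": 0.20, "x1": -0.10},
    {"x0": -0.30, "x1": 0.40},
    {"x0": 0.00, "x1": 0.00},
]

def inv_link(eta):
    """Inverse of the log link: maps the linear predictor to the response scale."""
    return math.exp(eta)

# Response-scale prediction = inv_link(intercept + row-wise sum of contributions).
total = [inv_link(intercept + sum(row.values())) for row in contributions]
```

A row whose contributions are all zero lands exactly on inv_link(intercept), which is why the intercept is reported on both the link and the response scale.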

3. Confidence intervals

mean, lo, hi = gam.predict_ci(X, level=0.95)       # response, 95% CI
mean, lo, hi = gam.predict_ci(X, scale="link")     # link-scale CI
mean, lo, hi = gam.predict_ci(X, level=0.5)        # 50% CI (narrower)

mean matches gam.predict(X, scale=scale) exactly. The CI is built from n_samples (default 1000) draws of β ~ N(β̂, vcov).
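
To make the sampling idea concrete, here is a rough stdlib-only sketch of a percentile interval built from draws of β ~ N(β̂, vcov). It is an illustration under assumptions, not the library's code: the helper names `cholesky` and `sample_ci`, the percentile rule, and the example numbers are all hypothetical.

```python
import random

def cholesky(a):
    """Lower-triangular L with L L^T == a, for a small symmetric positive-definite matrix."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = (a[i][i] - s) ** 0.5
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L

def sample_ci(x_row, beta_hat, vcov, level=0.95, n_samples=1000, seed=0):
    """Percentile interval for x_row . beta on the link scale, from draws of beta."""
    rng = random.Random(seed)
    L = cholesky(vcov)
    p = len(beta_hat)
    draws = []
    for _ in range(n_samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(p)]
        # beta = beta_hat + L z has covariance L L^T = vcov.
        beta = [beta_hat[i] + sum(L[i][k] * z[k] for k in range(i + 1)) for i in range(p)]
        draws.append(sum(x * b for x, b in zip(x_row, beta)))
    draws.sort()
    tail = (1.0 - level) / 2.0
    lo = draws[int(tail * (n_samples - 1))]
    hi = draws[int((1.0 - tail) * (n_samples - 1))]
    # The point prediction uses beta_hat itself, not the average of the draws.
    mean = sum(x * b for x, b in zip(x_row, beta_hat))
    return mean, lo, hi
```

For a non-identity link, the same draws would be passed through the inverse link before taking percentiles, which keeps the interval inside the valid response range.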

The legacy predict_ci(X, alpha=0.05) → (lo, hi) form still works (emits a DeprecationWarning) so existing callers don't break.

4. Differences between predictions, with paired CI

# Per-row diff: to[i] - from[i]
diff = gam.predict_diff(from_X, to_X)

# With paired posterior CI — much narrower than predict_ci differences
diff, lo, hi = gam.predict_diff(from_X, to_X, level=0.95)

# One row broadcast against many
diff = gam.predict_diff(baseline_row, candidates, broadcast="from")
diff = gam.predict_diff(candidates, target_row, broadcast="to")

Identity-link only — for non-identity links the response-scale difference isn't linear in β, so the closed-form CI argument doesn't transfer.
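
The reason the paired interval is tighter can be sketched in closed form. With an identity link the difference is (x_to - x_from)ᵀβ, so its variance is dᵀ V d with d = x_to - x_from, and shared terms such as the intercept cancel. Differencing two separate intervals instead adds their half-widths. The snippet below is a minimal illustration with a made-up covariance matrix; the function names are hypothetical and a normal approximation is assumed.

```python
from statistics import NormalDist

def quad_form(d, V):
    """d^T V d for a vector d and a square matrix V (lists of floats)."""
    return sum(d[i] * V[i][j] * d[j] for i in range(len(d)) for j in range(len(d)))

def paired_vs_naive_halfwidth(x_from, x_to, vcov, level=0.95):
    """Half-width of the paired CI and of a naive CI built by differencing two intervals."""
    z = NormalDist().inv_cdf(0.5 + level / 2.0)
    d = [t - f for f, t in zip(x_from, x_to)]
    paired = z * quad_form(d, vcov) ** 0.5
    # Naive bounds are hi_to - lo_from and lo_to - hi_from, so the half-widths add.
    naive = z * (quad_form(x_from, vcov) ** 0.5 + quad_form(x_to, vcov) ** 0.5)
    return paired, naive

# Hypothetical two-coefficient model: intercept column plus one slope column.
vcov = [[0.04, 0.01], [0.01, 0.09]]
paired, naive = paired_vs_naive_halfwidth([1.0, 0.5], [1.0, 0.8], vcov)
```

Because sqrt(dᵀ V d) obeys the triangle inequality for a positive semi-definite V, the paired half-width can never exceed the naive one.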

5. Subset / marginal views

gam[names] returns a view that includes only the listed smooths (plus the intercept if Gam.INTERCEPT is in the list). Every predict method on the view respects that selection.

gam[["x0"]].predict(X)                              # marginal effect of x0, response scale
gam[["x0"]].predict(X, scale="deviation")           # link-scale, intercept zeroed (sum-to-zero on train)
gam[["x0", Gam.INTERCEPT]].predict(X)               # marginal effect + intercept
gam[["x0"]].partial_effect("x0").plot()             # what the smooth looks like

All sklearn attributes filter to the active features:

gam[["x0"]].coef_            # just the x0 block (+ intercept if active)
gam[["x0"]].feature_names_in_  # ["x0"]

6. Plotting, summary, score

gam.plot()                  # figure with one subplot per smooth (CI bands shaded)
gam.plot("x0")              # single-smooth axes
df = gam.partial_effect("x0", level=0.95)  # underlying data: columns x0, effect, lo, hi

print(gam.summary())
# Gam summary  family=gaussian  link=identity  n_obs=500
#   intercept (link)     = -0.0123
#   intercept (response) = -0.0123
#   scale (σ²)           = 0.0101
#   deviance             = 4.8732
#   R² (adj)             = 0.9612
#   smooths:
#     s(          x0)  k= 10  edf=  6.13  λ=2.345e-02
#     s(          x1)  k= 10  edf=  3.07  λ=4.890e-01

gam.score(X, y)              # adjusted R² for regression; accuracy@0.5 for binomial
gam.predict_proba(X)         # binomial only: (n, 2) [[P(0), P(1)]]
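
For readers who want to see what these metrics compute, here is a small stdlib-only sketch. It uses the textbook adjusted R² with residual degrees of freedom n minus total edf, and treats P(1) >= 0.5 as the positive class. Both choices are assumptions made for illustration, and mgcv_rust's exact formulas may differ in detail. All function names are hypothetical.

```python
import math

def adjusted_r2(y, fitted, edf_total):
    """1 - (RSS / (n - edf_total)) / (TSS / (n - 1))."""
    n = len(y)
    y_bar = sum(y) / n
    rss = sum((yi - fi) ** 2 for yi, fi in zip(y, fitted))
    tss = sum((yi - y_bar) ** 2 for yi in y)
    return 1.0 - (rss / (n - edf_total)) / (tss / (n - 1))

def proba_from_link(eta):
    """Rows of [P(0), P(1)] from logit-scale linear predictors."""
    p1 = [1.0 / (1.0 + math.exp(-e)) for e in eta]
    return [[1.0 - p, p] for p in p1]

def accuracy_at_half(y, proba):
    """Share of rows where thresholding P(1) at 0.5 reproduces the 0/1 label."""
    hits = sum(int(row[1] >= 0.5) == yi for yi, row in zip(y, proba))
    return hits / len(y)
```

A perfect fit gives an adjusted R² of 1, and a mean-only fit (edf_total = 1) gives 0, which is a quick sanity check on any implementation.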

7. Serving with GamPredictor

For deployment paths, wrap the fitted Gam in a GamPredictor. Same API surface, plus strict column validation and a round-trip assertion.

from mgcv_rust import GamPredictor

predictor = GamPredictor(gam)
predictor.check_against(gam, X[:50])   # raises if predictions diverge
predictor.predict(X_serve)             # ValueError if any expected column is missing
predictor[["x0"]].predict(X_serve)     # subset view returns another GamPredictor

# `__slots__` blocks attribute leaks — once built, the bound Gam's
# coef/vcov/schema are the contract.

This structurally closes two production bug classes:

  • Column index drift when the serving DataFrame's columns differ from training: feature_names_in_ is enforced on every predict call (re-ordered → realigned; missing → ValueError).
  • x-grid clamping in interpolation-based predictors: GamPredictor recomputes the basis at the requested X via evaluate_lpmatrix, so there's no pre-computed grid to clamp against.
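
The first guarantee can be pictured with a small stand-in. The sketch below mimics "re-ordered → realigned; missing → ValueError" on a plain dict of columns. `align_columns` is a hypothetical helper written for this example, not part of mgcv_rust, and how GamPredictor treats extra columns is not stated here, so the sketch simply ignores them.

```python
def align_columns(columns, feature_names_in):
    """Return the columns in training order, or raise if an expected one is absent.

    `columns` maps column name -> list of values, standing in for a DataFrame.
    """
    missing = [name for name in feature_names_in if name not in columns]
    if missing:
        raise ValueError(f"missing expected columns: {missing}")
    # Assumption: extra columns are dropped silently in this sketch.
    return [columns[name] for name in feature_names_in]

feature_names_in = ["x0", "x1"]
serving = {"x1": [4.0, 5.0], "extra": [0, 0], "x0": [1.0, 2.0]}
aligned = align_columns(serving, feature_names_in)
```

Selecting by name rather than by position is what removes the index-drift failure mode: the model never sees a column it did not ask for by name.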

Families and links

| Family | Default link | Other links | Notes |
| --- | --- | --- | --- |
| gaussian | identity | | bs-spline / cubic-regression / random-effects bases supported |
| binomial | logit | | predict_proba() available |
| poisson | log | | |
| gamma | inverse | log | log-link via link="log" |
| tweedie | log | | tweedie_p= (default 1.5); fixed-p |
| nb / negbin | log | | nb profiles θ; negbin is fixed-θ |
| t-dist / scat | identity | | df profiled if not given (df ∈ [2, 100]) |
| quasipoisson, quasibinomial | log / logit | | dispersion-aware |
| quantile | identity | | see mgcv_rust.fit_quantile / fit_quantile_lss; OOS presets documented in docs/qgam_oos_presets.md |
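
As a reading aid for the table above, the default links and their inverses can be written out in a few lines. This is a stdlib-only sketch: the dictionaries and `to_response` are hypothetical names for illustration, and the quasi-family row is read as quasipoisson → log and quasibinomial → logit.

```python
import math

# Link functions and their inverses, keyed by the link names used in the table.
LINKS = {
    "identity": (lambda mu: mu, lambda eta: eta),
    "log": (math.log, math.exp),
    "logit": (lambda mu: math.log(mu / (1.0 - mu)),
              lambda eta: 1.0 / (1.0 + math.exp(-eta))),
    "inverse": (lambda mu: 1.0 / mu, lambda eta: 1.0 / eta),
}

# Default link per family, copied from the table.
DEFAULT_LINK = {
    "gaussian": "identity",
    "binomial": "logit",
    "poisson": "log",
    "gamma": "inverse",
    "tweedie": "log",
    "nb": "log",
    "negbin": "log",
    "t-dist": "identity",
    "scat": "identity",
    "quasipoisson": "log",
    "quasibinomial": "logit",
    "quantile": "identity",
}

def to_response(family, eta, link=None):
    """Map a linear predictor to the response scale; link=None picks the family default."""
    _, inverse = LINKS[link or DEFAULT_LINK[family]]
    return inverse(eta)
```

For example, to_response("gamma", eta, link="log") mirrors the link="log" override mentioned in the gamma row.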

Architecture (Rust side)

| File | Responsibility |
| --- | --- |
| src/gam.rs | GAM struct + fit / predict entry points |
| src/pirls.rs | Penalized IRLS inner loop |
| src/reml/ | Outer-loop REML / LAML optimization |
| src/smooth.rs | Basis functions (cubic-regression, B-spline, random effects, tensor products) |
| src/penalty.rs | Penalty-matrix construction |
| src/lib.rs | PyO3 bindings: PyGAM, compute_penalty_matrix, newton_pirls_py |

Build features:

  • python — enable PyO3 bindings.
  • blas / blas-system — link against system BLAS for matmul-heavy paths.

Parity and performance

Run the parity battery against R/mgcv:

pytest tests/parity/ -q
# 554 passed, 0 failed, 0 xfailed, 0 skipped

Microbench:

python3 scripts/python/bench_step_blend.py 5

Headline (2d_gaussian_additive_n50000_k15_cr, identity link):

| | Time |
| --- | --- |
| R mgcv | ~394 ms |
| mgcv_rust 0.11.0 | 97 ms |

Status and limitations

  • Joint (ρ, log φ) outer Newton not yet implemented — the binding constraint for closing the remaining performance gap on dispersion-bearing GLMs. Tracked in mgcv_rust - Backlog - Next Numerical Steps (Obsidian).
  • predict_diff is identity-link only; non-identity raises with a workaround pointing at get_posterior_samples.
  • Auto-k tuning is opt-in via Gam(auto_k=True). Default is a single fit at k_default=10 (mgcv's default), with term_k_mapping overrides — closer to mgcv's "tune k, run k.check" convention and avoids hidden multi-fit costs.
  • Gam does not yet inherit from sklearn's BaseEstimator mixin, so it cannot be dropped into Pipeline / GridSearchCV as-is. sklearn would be a soft dependency, and the work is deferred.

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. J. R. Stat. Soc. B, 73(1), 3–36.

License

MIT — see LICENSE.
