
Rust implementation of Generalized Additive Models with Python bindings


mgcv_rust

A Rust port of R's mgcv package — Generalized Additive Models with automatic smoothing-parameter selection (REML / LAML) and PIRLS fitting — with first-class Python bindings.

The numerical core targets byte-for-byte parity with mgcv on common families, and the Python API is designed to feel like sklearn: a single Gam class, named-predictor DataFrame inputs, fit() / predict() / score(), posterior CIs, and subset views for marginal analysis.

| Item | Value |
|---|---|
| Latest release | 0.11.0 |
| Parity | 554 passed / 0 failed / 0 xfailed on the mgcv comparison battery |
| Python tests | 211 passed, 1 xfailed |
| Headline perf | 2d_gaussian_additive_n50000_k15_cr: 394 ms (R mgcv) → 97 ms (mgcv_rust), 4.05× faster |
| Python wrapper | mgcv_rust.Gam (canonical); GAMFitter is a deprecated alias |

Why another GAM library?

  • R-equivalent answers, Python ergonomics. If you have R / mgcv users producing models and Python services consuming them, this gives both sides the same numbers from the same fit. The serialize() method produces a portable artefact for deployment.
  • Fast. Aggregate time across 80 parity fixtures: 2.0 s. On the largest Gaussian fixture (n=50,000, k=15), mgcv_rust is ~4× faster than R's mgcv.
  • Real CIs, paired diffs. predict_ci(X) returns (mean, lo, hi) so you never have to "+= intercept" by hand. predict_diff(from_X, to_X, level=...) gives a paired-posterior CI for the difference between two predictions — strictly tighter than the naive predict_ci difference.
  • Frozen serving view. GamPredictor(gam) is an __slots__-locked, inference-only wrapper with strict input validation (feature_names_in_ enforcement) and a check_against(gam, X_sample) round-trip assertion for deployment safety.

Installation

The Python package is built from this repo via maturin:

# Clone, then in a venv:
pip install maturin
maturin develop --features python,blas,blas-system --release

After that, import mgcv_rust works.

For a Rust-only consumer, add to your Cargo.toml:

[dependencies]
mgcv_rust = { git = "https://github.com/AlekJaworski/nn_exploring" }

Quick start

import numpy as np
import pandas as pd
from mgcv_rust import Gam

rng = np.random.default_rng(0)
n = 500
X = pd.DataFrame({
    "x0": rng.uniform(-2, 2, n),
    "x1": rng.uniform(0, 5, n),
})
y = np.sin(X["x0"]) + 0.3 * (X["x1"] - 2.5)**2 + rng.normal(0, 0.1, n)

gam = Gam(family="gaussian").fit(X, y)

gam.predict(X[:5])                  # response-scale predictions
mean, lo, hi = gam.predict_ci(X[:5])  # 95% CI, response scale
gam.score(X, y)                     # adjusted R²
print(gam.summary())                # mgcv-style block

That's it for the simple case. Read on for the curated tour, or jump to docs/GETTING_STARTED.md for a step-by-step tutorial with worked examples.


Tour

1. Build the model

fit() accepts pandas, polars, or NumPy inputs. If you pass a DataFrame, its column names become the predictor names; otherwise you can pass them explicitly.

gam = Gam(
    family="gaussian",           # gaussian, binomial, poisson, gamma, tweedie, nb, t-dist, scat, ...
    link=None,                   # None → canonical link for the family
    k_default=10,                # default basis dimension per smooth (mgcv default)
    term_k_mapping={"x0": 25},   # per-predictor overrides
    method="REML",
)
gam.fit(X, y)
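The interplay between k_default and term_k_mapping amounts to a per-predictor lookup with a fallback. A minimal sketch of that rule, assuming overrides simply replace the default for the named predictors; resolve_k is a hypothetical helper, not part of mgcv_rust, and rejecting unknown names is an assumption:

```python
def resolve_k(feature_names, k_default=10, term_k_mapping=None):
    """Basis dimension per predictor (hypothetical helper, not mgcv_rust API).

    Predictors named in term_k_mapping get their override; all others
    fall back to k_default.
    """
    overrides = term_k_mapping or {}
    unknown = set(overrides) - set(feature_names)
    if unknown:
        # Assumption: an override for a predictor that doesn't exist is an error.
        raise ValueError(f"term_k_mapping names unknown predictors: {sorted(unknown)}")
    return {name: overrides.get(name, k_default) for name in feature_names}


print(resolve_k(["x0", "x1"], k_default=10, term_k_mapping={"x0": 25}))
# {'x0': 25, 'x1': 10}
```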

After fit, sklearn-style attributes are populated:

gam.coef_              # full β vector (intercept first)
gam.intercept_         # link-scale
gam.intercept_response_ # response-scale (= mean(y) for identity link)
gam.feature_names_in_  # predictor names, fitting order
gam.n_features_in_
gam.lambda_            # per-smooth smoothing parameter
gam.edf_               # per-smooth effective degrees of freedom
gam.k_, gam.bs_        # basis dimension / basis type per smooth
gam.vcov_              # posterior covariance of β̂
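The edf_ values follow the standard penalized-regression definition in Wood (2017): the effective degrees of freedom are the diagonal of F = (XᵀWX + S_λ)⁻¹XᵀWX, summed over each smooth's coefficient block. The NumPy sketch below assumes a Gaussian fit (W = I) and uses a second-order difference penalty as a stand-in for the penalties src/penalty.rs actually builds; the column layout and the helper names are hypothetical:

```python
import numpy as np


def difference_penalty(k, order=2):
    """P-spline-style penalty D^T D (stand-in, not mgcv_rust's penalty)."""
    D = np.diff(np.eye(k), n=order, axis=0)
    return D.T @ D


def per_smooth_edf(X, blocks, lambdas):
    """Hypothetical helper: EDF per smooth from F = (X^T X + S_lambda)^{-1} X^T X.

    blocks maps smooth name -> slice of coefficient columns; column 0 is the
    unpenalized intercept.
    """
    p = X.shape[1]
    S = np.zeros((p, p))
    for name, cols in blocks.items():
        k = cols.stop - cols.start
        S[cols, cols] += lambdas[name] * difference_penalty(k)
    XtX = X.T @ X
    F = np.linalg.solve(XtX + S, XtX)
    diag = np.diag(F)
    per_smooth = {name: float(diag[cols].sum()) for name, cols in blocks.items()}
    return per_smooth, float(diag.sum())   # (per-smooth EDF, total EDF incl. intercept)


rng = np.random.default_rng(0)
n, k = 200, 10
X = np.column_stack([np.ones(n), rng.normal(size=(n, 2 * k))])
blocks = {"x0": slice(1, 1 + k), "x1": slice(1 + k, 1 + 2 * k)}
edf, total = per_smooth_edf(X, blocks, {"x0": 0.1, "x1": 100.0})
print(edf, total)
```

With both λ at zero, F is the identity and each smooth gets its full k = 10; the heavily penalized x1 block ends up with fewer effective degrees of freedom than x0.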

2. Predict on different scales

gam.predict(X)                       # response (default — same as mgcv predict(type="response"))
gam.predict(X, scale="link")         # linear predictor, no inverse link (mgcv type="link")

result = gam.predict(X, type="terms")  # decomposition
result.intercept                       # scalar, link scale
result.contributions                   # DataFrame (n × m_smooths), link-scale, centered on training
result.total                           # response-scale ≡ gam.predict(X)

Invariant: gam.predict(X) == inv_link(intercept + result.contributions.sum(axis=1)).
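The invariant holds because the link-scale predictor is a sum of per-smooth blocks, so centering each contribution on its training mean only moves a constant into the intercept. In mgcv-style fits the identifiability constraints already make each smooth sum to zero over the training data; this NumPy sketch makes the centering explicit instead. It assumes a log link, a design matrix whose column 0 is the intercept, and a hypothetical block layout and helper name:

```python
import numpy as np


def terms_decomposition(X_new, X_train, beta, blocks):
    """Hypothetical helper: intercept plus per-smooth terms centred on training data."""
    intercept = beta[0]
    contributions = {}
    for name, cols in blocks.items():
        f_new = X_new[:, cols] @ beta[cols]
        f_train_mean = (X_train[:, cols] @ beta[cols]).mean()
        contributions[name] = f_new - f_train_mean   # centred on the training set
        intercept += f_train_mean                    # the shift is absorbed by the intercept
    return intercept, contributions


rng = np.random.default_rng(1)
n, k = 100, 4
X_train = np.column_stack([np.ones(n), rng.uniform(size=(n, 2 * k))])
X_new = np.column_stack([np.ones(5), rng.uniform(size=(5, 2 * k))])
beta = rng.normal(scale=0.3, size=1 + 2 * k)
blocks = {"x0": slice(1, 1 + k), "x1": slice(1 + k, 1 + 2 * k)}

intercept, contrib = terms_decomposition(X_new, X_train, beta, blocks)
eta = intercept + sum(contrib.values())   # link scale
total = np.exp(eta)                       # inverse log link -> response scale
assert np.allclose(total, np.exp(X_new @ beta))
```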

3. Confidence intervals

mean, lo, hi = gam.predict_ci(X, level=0.95)       # response, 95% CI
mean, lo, hi = gam.predict_ci(X, scale="link")     # link-scale CI
mean, lo, hi = gam.predict_ci(X, level=0.5)        # 50% CI (narrower)

mean matches gam.predict(X, scale=scale) exactly. The CI is built from n_samples (default 1000) draws of β ~ N(β̂, vcov).
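The simulation recipe above can be sketched in NumPy: draw coefficient vectors from N(β̂, vcov), push each draw through the linear predictor, take empirical quantiles, and map them to the response scale. Mapping quantiles through the inverse link is valid because it is monotone; the sketch assumes an increasing one (log here). The function name, the lp-matrix argument Xp and the fixed seed are assumptions for illustration, not mgcv_rust internals:

```python
import numpy as np


def simulate_ci(Xp, beta_hat, vcov, level=0.95, n_samples=1000, inv_link=np.exp, seed=0):
    """Posterior-simulation CI (sketch; names and seed are illustrative assumptions).

    Xp maps coefficients to the linear predictor, one row per prediction.
    inv_link must be monotone increasing so quantiles map straight across.
    """
    rng = np.random.default_rng(seed)
    draws = rng.multivariate_normal(beta_hat, vcov, size=n_samples)  # (n_samples, p)
    eta_draws = draws @ Xp.T                                         # (n_samples, n_rows)
    alpha = 1.0 - level
    lo_eta, hi_eta = np.quantile(eta_draws, [alpha / 2, 1 - alpha / 2], axis=0)
    mean = inv_link(Xp @ beta_hat)   # point prediction from beta_hat, not the draw mean
    return mean, inv_link(lo_eta), inv_link(hi_eta)


Xp = np.array([[1.0, 0.5], [1.0, -1.0], [1.0, 2.0]])
beta_hat = np.array([0.2, 0.4])
vcov = np.array([[0.01, 0.002], [0.002, 0.02]])
mean, lo, hi = simulate_ci(Xp, beta_hat, vcov, level=0.95)
_, lo50, hi50 = simulate_ci(Xp, beta_hat, vcov, level=0.5)
```

Reusing the same seed for the 50% call reuses the same draws, so the 50% band sits inside the 95% band row by row.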

The legacy predict_ci(X, alpha=0.05) → (lo, hi) form still works (emits a DeprecationWarning) so existing callers don't break.

4. Differences between predictions, with paired CI

# Per-row diff: to[i] - from[i]
diff = gam.predict_diff(from_X, to_X)

# With paired posterior CI — much narrower than predict_ci differences
diff, lo, hi = gam.predict_diff(from_X, to_X, level=0.95)

# One row broadcast against many
diff = gam.predict_diff(baseline_row, candidates, broadcast="from")
diff = gam.predict_diff(candidates, target_row, broadcast="to")

Identity-link only — for non-identity links the response-scale difference isn't linear in β, so the closed-form CI argument doesn't transfer.
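On the identity link the difference is linear in β, Δ = (X_to − X_from)β, so one set of posterior draws yields a paired interval directly: uncertainty shared by both rows (the intercept, for instance) cancels instead of adding up. A NumPy sketch, assuming lp-matrix rows for both sides and using NumPy broadcasting as a stand-in for broadcast="from" / "to"; paired_diff_ci is a hypothetical name:

```python
import numpy as np


def paired_diff_ci(Xp_from, Xp_to, beta_hat, vcov, level=0.95, n_samples=1000, seed=0):
    """Hypothetical helper: paired posterior CI for to - from on the identity link.

    A single-row Xp_from or Xp_to broadcasts against the other side.
    """
    rng = np.random.default_rng(seed)
    draws = rng.multivariate_normal(beta_hat, vcov, size=n_samples)
    D = np.atleast_2d(Xp_to) - np.atleast_2d(Xp_from)   # broadcasting handles 1-vs-many
    diff_draws = draws @ D.T                             # same draw used for both sides
    alpha = 1.0 - level
    lo, hi = np.quantile(diff_draws, [alpha / 2, 1 - alpha / 2], axis=0)
    return D @ beta_hat, lo, hi


beta_hat = np.array([1.0, 0.5])
vcov = np.array([[0.25, 0.0], [0.0, 0.01]])   # large intercept variance
baseline = np.array([1.0, 0.0])                # one "from" row
candidates = np.array([[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
diff, lo, hi = paired_diff_ci(baseline, candidates, beta_hat, vcov)
```

Here the intercept variance dominates each individual prediction, yet it drops out of every difference, which is why the paired band is much narrower than differencing two separate intervals.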

5. Subset / marginal views

gam[[name, ...]] returns a view that includes only the listed smooths (and, optionally, the intercept). All predict methods on the view use only those terms.

gam[["x0"]].predict(X)                              # marginal effect of x0, response scale
gam[["x0"]].predict(X, scale="deviation")           # link-scale, intercept zeroed (sum-to-zero on train)
gam[["x0", Gam.INTERCEPT]].predict(X)               # marginal effect + intercept
gam[["x0"]].partial_effect("x0").plot()             # what the smooth looks like

All sklearn attributes filter to the active features:

gam[["x0"]].coef_            # just the x0 block (+ intercept if active)
gam[["x0"]].feature_names_in_  # ["x0"]
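Conceptually, a subset view masks the coefficients of inactive smooths before forming the linear predictor. A NumPy sketch, assuming column 0 of the lp-matrix is the intercept and the remaining columns are grouped by smooth; the INTERCEPT string and marginal_predict are hypothetical stand-ins for Gam.INTERCEPT and the view machinery:

```python
import numpy as np

INTERCEPT = "(Intercept)"  # hypothetical stand-in for Gam.INTERCEPT


def marginal_predict(Xp, beta, blocks, active, inv_link=lambda eta: eta):
    """Hypothetical helper: predict using only the active smooths (+ intercept if listed)."""
    mask = np.zeros_like(beta, dtype=bool)
    if INTERCEPT in active:
        mask[0] = True
    for name in active:
        if name != INTERCEPT:
            mask[blocks[name]] = True   # KeyError for unknown smooth names
    eta = Xp @ np.where(mask, beta, 0.0)   # inactive coefficients contribute zero
    return inv_link(eta)


rng = np.random.default_rng(2)
Xp = np.column_stack([np.ones(4), rng.normal(size=(4, 6))])
beta = rng.normal(size=7)
blocks = {"x0": slice(1, 4), "x1": slice(4, 7)}
only_x0 = marginal_predict(Xp, beta, blocks, ["x0"])
```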

6. Plotting, summary, score

gam.plot()                  # figure with one subplot per smooth (CI bands shaded)
gam.plot("x0")              # single-smooth axes
df = gam.partial_effect("x0", level=0.95)  # underlying data: columns x0, effect, lo, hi
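partial_effect returns the smooth evaluated on a grid with a pointwise band. One common way to build such a table is the Bayesian ±z·se band from the smooth's coefficient block and its covariance (Wood, 2017). The text does not say whether partial_effect uses this analytic band or posterior draws, so treat this pandas / NumPy sketch as an assumption; B is the basis evaluated at the grid, and partial_effect_table is a hypothetical helper:

```python
from statistics import NormalDist

import numpy as np
import pandas as pd


def partial_effect_table(name, grid, B, beta_j, V_j, level=0.95):
    """Hypothetical helper: effect = B beta_j, se = sqrt(diag(B V_j B^T)), +/- z*se band."""
    z = NormalDist().inv_cdf(0.5 + level / 2)
    effect = B @ beta_j
    se = np.sqrt(np.einsum("ij,jk,ik->i", B, V_j, B))   # diag(B V_j B^T) without the full matrix
    return pd.DataFrame({name: grid, "effect": effect,
                         "lo": effect - z * se, "hi": effect + z * se})


grid = np.linspace(-2, 2, 5)
B = np.column_stack([grid, grid**2 - (grid**2).mean()])   # toy 2-column basis
beta_j = np.array([0.8, -0.3])
V_j = np.diag([0.01, 0.004])
df = partial_effect_table("x0", grid, B, beta_j, V_j)
```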

print(gam.summary())
# Gam summary  family=gaussian  link=identity  n_obs=500
#   intercept (link)     = -0.0123
#   intercept (response) = -0.0123
#   scale (σ²)           = 0.0101
#   deviance             = 4.8732
#   R² (adj)             = 0.9612
#   smooths:
#     s(          x0)  k= 10  edf=  6.13  λ=2.345e-02
#     s(          x1)  k= 10  edf=  3.07  λ=4.890e-01

gam.score(X, y)              # adjusted R² for regression; accuracy@0.5 for binomial
gam.predict_proba(X)         # binomial only: (n, 2) [[P(0), P(1)]]
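For reference, mgcv's summary.gam computes adjusted R² (unweighted case) as 1 − var(y − ŷ)·(n − 1) / (var(y)·df_resid) with df_resid = n − Σ edf, and accuracy at 0.5 is the share of thresholded probabilities that match the labels. A NumPy sketch of both plus the (n, 2) predict_proba layout; whether score uses exactly these formulas (weights, how a probability of exactly 0.5 is classified) is an assumption:

```python
import numpy as np


def adjusted_r2(y, fitted, edf_total):
    """mgcv-style adjusted R^2, unweighted: 1 - var(res) (n-1) / (var(y) df_resid)."""
    n = len(y)
    df_resid = n - edf_total
    return 1.0 - np.var(y - fitted, ddof=1) * (n - 1) / (np.var(y, ddof=1) * df_resid)


def accuracy_at_half(y01, p1):
    """Share of rows where (P(1) >= 0.5) matches the 0/1 label (tie rule is an assumption)."""
    return float(np.mean((p1 >= 0.5).astype(int) == y01))


def proba_matrix(p1):
    """(n, 2) layout [[P(0), P(1)], ...] built from the fitted P(1)."""
    p1 = np.asarray(p1, dtype=float)
    return np.column_stack([1.0 - p1, p1])


y = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
fitted = np.array([1.1, 1.9, 3.2, 3.9, 4.9])
r2 = adjusted_r2(y, fitted, edf_total=2.0)
acc = accuracy_at_half(np.array([0, 1, 1, 0]), np.array([0.2, 0.7, 0.4, 0.1]))
P = proba_matrix([0.2, 0.7])
```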

7. Serving with GamPredictor

For deployment paths, wrap the fitted Gam in a GamPredictor. Same API surface, plus strict column validation and a round-trip assertion.

from mgcv_rust import GamPredictor

predictor = GamPredictor(gam)
predictor.check_against(gam, X[:50])   # raises if predictions diverge
predictor.predict(X_serve)             # ValueError if any expected column is missing
predictor[["x0"]].predict(X_serve)     # subset view returns another GamPredictor

# `__slots__` blocks attribute leaks — once built, the bound Gam's
# coef/vcov/schema are the contract.

This structurally closes two production bug classes:

  • Column index drift when the serving DataFrame's columns differ from training: feature_names_in_ is enforced on every predict call (re-ordered → realigned; missing → ValueError).
  • x-grid clamping in interpolation-based predictors: GamPredictor recomputes the basis at the requested X via evaluate_lpmatrix, so there's no pre-computed grid to clamp against.
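The column rule in the first bullet (re-ordered → realigned, missing → ValueError) can be sketched with pandas. align_columns and check_round_trip are hypothetical helpers that mirror the described behavior, not GamPredictor's code; how extra, unexpected columns are treated is not stated above, so dropping them here is an assumption, as are the tolerances:

```python
import numpy as np
import pandas as pd


def align_columns(X, feature_names_in):
    """Hypothetical helper: reorder X to training order; raise if a column is missing."""
    missing = [c for c in feature_names_in if c not in X.columns]
    if missing:
        raise ValueError(f"missing expected columns: {missing}")
    return X.loc[:, list(feature_names_in)]   # extra columns are dropped (assumption)


def check_round_trip(pred_a, pred_b, rtol=1e-10, atol=1e-12):
    """Hypothetical check_against-style assertion: raise if predictions diverge."""
    if not np.allclose(pred_a, pred_b, rtol=rtol, atol=atol):
        raise AssertionError("predictions diverge between fitted model and predictor")


X_serve = pd.DataFrame({"x1": [0.5, 1.0], "extra": [9, 9], "x0": [-1.0, 2.0]})
aligned = align_columns(X_serve, ["x0", "x1"])
```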

Families and links

| Family | Default link | Other links | Notes |
|---|---|---|---|
| gaussian | identity | | bs-spline / cubic-regression / random-effects bases supported |
| binomial | logit | | predict_proba() available |
| poisson | log | | |
| gamma | inverse | log | log-link via link="log" |
| tweedie | log | | tweedie_p= (default 1.5); fixed-p |
| nb / negbin | log | | nb profiles θ; negbin is fixed-θ |
| t-dist / scat | identity | | df profiled if not given (df ∈ [2, 100]) |
| quasipoisson, quasibinomial | log / logit | | dispersion-aware |
| quantile | identity | | see mgcv_rust.fit_quantile / fit_quantile_lss; OOS presets documented in docs/qgam_oos_presets.md |
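Each default link maps the mean μ to the linear predictor η. A small NumPy sketch of the link / inverse-link pairs in the table; the dictionaries are illustrative, not mgcv_rust's registry. Note that the gamma family's canonical inverse link is decreasing, so interval endpoints swap when mapped through it:

```python
import numpy as np

# Illustrative (link, inverse link) pairs; not mgcv_rust's internal registry.
LINKS = {
    "identity": (lambda mu: mu, lambda eta: eta),
    "logit": (lambda mu: np.log(mu / (1 - mu)), lambda eta: 1 / (1 + np.exp(-eta))),
    "log": (np.log, np.exp),
    "inverse": (lambda mu: 1 / mu, lambda eta: 1 / eta),
}

# Default link per family, as listed in the table.
DEFAULT_LINK = {
    "gaussian": "identity", "binomial": "logit", "poisson": "log",
    "gamma": "inverse", "tweedie": "log", "nb": "log", "negbin": "log",
    "t-dist": "identity", "scat": "identity",
    "quasipoisson": "log", "quasibinomial": "logit", "quantile": "identity",
}

mu = np.array([0.1, 0.5, 0.9])
for family, link_name in DEFAULT_LINK.items():
    link, inv = LINKS[link_name]
    assert np.allclose(inv(link(mu)), mu), family   # round trip mu -> eta -> mu
```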

Architecture (Rust side)

| File | Responsibility |
|---|---|
| src/gam.rs | GAM struct + fit / predict entry points |
| src/pirls.rs | Penalized IRLS inner loop |
| src/reml/ | Outer-loop REML / LAML optimization |
| src/smooth.rs | Basis functions (cubic-regression, B-spline, random effects, tensor products) |
| src/penalty.rs | Penalty-matrix construction |
| src/lib.rs | PyO3 bindings: PyGAM, compute_penalty_matrix, newton_pirls_py |
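src/pirls.rs is described as the penalized IRLS inner loop. For orientation, here is the textbook iteration for a Poisson log-link model with the smoothing parameter held fixed (Wood, 2017): form working weights and a working response, then solve a penalized weighted least-squares system. The sketch omits step-halving, the REML outer loop and any numerical safeguards; pirls_poisson and the toy penalty are hypothetical:

```python
import numpy as np


def pirls_poisson(X, y, S, lam, max_iter=50, tol=1e-10):
    """Hypothetical sketch: penalized IRLS, Poisson/log link, fixed penalty lam * S.

    Each step solves (X^T W X + lam S) beta = X^T W z with
    mu = exp(eta), W = diag(mu), z = eta + (y - mu) / mu.
    """
    beta = np.zeros(X.shape[1])
    beta[0] = np.log(y.mean())          # start at the intercept-only fit (column 0 = ones)
    for _ in range(max_iter):
        eta = X @ beta
        mu = np.exp(eta)
        z = eta + (y - mu) / mu         # working response
        XtW = X.T * mu                  # X^T W without forming diag(W)
        beta_new = np.linalg.solve(XtW @ X + lam * S, XtW @ z)
        if np.max(np.abs(beta_new - beta)) < tol:
            return beta_new
        beta = beta_new
    return beta


rng = np.random.default_rng(3)
n = 300
x = rng.uniform(-1, 1, n)
X = np.column_stack([np.ones(n), x, x**2])
y = rng.poisson(np.exp(0.5 + 0.8 * x - 0.4 * x**2)).astype(float)
S = np.diag([0.0, 0.0, 1.0])            # toy penalty on the quadratic term only
beta_hat = pirls_poisson(X, y, S, lam=0.0)
```

At convergence the penalized score equations Xᵀ(y − μ) = λSβ hold, which is a cheap correctness check for any PIRLS implementation.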

Build features:

  • python — enable PyO3 bindings.
  • blas / blas-system — link against system BLAS for matmul-heavy paths.

Parity and performance

Run the parity battery against R/mgcv:

pytest tests/parity/ -q
# 554 passed, 0 failed, 0 xfailed, 0 skipped

Microbench:

python3 scripts/python/bench_step_blend.py 5

Headline (2d_gaussian_additive_n50000_k15_cr, identity link):

| Implementation | Time |
|---|---|
| R mgcv | ~394 ms |
| mgcv_rust 0.11.0 | 97 ms |

Status and limitations

  • Joint (ρ, log φ) outer Newton not yet implemented — the binding constraint for closing the remaining performance gap on dispersion-bearing GLMs. Tracked in mgcv_rust - Backlog - Next Numerical Steps (Obsidian).
  • predict_diff is identity-link only; non-identity raises with a workaround pointing at get_posterior_samples.
  • Auto-k tuning is opt-in via Gam(auto_k=True). Default is a single fit at k_default=10 (mgcv's default), with term_k_mapping overrides — closer to mgcv's "tune k, run k.check" convention and avoids hidden multi-fit costs.
  • sklearn BaseEstimator mixin (for Pipeline / GridSearchCV) is not yet wrapped — soft-dep, deferred.

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. J. R. Stat. Soc. B, 73(1), 3–36.

License

MIT — see LICENSE.
