Rust implementation of Generalized Additive Models with Python bindings

Project description

mgcv_rust

A Rust port of R's mgcv package — Generalized Additive Models with automatic smoothing-parameter selection (REML / LAML) and PIRLS fitting — with first-class Python bindings.

The numerical core targets byte-for-byte parity with mgcv on common families, and the Python API is designed to feel like sklearn: a single Gam class, DataFrame inputs with named predictors, fit() / predict() / score(), posterior CIs, and subset views for marginal analysis.

Latest release: 0.11.0
Parity: 554 passed / 0 failed / 0 xfailed on the mgcv comparison battery
Python tests: 211 passed, 1 xfailed
Headline perf: 2d_gaussian_additive_n50000_k15_cr, R mgcv ~394 ms vs. mgcv_rust 97 ms (4.05× faster)
Python wrapper: mgcv_rust.Gam (canonical); GAMFitter is a deprecated alias

Why another GAM library?

  • R-equivalent answers, Python ergonomics. If you have R / mgcv users producing models and Python services consuming them, this gives both sides the same numbers from the same fit. The serialize() method produces a portable artefact for deployment.
  • Fast. Aggregate mgcv_rust time across the 80 parity fixtures is 2.0 s. On the largest Gaussian fixture (n=50,000, k=15), mgcv_rust is ~4× faster than R's mgcv.
  • Real CIs, paired diffs. predict_ci(X) returns (mean, lo, hi), so you never have to "+= intercept" by hand. predict_diff(from_X, to_X, level=...) gives a paired-posterior CI for the difference between two predictions, which is never wider than (and usually much narrower than) the naive interval built by differencing two predict_ci results.
  • Frozen serving view. GamPredictor(gam) is an __slots__-locked, inference-only wrapper with strict input validation (feature_names_in_ enforcement) and a check_against(gam, X_sample) round-trip assertion for deployment safety.

Installation

The Python package is built from this repo via maturin:

# Clone, then in a venv:
pip install maturin
maturin develop --features python,blas,blas-system --release

After that, import mgcv_rust works.

For a Rust-only consumer, add to your Cargo.toml:

[dependencies]
mgcv_rust = { git = "https://github.com/AlekJaworski/nn_exploring" }

Quick start

import numpy as np
import pandas as pd
from mgcv_rust import Gam

rng = np.random.default_rng(0)
n = 500
X = pd.DataFrame({
    "x0": rng.uniform(-2, 2, n),
    "x1": rng.uniform(0, 5, n),
})
y = np.sin(X["x0"]) + 0.3 * (X["x1"] - 2.5)**2 + rng.normal(0, 0.1, n)

gam = Gam(family="gaussian").fit(X, y)

gam.predict(X[:5])                  # response-scale predictions
mean, lo, hi = gam.predict_ci(X[:5])  # 95% CI, response scale
gam.score(X, y)                     # adjusted R²
print(gam.summary())                # mgcv-style block

That's it for the simple case. Read on for the curated tour, or jump to docs/GETTING_STARTED.md for a step-by-step tutorial with worked examples.


Tour

1. Build the model

Gam() accepts pandas / polars / numpy. If you pass a DataFrame, the column names become the predictor names; otherwise you can pass them explicitly.

gam = Gam(
    family="gaussian",           # gaussian, binomial, poisson, gamma, tweedie, nb, t-dist, scat, ...
    link=None,                   # None → canonical link for the family
    k_default=10,                # default basis dimension per smooth (mgcv default)
    term_k_mapping={"x0": 25},   # per-predictor overrides
    method="REML",
)
gam.fit(X, y)

After fit, sklearn-style attributes are populated:

gam.coef_              # full β vector (intercept first)
gam.intercept_         # link-scale
gam.intercept_response_ # response-scale (= mean(y) for identity link)
gam.feature_names_in_  # predictor names, fitting order
gam.n_features_in_
gam.lambda_            # per-smooth smoothing parameter
gam.edf_               # per-smooth effective degrees of freedom
gam.k_, gam.bs_        # basis dimension / basis type per smooth
gam.vcov_              # posterior covariance of β̂

2. Predict on different scales

gam.predict(X)                       # response (default — same as mgcv predict(type="response"))
gam.predict(X, scale="link")         # linear predictor, no inverse link (mgcv type="link")

result = gam.predict(X, type="terms")  # decomposition
result.intercept                       # scalar, link scale
result.contributions                   # DataFrame (n × m_smooths), link-scale, centered on training
result.total                           # response-scale ≡ gam.predict(X)

Invariant: gam.predict(X) == inv_link(intercept + result.contributions.sum(axis=1)).
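To see why the inverse link is applied after summing, here is a minimal pure-Python sketch of the invariant for a logit-link model. The intercept and contributions are made-up numbers standing in for result.intercept and rows of result.contributions; the helpers inv_logit and reconstruct_response are hypothetical, not part of mgcv_rust.

```python
import math

def inv_logit(eta: float) -> float:
    """Inverse of the logit link: maps a linear predictor to a probability."""
    return 1.0 / (1.0 + math.exp(-eta))

def reconstruct_response(intercept, contributions, inv_link):
    """Rebuild response-scale predictions from a terms-style decomposition.

    intercept:     scalar, link scale
    contributions: one list per row, holding each smooth's link-scale contribution
    inv_link:      inverse link function
    """
    # Sum on the link scale first, then apply the inverse link once per row.
    return [inv_link(intercept + sum(row)) for row in contributions]

# Hypothetical decomposition: 2 rows, 2 smooths (s(x0), s(x1)), logit link.
intercept = -0.4
contributions = [[0.9, -0.2], [-0.5, 0.3]]
probs = reconstruct_response(intercept, contributions, inv_logit)

# With an identity link the inverse link is a no-op, so the pieces simply add.
identity_preds = reconstruct_response(1.0, [[2.0, 3.0]], lambda eta: eta)
```

Only for the identity link do the per-smooth contributions add directly on the response scale; for any other link they add on the link scale.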

3. Confidence intervals

mean, lo, hi = gam.predict_ci(X, level=0.95)       # response, 95% CI
mean, lo, hi = gam.predict_ci(X, scale="link")     # link-scale CI
mean, lo, hi = gam.predict_ci(X, level=0.5)        # 50% CI (narrower)

mean matches gam.predict(X, scale=scale) exactly. The CI is built from n_samples (default 1000) draws of β ~ N(β̂, vcov).
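The text only says the interval is built from posterior draws. A common way to do that, assumed here rather than confirmed for mgcv_rust, is to simulate β from N(β̂, vcov), evaluate the prediction for each draw, and take empirical quantiles. The stdlib-only sketch below does this for one link-scale prediction of a 2-coefficient model; posterior_ci, the Cholesky sampler, the quantile rule, and all numbers are illustrative.

```python
import math
import random

def cholesky(a):
    """Lower-triangular L with L L^T == a (a must be symmetric positive definite)."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(a[i][i] - s)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L

def posterior_ci(x, beta_hat, vcov, level=0.95, n_samples=1000, seed=0):
    """Simulation-based CI for the linear predictor x . beta.

    Returns (mean, lo, hi): mean is the plug-in prediction x . beta_hat,
    lo/hi are empirical quantiles of x . beta over draws beta ~ N(beta_hat, vcov).
    """
    rng = random.Random(seed)
    L = cholesky(vcov)
    p = len(beta_hat)
    draws = []
    for _ in range(n_samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(p)]
        # beta = beta_hat + L z has covariance L L^T = vcov
        beta = [beta_hat[i] + sum(L[i][k] * z[k] for k in range(i + 1)) for i in range(p)]
        draws.append(sum(xi * bi for xi, bi in zip(x, beta)))
    draws.sort()
    alpha = 1.0 - level
    lo = draws[math.floor(alpha / 2 * (n_samples - 1))]
    hi = draws[math.ceil((1 - alpha / 2) * (n_samples - 1))]
    mean = sum(xi * bi for xi, bi in zip(x, beta_hat))
    return mean, lo, hi

# Illustrative model: [intercept, one basis coefficient].
beta_hat = [0.2, 1.0]
vcov = [[0.04, 0.01], [0.01, 0.09]]
x_row = [1.0, 0.5]
mean95, lo95, hi95 = posterior_ci(x_row, beta_hat, vcov, level=0.95)
mean50, lo50, hi50 = posterior_ci(x_row, beta_hat, vcov, level=0.5)
```

Because mean is the plug-in prediction at β̂ rather than the average of the draws, it equals the point prediction exactly; with the same seed, the 50% interval sits inside the 95% one.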

The legacy predict_ci(X, alpha=0.05) → (lo, hi) form still works (emits a DeprecationWarning) so existing callers don't break.

4. Differences between predictions, with paired CI

# Per-row diff: to[i] - from[i]
diff = gam.predict_diff(from_X, to_X)

# With paired posterior CI — much narrower than predict_ci differences
diff, lo, hi = gam.predict_diff(from_X, to_X, level=0.95)

# One row broadcast against many
diff = gam.predict_diff(baseline_row, candidates, broadcast="from")
diff = gam.predict_diff(candidates, target_row, broadcast="to")

Identity-link only — for non-identity links the response-scale difference isn't linear in β, so the closed-form CI argument doesn't transfer.
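With an identity link, to − from equals (x_to − x_from)·β, which is linear in β, so its variance is dᵀ V d with d = x_to − x_from, and any uncertainty shared by both rows (such as the intercept) cancels. By Cauchy-Schwarz, sqrt(Va + Vb − 2C) ≤ sqrt(Va) + sqrt(Vb), so the paired interval is never wider than differencing two separate CIs. The sketch assumes normal-approximation intervals and made-up coefficients; paired_diff and naive_diff are hypothetical helpers, and the list comprehension over candidates mirrors broadcast="from".

```python
import math
from statistics import NormalDist

def quad_form(d, V):
    """d^T V d for a vector d and a square matrix V (lists of lists)."""
    p = len(d)
    return sum(d[i] * V[i][j] * d[j] for i in range(p) for j in range(p))

def normal_ci(x, beta, V, level):
    """Normal-approximation CI for the identity-link prediction x . beta."""
    z = NormalDist().inv_cdf(0.5 + level / 2)
    mu = sum(a * b for a, b in zip(x, beta))
    half = z * math.sqrt(quad_form(x, V))
    return mu, mu - half, mu + half

def paired_diff(x_from, x_to, beta, V, level=0.95):
    """CI for to - from, using Var[(x_to - x_from) . beta] = d^T V d."""
    d = [t - f for f, t in zip(x_from, x_to)]
    return normal_ci(d, beta, V, level)

def naive_diff(x_from, x_to, beta, V, level=0.95):
    """Difference of two separate CIs, ignoring their covariance."""
    m_f, lo_f, hi_f = normal_ci(x_from, beta, V, level)
    m_t, lo_t, hi_t = normal_ci(x_to, beta, V, level)
    return m_t - m_f, lo_t - hi_f, hi_t - lo_f

# Illustrative model: [intercept, smooth coefficient]; the intercept is poorly determined.
beta = [1.0, 2.0]
V = [[0.25, 0.0], [0.0, 0.01]]
baseline = [1.0, 0.1]                    # one "from" row ...
candidates = [[1.0, 0.4], [1.0, 0.9]]    # ... broadcast against many "to" rows
paired = [paired_diff(baseline, c, beta, V) for c in candidates]
naive = [naive_diff(baseline, c, beta, V) for c in candidates]
```

Here the intercept's large variance dominates each separate CI but drops out of d entirely, which is why the paired interval is far narrower.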

5. Subset / marginal views

gam[names] takes a list of smooth names and returns a view that includes only those smooths (and, optionally, the intercept via Gam.INTERCEPT). All predict methods inherit the restriction.

gam[["x0"]].predict(X)                              # marginal effect of x0, response scale
gam[["x0"]].predict(X, scale="deviation")           # link-scale, intercept zeroed (sum-to-zero on train)
gam[["x0", Gam.INTERCEPT]].predict(X)               # marginal effect + intercept
gam[["x0"]].partial_effect("x0").plot()             # what the smooth looks like

All sklearn attributes filter to the active features:

gam[["x0"]].coef_            # just the x0 block (+ intercept if active)
gam[["x0"]].feature_names_in_  # ["x0"]
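The "deviation" scale relies on a sum-to-zero constraint: each smooth's contribution averages to zero over the training rows, and the intercept carries the overall level. In mgcv the constraint is built into the basis before fitting; the post hoc sketch below (center_contributions is a hypothetical helper, and the numbers are made up) only illustrates the effect, namely that moving each smooth's mean into the intercept leaves every fitted value unchanged.

```python
def center_contributions(intercept, contrib_by_smooth):
    """Apply a sum-to-zero constraint over the training rows.

    contrib_by_smooth maps smooth name -> list of link-scale contributions
    on the training rows. Each smooth's mean is subtracted from its values
    and added to the intercept, so intercept + sum(contributions) per row
    is unchanged.
    """
    new_intercept = intercept
    centered = {}
    for name, values in contrib_by_smooth.items():
        m = sum(values) / len(values)
        centered[name] = [v - m for v in values]
        new_intercept += m
    return new_intercept, centered

# Hypothetical uncentered contributions on three training rows.
raw = {"x0": [1.0, 2.0, 6.0], "x1": [0.5, -0.5, 3.0]}
new_intercept, centered = center_contributions(0.0, raw)
```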

6. Plotting, summary, score

gam.plot()                  # figure with one subplot per smooth (CI bands shaded)
gam.plot("x0")              # single-smooth axes
df = gam.partial_effect("x0", level=0.95)  # underlying data: columns x0, effect, lo, hi

print(gam.summary())
# Gam summary  family=gaussian  link=identity  n_obs=500
#   intercept (link)     = -0.0123
#   intercept (response) = -0.0123
#   scale (σ²)           = 0.0101
#   deviance             = 4.8732
#   R² (adj)             = 0.9612
#   smooths:
#     s(          x0)  k= 10  edf=  6.13  λ=2.345e-02
#     s(          x1)  k= 10  edf=  3.07  λ=4.890e-01

gam.score(X, y)              # adjusted R² for regression; accuracy@0.5 for binomial
gam.predict_proba(X)         # binomial only: (n, 2) [[P(0), P(1)]]
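The two score definitions can be sketched in plain Python. The exact adjustment mgcv_rust uses is not spelled out here; the sketch assumes the r.sq definition from mgcv's summary.gam, which compares residual variance to response variance with residual degrees of freedom n minus the total EDF (intercept included). The ≥ 0.5 threshold in the binomial case is also an assumption, and both function names are hypothetical.

```python
from statistics import variance

def adjusted_r2(y, fitted, total_edf):
    """Assumed mgcv-style adjusted R^2: 1 - var(resid)*(n-1) / (var(y)*(n - total_edf)).

    total_edf is the model's total effective degrees of freedom,
    counting the intercept as 1.
    """
    n = len(y)
    resid = [yi - fi for yi, fi in zip(y, fitted)]
    residual_df = n - total_edf
    return 1.0 - variance(resid) * (n - 1) / (variance(y) * residual_df)

def accuracy_at_half(y, prob):
    """Share of rows where the prediction (P(1) >= 0.5, assumed) matches the 0/1 label."""
    hits = sum(int((p >= 0.5) == bool(t)) for t, p in zip(y, prob))
    return hits / len(y)
```

Unlike plain R², the adjusted version penalizes wiggliness: spending more EDF on the smooths shrinks residual_df and lowers the score unless the residual variance falls enough to compensate.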

7. Serving with GamPredictor

For deployment paths, wrap the fitted Gam in a GamPredictor. Same API surface, plus strict column validation and a round-trip assertion.

from mgcv_rust import GamPredictor

predictor = GamPredictor(gam)
predictor.check_against(gam, X[:50])   # raises if predictions diverge
predictor.predict(X_serve)             # ValueError if any expected column is missing
predictor[["x0"]].predict(X_serve)     # subset view returns another GamPredictor

# `__slots__` blocks attribute leaks — once built, the bound Gam's
# coef/vcov/schema are the contract.

This structurally closes two production bug classes:

  • Column index drift when the serving DataFrame's columns differ from training: feature_names_in_ is enforced on every predict call (re-ordered → realigned; missing → ValueError).
  • x-grid clamping in interpolation-based predictors: GamPredictor recomputes the basis at the requested X via evaluate_lpmatrix, so there's no pre-computed grid to clamp against.
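The column rule from the first bullet can be sketched without pandas. align_features is a hypothetical helper that takes serving rows as dicts keyed by column name and returns them as lists in fitting order, which is what a basis evaluator consumes positionally. How GamPredictor treats extra columns is not stated here; this sketch ignores them, which is an assumption.

```python
def align_features(rows, feature_names_in):
    """Return each row's values in training column order.

    rows:             serving rows, each a dict mapping column name -> value
    feature_names_in: column names in fitting order (cf. feature_names_in_)
    Raises ValueError if a row lacks an expected column. Extra columns are
    silently ignored (an assumption of this sketch).
    """
    aligned = []
    for i, row in enumerate(rows):
        missing = [name for name in feature_names_in if name not in row]
        if missing:
            raise ValueError(f"row {i}: missing expected columns {missing}")
        # Look values up by name, never by position, so column order cannot drift.
        aligned.append([row[name] for name in feature_names_in])
    return aligned

# Serving rows arrive with columns in a different order, plus one extra column.
serving = [
    {"x1": 2.0, "x0": 1.0, "request_id": 7},
    {"x1": 4.5, "x0": -0.3, "request_id": 8},
]
aligned = align_features(serving, ["x0", "x1"])
```

A purely positional predictor fed the same rows would silently swap x0 and x1; name-based lookup turns that into either a correct realignment or a loud ValueError.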

Families and links

| Family | Default link | Other links | Notes |
|---|---|---|---|
| gaussian | identity | | B-spline / cubic-regression / random-effects bases supported |
| binomial | logit | | predict_proba() available |
| poisson | log | | |
| gamma | inverse | log | log link via link="log" |
| tweedie | log | | tweedie_p= (default 1.5); fixed p |
| nb / negbin | log | | nb profiles θ; negbin uses a fixed θ |
| t-dist / scat | identity | | df profiled if not given (df ∈ [2, 100]) |
| quasipoisson, quasibinomial | log / logit | | dispersion-aware |
| quantile | identity | | see mgcv_rust.fit_quantile / fit_quantile_lss; OOS presets documented in docs/qgam_oos_presets.md |
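link=None resolves to the family's default link from the table, and response-scale predictions apply that link's inverse to the linear predictor. A stdlib sketch of the lookup; LINKS, DEFAULT_LINK and resolve_link are illustrative names, not mgcv_rust API, and only the four links that appear in the table are included.

```python
import math

# link name -> (link g(mu), inverse link g^{-1}(eta))
LINKS = {
    "identity": (lambda mu: mu, lambda eta: eta),
    "log": (math.log, math.exp),
    "logit": (lambda mu: math.log(mu / (1.0 - mu)), lambda eta: 1.0 / (1.0 + math.exp(-eta))),
    "inverse": (lambda mu: 1.0 / mu, lambda eta: 1.0 / eta),
}

# Default link per family, transcribed from the Families and links table.
DEFAULT_LINK = {
    "gaussian": "identity", "binomial": "logit", "poisson": "log",
    "gamma": "inverse", "tweedie": "log", "nb": "log", "negbin": "log",
    "t-dist": "identity", "scat": "identity",
    "quasipoisson": "log", "quasibinomial": "logit", "quantile": "identity",
}

def resolve_link(family, link=None):
    """Return the explicit link if given, else the family's default."""
    if link is not None:
        return link
    try:
        return DEFAULT_LINK[family]
    except KeyError:
        raise ValueError(f"unknown family: {family!r}") from None
```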

Architecture (Rust side)

| File | Responsibility |
|---|---|
| src/gam.rs | GAM struct + fit / predict entry points |
| src/pirls.rs | Penalized IRLS inner loop |
| src/reml/ | Outer-loop REML / LAML optimization |
| src/smooth.rs | Basis functions (cubic-regression, B-spline, random effects, tensor products) |
| src/penalty.rs | Penalty-matrix construction |
| src/lib.rs | PyO3 bindings: PyGAM, compute_penalty_matrix, newton_pirls_py |
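As a rough illustration of what a penalized IRLS inner loop does, here is a stdlib-only version for a Poisson log-link model at a fixed smoothing parameter λ: each step forms working weights w = μ and working response z = η + (y − μ)/μ, then solves (XᵀWX + λS)β = XᵀWz. This is a generic textbook sketch, not the contents of src/pirls.rs; dense Gaussian elimination, the μ = y + 0.1 start, and the convergence rule are simplifications, and a production implementation would add safeguards such as step halving and a more stable factorization.

```python
import math

def solve(A, b):
    """Solve A x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    M = [row[:] + [b[i]] for i, row in enumerate(A)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= f * M[col][c]
    x = [0.0] * n
    for i in range(n - 1, -1, -1):
        x[i] = (M[i][n] - sum(M[i][j] * x[j] for j in range(i + 1, n))) / M[i][i]
    return x

def pirls_poisson(X, y, S, lam, max_iter=50, tol=1e-10):
    """Penalized IRLS for a Poisson GLM with log link at fixed lambda."""
    n, p = len(X), len(X[0])
    mu = [yi + 0.1 for yi in y]                  # starting values
    eta = [math.log(m) for m in mu]
    beta = [0.0] * p
    for _ in range(max_iter):
        w = mu[:]                                # Poisson/log: w = (dmu/deta)^2 / V(mu) = mu
        z = [e + (yi - m) / m for e, yi, m in zip(eta, y, mu)]  # working response
        A = [[sum(w[i] * X[i][j] * X[i][k] for i in range(n)) + lam * S[j][k]
              for k in range(p)] for j in range(p)]
        b = [sum(w[i] * X[i][j] * z[i] for i in range(n)) for j in range(p)]
        new_beta = solve(A, b)
        eta = [sum(X[i][j] * new_beta[j] for j in range(p)) for i in range(n)]
        mu = [math.exp(e) for e in eta]
        converged = max(abs(a - c) for a, c in zip(new_beta, beta)) < tol
        beta = new_beta
        if converged:
            break
    return beta

# Intercept + slope; only the slope is penalized.
X = [[1.0, float(x)] for x in range(4)]
y = [1, 2, 4, 8]
S = [[0.0, 0.0], [0.0, 1.0]]
b_mle = pirls_poisson(X, y, S, 0.0)    # unpenalized fit
b_pen = pirls_poisson(X, y, S, 10.0)   # slope shrunk toward 0
```

In a full GAM, this loop runs inside an outer optimization (REML / LAML here) that chooses λ.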

Build features:

  • python — enable PyO3 bindings.
  • blas / blas-system — link against system BLAS for matmul-heavy paths.

Parity and performance

Run the parity battery against R/mgcv:

pytest tests/parity/ -q
# 554 passed, 0 failed, 0 xfailed, 0 skipped

Microbench:

python3 scripts/python/bench_step_blend.py 5

Headline (2d_gaussian_additive_n50000_k15_cr, identity link):

| Implementation | Time |
|---|---|
| R mgcv | ~394 ms |
| mgcv_rust 0.11.0 | 97 ms |

Status and limitations

  • Joint (ρ, log φ) outer Newton not yet implemented — the binding constraint for closing the remaining performance gap on dispersion-bearing GLMs. Tracked in mgcv_rust - Backlog - Next Numerical Steps (Obsidian).
  • predict_diff is identity-link only; non-identity raises with a workaround pointing at get_posterior_samples.
  • Auto-k tuning is opt-in via Gam(auto_k=True). Default is a single fit at k_default=10 (mgcv's default), with term_k_mapping overrides — closer to mgcv's "tune k, run k.check" convention and avoids hidden multi-fit costs.
  • sklearn BaseEstimator mixin (for Pipeline / GridSearchCV) is not yet wrapped — soft-dep, deferred.

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. J. R. Stat. Soc. B, 73(1), 3–36.

License

MIT — see LICENSE.

Download files

Download the file for your platform.

Source Distribution

  • mgcv_rust-0.16.3.tar.gz (13.7 MB), source

Built Distributions

  • mgcv_rust-0.16.3-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB), PyPy, manylinux glibc 2.17+ x86-64
  • mgcv_rust-0.16.3-cp314-cp314-win_amd64.whl (3.7 MB), CPython 3.14, Windows x86-64
  • mgcv_rust-0.16.3-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB), CPython 3.14, manylinux glibc 2.17+ x86-64
  • mgcv_rust-0.16.3-cp314-cp314-macosx_11_0_arm64.whl (872.0 kB), CPython 3.14, macOS 11.0+ ARM64
  • mgcv_rust-0.16.3-cp313-cp313-win_amd64.whl (3.7 MB), CPython 3.13, Windows x86-64
  • mgcv_rust-0.16.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB), CPython 3.13, manylinux glibc 2.17+ x86-64
  • mgcv_rust-0.16.3-cp313-cp313-macosx_11_0_arm64.whl (870.8 kB), CPython 3.13, macOS 11.0+ ARM64
  • mgcv_rust-0.16.3-cp312-cp312-win_amd64.whl (3.7 MB), CPython 3.12, Windows x86-64
  • mgcv_rust-0.16.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB), CPython 3.12, manylinux glibc 2.17+ x86-64
  • mgcv_rust-0.16.3-cp312-cp312-macosx_11_0_arm64.whl (871.3 kB), CPython 3.12, macOS 11.0+ ARM64
  • mgcv_rust-0.16.3-cp311-cp311-win_amd64.whl (3.7 MB), CPython 3.11, Windows x86-64
  • mgcv_rust-0.16.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB), CPython 3.11, manylinux glibc 2.17+ x86-64
  • mgcv_rust-0.16.3-cp311-cp311-macosx_11_0_arm64.whl (872.9 kB), CPython 3.11, macOS 11.0+ ARM64
  • mgcv_rust-0.16.3-cp310-cp310-win_amd64.whl (3.7 MB), CPython 3.10, Windows x86-64
  • mgcv_rust-0.16.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB), CPython 3.10, manylinux glibc 2.17+ x86-64
  • mgcv_rust-0.16.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB), CPython 3.9, manylinux glibc 2.17+ x86-64

File details

Details for the file mgcv_rust-0.16.3.tar.gz.

File metadata

  • Download URL: mgcv_rust-0.16.3.tar.gz
  • Upload date:
  • Size: 13.7 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: maturin/1.13.3

File hashes

Hashes for mgcv_rust-0.16.3.tar.gz
Algorithm Hash digest
SHA256 64a1e8c690709231e4ec16bdb71932d3ca12f86954acc64ccf1e2b3c5c2561ad
MD5 8966467be1c60a74f7ea7edcb68a3682
BLAKE2b-256 cdb7620dc9b33cbbff7a8ce4b16816a93f47772e5de2c30733ede0badf0b3bfd

File details

Details for the file mgcv_rust-0.16.3-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.16.3-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 450051b95b1310fd11ac7b835adf631789a281a1d00d49e73cc60c5a0e307298
MD5 09fa99cd8c3160939b1f4668cacb32bd
BLAKE2b-256 f6a6ff2971c55cdcd093784ae62e6b6659157fe4f448c4eaa02f669137f7b4dd

File details

Details for the file mgcv_rust-0.16.3-cp314-cp314-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.16.3-cp314-cp314-win_amd64.whl
Algorithm Hash digest
SHA256 14553fbd273a0883c07d591f13239b7f099e707f6d3cb25480fd08099b6de7a8
MD5 44c2180e04984052e292262f8c03f19e
BLAKE2b-256 8479844cd08e7e8eca31444dbf2b312498616b4c289dcc7b1ee5de9d620d8545

File details

Details for the file mgcv_rust-0.16.3-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.16.3-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 b77d474b18579d2e59ceffe9b2b7917229f943ce3366c4255bdbd9d6c3d5a695
MD5 e824841cf432ce6588a3da5723048749
BLAKE2b-256 96e59cb3c5e774f31e569768f78cc4a18aa7dbe70b4c936d81fc49232eb11f7a

File details

Details for the file mgcv_rust-0.16.3-cp314-cp314-macosx_11_0_arm64.whl.

File hashes

Hashes for mgcv_rust-0.16.3-cp314-cp314-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 c07a00fdb15c03562d63bf6bc69d75bf46f6c28732fafa87cc86622e94f12ad3
MD5 d98355399162e953d11741dd727793f4
BLAKE2b-256 c346af9a7614aa6d749d5995818aa82cced4dabcc511fa86cf46d43025ec2946

File details

Details for the file mgcv_rust-0.16.3-cp313-cp313-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.16.3-cp313-cp313-win_amd64.whl
Algorithm Hash digest
SHA256 d1cbc215c6bfa576a0badb7709baf16de8c1d0090456a44e85384a5fbb902172
MD5 ac9092408dd7bd089314af84284b40b1
BLAKE2b-256 f39caaf7c84c2f0f3d1222596aa9220e092200f86879a674ca2ccae420325e95

File details

Details for the file mgcv_rust-0.16.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.16.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 a8e6b2839fb9212fa870ffc710c4f1fa50e88f6667eff20656b311e4f831f190
MD5 71ac9691ed558cefeb8dce730b9da528
BLAKE2b-256 f15a6c85af6376d190654313389cd2d17bd712f7c7ee18dfd31cd97fc94e1d96

File details

Details for the file mgcv_rust-0.16.3-cp313-cp313-macosx_11_0_arm64.whl.

File hashes

Hashes for mgcv_rust-0.16.3-cp313-cp313-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 3a25b245582cd39c341525abbb2fdff6baf197ad2cd63627b4115b3ce8df52a7
MD5 e0f39fbbd10abc6e8379cbedac7c8e93
BLAKE2b-256 f92f5b9a31a92430a55d67bb9880433522e6b09f321869bcb49d10e837d2f511

File details

Details for the file mgcv_rust-0.16.3-cp312-cp312-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.16.3-cp312-cp312-win_amd64.whl
Algorithm Hash digest
SHA256 06535bda78e9e10955dbac7efe2df24d421489b0fb846b8cf76c4cfe71ae3775
MD5 480600aa93b42c3f8c1d96bd2aa0bfa3
BLAKE2b-256 82e0964711eedc9373372707ce651a132bb2349191a1dc030ba59acb36e53964

File details

Details for the file mgcv_rust-0.16.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.16.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 fc1f03f0a97156e1557ed25da626fc63fe647694d88468949b488421eb100c77
MD5 bf512e02e7c69451cc244d99cc6029b1
BLAKE2b-256 247df7b9f11daf0e86816580ee083dace210db7e15b0e47b751dd189572086f7

File details

Details for the file mgcv_rust-0.16.3-cp312-cp312-macosx_11_0_arm64.whl.

File hashes

Hashes for mgcv_rust-0.16.3-cp312-cp312-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 91c8c790ad2f7566ed43338945ddbf7e90ab872aeb99a025c3e0f6ec931f489c
MD5 f9407085d421a9d308eb80217862f38f
BLAKE2b-256 64735bfd8ed5d7a07008155a70239fba8f7d7fa262ee635028d862a0342cb2b6

File details

Details for the file mgcv_rust-0.16.3-cp311-cp311-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.16.3-cp311-cp311-win_amd64.whl
Algorithm Hash digest
SHA256 2d7f19823acd9ee99dfb6bd95bb1eda7be44a36d88b202bc756bcdb27b3ddabc
MD5 022265f878c9cf81c1b4add7f00eeef5
BLAKE2b-256 57f3b7d9ffb8da4017df8d6361d04d75f8f15961f5d9f5f890e1ca8e80ba30e9

File details

Details for the file mgcv_rust-0.16.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.16.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 c7fce856a7d6e88b52a813b5e22724064d7c83dffb13c2d6ed1d24ef3887267d
MD5 51f5c40710f94fa8ffd32944b4451630
BLAKE2b-256 e64e55257e94a31adaad6fa63d9fe728a9096830993382e5d700d5345e73fc98

File details

Details for the file mgcv_rust-0.16.3-cp311-cp311-macosx_11_0_arm64.whl.

File hashes

Hashes for mgcv_rust-0.16.3-cp311-cp311-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 935484c36ee4851abd0e21d910e72e6c9722869387e3dc3ce5d1d9210b593409
MD5 72a138fe03bb8991064d0966b8cd60ef
BLAKE2b-256 cc0029060be587f1b4125cb664f8dc4d324249f589a7b26eaf9fd31e5840c0fa

File details

Details for the file mgcv_rust-0.16.3-cp310-cp310-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.16.3-cp310-cp310-win_amd64.whl
Algorithm Hash digest
SHA256 11bf20cb40ab8de64b8e507d7a058d19c266404a8e8f4831f5572f1ef0cc19cb
MD5 59ed486636be5550181e93a331384d91
BLAKE2b-256 26ed4bea21e4ec1f72bfd76d48ed3653b689cb2dcc46013451e12d9aabeb8915

File details

Details for the file mgcv_rust-0.16.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.16.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 2896ac1be87d9fcce20e52a25bedb65be65a1e42fd11ac12c3e652adf1f2f989
MD5 c449e0a6242674d1a90e4ac012ea13ef
BLAKE2b-256 92d67889215f0ba553a3518ed475971e273ace9df3d597397a387b632193dea8

File details

Details for the file mgcv_rust-0.16.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.16.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 125c4faff84f400c16604f16bc19b5009f8cc92a4e520f5bede22d1deb202670
MD5 15c9ab77984d34d6009407bf1a74877a
BLAKE2b-256 33eee4ded837412ccda3dd1ccacb00b77291d2912efa6b398fa795eb64d38a95
