Rust implementation of Generalized Additive Models with Python bindings

mgcv_rust

A Rust port of R's mgcv package — Generalized Additive Models with automatic smoothing-parameter selection (REML / LAML) and PIRLS fitting — with first-class Python bindings.

The numerical core targets byte-for-byte parity with mgcv on common families, and the Python API is designed to feel like sklearn: a single Gam class, named-predictor DataFrame inputs, fit() / predict() / score(), posterior CI, and subset views for marginal analysis.

Latest release: 0.11.0
Parity: 554 passed / 0 failed / 0 xfailed on the mgcv comparison battery
Python tests: 211 passed, 1 xfailed
Headline perf: 2d_gaussian_additive_n50000_k15_cr, 394 ms (R mgcv) → 97 ms (mgcv_rust), 4.05× vs. mgcv
Python wrapper: mgcv_rust.Gam (canonical); GAMFitter is a deprecated alias

Why another GAM library?

  • R-equivalent answers, Python ergonomics. If you have R / mgcv users producing models and Python services consuming them, this gives both sides the same numbers from the same fit. The serialize() method produces a portable artefact for deployment.
  • Fast. Aggregate time across 80 parity fixtures: 2.0 s. On the largest Gaussian fixture (n=50,000, k=15), mgcv_rust is ~4× faster than R's mgcv.
  • Real CIs, paired diffs. predict_ci(X) returns (mean, lo, hi) so you never have to "+= intercept" by hand. predict_diff(from_X, to_X, level=...) gives a paired-posterior CI for the difference between two predictions (identity link only), which is tighter than differencing two separate predict_ci intervals.
  • Frozen serving view. GamPredictor(gam) is an __slots__-locked, inference-only wrapper with strict input validation (feature_names_in_ enforcement) and a check_against(gam, X_sample) round-trip assertion for deployment safety.

Installation

The Python package is built from this repo via maturin:

# Clone, then in a venv:
pip install maturin
maturin develop --features python,blas,blas-system --release

After that, import mgcv_rust works.

For a Rust-only consumer, add to your Cargo.toml:

[dependencies]
mgcv_rust = { git = "https://github.com/AlekJaworski/nn_exploring" }

Quick start

import numpy as np
import pandas as pd
from mgcv_rust import Gam

rng = np.random.default_rng(0)
n = 500
X = pd.DataFrame({
    "x0": rng.uniform(-2, 2, n),
    "x1": rng.uniform(0, 5, n),
})
y = np.sin(X["x0"]) + 0.3 * (X["x1"] - 2.5)**2 + rng.normal(0, 0.1, n)

gam = Gam(family="gaussian").fit(X, y)

gam.predict(X[:5])                  # response-scale predictions
mean, lo, hi = gam.predict_ci(X[:5])  # 95% CI, response scale
gam.score(X, y)                     # adjusted R²
print(gam.summary())                # mgcv-style block

That's it for the simple case. Read on for the curated tour, or jump to docs/GETTING_STARTED.md for a step-by-step tutorial with worked examples.


Tour

1. Build the model

Gam() accepts pandas / polars / numpy. If you pass a DataFrame, the column names become the predictor names; otherwise you can pass them explicitly.

gam = Gam(
    family="gaussian",           # gaussian, binomial, poisson, gamma, tweedie, nb, t-dist, scat, ...
    link=None,                   # None → canonical link for the family
    k_default=10,                # default basis dimension per smooth (mgcv default)
    term_k_mapping={"x0": 25},   # per-predictor overrides
    method="REML",
)
gam.fit(X, y)

After fit, sklearn-style attributes are populated:

gam.coef_              # full β vector (intercept first)
gam.intercept_         # link-scale
gam.intercept_response_ # response-scale (= mean(y) for identity link)
gam.feature_names_in_  # predictor names, fitting order
gam.n_features_in_
gam.lambda_            # per-smooth smoothing parameter
gam.edf_               # per-smooth effective degrees of freedom
gam.k_, gam.bs_        # basis dimension / basis type per smooth
gam.vcov_              # posterior covariance of β̂

2. Predict on different scales

gam.predict(X)                       # response (default — same as mgcv predict(type="response"))
gam.predict(X, scale="link")         # linear predictor, no inverse link (mgcv type="link")

result = gam.predict(X, type="terms")  # decomposition
result.intercept                       # scalar, link scale
result.contributions                   # DataFrame (n × m_smooths), link-scale, centered on training
result.total                           # response-scale ≡ gam.predict(X)

Invariant: gam.predict(X) == inv_link(result.intercept + result.contributions.sum(axis=1)), and both equal result.total.
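The invariant can be checked by hand. The sketch below is plain Python and does not call mgcv_rust: the intercept, the two smooth functions f_x0 and f_x1, and the logit link are hypothetical stand-ins for a fitted binomial model, chosen only to show how the link-scale pieces recombine into the response-scale prediction.

```python
import math

def inv_logit(eta):
    """Inverse of the logit link: maps a link-scale value to a probability."""
    return 1.0 / (1.0 + math.exp(-eta))

# Hypothetical fitted pieces (stand-ins, not produced by mgcv_rust).
intercept = -0.4  # link scale

def f_x0(v):
    return 0.8 * math.sin(v)

def f_x1(v):
    return 0.1 * (v - 2.5) ** 2 - 0.2

rows = [{"x0": -1.0, "x1": 0.5}, {"x0": 0.3, "x1": 2.5}, {"x0": 1.7, "x1": 4.0}]

# "contributions": one link-scale value per smooth and per row.
contributions = [{"x0": f_x0(r["x0"]), "x1": f_x1(r["x1"])} for r in rows]

# "total": inverse link applied to intercept + row-wise sum of contributions.
total = [inv_logit(intercept + sum(c.values())) for c in contributions]

# Direct response-scale prediction, computed without the decomposition.
direct = [inv_logit(intercept + f_x0(r["x0"]) + f_x1(r["x1"])) for r in rows]

assert all(math.isclose(t, d) for t, d in zip(total, direct))
```

The contributions stay on the link scale, so they add up; the inverse link is applied once, at the end.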

3. Confidence intervals

mean, lo, hi = gam.predict_ci(X, level=0.95)       # response, 95% CI
mean, lo, hi = gam.predict_ci(X, scale="link")     # link-scale CI
mean, lo, hi = gam.predict_ci(X, level=0.5)        # 50% CI (narrower)

mean matches gam.predict(X, scale=scale) exactly. The CI is built from n_samples (default 1000) draws of β ~ N(β̂, vcov).
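As a rough illustration of how a simulation-based interval of this kind can be formed, the sketch below draws β from N(β̂, vcov) for a two-coefficient identity-link model and reads off percentile bounds for one prediction. It is plain Python under stated assumptions: the numbers in beta_hat and vcov are made up, the helper names are hypothetical, and the percentile rule is one common choice, not necessarily the one mgcv_rust uses.

```python
import math
import random

def cholesky_2x2(v):
    """Lower-triangular L with L @ L.T == v for a 2x2 SPD matrix."""
    l00 = math.sqrt(v[0][0])
    l10 = v[1][0] / l00
    l11 = math.sqrt(v[1][1] - l10 * l10)
    return [[l00, 0.0], [l10, l11]]

def draw_beta(beta_hat, chol, rng):
    """One draw from N(beta_hat, vcov) via beta_hat + L z, z ~ N(0, I)."""
    z0, z1 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
    return [beta_hat[0] + chol[0][0] * z0,
            beta_hat[1] + chol[1][0] * z0 + chol[1][1] * z1]

def simulate_ci(x_row, beta_hat, vcov, level=0.95, n_samples=1000, seed=0):
    """Return (mean, lo, hi) for the linear predictor x_row . beta."""
    rng = random.Random(seed)
    chol = cholesky_2x2(vcov)
    preds = sorted(
        sum(xj * bj for xj, bj in zip(x_row, draw_beta(beta_hat, chol, rng)))
        for _ in range(n_samples)
    )
    alpha = 1.0 - level
    lo = preds[math.floor(alpha / 2 * (n_samples - 1))]
    hi = preds[math.ceil((1 - alpha / 2) * (n_samples - 1))]
    # The point estimate uses beta_hat itself, not the sample average,
    # so it agrees exactly with the plain prediction.
    mean = sum(xj * bj for xj, bj in zip(x_row, beta_hat))
    return mean, lo, hi

beta_hat = [1.0, 0.5]                      # made-up coefficients
vcov = [[0.04, 0.01], [0.01, 0.09]]        # made-up posterior covariance
mean, lo, hi = simulate_ci([1.0, 2.0], beta_hat, vcov)
_, lo50, hi50 = simulate_ci([1.0, 2.0], beta_hat, vcov, level=0.5)
```

With the same seed the 50% interval is nested inside the 95% interval, which matches the "narrower" comment on the level=0.5 call.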

The legacy predict_ci(X, alpha=0.05) → (lo, hi) form still works (emits a DeprecationWarning) so existing callers don't break.

4. Differences between predictions, with paired CI

# Per-row diff: to[i] - from[i]
diff = gam.predict_diff(from_X, to_X)

# With paired posterior CI — much narrower than predict_ci differences
diff, lo, hi = gam.predict_diff(from_X, to_X, level=0.95)

# One row broadcast against many
diff = gam.predict_diff(baseline_row, candidates, broadcast="from")
diff = gam.predict_diff(candidates, target_row, broadcast="to")

Identity-link only — for non-identity links the response-scale difference isn't linear in β, so the closed-form CI argument doesn't transfer.
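For an identity link the difference is linear in β, so its variance has the closed form dᵀ V d with d = x_to − x_from. The sketch below compares that paired half-width with the half-width obtained by differencing two separate intervals. It is plain Python with a made-up vcov and made-up design rows, and the helper names are hypothetical; it illustrates the argument and is not the library's internals.

```python
import math
from statistics import NormalDist

def quad_form(d, v):
    """d^T V d for a vector d and a square matrix V (nested lists)."""
    n = len(d)
    return sum(d[i] * v[i][j] * d[j] for i in range(n) for j in range(n))

def paired_and_naive_halfwidth(x_from, x_to, vcov, level=0.95):
    z = NormalDist().inv_cdf(0.5 + level / 2.0)
    d = [t - f for f, t in zip(x_from, x_to)]
    # Paired: uncertainty of the difference itself.
    paired = z * math.sqrt(quad_form(d, vcov))
    # Naive: subtracting interval endpoints adds the two half-widths.
    naive = z * (math.sqrt(quad_form(x_from, vcov))
                 + math.sqrt(quad_form(x_to, vcov)))
    return paired, naive

vcov = [[0.04, 0.01], [0.01, 0.09]]   # made-up posterior covariance
x_from = [1.0, 2.0]                   # design row: intercept column, one covariate
x_to = [1.0, 2.5]
paired, naive = paired_and_naive_halfwidth(x_from, x_to, vcov)
```

The intercept column cancels in d, so uncertainty shared by both predictions drops out of the paired interval; the naive version counts it twice.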

5. Subset / marginal views

gam[names], where names is a list such as ["x0"], returns a view that includes only the listed smooths (plus the intercept if Gam.INTERCEPT is listed). All predict methods work on the view.

gam[["x0"]].predict(X)                              # marginal effect of x0, response scale
gam[["x0"]].predict(X, scale="deviation")           # link-scale, intercept zeroed (sum-to-zero on train)
gam[["x0", Gam.INTERCEPT]].predict(X)               # marginal effect + intercept
gam[["x0"]].partial_effect("x0").plot()             # what the smooth looks like

All sklearn attributes filter to the active features:

gam[["x0"]].coef_            # just the x0 block (+ intercept if active)
gam[["x0"]].feature_names_in_  # ["x0"]

6. Plotting, summary, score

gam.plot()                  # figure with one subplot per smooth (CI bands shaded)
gam.plot("x0")              # single-smooth axes
df = gam.partial_effect("x0", level=0.95)  # underlying data: columns x0, effect, lo, hi

print(gam.summary())
# Gam summary  family=gaussian  link=identity  n_obs=500
#   intercept (link)     = -0.0123
#   intercept (response) = -0.0123
#   scale (σ²)           = 0.0101
#   deviance             = 4.8732
#   R² (adj)             = 0.9612
#   smooths:
#     s(          x0)  k= 10  edf=  6.13  λ=2.345e-02
#     s(          x1)  k= 10  edf=  3.07  λ=4.890e-01

gam.score(X, y)              # adjusted R² for regression; accuracy@0.5 for binomial
gam.predict_proba(X)         # binomial only: (n, 2) [[P(0), P(1)]]
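The two binomial conventions in the comments above can be sketched without the library. The helpers below are hypothetical and plain Python; they only show the [[P(0), P(1)]] row layout and what accuracy at a 0.5 threshold means. Treating a probability of exactly 0.5 as class 1 is an assumption of this sketch.

```python
def proba_matrix(p1):
    """Stack P(y=0) and P(y=1) into rows [[P(0), P(1)], ...]."""
    return [[1.0 - p, p] for p in p1]

def accuracy_at_half(p1, y_true):
    """Share of rows where thresholding P(y=1) at 0.5 recovers the label."""
    y_pred = [1 if p >= 0.5 else 0 for p in p1]  # tie rule is an assumption
    return sum(int(a == b) for a, b in zip(y_pred, y_true)) / len(y_true)

p1 = [0.9, 0.2, 0.6, 0.4]   # made-up P(y=1) values
y = [1, 0, 0, 0]
proba = proba_matrix(p1)
acc = accuracy_at_half(p1, y)   # 3 of 4 labels recovered
```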

7. Serving with GamPredictor

For deployment paths, wrap the fitted Gam in a GamPredictor. Same API surface, plus strict column validation and a round-trip assertion.

from mgcv_rust import GamPredictor

predictor = GamPredictor(gam)
predictor.check_against(gam, X[:50])   # raises if predictions diverge
predictor.predict(X_serve)             # ValueError if any expected column is missing
predictor[["x0"]].predict(X_serve)     # subset view returns another GamPredictor

# `__slots__` blocks attribute leaks — once built, the bound Gam's
# coef/vcov/schema are the contract.

This structurally closes two production bug classes:

  • Column index drift when the serving DataFrame's columns differ from training: feature_names_in_ is enforced on every predict call (re-ordered → realigned; missing → ValueError).
  • x-grid clamping in interpolation-based predictors: GamPredictor recomputes the basis at the requested X via evaluate_lpmatrix, so there's no pre-computed grid to clamp against.
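The column check in the first bullet can be pictured as follows. This is a minimal sketch in plain Python, with a dict of lists standing in for a DataFrame and a hypothetical helper named align_columns; it is not GamPredictor's code. Dropping extra columns silently is an assumption of the sketch, since the text above only specifies the re-ordered and missing cases.

```python
def align_columns(frame, feature_names_in):
    """Return the columns of `frame` in training order.

    `frame` maps column name -> list of values. Missing columns raise
    ValueError; columns not seen at fit time are dropped (assumption).
    """
    missing = [name for name in feature_names_in if name not in frame]
    if missing:
        raise ValueError(f"missing expected columns: {missing}")
    return {name: frame[name] for name in feature_names_in}

feature_names_in = ["x0", "x1"]
serve = {"x1": [4.0, 5.0], "extra": [0, 0], "x0": [0.1, 0.2]}
aligned = align_columns(serve, feature_names_in)
```

Selecting by name instead of by position is what removes the index-drift failure mode: a serving frame with shuffled columns yields the same design matrix as the training frame.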

Families and links

Family | Default link | Other links | Notes
gaussian | identity | | bs-spline / cubic-regression / random-effects bases supported
binomial | logit | | predict_proba() available
poisson | log | |
gamma | inverse | log | log-link via link="log"
tweedie | log | | tweedie_p= (default 1.5); fixed-p
nb / negbin | log | | nb profiles θ; negbin is fixed-θ
t-dist / scat | identity | | df profiled if not given (df ∈ [2, 100])
quasipoisson, quasibinomial | log / logit | | dispersion-aware
quantile | identity | | see mgcv_rust.fit_quantile / fit_quantile_lss; OOS presets documented in docs/qgam_oos_presets.md

Architecture (Rust side)

File | Responsibility
src/gam.rs | GAM struct + fit / predict entry points
src/pirls.rs | Penalized IRLS inner loop
src/reml/ | Outer-loop REML / LAML optimization
src/smooth.rs | Basis functions (cubic-regression, B-spline, random effects, tensor products)
src/penalty.rs | Penalty-matrix construction
src/lib.rs | PyO3 bindings: PyGAM, compute_penalty_matrix, newton_pirls_py
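To give a feel for what a penalized IRLS inner loop does, here is a toy version in Python for a Poisson model with a log link, a single coefficient, and a fixed ridge penalty. It is a sketch under those simplifying assumptions, with hypothetical function names and made-up data; the real src/pirls.rs works with full basis matrices and penalty matrices and is not reproduced here.

```python
import math

def pirls_poisson_1d(x, y, lam, n_iter=100, tol=1e-10):
    """Penalized IRLS for a one-coefficient Poisson model with log link.

    Maximizes sum(y*eta - exp(eta)) - 0.5*lam*beta**2 with eta = beta*x.
    """
    beta = 0.0
    for _ in range(n_iter):
        eta = [beta * xi for xi in x]
        mu = [math.exp(e) for e in eta]
        # Working weights and working response for the log link.
        w = mu
        z = [e + (yi - m) / m for e, yi, m in zip(eta, y, mu)]
        # Penalized weighted least squares: (x'Wx + lam) beta = x'Wz.
        num = sum(wi * xi * zi for wi, xi, zi in zip(w, x, z))
        den = sum(wi * xi * xi for wi, xi in zip(w, x)) + lam
        beta_new = num / den
        if abs(beta_new - beta) < tol:
            return beta_new
        beta = beta_new
    return beta

def penalized_score(beta, x, y, lam):
    """Gradient of the penalized log-likelihood; zero at the optimum."""
    return sum(xi * (yi - math.exp(beta * xi)) for xi, yi in zip(x, y)) - lam * beta

x = [0.0, 0.5, 1.0, 1.5, 2.0]   # made-up covariate values
y = [1, 1, 2, 4, 7]             # made-up counts
beta_hat = pirls_poisson_1d(x, y, lam=1.0)
```

Each pass solves one penalized weighted least-squares problem; the smoothing parameter lam is held fixed here, whereas in the package the outer REML / LAML loop is what chooses it.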

Build features:

  • python — enable PyO3 bindings.
  • blas / blas-system — link against system BLAS for matmul-heavy paths.

Parity and performance

Run the parity battery against R/mgcv:

pytest tests/parity/ -q
# 554 passed, 0 failed, 0 xfailed, 0 skipped

Microbench:

python3 scripts/python/bench_step_blend.py 5

Headline (2d_gaussian_additive_n50000_k15_cr, identity link):

Implementation | Time
R mgcv | ~394 ms
mgcv_rust 0.11.0 | 97 ms

Status and limitations

  • Joint (ρ, log φ) outer Newton not yet implemented — the binding constraint for closing the remaining performance gap on dispersion-bearing GLMs. Tracked in mgcv_rust - Backlog - Next Numerical Steps (Obsidian).
  • predict_diff is identity-link only; for a non-identity link it raises an error whose message points to get_posterior_samples as a workaround.
  • Auto-k tuning is opt-in via Gam(auto_k=True). Default is a single fit at k_default=10 (mgcv's default), with term_k_mapping overrides — closer to mgcv's "tune k, run k.check" convention and avoids hidden multi-fit costs.
  • The sklearn BaseEstimator mixin (for Pipeline / GridSearchCV) is not yet implemented. sklearn would be a soft dependency, so this is deferred.

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. J. R. Stat. Soc. B, 73(1), 3–36.

License

MIT — see LICENSE.
