Rust implementation of Generalized Additive Models with Python bindings

mgcv_rust

A Rust port of R's mgcv package — Generalized Additive Models with automatic smoothing-parameter selection (REML / LAML) and PIRLS fitting — with first-class Python bindings.

The numerical core targets byte-for-byte parity with mgcv on common families, and the Python API is designed to feel like sklearn: a single Gam class, named-predictor DataFrame inputs, fit() / predict() / score(), posterior CI, and subset views for marginal analysis.

Latest release: 0.11.0
Parity: 554 passed / 0 failed / 0 xfailed on the mgcv comparison battery
Python tests: 211 passed, 1 xfailed
Headline perf: 2d_gaussian_additive_n50000_k15_cr, R mgcv 394 ms vs. mgcv_rust 97 ms (4.05× faster)
Python wrapper: mgcv_rust.Gam (canonical); GAMFitter is a deprecated alias

Why another GAM library?

  • R-equivalent answers, Python ergonomics. If you have R / mgcv users producing models and Python services consuming them, this gives both sides the same numbers from the same fit. The serialize() method produces a portable artefact for deployment.
  • Fast. Aggregate time across 80 parity fixtures: 2.0 s. On the largest Gaussian fixture (n=50,000, k=15), mgcv_rust is ~4× faster than R's mgcv.
  • Real CIs, paired diffs. predict_ci(X) returns (mean, lo, hi) so you never have to "+= intercept" by hand. predict_diff(from_X, to_X, level=...) gives a paired-posterior CI for the difference between two predictions — strictly tighter than the naive predict_ci difference.
  • Frozen serving view. GamPredictor(gam) is an __slots__-locked, inference-only wrapper with strict input validation (feature_names_in_ enforcement) and a check_against(gam, X_sample) round-trip assertion for deployment safety.

Installation

The Python package is built from this repo via maturin:

# Clone, then in a venv:
pip install maturin
maturin develop --features python,blas,blas-system --release

After that, import mgcv_rust works.

For a Rust-only consumer, add to your Cargo.toml:

[dependencies]
mgcv_rust = { git = "https://github.com/AlekJaworski/nn_exploring" }

Quick start

import numpy as np
import pandas as pd
from mgcv_rust import Gam

rng = np.random.default_rng(0)
n = 500
X = pd.DataFrame({
    "x0": rng.uniform(-2, 2, n),
    "x1": rng.uniform(0, 5, n),
})
y = np.sin(X["x0"]) + 0.3 * (X["x1"] - 2.5)**2 + rng.normal(0, 0.1, n)

gam = Gam(family="gaussian").fit(X, y)

gam.predict(X[:5])                  # response-scale predictions
mean, lo, hi = gam.predict_ci(X[:5])  # 95% CI, response scale
gam.score(X, y)                     # adjusted R²
print(gam.summary())                # mgcv-style block

That's it for the simple case. Read on for the curated tour, or jump to docs/GETTING_STARTED.md for a step-by-step tutorial with worked examples.


Tour

1. Build the model

Gam.fit() accepts pandas, polars, or NumPy inputs. If you pass a DataFrame, its column names become the predictor names; otherwise you can pass the names explicitly.

gam = Gam(
    family="gaussian",           # gaussian, binomial, poisson, gamma, tweedie, nb, t-dist, scat, ...
    link=None,                   # None → canonical link for the family
    k_default=10,                # default basis dimension per smooth (mgcv default)
    term_k_mapping={"x0": 25},   # per-predictor overrides
    method="REML",
)
gam.fit(X, y)

After fit, sklearn-style attributes are populated:

gam.coef_              # full β vector (intercept first)
gam.intercept_         # link-scale
gam.intercept_response_ # response-scale (= mean(y) for identity link)
gam.feature_names_in_  # predictor names, fitting order
gam.n_features_in_
gam.lambda_            # per-smooth smoothing parameter
gam.edf_               # per-smooth effective degrees of freedom
gam.k_, gam.bs_        # basis dimension / basis type per smooth
gam.vcov_              # posterior covariance of β̂
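
The per-smooth k_ and edf_ attributes are linked: k_default and term_k_mapping decide each smooth's basis dimension, and a centred smooth can spend at most k - 1 effective degrees of freedom, because mgcv's sum-to-zero identifiability constraint removes one coefficient per smooth. A minimal sketch of that bookkeeping follows; `resolve_basis_dims` is a hypothetical helper, and raising on unknown mapping keys is an assumption rather than documented Gam behaviour.

```python
def resolve_basis_dims(feature_names, k_default=10, term_k_mapping=None):
    """Per-smooth basis dimension: explicit override if given, else k_default.

    Also returns the EDF ceiling k - 1 for a centred smooth (mgcv's
    sum-to-zero constraint removes one coefficient from each smooth).
    """
    overrides = dict(term_k_mapping or {})
    unknown = sorted(set(overrides) - set(feature_names))
    if unknown:
        # Assumption: the real Gam may handle unknown keys differently.
        raise KeyError(f"term_k_mapping names unknown predictors: {unknown}")
    k = {name: int(overrides.get(name, k_default)) for name in feature_names}
    max_edf = {name: kk - 1 for name, kk in k.items()}
    return k, max_edf


k, max_edf = resolve_basis_dims(["x0", "x1"], k_default=10, term_k_mapping={"x0": 25})
print(k)        # {'x0': 25, 'x1': 10}
print(max_edf)  # {'x0': 24, 'x1': 9}
```

An edf_ value close to its k - 1 ceiling is the usual hint that k is too small for that smooth.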

2. Predict on different scales

gam.predict(X)                       # response (default — same as mgcv predict(type="response"))
gam.predict(X, scale="link")         # linear predictor, no inverse link (mgcv type="link")

result = gam.predict(X, type="terms")  # decomposition
result.intercept                       # scalar, link scale
result.contributions                   # DataFrame (n × m_smooths), link-scale, centered on training
result.total                           # response-scale ≡ gam.predict(X)

Invariant: gam.predict(X) == inv_link(intercept + result.contributions.sum(axis=1)).
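
The invariant can be checked by hand with plain arrays. A minimal NumPy sketch, assuming a log link and made-up link-scale values standing in for result.intercept and result.contributions (these are ordinary arrays here, not the library's result object):

```python
import numpy as np

# Stand-ins for result.intercept and result.contributions (3 rows, 2 smooths),
# both on the link scale. The values are made up for illustration.
intercept = 0.5
contributions = np.array([
    [0.10, -0.20],
    [0.00, 0.30],
    [-0.40, 0.05],
])

# Linear predictor: intercept plus the row-wise sum of smooth contributions.
eta = intercept + contributions.sum(axis=1)

# Response scale for a log link: apply the inverse link once, to the sum.
mu = np.exp(eta)

# Exponentiating each piece and adding would NOT reproduce mu.
wrong = np.exp(intercept) + np.exp(contributions).sum(axis=1)
```

The order matters only for non-identity links: with the identity link, summing on either scale gives the same answer.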

3. Confidence intervals

mean, lo, hi = gam.predict_ci(X, level=0.95)       # response, 95% CI
mean, lo, hi = gam.predict_ci(X, scale="link")     # link-scale CI
mean, lo, hi = gam.predict_ci(X, level=0.5)        # 50% CI (narrower)

mean matches gam.predict(X, scale=scale) exactly. The CI is built from n_samples (default 1000) draws of β ~ N(β̂, vcov).

The legacy predict_ci(X, alpha=0.05) → (lo, hi) form still works (emits a DeprecationWarning) so existing callers don't break.
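
The draw-based interval described above can be sketched in a few lines of NumPy. This is an illustration, not the library's code: it assumes the model matrix `Xp` at the prediction points is available and that the endpoints are equal-tailed quantiles of the simulated linear predictor. Note that `alpha = 1 - level`, which is how the legacy alpha=0.05 form maps onto level=0.95.

```python
import numpy as np

def posterior_ci(Xp, beta_hat, vcov, level=0.95, n_samples=1000, seed=0):
    """Simulation-based pointwise CI on the link scale (a sketch).

    Xp: (n, p) model matrix at the prediction points (assumed available;
    the library builds it internally from the smooth bases).
    """
    rng = np.random.default_rng(seed)
    # Draw beta ~ N(beta_hat, vcov); shape (n_samples, p).
    draws = rng.multivariate_normal(beta_hat, vcov, size=n_samples)
    # Linear predictor for every draw; shape (n_samples, n).
    eta_draws = draws @ Xp.T
    alpha = 1.0 - level
    lo = np.quantile(eta_draws, alpha / 2, axis=0)
    hi = np.quantile(eta_draws, 1 - alpha / 2, axis=0)
    # The point estimate uses beta_hat itself, not the average of the draws.
    mean = Xp @ beta_hat
    return mean, lo, hi


Xp = np.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]])
beta_hat = np.array([0.2, 0.5])
vcov = np.array([[0.01, 0.0], [0.0, 0.004]])
mean, lo, hi = posterior_ci(Xp, beta_hat, vcov)
```

For a monotone increasing inverse link (log, logit), applying it to mean, lo and hi gives the response-scale interval; a decreasing one, such as the gamma family's inverse link, swaps the endpoints.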

4. Differences between predictions, with paired CI

# Per-row diff: to[i] - from[i]
diff = gam.predict_diff(from_X, to_X)

# With paired posterior CI — much narrower than predict_ci differences
diff, lo, hi = gam.predict_diff(from_X, to_X, level=0.95)

# One row broadcast against many
diff = gam.predict_diff(baseline_row, candidates, broadcast="from")
diff = gam.predict_diff(candidates, target_row, broadcast="to")

predict_diff supports identity links only: for non-identity links the response-scale difference isn't linear in β, so the closed-form CI argument doesn't transfer.
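
Under an identity link the paired interval has a closed form: with D = X_to - X_from, the difference is D β̂ and its variance is D V Dᵀ, so the covariance between the two predictions is used rather than ignored. A minimal NumPy sketch, assuming the model matrices and vcov are available as plain arrays and that a normal quantile sets the width (both assumptions about the library's internals):

```python
import numpy as np
from statistics import NormalDist

def _row_se(M, vcov):
    """Standard error of each row's linear predictor: sqrt(m_i V m_i^T)."""
    return np.sqrt(np.einsum("ij,jk,ik->i", M, vcov, M))

def paired_diff_ci(X_from, X_to, beta_hat, vcov, level=0.95):
    """Closed-form CI for eta_to - eta_from under an identity link."""
    D = X_to - X_from            # intercept column cancels exactly here
    diff = D @ beta_hat          # point estimate of the difference
    se = _row_se(D, vcov)        # includes the cross-covariance term
    z = NormalDist().inv_cdf(0.5 + level / 2)
    return diff, diff - z * se, diff + z * se

def naive_half_width(X_from, X_to, vcov, level=0.95):
    """Half-width obtained by subtracting two separate CIs."""
    z = NormalDist().inv_cdf(0.5 + level / 2)
    return z * (_row_se(X_from, vcov) + _row_se(X_to, vcov))


Xf = np.array([[1.0, 0.0, 0.2], [1.0, 1.0, 0.1]])
Xt = np.array([[1.0, 0.5, 0.3], [1.0, 2.0, 0.4]])
b = np.array([0.1, 0.8, -0.3])
V = np.array([[0.04, 0.01, 0.0], [0.01, 0.02, 0.005], [0.0, 0.005, 0.03]])
d, lo, hi = paired_diff_ci(Xf, Xt, b, V)
```

By Cauchy-Schwarz, |Cov(η_from, η_to)| ≤ se_from · se_to, so the paired standard error can never exceed se_from + se_to.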

5. Subset / marginal views

gam[names], with names given as a list, returns a view that includes only the listed smooths (and, optionally, the intercept via Gam.INTERCEPT). All predict methods inherit the restriction.

gam[["x0"]].predict(X)                              # marginal effect of x0, response scale
gam[["x0"]].predict(X, scale="deviation")           # link-scale, intercept zeroed (sum-to-zero on train)
gam[["x0", Gam.INTERCEPT]].predict(X)               # marginal effect + intercept
gam[["x0"]].partial_effect("x0").plot()             # what the smooth looks like

All sklearn attributes filter to the active features:

gam[["x0"]].coef_            # just the x0 block (+ intercept if active)
gam[["x0"]].feature_names_in_  # ["x0"]
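
One way to picture a subset view is as coefficient masking: inactive blocks of β̂ are zeroed, so their columns contribute nothing to the linear predictor. A minimal NumPy sketch; the `BLOCKS` column layout is hypothetical (the real layout is internal), and the "(Intercept)" key is a plain-string stand-in for Gam.INTERCEPT.

```python
import numpy as np

# Hypothetical model-matrix layout: intercept in column 0, then one
# coefficient block per smooth.
BLOCKS = {"(Intercept)": slice(0, 1), "x0": slice(1, 4), "x1": slice(4, 7)}

def subset_eta(Xp, beta_hat, active):
    """Link-scale prediction using only the active coefficient blocks."""
    mask = np.zeros_like(beta_hat)
    for name in active:
        mask[BLOCKS[name]] = 1.0
    # Inactive coefficients are multiplied by zero and drop out.
    return Xp @ (beta_hat * mask)


rng = np.random.default_rng(0)
Xp = rng.normal(size=(3, 7))
b = rng.normal(size=7)
x0_only = subset_eta(Xp, b, ["x0"])   # link scale, intercept excluded
```

Under these assumptions, x0_only plays the role of scale="deviation", and the sliced coefficients b[BLOCKS["x0"]] correspond to the filtered coef_. Because the masks partition β̂, the per-block predictions add back up to the full linear predictor.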

6. Plotting, summary, score

gam.plot()                  # figure with one subplot per smooth (CI bands shaded)
gam.plot("x0")              # single-smooth axes
df = gam.partial_effect("x0", level=0.95)  # underlying data: columns x0, effect, lo, hi

print(gam.summary())
# Gam summary  family=gaussian  link=identity  n_obs=500
#   intercept (link)     = -0.0123
#   intercept (response) = -0.0123
#   scale (σ²)           = 0.0101
#   deviance             = 4.8732
#   R² (adj)             = 0.9612
#   smooths:
#     s(          x0)  k= 10  edf=  6.13  λ=2.345e-02
#     s(          x1)  k= 10  edf=  3.07  λ=4.890e-01

gam.score(X, y)              # adjusted R² for regression; accuracy@0.5 for binomial
gam.predict_proba(X)         # binomial only: (n, 2) [[P(0), P(1)]]
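
For reference, here is a sketch of the quantities score() and predict_proba() are described as returning. The adjusted R² follows the formula used by R's summary.gam with unit weights; whether mgcv_rust uses exactly this formula, and how it breaks ties at exactly 0.5, are assumptions.

```python
import numpy as np

def adjusted_r2(y, fitted, edf_total):
    """Adjusted R^2 in the style of mgcv's summary.gam (unit weights).

    edf_total: total effective degrees of freedom, intercept included.
    """
    y = np.asarray(y, dtype=float)
    fitted = np.asarray(fitted, dtype=float)
    n = y.size
    residual_df = n - edf_total
    # ddof=1 matches R's var(); var(resid) * (n - 1) is the centred RSS.
    return 1.0 - np.var(y - fitted, ddof=1) * (n - 1) / (np.var(y, ddof=1) * residual_df)

def proba_two_column(p1):
    """(n, 2) layout [[P(0), P(1)]] built from P(y = 1)."""
    p1 = np.asarray(p1, dtype=float)
    return np.column_stack([1.0 - p1, p1])

def accuracy_at_half(y, p1):
    """Binomial accuracy; ties at exactly 0.5 count as class 1 (assumption)."""
    pred = np.asarray(p1, dtype=float) >= 0.5
    return float(np.mean(pred == np.asarray(y, dtype=bool)))
```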

7. Serving with GamPredictor

For deployment paths, wrap the fitted Gam in a GamPredictor. Same API surface, plus strict column validation and a round-trip assertion.

from mgcv_rust import GamPredictor

predictor = GamPredictor(gam)
predictor.check_against(gam, X[:50])   # raises if predictions diverge
predictor.predict(X_serve)             # ValueError if any expected column is missing
predictor[["x0"]].predict(X_serve)     # subset view returns another GamPredictor

# `__slots__` blocks attribute leaks — once built, the bound Gam's
# coef/vcov/schema are the contract.

This structurally closes two production bug classes:

  • Column index drift when the serving DataFrame's columns differ from training: feature_names_in_ is enforced on every predict call (re-ordered → realigned; missing → ValueError).
  • x-grid clamping in interpolation-based predictors: GamPredictor recomputes the basis at the requested X via evaluate_lpmatrix, so there's no pre-computed grid to clamp against.
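
The column-alignment and round-trip ideas can be sketched independently of the GAM machinery. The class below is hypothetical, not GamPredictor: it replaces basis evaluation with a plain linear map to stay self-contained, and it takes a reference callable instead of a Gam. It accepts a dict of columns or a pandas DataFrame, since both support `in` and `[]` by column name. Ignoring extra columns is an assumption.

```python
import numpy as np

class FrozenLinearPredictor:
    """Hypothetical slots-locked serving wrapper (illustration only)."""

    # __slots__ removes the instance __dict__, so *new* attributes cannot be
    # attached; existing slots can still be reassigned.
    __slots__ = ("_names", "_coef")

    def __init__(self, feature_names, coef):
        self._names = tuple(feature_names)
        self._coef = np.asarray(coef, dtype=float).copy()

    def _align(self, columns):
        """Build an (n, p) matrix in training order from named columns."""
        missing = [c for c in self._names if c not in columns]
        if missing:
            raise ValueError(f"missing expected columns: {missing}")
        # Re-ordered inputs are realigned by name; extra columns are ignored.
        return np.column_stack([np.asarray(columns[c], dtype=float) for c in self._names])

    def predict(self, columns):
        return self._align(columns) @ self._coef

    def check_against(self, reference_predict, columns, rtol=1e-10):
        """Raise if this wrapper disagrees with a reference predictor."""
        ours = self.predict(columns)
        theirs = np.asarray(reference_predict(columns), dtype=float)
        if not np.allclose(ours, theirs, rtol=rtol, atol=0.0):
            raise AssertionError("serving predictions diverge from the reference")


p = FrozenLinearPredictor(["x0", "x1"], [2.0, -1.0])
cols = {"x1": [1.0, 0.0], "x0": [3.0, 1.0]}   # deliberately re-ordered
p.predict(cols)
```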

Families and links

Family | Default link | Other links | Notes
gaussian | identity | - | bs-spline / cubic-regression / random-effects bases supported
binomial | logit | - | predict_proba() available
poisson | log | - | -
gamma | inverse | log | log link via link="log"
tweedie | log | - | tweedie_p= (default 1.5); fixed-p
nb / negbin | log | - | nb profiles θ; negbin is fixed-θ
t-dist / scat | identity | - | df profiled if not given (df ∈ [2, 100])
quasipoisson, quasibinomial | log / logit | - | dispersion-aware
quantile | identity | - | see mgcv_rust.fit_quantile / fit_quantile_lss; OOS presets documented in docs/qgam_oos_presets.md
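
The default links in the table, written out as (link, inverse link) pairs, along with the "link=None → family default" rule from the constructor. This is a sketch using the standard formulas; `LINKS`, `DEFAULT_LINK` and `resolve_link` are hypothetical names, not part of the mgcv_rust API.

```python
import math

# (link g, inverse link g^-1) pairs for the links named in the table.
LINKS = {
    "identity": (lambda mu: mu, lambda eta: eta),
    "log": (math.log, math.exp),
    "logit": (lambda mu: math.log(mu / (1.0 - mu)),
              lambda eta: 1.0 / (1.0 + math.exp(-eta))),
    "inverse": (lambda mu: 1.0 / mu, lambda eta: 1.0 / eta),
}

# Default link per family, as listed in the table.
DEFAULT_LINK = {
    "gaussian": "identity",
    "binomial": "logit",
    "poisson": "log",
    "gamma": "inverse",
    "tweedie": "log",
    "nb": "log",
}

def resolve_link(family, link=None):
    """link=None selects the family default; otherwise use the named link."""
    name = DEFAULT_LINK[family] if link is None else link
    return LINKS[name]
```

The inverse link only yields a valid gamma mean while the linear predictor stays positive, whereas the log link keeps μ > 0 for any η; this is one common reason to pass link="log" for gamma models.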

Architecture (Rust side)

File | Responsibility
src/gam.rs | GAM struct + fit / predict entry points
src/pirls.rs | Penalized IRLS inner loop
src/reml/ | Outer-loop REML / LAML optimization
src/smooth.rs | Basis functions (cubic-regression, B-spline, random effects, tensor products)
src/penalty.rs | Penalty-matrix construction
src/lib.rs | PyO3 bindings: PyGAM, compute_penalty_matrix, newton_pirls_py
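
To show what the PIRLS inner loop does for a fixed smoothing parameter, here is the textbook penalized IRLS update for a Poisson GLM with a log link, written in NumPy. It is a generic illustration and not a transcription of src/pirls.rs. It assumes column 0 of X is the intercept and takes a single penalty λS; choosing λ is the job of the outer REML/LAML loop, which is not shown.

```python
import numpy as np

def pirls_poisson(X, y, S, lam, max_iter=50, tol=1e-10):
    """Penalized IRLS for a Poisson GLM with log link and penalty lam * S.

    Minimizes deviance / 2 + (lam / 2) * beta^T S beta at a fixed lam.
    """
    n, p = X.shape
    beta = np.zeros(p)
    beta[0] = np.log(y.mean() + 1e-8)    # start at the mean (column 0 = intercept)
    for _ in range(max_iter):
        eta = X @ beta
        mu = np.exp(eta)
        # Log link: d(eta)/d(mu) = 1/mu and Var(y) = mu, so the IRLS weights are mu.
        w = mu
        z = eta + (y - mu) / mu          # working response
        XtW = X.T * w                    # X^T W without forming diag(w)
        beta_new = np.linalg.solve(XtW @ X + lam * S, XtW @ z)
        converged = np.max(np.abs(beta_new - beta)) < tol
        beta = beta_new
        if converged:
            break
    return beta


rng = np.random.default_rng(1)
n = 200
x = rng.uniform(0, 1, n)
X = np.column_stack([np.ones(n), x, x**2])
y = rng.poisson(np.exp(0.5 + 0.8 * x)).astype(float)
S = np.diag([0.0, 1.0, 1.0])             # intercept left unpenalized
beta = pirls_poisson(X, y, S, lam=2.0)
```

At convergence, the penalized score equation Xᵀ(y - μ) = λSβ holds, which is a convenient check on any PIRLS implementation.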

Build features:

  • python — enable PyO3 bindings.
  • blas / blas-system — link against system BLAS for matmul-heavy paths.

Parity and performance

Run the parity battery against R/mgcv:

pytest tests/parity/ -q
# 554 passed, 0 failed, 0 xfailed, 0 skipped

Microbench:

python3 scripts/python/bench_step_blend.py 5

Headline (2d_gaussian_additive_n50000_k15_cr, identity link):

Implementation | Time
R mgcv | ~394 ms
mgcv_rust 0.11.0 | 97 ms

Status and limitations

  • Joint (ρ, log φ) outer Newton not yet implemented — the binding constraint for closing the remaining performance gap on dispersion-bearing GLMs. Tracked in mgcv_rust - Backlog - Next Numerical Steps (Obsidian).
  • predict_diff is identity-link only; non-identity raises with a workaround pointing at get_posterior_samples.
  • Auto-k tuning is opt-in via Gam(auto_k=True). Default is a single fit at k_default=10 (mgcv's default), with term_k_mapping overrides — closer to mgcv's "tune k, run k.check" convention and avoids hidden multi-fit costs.
  • sklearn BaseEstimator mixin (for Pipeline / GridSearchCV) is not yet wrapped — soft-dep, deferred.
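
The suggested workaround for non-identity links is to work from posterior samples. The signature of get_posterior_samples isn't documented here, so the sketch below draws its own samples from N(β̂, V) with NumPy and assumes a log link. It applies the inverse link to each draw before differencing, so no linearity in β is needed.

```python
import numpy as np

def response_diff_from_draws(X_from, X_to, beta_hat, vcov, inv_link=np.exp,
                             level=0.95, n_samples=2000, seed=0):
    """Simulation CI for mu_to - mu_from under a non-identity link (a sketch)."""
    rng = np.random.default_rng(seed)
    draws = rng.multivariate_normal(beta_hat, vcov, size=n_samples)  # (S, p)
    # Inverse link per draw, then difference: shape (S, n).
    d = inv_link(draws @ X_to.T) - inv_link(draws @ X_from.T)
    alpha = 1.0 - level
    lo, hi = np.quantile(d, [alpha / 2, 1 - alpha / 2], axis=0)
    point = inv_link(X_to @ beta_hat) - inv_link(X_from @ beta_hat)
    return point, lo, hi


Xf = np.array([[1.0, 0.0], [1.0, 1.0]])
Xt = np.array([[1.0, 1.0], [1.0, 2.0]])
b = np.array([0.1, 0.3])
V = np.array([[0.01, 0.0], [0.0, 0.005]])
point, lo, hi = response_diff_from_draws(Xf, Xt, b, V)
```

Because each draw is differenced as a pair, the correlation between the two predictions is still respected, as it is in the identity-link closed form.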

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. J. R. Stat. Soc. B, 73(1), 3–36.

License

MIT — see LICENSE.

Download files

Source Distribution

mgcv_rust-0.23.0.tar.gz (14.3 MB)

Source

Built Distributions

mgcv_rust-0.23.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)

PyPy, manylinux: glibc 2.17+, x86-64

mgcv_rust-0.23.0-cp314-cp314-win_amd64.whl (3.7 MB)

CPython 3.14, Windows x86-64

mgcv_rust-0.23.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)

CPython 3.14, manylinux: glibc 2.17+, x86-64

mgcv_rust-0.23.0-cp314-cp314-macosx_11_0_arm64.whl (940.2 kB)

CPython 3.14, macOS 11.0+, ARM64

mgcv_rust-0.23.0-cp313-cp313-win_amd64.whl (3.7 MB)

CPython 3.13, Windows x86-64

mgcv_rust-0.23.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)

CPython 3.13, manylinux: glibc 2.17+, x86-64

mgcv_rust-0.23.0-cp313-cp313-macosx_11_0_arm64.whl (939.3 kB)

CPython 3.13, macOS 11.0+, ARM64

mgcv_rust-0.23.0-cp312-cp312-win_amd64.whl (3.7 MB)

CPython 3.12, Windows x86-64

mgcv_rust-0.23.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)

CPython 3.12, manylinux: glibc 2.17+, x86-64

mgcv_rust-0.23.0-cp312-cp312-macosx_11_0_arm64.whl (939.7 kB)

CPython 3.12, macOS 11.0+, ARM64

mgcv_rust-0.23.0-cp311-cp311-win_amd64.whl (3.7 MB)

CPython 3.11, Windows x86-64

mgcv_rust-0.23.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)

CPython 3.11, manylinux: glibc 2.17+, x86-64

mgcv_rust-0.23.0-cp311-cp311-macosx_11_0_arm64.whl (941.7 kB)

CPython 3.11, macOS 11.0+, ARM64

mgcv_rust-0.23.0-cp310-cp310-win_amd64.whl (3.7 MB)

CPython 3.10, Windows x86-64

mgcv_rust-0.23.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)

CPython 3.10, manylinux: glibc 2.17+, x86-64

mgcv_rust-0.23.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)

CPython 3.9, manylinux: glibc 2.17+, x86-64

File details

Details for the file mgcv_rust-0.23.0.tar.gz.

File metadata

  • Download URL: mgcv_rust-0.23.0.tar.gz
  • Size: 14.3 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: maturin/1.13.3

File hashes

Hashes for mgcv_rust-0.23.0.tar.gz
Algorithm Hash digest
SHA256 f56b85ebd37e809f29f20ba9f2e67a80fae81023af170b5612aef424aaad4516
MD5 030b94582e5493b80f5d48ff9e9ade91
BLAKE2b-256 707a6c2eb3be9e269fcb92917f98ebf2ba95d07bb41e1bbded7e6af010996cd3

Hashes for mgcv_rust-0.23.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 ec41282b73bb74e1c6c5c08ed982cf2f14da6c6456a9ea24f53f57e80657e20d
MD5 c71c1c0adf6934b3334546334c7d12ab
BLAKE2b-256 9209b6128c66b110bfd34889aee2cd87d69b04019c7908e911cf868536cf6904

Hashes for mgcv_rust-0.23.0-cp314-cp314-win_amd64.whl
Algorithm Hash digest
SHA256 ec2a0a447af9051869efeae4dfa263e4a52923ddbce63178233f85791bc880a6
MD5 f5d0a73d1101e8c7d9fe77f511bab92c
BLAKE2b-256 88f8a5978a79562512806e23a19ea99549f856bcfb9bb9af3179dc4e8bd11e57

Hashes for mgcv_rust-0.23.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 d118bb4a4e497a1cea3af04642d26a29a3bc1a665375914ce70c32b6355c0a4e
MD5 7fcf71e5f9d4807a893cc783e9545e58
BLAKE2b-256 e9eb61c1ec7cfbb74ce8445de73176ec9ae1f87ad5ed929527beb42448d565a9

Hashes for mgcv_rust-0.23.0-cp314-cp314-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 273588a79fba8bd3efa723d565b987ff49929bb4ec021e57d777867461667c56
MD5 6b44490de741df2737b06aa6ed28b328
BLAKE2b-256 40daad8600e736342d0dbd17ee748d711e319f53ff714e64e8c9bccb95a91fb9

Hashes for mgcv_rust-0.23.0-cp313-cp313-win_amd64.whl
Algorithm Hash digest
SHA256 d72c6cba74c97ee8594046ff8a004d7dd751a2a6b5f490ebd9e02acb3cd5738b
MD5 8af7315442882628f424d2dc9279c506
BLAKE2b-256 df4d9b024d1d8ec90e20cd35a4de90e889216e2401600dca9f9e28cd5a91f802

Hashes for mgcv_rust-0.23.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 c64e2e162d8ee867b56c2b4a24deee16422d84157fd40785b5a68b4cb8c727f7
MD5 c3dd1fc8f152bcb9ed1e0cfbc9e791f7
BLAKE2b-256 f4f75cdaf255ec0555b8d30554e09ddf0e9ac777db6279a14231c4a4ede9b053

Hashes for mgcv_rust-0.23.0-cp313-cp313-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 da08c4b9c7aa0147ec055d48d68372bf7a1a1376d1ee5f0d691205f5e6b581b3
MD5 808105a9a2f94223ee49f6916a2c90c9
BLAKE2b-256 e0339d085ebc0298e9df8b1afd3ff405fe5421907131c53f5f4ca8781ec7a7da

Hashes for mgcv_rust-0.23.0-cp312-cp312-win_amd64.whl
Algorithm Hash digest
SHA256 14db02c342119ecbe6b3e39024fe976911ccbf2c62499d71c953e8d206793ae8
MD5 cc99cf155f807564ae46eed54f726d03
BLAKE2b-256 f53a8a083d10cf6d2a735f63ae101c5dfcd2b8148b31d5f7f0fa3dfcaf891498

Hashes for mgcv_rust-0.23.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 45cc0e2d8a402b10f1c36a4f9bb12068c0dfea485e08d9fa51a464568d0385f1
MD5 e81854287a6f8709b76babb8bd239798
BLAKE2b-256 b5dff36c84042fecc6e3a440f870466c0efa302ba6b41d1fa8258774f1f75e24

Hashes for mgcv_rust-0.23.0-cp312-cp312-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 8d132a1cf4b189de796b3c00239c5d638be2fa41f689aaae05de4f839d5e1060
MD5 4da4ad6b14902eb8f2041af9c80c3ab0
BLAKE2b-256 ca8c4031c8a6bc62aea9afd8290416ccc39faaaf728bf05c5d8ab106040d3074

Hashes for mgcv_rust-0.23.0-cp311-cp311-win_amd64.whl
Algorithm Hash digest
SHA256 5660d7b925e4889e4275669d6d2761eec4cee2e3a317f2faaf92f0dd8538d9e2
MD5 945fce9bf3af757b74952724fc956f80
BLAKE2b-256 8987146d44ffbb9b37725d40e214fdcd942243dff1e42122aca0b6f1de6b6d00

Hashes for mgcv_rust-0.23.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 8742532a439e3fda861f509de480cc2959bcc33b3163355608f1ee272dd38f46
MD5 43ec319e1dfe04470e208facaf0491bf
BLAKE2b-256 59de17537335e506a2ed9f6c7c11ed48134207971af5974599e691bf9dfaffb5

Hashes for mgcv_rust-0.23.0-cp311-cp311-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 4230d124f9b5aa29bb1c0cf21b1eb3c217cf6888da92efed7bed73c632830a31
MD5 e7318dfb76b1942f3a9cf4be766a57ec
BLAKE2b-256 244a08472255a3c9160f8d9259d0d0cb2558bae55c2f4c2186f7df46669f99a0

Hashes for mgcv_rust-0.23.0-cp310-cp310-win_amd64.whl
Algorithm Hash digest
SHA256 fdd0296e96462a07f9c1aa60b3c011adbc2c10aeac4711adfffe6cde960f36a3
MD5 7fbc88c6365f2a020d272cc6a0babc7b
BLAKE2b-256 9d59f6c7d90c7e819761aaf9fd2fa26bfd14471ab28acbf690291bc02c35624d

Hashes for mgcv_rust-0.23.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 466559fd2f7dcb8f25f6e2212b527fe30461d52a0cb14f04efe31a24f6236a33
MD5 48e5925c79ea922cd20e56a362cf4b79
BLAKE2b-256 81e7c672704485f360fcb4ffae9b400bea76d00b716f2c117b2707ae1cc6c122

Hashes for mgcv_rust-0.23.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 0922a31cb58d75bbca5170831b50dcd33da4e1a43192e42262828e9a84679a11
MD5 442e25f90e58dcb6d9f23d048d6b0430
BLAKE2b-256 0b59e6d55872cd50d9b52bf3c044a937df9c0e2a4c871a355db74937e7cf600d
