Rust implementation of Generalized Additive Models with Python bindings

mgcv_rust

A Rust port of R's mgcv package — Generalized Additive Models with automatic smoothing-parameter selection (REML / LAML) and PIRLS fitting — with first-class Python bindings.

The numerical core targets byte-for-byte parity with mgcv on common families, and the Python API is designed to feel like sklearn: a single Gam class, named-predictor DataFrame inputs, fit() / predict() / score(), posterior CI, and subset views for marginal analysis.

Latest release: 0.11.0
Parity: 554 passed / 0 failed / 0 xfailed on the mgcv comparison battery
Python tests: 211 passed, 1 xfailed
Headline perf: 2d_gaussian_additive_n50000_k15_cr, 394 ms → 97 ms (4.05× vs. mgcv)
Python wrapper: mgcv_rust.Gam (canonical); GAMFitter is a deprecated alias

Why another GAM library?

  • R-equivalent answers, Python ergonomics. If you have R / mgcv users producing models and Python services consuming them, this gives both sides the same numbers from the same fit. The serialize() method produces a portable artefact for deployment.
  • Fast. Aggregate time across 80 parity fixtures: 2.0 s. On the largest Gaussian fixture (n=50,000, k=15), mgcv_rust is ~4× faster than R's mgcv.
  • Real CIs, paired diffs. predict_ci(X) returns (mean, lo, hi), so you never have to add the intercept back by hand. predict_diff(from_X, to_X, level=...) gives a paired-posterior CI for the difference between two predictions (identity link only). Because it accounts for the posterior correlation between the two predictions, it is typically much narrower than subtracting two predict_ci intervals.
  • Frozen serving view. GamPredictor(gam) is an __slots__-locked, inference-only wrapper with strict input validation (feature_names_in_ enforcement) and a check_against(gam, X_sample) round-trip assertion for deployment safety.

Installation

The Python package is built from this repo via maturin:

# Clone, then in a venv:
pip install maturin
maturin develop --features python,blas,blas-system --release

After that, import mgcv_rust works.

For a Rust-only consumer, add to your Cargo.toml:

[dependencies]
mgcv_rust = { git = "https://github.com/AlekJaworski/nn_exploring" }

Quick start

import numpy as np
import pandas as pd
from mgcv_rust import Gam

rng = np.random.default_rng(0)
n = 500
X = pd.DataFrame({
    "x0": rng.uniform(-2, 2, n),
    "x1": rng.uniform(0, 5, n),
})
y = np.sin(X["x0"]) + 0.3 * (X["x1"] - 2.5)**2 + rng.normal(0, 0.1, n)

gam = Gam(family="gaussian").fit(X, y)

gam.predict(X[:5])                  # response-scale predictions
mean, lo, hi = gam.predict_ci(X[:5])  # 95% CI, response scale
gam.score(X, y)                     # adjusted R²
print(gam.summary())                # mgcv-style block

That's it for the simple case. Read on for the curated tour, or jump to docs/GETTING_STARTED.md for a step-by-step tutorial with worked examples.


Tour

1. Build the model

Gam.fit() accepts pandas / polars / numpy inputs. If you pass a DataFrame, the column names become the predictor names; otherwise you can pass them explicitly.

gam = Gam(
    family="gaussian",           # gaussian, binomial, poisson, gamma, tweedie, nb, t-dist, scat, ...
    link=None,                   # None → canonical link for the family
    k_default=10,                # default basis dimension per smooth (mgcv default)
    term_k_mapping={"x0": 25},   # per-predictor overrides
    method="REML",
)
gam.fit(X, y)

After fit, sklearn-style attributes are populated:

gam.coef_              # full β vector (intercept first)
gam.intercept_         # link-scale
gam.intercept_response_ # response-scale (= mean(y) for identity link)
gam.feature_names_in_  # predictor names, fitting order
gam.n_features_in_
gam.lambda_            # per-smooth smoothing parameter
gam.edf_               # per-smooth effective degrees of freedom
gam.k_, gam.bs_        # basis dimension / basis type per smooth
gam.vcov_              # posterior covariance of β̂

2. Predict on different scales

gam.predict(X)                       # response (default — same as mgcv predict(type="response"))
gam.predict(X, scale="link")         # linear predictor, no inverse link (mgcv type="link")

result = gam.predict(X, type="terms")  # decomposition
result.intercept                       # scalar, link scale
result.contributions                   # DataFrame (n × m_smooths), link-scale, centered on training
result.total                           # response-scale ≡ gam.predict(X)

Invariant: gam.predict(X) == inv_link(intercept + result.contributions.sum(axis=1)).
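
The invariant can be checked by hand. The sketch below uses plain Python and made-up numbers for a log-link model; the intercept and contribution values are illustrative only and are not output from mgcv_rust.

```python
import math

# Hypothetical link-scale pieces for three rows of a log-link model.
intercept = 0.5
contributions = {            # one list per smooth, link scale
    "x0": [0.10, -0.20, 0.05],
    "x1": [-0.30, 0.40, 0.00],
}


def inv_link(eta):
    """Inverse of the log link."""
    return math.exp(eta)


n_rows = len(contributions["x0"])
# Sum the per-smooth contributions row by row, add the intercept,
# then map the linear predictor to the response scale.
eta = [intercept + sum(col[i] for col in contributions.values())
       for i in range(n_rows)]
total = [inv_link(e) for e in eta]
print(total)
```

With an identity link, inv_link is a no-op and the total is just the intercept plus the row sums.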

3. Confidence intervals

mean, lo, hi = gam.predict_ci(X, level=0.95)       # response, 95% CI
mean, lo, hi = gam.predict_ci(X, scale="link")     # link-scale CI
mean, lo, hi = gam.predict_ci(X, level=0.5)        # 50% CI (narrower)

mean matches gam.predict(X, scale=scale) exactly. The CI is built from n_samples (default 1000) draws of β ~ N(β̂, vcov).
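
To make the sampling idea concrete, here is a minimal standard-library sketch of a sampled interval for a single row under an identity link. The helper names (cholesky, sampled_ci), the toy β̂ and vcov values, and the use of simple percentiles are assumptions for illustration; the library's internal procedure may differ.

```python
import math
import random


def cholesky(a):
    """Lower-triangular L with L @ L.T == a, for a small SPD matrix."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(a[i][i] - s)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L


def sampled_ci(x_row, beta_hat, vcov, level=0.95, n_samples=1000, seed=0):
    """Percentile interval for x_row . beta under beta ~ N(beta_hat, vcov)."""
    rng = random.Random(seed)
    L = cholesky(vcov)
    p = len(beta_hat)
    draws = []
    for _ in range(n_samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(p)]
        # beta = beta_hat + L z has covariance L L^T = vcov.
        beta = [beta_hat[i] + sum(L[i][k] * z[k] for k in range(i + 1))
                for i in range(p)]
        draws.append(sum(x * b for x, b in zip(x_row, beta)))
    draws.sort()
    tail = (1.0 - level) / 2.0
    lo = draws[int(round(tail * (n_samples - 1)))]
    hi = draws[int(round((1.0 - tail) * (n_samples - 1)))]
    # The reported centre is the point estimate, not the mean of the draws.
    mean = sum(x * b for x, b in zip(x_row, beta_hat))
    return mean, lo, hi


beta_hat = [1.0, 0.5]                    # toy coefficients (intercept first)
vcov = [[0.04, 0.01], [0.01, 0.09]]      # toy posterior covariance
x_row = [1.0, 2.0]                       # one row of the design matrix

mean, lo95, hi95 = sampled_ci(x_row, beta_hat, vcov, level=0.95)
_, lo50, hi50 = sampled_ci(x_row, beta_hat, vcov, level=0.5)
print(mean, (lo95, hi95), (lo50, hi50))
```

For a non-identity link, one would presumably apply the inverse link to each draw before taking percentiles.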

The legacy predict_ci(X, alpha=0.05) → (lo, hi) form still works (emits a DeprecationWarning) so existing callers don't break.

4. Differences between predictions, with paired CI

# Per-row diff: to[i] - from[i]
diff = gam.predict_diff(from_X, to_X)

# With paired posterior CI — much narrower than predict_ci differences
diff, lo, hi = gam.predict_diff(from_X, to_X, level=0.95)

# One row broadcast against many
diff = gam.predict_diff(baseline_row, candidates, broadcast="from")
diff = gam.predict_diff(candidates, target_row, broadcast="to")

Identity-link only — for non-identity links the response-scale difference isn't linear in β, so the closed-form CI argument doesn't transfer.
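
The reason the paired interval is narrower can be seen from the closed form for an identity link: the difference is (x_to − x_from)·β, so its variance is a single quadratic form in vcov and any shared terms (such as the intercept column) cancel. The sketch below uses toy numbers and a hypothetical quad_form helper; it is not the library's code.

```python
import math


def quad_form(v, m):
    """v^T m v for a small dense matrix."""
    n = len(v)
    return sum(v[i] * m[i][j] * v[j] for i in range(n) for j in range(n))


vcov = [[0.04, 0.03], [0.03, 0.09]]   # toy posterior covariance of beta
x_from = [1.0, 0.2]                   # design rows: intercept column first
x_to = [1.0, 0.8]

# Paired: variance of the difference itself; the intercept column cancels.
d = [t - f for t, f in zip(x_to, x_from)]
se_paired = math.sqrt(quad_form(d, vcov))

# Naive: subtract the endpoints of two separate intervals, which adds
# the two standard errors instead of using the variance of the difference.
se_from = math.sqrt(quad_form(x_from, vcov))
se_to = math.sqrt(quad_form(x_to, vcov))

z = 1.959964  # 97.5% standard-normal quantile, for a 95% interval
half_width_paired = z * se_paired
half_width_naive = z * (se_from + se_to)
print(half_width_paired, half_width_naive)
```

By the triangle inequality the paired standard error can never exceed se_from + se_to, so the paired interval is never wider than the endpoint-subtraction interval.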

5. Subset / marginal views

gam[names] returns a view that includes only the listed smooths (plus the intercept, if Gam.INTERCEPT is listed). All predict methods are available on the view.

gam[["x0"]].predict(X)                              # marginal effect of x0, response scale
gam[["x0"]].predict(X, scale="deviation")           # link-scale, intercept zeroed (sum-to-zero on train)
gam[["x0", Gam.INTERCEPT]].predict(X)               # marginal effect + intercept
gam[["x0"]].partial_effect("x0").plot()             # what the smooth looks like

All sklearn attributes filter to the active features:

gam[["x0"]].coef_            # just the x0 block (+ intercept if active)
gam[["x0"]].feature_names_in_  # ["x0"]

6. Plotting, summary, score

gam.plot()                  # figure with one subplot per smooth (CI bands shaded)
gam.plot("x0")              # single-smooth axes
df = gam.partial_effect("x0", level=0.95)  # underlying data: columns x0, effect, lo, hi

print(gam.summary())
# Gam summary  family=gaussian  link=identity  n_obs=500
#   intercept (link)     = -0.0123
#   intercept (response) = -0.0123
#   scale (σ²)           = 0.0101
#   deviance             = 4.8732
#   R² (adj)             = 0.9612
#   smooths:
#     s(          x0)  k= 10  edf=  6.13  λ=2.345e-02
#     s(          x1)  k= 10  edf=  3.07  λ=4.890e-01

gam.score(X, y)              # adjusted R² for regression; accuracy@0.5 for binomial
gam.predict_proba(X)         # binomial only: (n, 2) [[P(0), P(1)]]
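
As a rough illustration of the binomial outputs described above, the following standard-library sketch builds the (n, 2) probability layout and computes accuracy at a 0.5 threshold. The function names and the tie rule (P(1) ≥ 0.5 counts as class 1) are assumptions, not the library's documented behaviour.

```python
def proba_matrix(p1):
    """Stack [P(0), P(1)] per row, giving an (n, 2) nested list."""
    return [[1.0 - p, p] for p in p1]


def accuracy_at_half(p1, y_true):
    """Fraction of rows where thresholding P(1) at 0.5 matches the label."""
    y_pred = [1 if p >= 0.5 else 0 for p in p1]
    hits = sum(int(a == b) for a, b in zip(y_pred, y_true))
    return hits / len(y_true)


p1 = [0.9, 0.2, 0.6, 0.4]   # toy fitted probabilities of class 1
y = [1, 0, 0, 0]            # toy labels

proba = proba_matrix(p1)
acc = accuracy_at_half(p1, y)
print(proba, acc)
```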

7. Serving with GamPredictor

For deployment paths, wrap the fitted Gam in a GamPredictor. Same API surface, plus strict column validation and a round-trip assertion.

from mgcv_rust import GamPredictor

predictor = GamPredictor(gam)
predictor.check_against(gam, X[:50])   # raises if predictions diverge
predictor.predict(X_serve)             # ValueError if any expected column is missing
predictor[["x0"]].predict(X_serve)     # subset view returns another GamPredictor

# `__slots__` blocks attribute leaks — once built, the bound Gam's
# coef/vcov/schema are the contract.

This structurally closes two production bug classes:

  • Column index drift when the serving DataFrame's columns differ from training: feature_names_in_ is enforced on every predict call (re-ordered → realigned; missing → ValueError).
  • x-grid clamping in interpolation-based predictors: GamPredictor recomputes the basis at the requested X via evaluate_lpmatrix, so there's no pre-computed grid to clamp against.
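
The column-alignment rule in the first bullet can be sketched as follows. This is a simplified stand-in that uses a dict of lists in place of a DataFrame; align_columns is a hypothetical helper, and silently dropping extra columns is an assumption rather than documented GamPredictor behaviour.

```python
def align_columns(frame, feature_names_in):
    """Return the columns of `frame` in training order; raise if any are missing."""
    missing = [name for name in feature_names_in if name not in frame]
    if missing:
        raise ValueError(f"missing expected columns: {missing}")
    # Rebuild in training order so positional indexing downstream stays valid.
    return {name: frame[name] for name in feature_names_in}


feature_names_in = ["x0", "x1"]          # order recorded at fit time
serving = {                               # re-ordered, with an extra column
    "x1": [3.0, 4.0],
    "x0": [0.1, 0.2],
    "extra": [9, 9],
}

aligned = align_columns(serving, feature_names_in)
print(list(aligned))
```

Keying on names rather than positions is what prevents a re-ordered serving frame from feeding x1 values into the x0 smooth.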

Families and links

| Family | Default link | Other links | Notes |
|---|---|---|---|
| gaussian | identity | | bs-spline / cubic-regression / random-effects bases supported |
| binomial | logit | | predict_proba() available |
| poisson | log | | |
| gamma | inverse | log | log-link via link="log" |
| tweedie | log | | tweedie_p= (default 1.5); fixed-p |
| nb / negbin | log | | nb profiles θ; negbin is fixed-θ |
| t-dist / scat | identity | | df profiled if not given (df ∈ [2, 100]) |
| quasipoisson, quasibinomial | log / logit | | dispersion-aware |
| quantile | identity | | see mgcv_rust.fit_quantile / fit_quantile_lss; OOS presets documented in docs/qgam_oos_presets.md |
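
For readers less familiar with link functions, the sketch below shows how a default link per family and an optional override might map a linear predictor to the response scale. It covers only four families from the table; the LINKS and DEFAULT_LINK dictionaries and the to_response helper are illustrative names, not part of mgcv_rust.

```python
import math

# name -> (link, inverse link)
LINKS = {
    "identity": (lambda mu: mu, lambda eta: eta),
    "log": (math.log, math.exp),
    "logit": (
        lambda mu: math.log(mu / (1.0 - mu)),
        lambda eta: 1.0 / (1.0 + math.exp(-eta)),
    ),
    "inverse": (lambda mu: 1.0 / mu, lambda eta: 1.0 / eta),
}

# Default links as listed in the table above (subset).
DEFAULT_LINK = {
    "gaussian": "identity",
    "binomial": "logit",
    "poisson": "log",
    "gamma": "inverse",
}


def to_response(family, eta, link=None):
    """Map a linear predictor to the response scale; link=None uses the default."""
    name = link or DEFAULT_LINK[family]
    return LINKS[name][1](eta)


print(to_response("binomial", 0.0))           # logit default
print(to_response("gamma", 2.0))              # inverse default
print(to_response("gamma", 0.0, link="log"))  # explicit override
```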

Architecture (Rust side)

| File | Responsibility |
|---|---|
| src/gam.rs | GAM struct + fit / predict entry points |
| src/pirls.rs | Penalized IRLS inner loop |
| src/reml/ | Outer-loop REML / LAML optimization |
| src/smooth.rs | Basis functions (cubic-regression, B-spline, random effects, tensor products) |
| src/penalty.rs | Penalty-matrix construction |
| src/lib.rs | PyO3 bindings: PyGAM, compute_penalty_matrix, newton_pirls_py |

Build features:

  • python — enable PyO3 bindings.
  • blas / blas-system — link against system BLAS for matmul-heavy paths.

Parity and performance

Run the parity battery against R/mgcv:

pytest tests/parity/ -q
# 554 passed, 0 failed, 0 xfailed, 0 skipped

Microbench:

python3 scripts/python/bench_step_blend.py 5

Headline (2d_gaussian_additive_n50000_k15_cr, identity link):

| Implementation | Time |
|---|---|
| R mgcv | ~394 ms |
| mgcv_rust 0.11.0 | 97 ms |

Status and limitations

  • Joint (ρ, log φ) outer Newton not yet implemented — the binding constraint for closing the remaining performance gap on dispersion-bearing GLMs. Tracked in mgcv_rust - Backlog - Next Numerical Steps (Obsidian).
  • predict_diff is identity-link only; non-identity raises with a workaround pointing at get_posterior_samples.
  • Auto-k tuning is opt-in via Gam(auto_k=True). Default is a single fit at k_default=10 (mgcv's default), with term_k_mapping overrides — closer to mgcv's "tune k, run k.check" convention and avoids hidden multi-fit costs.
  • sklearn BaseEstimator mixin (for Pipeline / GridSearchCV) is not yet wrapped — soft-dep, deferred.

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. J. R. Stat. Soc. B, 73(1), 3–36.

License

MIT — see LICENSE.
