
Rust implementation of Generalized Additive Models with Python bindings


mgcv_rust

A Rust port of R's mgcv package — Generalized Additive Models with automatic smoothing-parameter selection (REML / LAML) and PIRLS fitting — with first-class Python bindings.

The numerical core targets byte-for-byte parity with mgcv on common families. The Python API is designed to feel like sklearn: a single Gam class, DataFrame inputs with named predictors, fit() / predict() / score(), posterior confidence intervals, and subset views for marginal analysis.

  • Latest release: 0.11.0
  • Parity: 554 passed / 0 failed / 0 xfailed on the mgcv comparison battery
  • Python tests: 211 passed, 1 xfailed
  • Headline performance: 2d_gaussian_additive_n50000_k15_cr runs in 97 ms with mgcv_rust vs. 394 ms with R mgcv (4.05× faster)
  • Python wrapper: mgcv_rust.Gam (canonical); GAMFitter is a deprecated alias

Why another GAM library?

  • R-equivalent answers, Python ergonomics. If you have R / mgcv users producing models and Python services consuming them, this gives both sides the same numbers from the same fit. The serialize() method produces a portable artefact for deployment.
  • Fast. Aggregate time across 80 parity fixtures: 2.0 s. On the largest Gaussian fixture (n=50,000, k=15), mgcv_rust is ~4× faster than R's mgcv.
  • Real CIs, paired diffs. predict_ci(X) returns (mean, lo, hi) so you never have to "+= intercept" by hand. predict_diff(from_X, to_X, level=...) gives a paired-posterior CI for the difference between two predictions — strictly tighter than the naive predict_ci difference.
  • Frozen serving view. GamPredictor(gam) is an __slots__-locked, inference-only wrapper with strict input validation (feature_names_in_ enforcement) and a check_against(gam, X_sample) round-trip assertion for deployment safety.

Installation

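Prebuilt wheels are published on PyPI (see Download files), so pip install mgcv_rust should work on the listed platforms. To build the Python package from this repo instead, use maturin: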

# Clone, then in a venv:
pip install maturin
maturin develop --features python,blas,blas-system --release

After that, import mgcv_rust works.

For a Rust-only consumer, add to your Cargo.toml:

[dependencies]
mgcv_rust = { git = "https://github.com/AlekJaworski/nn_exploring" }

Quick start

import numpy as np
import pandas as pd
from mgcv_rust import Gam

rng = np.random.default_rng(0)
n = 500
X = pd.DataFrame({
    "x0": rng.uniform(-2, 2, n),
    "x1": rng.uniform(0, 5, n),
})
y = np.sin(X["x0"]) + 0.3 * (X["x1"] - 2.5)**2 + rng.normal(0, 0.1, n)

gam = Gam(family="gaussian").fit(X, y)

gam.predict(X[:5])                  # response-scale predictions
mean, lo, hi = gam.predict_ci(X[:5])  # 95% CI, response scale
gam.score(X, y)                     # adjusted R²
print(gam.summary())                # mgcv-style block

That's it for the simple case. Read on for the curated tour, or jump to docs/GETTING_STARTED.md for a step-by-step tutorial with worked examples.


Tour

1. Build the model

fit() accepts pandas / polars / numpy input. If you pass a DataFrame, its column names become the predictor names; otherwise you can pass the names explicitly.

gam = Gam(
    family="gaussian",           # gaussian, binomial, poisson, gamma, tweedie, nb, t-dist, scat, ...
    link=None,                   # None → canonical link for the family
    k_default=10,                # default basis dimension per smooth (mgcv default)
    term_k_mapping={"x0": 25},   # per-predictor overrides
    method="REML",
)
gam.fit(X, y)

After fit, sklearn-style attributes are populated:

gam.coef_              # full β vector (intercept first)
gam.intercept_         # link-scale
gam.intercept_response_ # response-scale (= mean(y) for identity link)
gam.feature_names_in_  # predictor names, fitting order
gam.n_features_in_
gam.lambda_            # per-smooth smoothing parameter
gam.edf_               # per-smooth effective degrees of freedom
gam.k_, gam.bs_        # basis dimension / basis type per smooth
gam.vcov_              # posterior covariance of β̂
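Why intercept_response_ equals mean(y) for an identity link can be checked on a toy problem, assuming the usual mgcv setup in which the intercept is unpenalized and every smooth is centered to sum to zero over the training rows. The plain-Python sketch below uses one centered linear term as a stand-in for a smooth; fit_intercept_and_slope is a hypothetical helper, not part of mgcv_rust.

```python
def fit_intercept_and_slope(x, y):
    """Least squares for y ~ a + b * (x - mean(x)) via the 2x2 normal equations.

    The centered column stands in for a smooth with a sum-to-zero constraint.
    """
    n = len(x)
    x_bar = sum(x) / n
    xc = [xi - x_bar for xi in x]  # centered term: sums to zero on the training rows
    # Normal equations: [[n, sum(xc)], [sum(xc), sum(xc^2)]] @ [a, b] = [sum(y), sum(xc*y)]
    s1 = sum(xc)
    s2 = sum(v * v for v in xc)
    t0 = sum(y)
    t1 = sum(v * yi for v, yi in zip(xc, y))
    det = n * s2 - s1 * s1
    a = (t0 * s2 - s1 * t1) / det  # Cramer's rule
    b = (n * t1 - s1 * t0) / det
    return a, b

x = [0.0, 1.0, 2.0, 3.0, 4.0]
y = [1.2, 1.9, 3.1, 4.2, 4.8]
a, b = fit_intercept_and_slope(x, y)
print(a, sum(y) / len(y))  # the fitted intercept and mean(y) agree up to rounding
```

Because the centered column sums to zero, the off-diagonal entries of the normal equations vanish and the intercept equation reduces to n·a = Σy. A penalty on the smooth's own coefficients would only add to the b diagonal, so it leaves this result unchanged.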

2. Predict on different scales

gam.predict(X)                       # response (default — same as mgcv predict(type="response"))
gam.predict(X, scale="link")         # linear predictor, no inverse link (mgcv type="link")

result = gam.predict(X, type="terms")  # decomposition
result.intercept                       # scalar, link scale
result.contributions                   # DataFrame (n × m_smooths), link-scale, centered on training
result.total                           # response-scale ≡ gam.predict(X)

Invariant: gam.predict(X) == inv_link(intercept + result.contributions.sum(axis=1)).
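For a non-identity link the invariant is applied on the link scale first: the contributions add up, and the inverse link is applied once to the total. A minimal sketch in plain Python, assuming a binomial fit with logit link; the intercept and per-term values are made-up numbers for a single row, not mgcv_rust output, and rebuild_prediction is a hypothetical helper.

```python
import math

def inv_logit(eta):
    """Inverse of the logit link: maps a linear predictor to a probability."""
    return 1.0 / (1.0 + math.exp(-eta))

def rebuild_prediction(intercept, contributions, inv_link):
    """inv_link(intercept + sum of link-scale term contributions) for one row."""
    return inv_link(intercept + sum(contributions.values()))

intercept = -0.5                         # link scale
contributions = {"x0": 0.8, "x1": -0.1}  # link scale, one row
p = rebuild_prediction(intercept, contributions, inv_logit)
# p == inv_logit(0.2) up to rounding; adding response-scale pieces would not give this.
```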

3. Confidence intervals

mean, lo, hi = gam.predict_ci(X, level=0.95)       # response, 95% CI
mean, lo, hi = gam.predict_ci(X, scale="link")     # link-scale CI
mean, lo, hi = gam.predict_ci(X, level=0.5)        # 50% CI (narrower)

mean matches gam.predict(X, scale=scale) exactly. The CI is built from n_samples (default 1000) draws of β ~ N(β̂, vcov).
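The draw-based interval can be sketched in plain Python. This assumes a percentile interval over the simulated predictions; the quantile rule, the helper names (cholesky, posterior_ci) and the design row and covariance below are hypothetical and are not taken from mgcv_rust.

```python
import math
import random

def cholesky(a):
    """Lower-triangular L with L @ L^T == a, for a symmetric positive-definite a."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(a[i][i] - s)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L

def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))

def posterior_ci(x, beta_hat, vcov, level=0.95, n_samples=1000,
                 inv_link=lambda eta: eta, seed=0):
    """Percentile interval for inv_link(x . beta) with beta ~ N(beta_hat, vcov).

    Returns (mean, lo, hi); mean is the plug-in prediction at beta_hat,
    not the average of the draws.
    """
    rng = random.Random(seed)
    L = cholesky(vcov)
    p = len(beta_hat)
    draws = []
    for _ in range(n_samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(p)]
        beta = [beta_hat[i] + dot(L[i][: i + 1], z[: i + 1]) for i in range(p)]
        draws.append(inv_link(dot(x, beta)))
    draws.sort()
    tail = (1.0 - level) / 2.0
    lo = draws[math.floor(tail * (n_samples - 1))]
    hi = draws[math.ceil((1.0 - tail) * (n_samples - 1))]
    return inv_link(dot(x, beta_hat)), lo, hi

x = [1.0, 0.5]  # hypothetical design row: intercept column, one basis column
beta_hat = [0.2, 1.0]
vcov = [[0.01, 0.002], [0.002, 0.04]]
mean95, lo95, hi95 = posterior_ci(x, beta_hat, vcov, level=0.95)
mean50, lo50, hi50 = posterior_ci(x, beta_hat, vcov, level=0.50)
```

Using the same seed for both levels means the 50% interval is cut from the same sorted draws as the 95% one, so it is guaranteed to be nested inside it.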

The legacy predict_ci(X, alpha=0.05) → (lo, hi) form still works (emits a DeprecationWarning) so existing callers don't break.

4. Differences between predictions, with paired CI

# Per-row diff: to[i] - from[i]
diff = gam.predict_diff(from_X, to_X)

# With paired posterior CI — much narrower than predict_ci differences
diff, lo, hi = gam.predict_diff(from_X, to_X, level=0.95)

# One row broadcast against many
diff = gam.predict_diff(baseline_row, candidates, broadcast="from")
diff = gam.predict_diff(candidates, target_row, broadcast="to")

Identity-link only — for non-identity links the response-scale difference isn't linear in β, so the closed-form CI argument doesn't transfer.
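Under an identity link the paired difference is the linear contrast d·β with d = x_to − x_from, so a Gaussian interval needs only d' V d, and any uncertainty the two rows share (such as the intercept) cancels. The plain-Python sketch below compares that with subtracting two separate intervals. It is a closed-form illustration under a normal approximation, with hypothetical helper names and numbers, not mgcv_rust's implementation.

```python
import math

Z95 = 1.959963984540054  # standard normal 97.5% quantile

def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))

def quad_form(v, m):
    """v^T M v."""
    n = len(v)
    return sum(v[i] * m[i][j] * v[j] for i in range(n) for j in range(n))

def paired_diff_ci(x_from, x_to, beta_hat, vcov, z=Z95):
    """CI for (x_to - x_from) . beta using the variance of the contrast itself."""
    d = [t - f for f, t in zip(x_from, x_to)]
    diff = dot(d, beta_hat)
    half = z * math.sqrt(quad_form(d, vcov))
    return diff, diff - half, diff + half

def naive_diff_ci(x_from, x_to, beta_hat, vcov, z=Z95):
    """Subtract two separate CIs, ignoring the covariance between the predictions."""
    f_hat, t_hat = dot(x_from, beta_hat), dot(x_to, beta_hat)
    f_half = z * math.sqrt(quad_form(x_from, vcov))
    t_half = z * math.sqrt(quad_form(x_to, vcov))
    return t_hat - f_hat, (t_hat - t_half) - (f_hat + f_half), (t_hat + t_half) - (f_hat - f_half)

beta_hat = [1.0, 2.0]                # intercept, slope
vcov = [[0.04, 0.0], [0.0, 0.01]]    # intercept is the dominant uncertainty
x_from, x_to = [1.0, 0.2], [1.0, 0.5]
diff, lo, hi = paired_diff_ci(x_from, x_to, beta_hat, vcov)
ndiff, nlo, nhi = naive_diff_ci(x_from, x_to, beta_hat, vcov)
```

Since sqrt(d'Vd) ≤ sqrt(x_from'V x_from) + sqrt(x_to'V x_to) (the triangle inequality for the V-norm), the paired interval is never wider than the naive one; here the intercept variance cancels entirely, so it is far narrower.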

5. Subset / marginal views

gam[names] returns a view restricted to the listed smooths (and, optionally, the intercept via Gam.INTERCEPT). Every predict method works on the view.

gam[["x0"]].predict(X)                              # marginal effect of x0, response scale
gam[["x0"]].predict(X, scale="deviation")           # link-scale, intercept zeroed (sum-to-zero on train)
gam[["x0", Gam.INTERCEPT]].predict(X)               # marginal effect + intercept
gam[["x0"]].partial_effect("x0").plot()             # what the smooth looks like

All sklearn attributes filter to the active features:

gam[["x0"]].coef_            # just the x0 block (+ intercept if active)
gam[["x0"]].feature_names_in_  # ["x0"]
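The bookkeeping behind a subset view can be sketched for one row: sum only the active link-scale contributions, add the intercept only when it is listed, then apply the inverse link. Assumptions: INTERCEPT below is a hypothetical stand-in for Gam.INTERCEPT, the inverse link is applied to the partial sum (mirroring "response scale" in the examples), and the contribution values are made up.

```python
import math

INTERCEPT = "(Intercept)"  # hypothetical stand-in for Gam.INTERCEPT

def subset_predict(intercept, contributions, active, inv_link=lambda eta: eta):
    """Prediction for one row from only the active terms (illustrative).

    contributions: link-scale value of each smooth for this row, centered so
    each smooth averages to zero over the training data.
    """
    eta = sum(contributions[name] for name in active if name != INTERCEPT)
    if INTERCEPT in active:
        eta += intercept
    return inv_link(eta)

row = {"x0": 0.4, "x1": -0.25}
only_x0 = subset_predict(1.5, row, ["x0"])
x0_plus_intercept = subset_predict(1.5, row, ["x0", INTERCEPT])
full = subset_predict(1.5, row, ["x0", "x1", INTERCEPT])
rate_x0 = subset_predict(1.5, row, ["x0"], inv_link=math.exp)  # e.g. a log link
```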

6. Plotting, summary, score

gam.plot()                  # figure with one subplot per smooth (CI bands shaded)
gam.plot("x0")              # single-smooth axes
df = gam.partial_effect("x0", level=0.95)  # underlying data: columns x0, effect, lo, hi

print(gam.summary())
# Gam summary  family=gaussian  link=identity  n_obs=500
#   intercept (link)     = -0.0123
#   intercept (response) = -0.0123
#   scale (σ²)           = 0.0101
#   deviance             = 4.8732
#   R² (adj)             = 0.9612
#   smooths:
#     s(          x0)  k= 10  edf=  6.13  λ=2.345e-02
#     s(          x1)  k= 10  edf=  3.07  λ=4.890e-01

gam.score(X, y)              # adjusted R² for regression; accuracy@0.5 for binomial
gam.predict_proba(X)         # binomial only: (n, 2) [[P(0), P(1)]]
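The two score definitions can be written out in plain Python. This assumes score mirrors the r.sq formula used by mgcv's summary.gam for unweighted fits (residual degrees of freedom = n minus the total effective degrees of freedom, intercept included) and a p >= 0.5 threshold for accuracy; neither detail is confirmed above, and both helpers are hypothetical.

```python
import statistics

def adjusted_r2(y, fitted, edf_total):
    """1 - var(residuals) * (n - 1) / (var(y) * residual_df), mgcv-style."""
    n = len(y)
    resid = [yi - fi for yi, fi in zip(y, fitted)]
    residual_df = n - edf_total
    return 1.0 - statistics.variance(resid) * (n - 1) / (statistics.variance(y) * residual_df)

def accuracy_at_half(y, p):
    """Fraction of rows where (p >= 0.5) agrees with the 0/1 label."""
    return sum((pi >= 0.5) == bool(yi) for yi, pi in zip(y, p)) / len(y)

r2 = adjusted_r2([1.0, 2.0, 3.0, 4.0, 5.0], [1.1, 1.9, 3.2, 3.8, 5.0], edf_total=2.0)
acc = accuracy_at_half([0, 1, 1, 1], [0.2, 0.7, 0.4, 0.9])
```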

7. Serving with GamPredictor

For deployment paths, wrap the fitted Gam in a GamPredictor. Same API surface, plus strict column validation and a round-trip assertion.

from mgcv_rust import GamPredictor

predictor = GamPredictor(gam)
predictor.check_against(gam, X[:50])   # raises if predictions diverge
predictor.predict(X_serve)             # ValueError if any expected column is missing
predictor[["x0"]].predict(X_serve)     # subset view returns another GamPredictor

# `__slots__` blocks attribute leaks — once built, the bound Gam's
# coef/vcov/schema are the contract.

This structurally closes two production bug classes:

  • Column index drift when the serving DataFrame's columns differ from training: feature_names_in_ is enforced on every predict call (re-ordered → realigned; missing → ValueError).
  • x-grid clamping in interpolation-based predictors: GamPredictor recomputes the basis at the requested X via evaluate_lpmatrix, so there's no pre-computed grid to clamp against.
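The column contract and the round-trip check can be sketched without pandas. This is an illustration of the behavior described above (re-ordered columns are realigned, missing ones raise ValueError, diverging predictions raise), not GamPredictor's code; align_columns and check_round_trip are hypothetical names, and ignoring extra columns is an assumption.

```python
import math

def align_columns(columns, feature_names):
    """Return the columns in training order, raising ValueError if any are missing.

    `columns` maps column name -> list of values (a stand-in for a DataFrame).
    Extra columns are silently ignored here; that is an assumption of this sketch.
    """
    missing = [name for name in feature_names if name not in columns]
    if missing:
        raise ValueError(f"missing expected columns: {missing}")
    return [columns[name] for name in feature_names]

def check_round_trip(expected, actual, rel_tol=1e-9, abs_tol=1e-12):
    """Raise AssertionError if two prediction vectors diverge beyond tolerance."""
    if len(expected) != len(actual):
        raise AssertionError("prediction lengths differ")
    for i, (e, a) in enumerate(zip(expected, actual)):
        if not math.isclose(e, a, rel_tol=rel_tol, abs_tol=abs_tol):
            raise AssertionError(f"row {i}: expected {e!r}, got {a!r}")

feature_names_in_ = ["x0", "x1"]
serving = {"x1": [3.0, 4.0], "x0": [0.1, -0.2], "request_id": [7, 8]}
x0, x1 = align_columns(serving, feature_names_in_)  # realigned to training order
check_round_trip([0.5, 1.25], [0.5, 1.25])           # identical predictions: no error
```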

Families and links

| Family | Default link | Other links | Notes |
|---|---|---|---|
| gaussian | identity | | bs-spline / cubic-regression / random-effects bases supported |
| binomial | logit | | predict_proba() available |
| poisson | log | | |
| gamma | inverse | log | log link via link="log" |
| tweedie | log | | tweedie_p= (default 1.5); fixed p |
| nb / negbin | log | | nb profiles θ; negbin uses a fixed θ |
| t-dist / scat | identity | | df profiled if not given (df ∈ [2, 100]) |
| quasipoisson, quasibinomial | log / logit | | dispersion-aware |
| quantile | identity | | see mgcv_rust.fit_quantile / fit_quantile_lss |
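The link column determines how the link scale and the response scale relate (link=None resolves to the default link). A lookup sketch in plain Python: LINKS, DEFAULT_LINK and resolve_link are hypothetical names covering only some rows of the families table, not mgcv_rust's internals.

```python
import math

# (link, inverse link) pairs: eta = link(mu), mu = inverse(eta).
LINKS = {
    "identity": (lambda mu: mu, lambda eta: eta),
    "log": (math.log, math.exp),
    "logit": (lambda mu: math.log(mu / (1.0 - mu)),
              lambda eta: 1.0 / (1.0 + math.exp(-eta))),
    "inverse": (lambda mu: 1.0 / mu, lambda eta: 1.0 / eta),
}

# Default links for a subset of the families table.
DEFAULT_LINK = {
    "gaussian": "identity",
    "binomial": "logit",
    "poisson": "log",
    "gamma": "inverse",
    "tweedie": "log",
    "nb": "log",
}

def resolve_link(family, link=None):
    """Return (link, inverse) for a family, using its default when link is None."""
    name = DEFAULT_LINK[family] if link is None else link
    return LINKS[name]
```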

Architecture (Rust side)

| File | Responsibility |
|---|---|
| src/gam.rs | GAM struct + fit / predict entry points |
| src/pirls.rs | Penalized IRLS inner loop |
| src/reml/ | Outer-loop REML / LAML optimization |
| src/smooth.rs | Basis functions (cubic-regression, B-spline, random effects, tensor products) |
| src/penalty.rs | Penalty-matrix construction |
| src/lib.rs | PyO3 bindings: PyGAM, compute_penalty_matrix, newton_pirls_py |
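The penalized IRLS inner loop can be illustrated on a deliberately tiny case: a Poisson model with log link, an intercept plus one covariate, and a ridge penalty on the slope standing in for a smoothing penalty. This is a plain-Python sketch of textbook PIRLS, not a transcription of src/pirls.rs; the function names, data and fixed λ are assumptions.

```python
import math

def solve2(a, b):
    """Solve the 2x2 system a @ x = b by Cramer's rule."""
    det = a[0][0] * a[1][1] - a[0][1] * a[1][0]
    return [(b[0] * a[1][1] - a[0][1] * b[1]) / det,
            (a[0][0] * b[1] - a[1][0] * b[0]) / det]

def pirls_poisson(x, y, lam, max_iter=50, tol=1e-10):
    """Penalized IRLS for y ~ Poisson(exp(b0 + b1 * x)) with penalty (lam / 2) * b1**2.

    Each iteration solves (X'WX + lam * S) beta = X'Wz with S = diag(0, 1),
    working weights w = mu and working response z = eta + (y - mu) / mu.
    """
    beta = [math.log(sum(y) / len(y)), 0.0]  # start at the intercept-only fit
    for _ in range(max_iter):
        eta = [beta[0] + beta[1] * xi for xi in x]
        mu = [math.exp(e) for e in eta]
        z = [e + (yi - m) / m for e, yi, m in zip(eta, y, mu)]
        w = mu  # log link: (dmu/deta)^2 / V(mu) = mu^2 / mu = mu
        swx = sum(wi * xi for wi, xi in zip(w, x))
        a = [[sum(w), swx],
             [swx, sum(wi * xi * xi for wi, xi in zip(w, x)) + lam]]  # penalty on b1 only
        b = [sum(wi * zi for wi, zi in zip(w, z)),
             sum(wi * xi * zi for wi, xi, zi in zip(w, x, z))]
        new_beta = solve2(a, b)
        step = max(abs(nb - ob) for nb, ob in zip(new_beta, beta))
        beta = new_beta
        if step < tol:
            break
    return beta

x = [0.0, 1.0, 2.0, 3.0, 4.0]
y = [1, 2, 2, 5, 8]
lam = 0.5
beta = pirls_poisson(x, y, lam)
```

At convergence the penalized score equations X'(y − μ) = λSβ hold. In the real fit λ is not fixed by hand: the outer REML / LAML loop chooses it.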

Build features:

  • python — enable PyO3 bindings.
  • blas / blas-system — link against system BLAS for matmul-heavy paths.

Parity and performance

Run the parity battery against R/mgcv:

pytest tests/parity/ -q
# 554 passed, 0 failed, 0 xfailed, 0 skipped

Microbench:

python3 scripts/python/bench_step_blend.py 5

Headline (2d_gaussian_additive_n50000_k15_cr, identity link):

| Implementation | Time |
|---|---|
| R mgcv | ~394 ms |
| mgcv_rust 0.11.0 | 97 ms |

Status and limitations

  • Joint (ρ, log φ) outer Newton is not yet implemented. It is the binding constraint for closing the remaining performance gap on dispersion-bearing GLMs, and is tracked in the backlog note "mgcv_rust - Backlog - Next Numerical Steps" (Obsidian).
  • predict_diff is identity-link only; for non-identity links it raises an error whose message points at get_posterior_samples as a workaround.
  • Auto-k tuning is opt-in via Gam(auto_k=True). By default a single fit runs at k_default=10 (mgcv's default), with term_k_mapping overrides. This follows mgcv's convention of choosing k and then checking it with k.check, and avoids hidden multi-fit costs.
  • The sklearn BaseEstimator mixin (for Pipeline / GridSearchCV) is not wrapped yet; sklearn would be a soft dependency, and this is deferred.
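The get_posterior_samples workaround amounts to evaluating both rows under each posterior draw and taking percentiles of the response-scale differences, which keeps the pairing that a closed-form identity-link CI relies on. A plain-Python sketch for a logit link: get_posterior_samples' actual signature is not documented here, so the sketch takes a plain list of coefficient vectors, and the independent-normal draws below are a hypothetical stand-in for a real posterior (which would use the full vcov).

```python
import math
import random

def inv_logit(eta):
    return 1.0 / (1.0 + math.exp(-eta))

def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))

def sampled_diff_ci(samples, x_from, x_to, inv_link, level=0.95):
    """Percentile CI for inv_link(x_to . beta) - inv_link(x_from . beta).

    Both rows are evaluated with the same draw, so the pairing is preserved.
    """
    diffs = sorted(inv_link(dot(x_to, b)) - inv_link(dot(x_from, b)) for b in samples)
    tail = (1.0 - level) / 2.0
    n = len(diffs)
    return diffs[math.floor(tail * (n - 1))], diffs[math.ceil((1.0 - tail) * (n - 1))]

# Hypothetical posterior draws: independent normals around beta_hat.
rng = random.Random(1)
beta_hat = [-0.5, 1.2]
samples = [[rng.gauss(beta_hat[0], 0.2), rng.gauss(beta_hat[1], 0.1)] for _ in range(2000)]
x_from, x_to = [1.0, 0.0], [1.0, 1.0]
point = inv_logit(dot(x_to, beta_hat)) - inv_logit(dot(x_from, beta_hat))
lo, hi = sampled_diff_ci(samples, x_from, x_to, inv_logit)
```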


License

MIT — see LICENSE.



Download files

Source distribution

  • mgcv_rust-0.12.1.tar.gz (9.8 MB)

Built distributions

| Wheel | Size | Interpreter | Platform |
|---|---|---|---|
| mgcv_rust-0.12.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 973.1 kB | PyPy 3.11 | manylinux, glibc 2.17+, x86-64 |
| mgcv_rust-0.12.1-cp314-cp314-win_amd64.whl | 3.6 MB | CPython 3.14 | Windows x86-64 |
| mgcv_rust-0.12.1-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 975.6 kB | CPython 3.14 | manylinux, glibc 2.17+, x86-64 |
| mgcv_rust-0.12.1-cp314-cp314-macosx_11_0_arm64.whl | 794.7 kB | CPython 3.14 | macOS 11.0+ ARM64 |
| mgcv_rust-0.12.1-cp313-cp313-win_amd64.whl | 3.6 MB | CPython 3.13 | Windows x86-64 |
| mgcv_rust-0.12.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 973.9 kB | CPython 3.13 | manylinux, glibc 2.17+, x86-64 |
| mgcv_rust-0.12.1-cp313-cp313-macosx_11_0_arm64.whl | 792.9 kB | CPython 3.13 | macOS 11.0+ ARM64 |
| mgcv_rust-0.12.1-cp312-cp312-win_amd64.whl | 3.6 MB | CPython 3.12 | Windows x86-64 |
| mgcv_rust-0.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 974.2 kB | CPython 3.12 | manylinux, glibc 2.17+, x86-64 |
| mgcv_rust-0.12.1-cp312-cp312-macosx_11_0_arm64.whl | 793.7 kB | CPython 3.12 | macOS 11.0+ ARM64 |
| mgcv_rust-0.12.1-cp311-cp311-win_amd64.whl | 3.6 MB | CPython 3.11 | Windows x86-64 |
| mgcv_rust-0.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 973.6 kB | CPython 3.11 | manylinux, glibc 2.17+, x86-64 |
| mgcv_rust-0.12.1-cp311-cp311-macosx_11_0_arm64.whl | 796.1 kB | CPython 3.11 | macOS 11.0+ ARM64 |
| mgcv_rust-0.12.1-cp310-cp310-win_amd64.whl | 3.6 MB | CPython 3.10 | Windows x86-64 |
| mgcv_rust-0.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 975.5 kB | CPython 3.10 | manylinux, glibc 2.17+, x86-64 |
| mgcv_rust-0.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 976.9 kB | CPython 3.9 | manylinux, glibc 2.17+, x86-64 |

File details

mgcv_rust-0.12.1.tar.gz

  • Download URL: mgcv_rust-0.12.1.tar.gz
  • Size: 9.8 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: maturin/1.13.3

File hashes

Hashes for mgcv_rust-0.12.1.tar.gz
  SHA256       e8a3a3330e553d18db52eb685db6cbbd3d8dbb7bf0751fe326676d477cd81e1c
  MD5          d94699de502892629ba02dcd60372820
  BLAKE2b-256  316c600b86ccdcc0e70e99acafff36d024ae151e510cc07ca9006220abd72c01

Hashes for mgcv_rust-0.12.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
  SHA256       249de6916a6f32778f82d74298a9e3f05e2848012db9bcc4e22eee9828b9d2f2
  MD5          4cbf11ee800354c31670a1ae8839c535
  BLAKE2b-256  5f5ac300ca2a8e582bebf14be75115bd258c1e78b5790e87b9c1287928badfb4

Hashes for mgcv_rust-0.12.1-cp314-cp314-win_amd64.whl
  SHA256       0eba7eb7f12f66bd8fae3c51b59e53d5648d88b77a6a08c6ba995ef37d322b36
  MD5          27d3c45433754308a01d683dc120b75e
  BLAKE2b-256  0c598fcb9e55a94bcc042039d014e4ca02551f1bf241b52283d86872599aa081

Hashes for mgcv_rust-0.12.1-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
  SHA256       5bf580716af8c32304cf6ccdf5dd63caf4a2e93665e4411a3f36cf0ae89494a6
  MD5          93551ffa7853d1ecd951c05499d507d5
  BLAKE2b-256  c77962592a5aa20ee3fb12d02ecc1517de7ef795a4131097ec52da72319f8666

Hashes for mgcv_rust-0.12.1-cp314-cp314-macosx_11_0_arm64.whl
  SHA256       5baeb632edb6aa763bac5bce648c553bdfdc25fe1a17adb8440cfb6e6acf3f97
  MD5          797adca33beb73e2608e405d4dae13f0
  BLAKE2b-256  5312e27f2b98060d3720ae6b33f45a50bcb01893b4b10c8a40dc695579b9909d

Hashes for mgcv_rust-0.12.1-cp313-cp313-win_amd64.whl
  SHA256       233a20e8de8345e4bfd4dca8388be64572fe8d6dba7917e532b94233a398c830
  MD5          ac66fa999a3aa470b3db25fb395b7aab
  BLAKE2b-256  f7139fe239e8c9b3def056d93488c66edd3d18d58c97d0efce29dd8c9e5aa6f7

Hashes for mgcv_rust-0.12.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
  SHA256       73e88f49f789f0b4e80ffec59e2e11564f66e52b49466f6c96f93107d4d273f6
  MD5          6923f116a6e603929e4fa442b21018ad
  BLAKE2b-256  7c9b2a8b381a1de0af720312458a2ed05e0e2d999b4a0b3832833ba13b1eb39d

Hashes for mgcv_rust-0.12.1-cp313-cp313-macosx_11_0_arm64.whl
  SHA256       6b4b93c55539cf91c8545440318bbba17668979bb6d0e14df56df9d05e7c53bd
  MD5          99b6ccdf028687191f79b98f0a2aac25
  BLAKE2b-256  ea5277a2a91cd4b31d8085d8e15c763ff9b6c2b549e2da8888bc03d6a9936fe2

Hashes for mgcv_rust-0.12.1-cp312-cp312-win_amd64.whl
  SHA256       c01f094f5ab3d2cc81496a43c9d338182ea4693bd7566b6dd209e48b9b900d42
  MD5          1c8be17d1739dfaf6696c145b4f81c3e
  BLAKE2b-256  c37c8afe11f3dc724561190c9658412c72da65439ed23c376a19daf9fe6dd7e0

Hashes for mgcv_rust-0.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
  SHA256       fd2865abce65e2fe7ac7e705ce7004519be17660e5199becc8602b4266eb4236
  MD5          1409322ed1fde6aabdac1f8b6398fd40
  BLAKE2b-256  e3aa550ab9927e10ca877094671c75e390c2189b537dd8e0cc2de8b9e7e54c2c

Hashes for mgcv_rust-0.12.1-cp312-cp312-macosx_11_0_arm64.whl
  SHA256       9a22c43748a8833c551c093f61087e253668e27c02dd9ecc74baaacac27bc10f
  MD5          180abb9ee731bbd6bc72bacffe63e2c8
  BLAKE2b-256  8091b51f95b108d2036b9c2f57f2d45a789d1eedc31b790ab6d059393ae5ccee

Hashes for mgcv_rust-0.12.1-cp311-cp311-win_amd64.whl
  SHA256       003d6e6a84920e6b80eef12f4b5c09c19d654788d12b2ee2ddfd9b0fc3d67b76
  MD5          076d700aee8177646faf19f28ff15d97
  BLAKE2b-256  52d24ee232912d63b3cdd37702b33579c70c4aa06aee1a030065ad758e8beb95

Hashes for mgcv_rust-0.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
  SHA256       307e228c6690a1bfe42ba079600d78a1aabac217353d5945938896172ba1d248
  MD5          29309dbb00dcce5aaea2ccc41be1497a
  BLAKE2b-256  e3dd2011e1009a4f3f524c08265646b12c2c480a5051366c82a95c5fb8d28624

Hashes for mgcv_rust-0.12.1-cp311-cp311-macosx_11_0_arm64.whl
  SHA256       50da164b28401bbec153948d98377c579395352fab948f6754f23da8fe8e44cc
  MD5          eba89e24c7d035c811fa0a5ee17d0cbc
  BLAKE2b-256  0cfb02ee8156b90297880c1dd6981e9a3d4db337e80eec9e4f3d80dbb8657743

Hashes for mgcv_rust-0.12.1-cp310-cp310-win_amd64.whl
  SHA256       bfee9f86e3064498a2574e08a2a967af2c2713224bdfa771c6ce13b81f746309
  MD5          746985561674b3f3ac3cb785b549c308
  BLAKE2b-256  2e0e04c26473ee2a9916132e376649790ec0f53b1dae60886fb9d0d381a42b0c

Hashes for mgcv_rust-0.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
  SHA256       acc4a7b4d6711b3ae5c73e82cca0db55f56e950788c76f134d3bd95e37060450
  MD5          8683b242df9d75ad3d4ce876262499aa
  BLAKE2b-256  865c3cca169be0d427425b6dd33b9228a35024d4a1c15cbfc31ace5706249e69

Hashes for mgcv_rust-0.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
  SHA256       3e848df41ca12de831ab3e62141b1bfff68f0a1edcd49601cd2089a261c30016
  MD5          70372a5dc37c82be6ff8bf5c92e6bac9
  BLAKE2b-256  c767920659fa707aca9e25684d9536d79b0f8918b59bd5f548f9943d75760dee
