Rust implementation of Generalized Additive Models with Python bindings

mgcv_rust

A Rust port of R's mgcv package — Generalized Additive Models with automatic smoothing-parameter selection (REML / LAML) and PIRLS fitting — with first-class Python bindings.

The numerical core targets byte-for-byte parity with mgcv on common families, and the Python API is designed to feel like sklearn: a single Gam class, named-predictor DataFrame inputs, fit() / predict() / score(), posterior CI, and subset views for marginal analysis.

Latest release: 0.11.0
Parity: 554 / 0 / 0 on the mgcv comparison battery
Python tests: 211 passed, 1 xfailed
Headline perf: 2d_gaussian_additive_n50000_k15_cr, 394 ms → 97 ms (4.05× vs. mgcv)
Python wrapper: mgcv_rust.Gam (canonical); GAMFitter is a deprecated alias

Why another GAM library?

  • R-equivalent answers, Python ergonomics. If you have R / mgcv users producing models and Python services consuming them, this gives both sides the same numbers from the same fit. The serialize() method produces a portable artefact for deployment.
  • Fast. Aggregate time across 80 parity fixtures: 2.0 s. On the largest Gaussian fixture (n=50,000, k=15), mgcv_rust is ~4× faster than R's mgcv.
  • Real CIs, paired diffs. predict_ci(X) returns (mean, lo, hi) so you never have to "+= intercept" by hand. predict_diff(from_X, to_X, level=...) gives a paired-posterior CI for the difference between two predictions — strictly tighter than the naive predict_ci difference.
  • Frozen serving view. GamPredictor(gam) is an __slots__-locked, inference-only wrapper with strict input validation (feature_names_in_ enforcement) and a check_against(gam, X_sample) round-trip assertion for deployment safety.

Installation

The Python package is built from this repo via maturin:

# Clone, then in a venv:
pip install maturin
maturin develop --features python,blas,blas-system --release

After that, import mgcv_rust works.

For a Rust-only consumer, add to your Cargo.toml:

[dependencies]
mgcv_rust = { git = "https://github.com/AlekJaworski/nn_exploring" }

Quick start

import numpy as np
import pandas as pd
from mgcv_rust import Gam

rng = np.random.default_rng(0)
n = 500
X = pd.DataFrame({
    "x0": rng.uniform(-2, 2, n),
    "x1": rng.uniform(0, 5, n),
})
y = np.sin(X["x0"]) + 0.3 * (X["x1"] - 2.5)**2 + rng.normal(0, 0.1, n)

gam = Gam(family="gaussian").fit(X, y)

gam.predict(X[:5])                  # response-scale predictions
mean, lo, hi = gam.predict_ci(X[:5])  # 95% CI, response scale
gam.score(X, y)                     # adjusted R²
print(gam.summary())                # mgcv-style block

That's it for the simple case. Read on for the curated tour, or jump to docs/GETTING_STARTED.md for a step-by-step tutorial with worked examples.


Tour

1. Build the model

Gam() accepts pandas / polars / numpy. If you pass a DataFrame, the column names become the predictor names; otherwise you can pass them explicitly.

gam = Gam(
    family="gaussian",           # gaussian, binomial, poisson, gamma, tweedie, nb, t-dist, scat, ...
    link=None,                   # None → canonical link for the family
    k_default=10,                # default basis dimension per smooth (mgcv default)
    term_k_mapping={"x0": 25},   # per-predictor overrides
    method="REML",
)
gam.fit(X, y)

After fit, sklearn-style attributes are populated:

gam.coef_              # full β vector (intercept first)
gam.intercept_         # link-scale
gam.intercept_response_ # response-scale (= mean(y) for identity link)
gam.feature_names_in_  # predictor names, fitting order
gam.n_features_in_
gam.lambda_            # per-smooth smoothing parameter
gam.edf_               # per-smooth effective degrees of freedom
gam.k_, gam.bs_        # basis dimension / basis type per smooth
gam.vcov_              # posterior covariance of β̂

2. Predict on different scales

gam.predict(X)                       # response (default — same as mgcv predict(type="response"))
gam.predict(X, scale="link")         # linear predictor, no inverse link (mgcv type="link")

result = gam.predict(X, type="terms")  # decomposition
result.intercept                       # scalar, link scale
result.contributions                   # DataFrame (n × m_smooths), link-scale, centered on training
result.total                           # response-scale ≡ gam.predict(X)

Invariant: gam.predict(X) == inv_link(result.intercept + result.contributions.sum(axis=1)).
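
The invariant can be checked by hand on a toy decomposition. The sketch below is plain Python and does not call mgcv_rust: the helper name recompose, the dict-of-lists layout, and the numbers are all hypothetical, and a log link is assumed purely for illustration.

```python
import math

def recompose(intercept, contributions, inv_link):
    """Rebuild response-scale predictions from a terms decomposition.

    intercept     -- scalar on the link scale
    contributions -- dict: smooth name -> list of link-scale values, one per row
    inv_link      -- callable mapping the linear predictor to the response scale
    """
    n_rows = len(next(iter(contributions.values())))
    totals = []
    for i in range(n_rows):
        # Linear predictor for row i: intercept plus every smooth's contribution.
        eta = intercept + sum(col[i] for col in contributions.values())
        totals.append(inv_link(eta))
    return totals

# Toy decomposition for two rows and two smooths (made-up numbers).
intercept = 0.5
contributions = {"x0": [0.2, -0.1], "x1": [-0.3, 0.4]}

# Log link assumed here, so the inverse link is exp.
total = recompose(intercept, contributions, math.exp)
```

With an identity link, inv_link is the identity function and the total is just the row-wise sum.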

3. Confidence intervals

mean, lo, hi = gam.predict_ci(X, level=0.95)       # response, 95% CI
mean, lo, hi = gam.predict_ci(X, scale="link")     # link-scale CI
mean, lo, hi = gam.predict_ci(X, level=0.5)        # 50% CI (narrower)

mean matches gam.predict(X, scale=scale) exactly. The CI is built from n_samples (default 1000) draws of β ~ N(β̂, vcov).
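
To make the simulation idea concrete, here is a minimal dependency-free sketch of a percentile interval built from draws of β ~ N(β̂, V). It is not the library's implementation: simulate_ci, the Cholesky-factor input, the percentile indexing, and the toy numbers are assumptions for illustration.

```python
import math
import random

def simulate_ci(x_row, beta_hat, chol, inv_link, level=0.95, n_samples=1000, seed=0):
    """Percentile CI for one prediction from draws beta ~ N(beta_hat, V).

    chol is the lower-triangular Cholesky factor L of V (V = L L^T),
    given as a list of rows.
    """
    rng = random.Random(seed)
    p = len(beta_hat)
    draws = []
    for _ in range(n_samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(p)]
        # beta = beta_hat + L z has covariance L L^T = V.
        beta = [beta_hat[i] + sum(chol[i][j] * z[j] for j in range(i + 1))
                for i in range(p)]
        eta = sum(x * b for x, b in zip(x_row, beta))
        draws.append(inv_link(eta))
    draws.sort()
    tail = (1.0 - level) / 2.0
    lo = draws[int(round(tail * (n_samples - 1)))]
    hi = draws[int(round((1.0 - tail) * (n_samples - 1)))]
    # The point estimate is evaluated at beta_hat, not averaged over draws.
    mean = inv_link(sum(x * b for x, b in zip(x_row, beta_hat)))
    return mean, lo, hi

# Toy two-coefficient model (made-up numbers): V = [[0.04, 0.01], [0.01, 0.09]].
beta_hat = [1.0, 0.5]
chol = [[0.2, 0.0], [0.05, math.sqrt(0.0875)]]
x_row = [1.0, 2.0]  # one row of the model matrix

mean, lo, hi = simulate_ci(x_row, beta_hat, chol, math.exp)              # 95%, log link
_, lo50, hi50 = simulate_ci(x_row, beta_hat, chol, math.exp, level=0.5)  # 50%, narrower
```

Because both calls reuse the same seed, the 50% interval is nested inside the 95% interval by construction.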

The legacy predict_ci(X, alpha=0.05) → (lo, hi) form still works (emits a DeprecationWarning) so existing callers don't break.

4. Differences between predictions, with paired CI

# Per-row diff: to[i] - from[i]
diff = gam.predict_diff(from_X, to_X)

# With paired posterior CI — much narrower than predict_ci differences
diff, lo, hi = gam.predict_diff(from_X, to_X, level=0.95)

# One row broadcast against many
diff = gam.predict_diff(baseline_row, candidates, broadcast="from")
diff = gam.predict_diff(candidates, target_row, broadcast="to")

Identity-link only — for non-identity links the response-scale difference isn't linear in β, so the closed-form CI argument doesn't transfer.
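
Why the paired interval is narrower can be seen from the variances alone. Under an identity link the difference is (x_to - x_from)ᵀβ, so its variance is dᵀVd with d = x_to - x_from, whereas subtracting two separate intervals adds their half-widths. The sketch below compares the two using a normal approximation; the function names and the toy covariance are hypothetical, the rows stand for rows of the model matrix, and the library itself may compute the interval from posterior draws instead.

```python
import math
from statistics import NormalDist

def quad_form(d, V):
    """Return d^T V d for a vector d and a square matrix V (list of rows)."""
    n = len(d)
    return sum(d[i] * V[i][j] * d[j] for i in range(n) for j in range(n))

def paired_and_naive_halfwidths(x_from, x_to, V, level=0.95):
    """Half-widths of the paired CI and of the naive interval-difference CI."""
    z = NormalDist().inv_cdf(0.5 + level / 2.0)
    d = [t - f for f, t in zip(x_from, x_to)]
    # Paired: one variance for the difference, so shared uncertainty cancels.
    paired = z * math.sqrt(quad_form(d, V))
    # Naive: [lo_to - hi_from, hi_to - lo_from] has the sum of both half-widths.
    naive = z * (math.sqrt(quad_form(x_from, V)) + math.sqrt(quad_form(x_to, V)))
    return paired, naive

# Toy posterior covariance and two model-matrix rows (made-up numbers).
V = [[0.04, 0.01], [0.01, 0.09]]
x_from = [1.0, 1.0]
x_to = [1.0, 1.5]

paired, naive = paired_and_naive_halfwidths(x_from, x_to, V)
```

In this toy case the intercept column is identical in both rows, so its uncertainty drops out of the paired interval entirely but is counted twice in the naive one.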

5. Subset / marginal views

Indexing with a list of names, gam[names], returns a view restricted to the listed smooths; add Gam.INTERCEPT to the list to include the intercept. All predict methods are available on the view.

gam[["x0"]].predict(X)                              # marginal effect of x0, response scale
gam[["x0"]].predict(X, scale="deviation")           # link-scale, intercept zeroed (sum-to-zero on train)
gam[["x0", Gam.INTERCEPT]].predict(X)               # marginal effect + intercept
gam[["x0"]].partial_effect("x0").plot()             # what the smooth looks like

All sklearn attributes filter to the active features:

gam[["x0"]].coef_            # just the x0 block (+ intercept if active)
gam[["x0"]].feature_names_in_  # ["x0"]

6. Plotting, summary, score

gam.plot()                  # figure with one subplot per smooth (CI bands shaded)
gam.plot("x0")              # single-smooth axes
df = gam.partial_effect("x0", level=0.95)  # underlying data: columns x0, effect, lo, hi

print(gam.summary())
# Gam summary  family=gaussian  link=identity  n_obs=500
#   intercept (link)     = -0.0123
#   intercept (response) = -0.0123
#   scale (σ²)           = 0.0101
#   deviance             = 4.8732
#   R² (adj)             = 0.9612
#   smooths:
#     s(          x0)  k= 10  edf=  6.13  λ=2.345e-02
#     s(          x1)  k= 10  edf=  3.07  λ=4.890e-01

gam.score(X, y)              # adjusted R² for regression; accuracy@0.5 for binomial
gam.predict_proba(X)         # binomial only: (n, 2) [[P(0), P(1)]]
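
For readers who want to sanity-check the two score() metrics, here is a rough sketch of the textbook definitions. It is an assumption that the library follows these forms: the exact adjusted R² formula (for example how residual degrees of freedom are derived from the edf) and the tie rule at exactly 0.5 may differ, and both function names are hypothetical.

```python
def adjusted_r2(y, fitted, edf_total):
    """Textbook adjusted R^2 using total effective degrees of freedom.

    edf_total counts the intercept plus the edf of every smooth (assumed).
    """
    n = len(y)
    y_bar = sum(y) / n
    rss = sum((yi - fi) ** 2 for yi, fi in zip(y, fitted))
    tss = sum((yi - y_bar) ** 2 for yi in y)
    return 1.0 - (rss / (n - edf_total)) / (tss / (n - 1))

def accuracy_at_half(y_true, p1):
    """Share of rows where thresholding P(1) at 0.5 reproduces the 0/1 label."""
    hits = sum(int((p >= 0.5) == bool(t)) for t, p in zip(y_true, p1))
    return hits / len(y_true)

# Made-up data for illustration.
r2 = adjusted_r2([1, 2, 3, 4, 5], [1.1, 1.9, 3.2, 3.8, 5.0], edf_total=2)
acc = accuracy_at_half([0, 1, 1, 0], [0.2, 0.7, 0.4, 0.6])
```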

7. Serving with GamPredictor

For deployment paths, wrap the fitted Gam in a GamPredictor. Same API surface, plus strict column validation and a round-trip assertion.

from mgcv_rust import GamPredictor

predictor = GamPredictor(gam)
predictor.check_against(gam, X[:50])   # raises if predictions diverge
predictor.predict(X_serve)             # ValueError if any expected column is missing
predictor[["x0"]].predict(X_serve)     # subset view returns another GamPredictor

# `__slots__` blocks attribute leaks — once built, the bound Gam's
# coef/vcov/schema are the contract.

This structurally closes two production bug classes:

  • Column index drift when the serving DataFrame's columns differ from training: feature_names_in_ is enforced on every predict call (re-ordered → realigned; missing → ValueError).
  • x-grid clamping in interpolation-based predictors: GamPredictor recomputes the basis at the requested X via evaluate_lpmatrix, so there's no pre-computed grid to clamp against.
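
The first guarantee amounts to selecting columns by name rather than by position. A minimal sketch of that idea, using a plain dict of columns as a stand-in for a DataFrame: align_columns is a hypothetical helper, not part of mgcv_rust, and silently ignoring extra columns is an assumption the source does not confirm.

```python
def align_columns(frame, feature_names_in):
    """Return the columns of `frame` in training order.

    frame            -- dict mapping column name -> list of values
    feature_names_in -- predictor names in fitting order
    """
    missing = [name for name in feature_names_in if name not in frame]
    if missing:
        raise ValueError(f"missing expected columns: {missing}")
    # Select by name, never by position, so a re-ordered frame is realigned.
    # Extra columns are ignored here (an assumption, see above).
    return [frame[name] for name in feature_names_in]

# Serving frame with columns in a different order than at fit time.
aligned = align_columns({"x1": [3.0], "x0": [1.0]}, ["x0", "x1"])
```

Calling align_columns({"x0": [1.0]}, ["x0", "x1"]) raises ValueError because x1 is absent.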

Families and links

| Family | Default link | Other links | Notes |
|---|---|---|---|
| gaussian | identity | | bs-spline / cubic-regression / random-effects bases supported |
| binomial | logit | | predict_proba() available |
| poisson | log | | |
| gamma | inverse | log | log-link via link="log" |
| tweedie | log | | tweedie_p= (default 1.5); fixed-p |
| nb / negbin | log | | nb profiles θ; negbin is fixed-θ |
| t-dist / scat | identity | | df profiled if not given (df ∈ [2, 100]) |
| quasipoisson, quasibinomial | log / logit | | dispersion-aware |
| quantile | identity | | see mgcv_rust.fit_quantile / fit_quantile_lss; OOS presets documented in docs/qgam_oos_presets.md |
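
As a reading aid for the default-link column, the sketch below pairs each link with its inverse and resolves link=None to the family default. The dictionaries and resolve_link are hypothetical and cover only a few rows of the table; they illustrate the convention and are not the library's internals.

```python
import math

# Each entry is (link, inverse link): eta = g(mu), mu = g_inv(eta).
LINKS = {
    "identity": (lambda mu: mu, lambda eta: eta),
    "log": (math.log, math.exp),
    "logit": (lambda mu: math.log(mu / (1.0 - mu)),
              lambda eta: 1.0 / (1.0 + math.exp(-eta))),
    "inverse": (lambda mu: 1.0 / mu, lambda eta: 1.0 / eta),
}

# Default links as listed in the table above (subset).
DEFAULT_LINK = {
    "gaussian": "identity",
    "binomial": "logit",
    "poisson": "log",
    "gamma": "inverse",
    "tweedie": "log",
}

def resolve_link(family, link=None):
    """Return (name, (g, g_inv)); None falls back to the family default."""
    name = link if link is not None else DEFAULT_LINK[family]
    return name, LINKS[name]

name_default, _ = resolve_link("gamma")            # falls back to "inverse"
name_override, (g, g_inv) = resolve_link("gamma", link="log")
```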

Architecture (Rust side)

| File | Responsibility |
|---|---|
| src/gam.rs | GAM struct + fit / predict entry points |
| src/pirls.rs | Penalized IRLS inner loop |
| src/reml/ | Outer-loop REML / LAML optimization |
| src/smooth.rs | Basis functions (cubic-regression, B-spline, random effects, tensor products) |
| src/penalty.rs | Penalty-matrix construction |
| src/lib.rs | PyO3 bindings: PyGAM, compute_penalty_matrix, newton_pirls_py |

Build features:

  • python — enable PyO3 bindings.
  • blas / blas-system — link against system BLAS for matmul-heavy paths.

Parity and performance

Run the parity battery against R/mgcv:

pytest tests/parity/ -q
# 554 passed, 0 failed, 0 xfailed, 0 skipped

Microbench:

python3 scripts/python/bench_step_blend.py 5

Headline (2d_gaussian_additive_n50000_k15_cr, identity link):

| Implementation | Time |
|---|---|
| R mgcv | ~394 ms |
| mgcv_rust 0.11.0 | 97 ms |

Status and limitations

  • Joint (ρ, log φ) outer Newton not yet implemented — the binding constraint for closing the remaining performance gap on dispersion-bearing GLMs. Tracked in mgcv_rust - Backlog - Next Numerical Steps (Obsidian).
  • predict_diff is identity-link only; non-identity raises with a workaround pointing at get_posterior_samples.
  • Auto-k tuning is opt-in via Gam(auto_k=True). Default is a single fit at k_default=10 (mgcv's default), with term_k_mapping overrides — closer to mgcv's "tune k, run k.check" convention and avoids hidden multi-fit costs.
  • sklearn BaseEstimator mixin (for Pipeline / GridSearchCV) is not yet wrapped — soft-dep, deferred.

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. J. R. Stat. Soc. B, 73(1), 3–36.

License

MIT — see LICENSE.
