Rust implementation of Generalized Additive Models with Python bindings

mgcv_rust

A Rust port of R's mgcv package — Generalized Additive Models with automatic smoothing-parameter selection (REML / LAML) and PIRLS fitting — with first-class Python bindings.

The numerical core targets byte-for-byte parity with mgcv on common families, and the Python API is designed to feel like sklearn: a single Gam class, named-predictor DataFrame inputs, fit() / predict() / score(), posterior CI, and subset views for marginal analysis.

| | |
|---|---|
| Latest release | 0.11.0 |
| Parity | 554 / 0 / 0 on the mgcv comparison battery |
| Python tests | 211 passed, 1 xfailed |
| Headline perf | 2d_gaussian_additive_n50000_k15_cr: 394 ms → 97 ms (4.05× vs. mgcv) |
| Python wrapper | mgcv_rust.Gam (canonical), GAMFitter deprecated alias |

Why another GAM library?

  • R-equivalent answers, Python ergonomics. If you have R / mgcv users producing models and Python services consuming them, this gives both sides the same numbers from the same fit. The serialize() method produces a portable artefact for deployment.
  • Fast. Aggregate time across 80 parity fixtures: 2.0 s. On the largest Gaussian fixture (n=50,000, k=15), mgcv_rust is ~4× faster than R's mgcv.
  • Real CIs, paired diffs. predict_ci(X) returns (mean, lo, hi) so you never have to "+= intercept" by hand. predict_diff(from_X, to_X, level=...) gives a paired-posterior CI for the difference between two predictions — strictly tighter than the naive predict_ci difference.
  • Frozen serving view. GamPredictor(gam) is an __slots__-locked, inference-only wrapper with strict input validation (feature_names_in_ enforcement) and a check_against(gam, X_sample) round-trip assertion for deployment safety.

Installation

The Python package is built from this repo via maturin:

# Clone, then in a venv:
pip install maturin
maturin develop --features python,blas,blas-system --release

After that, import mgcv_rust works.

For a Rust-only consumer, add to your Cargo.toml:

[dependencies]
mgcv_rust = { git = "https://github.com/AlekJaworski/nn_exploring" }

Quick start

import numpy as np
import pandas as pd
from mgcv_rust import Gam

rng = np.random.default_rng(0)
n = 500
X = pd.DataFrame({
    "x0": rng.uniform(-2, 2, n),
    "x1": rng.uniform(0, 5, n),
})
y = np.sin(X["x0"]) + 0.3 * (X["x1"] - 2.5)**2 + rng.normal(0, 0.1, n)

gam = Gam(family="gaussian").fit(X, y)

gam.predict(X[:5])                  # response-scale predictions
mean, lo, hi = gam.predict_ci(X[:5])  # 95% CI, response scale
gam.score(X, y)                     # adjusted R²
print(gam.summary())                # mgcv-style block

That's it for the simple case. Read on for the curated tour, or jump to docs/GETTING_STARTED.md for a step-by-step tutorial with worked examples.


Tour

1. Build the model

Gam() accepts pandas / polars / numpy. If you pass a DataFrame, the column names become the predictor names; otherwise you can pass them explicitly.

gam = Gam(
    family="gaussian",           # gaussian, binomial, poisson, gamma, tweedie, nb, t-dist, scat, ...
    link=None,                   # None → canonical link for the family
    k_default=10,                # default basis dimension per smooth (mgcv default)
    term_k_mapping={"x0": 25},   # per-predictor overrides
    method="REML",
)
gam.fit(X, y)

After fit, sklearn-style attributes are populated:

gam.coef_              # full β vector (intercept first)
gam.intercept_         # link-scale
gam.intercept_response_ # response-scale (= mean(y) for identity link)
gam.feature_names_in_  # predictor names, fitting order
gam.n_features_in_
gam.lambda_            # per-smooth smoothing parameter
gam.edf_               # per-smooth effective degrees of freedom
gam.k_, gam.bs_        # basis dimension / basis type per smooth
gam.vcov_              # posterior covariance of β̂

2. Predict on different scales

gam.predict(X)                       # response (default — same as mgcv predict(type="response"))
gam.predict(X, scale="link")         # linear predictor, no inverse link (mgcv type="link")

result = gam.predict(X, type="terms")  # decomposition
result.intercept                       # scalar, link scale
result.contributions                   # DataFrame (n × m_smooths), link-scale, centered on training
result.total                           # response-scale ≡ gam.predict(X)

Invariant: gam.predict(X) == inv_link(intercept + result.contributions.sum(axis=1)).
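
The invariant can be checked by hand on a toy decomposition. The sketch below uses only the standard library and made-up numbers; `recompose`, the intercept, and the contribution values are hypothetical stand-ins for result.intercept and result.contributions, not part of the mgcv_rust API.

```python
import math

# Hypothetical link-scale pieces for three rows of a two-smooth model.
intercept = 0.4
contributions = {
    "x0": [0.10, -0.25, 0.30],
    "x1": [-0.05, 0.15, 0.20],
}

def recompose(intercept, contributions, inv_link):
    """Add the intercept to the row-wise sum of smooth contributions,
    then map the linear predictor through the inverse link."""
    n_rows = len(next(iter(contributions.values())))
    eta = [
        intercept + sum(col[i] for col in contributions.values())
        for i in range(n_rows)
    ]
    return [inv_link(e) for e in eta]

# Log link: the inverse link is exp.
response = recompose(intercept, contributions, math.exp)

# Identity link: the inverse link is a no-op.
identity = recompose(intercept, contributions, lambda eta: eta)
```

Because both sides are floating-point, a tolerance-based comparison (math.isclose, or numpy.allclose on arrays) is a safer way to assert the invariant than a literal ==.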

3. Confidence intervals

mean, lo, hi = gam.predict_ci(X, level=0.95)       # response, 95% CI
mean, lo, hi = gam.predict_ci(X, scale="link")     # link-scale CI
mean, lo, hi = gam.predict_ci(X, level=0.5)        # 50% CI (narrower)

mean matches gam.predict(X, scale=scale) exactly. The CI is built from n_samples (default 1000) draws of β ~ N(β̂, vcov).
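
A minimal sketch of how a sampling-based interval of this kind can be built, assuming an identity link and a two-coefficient model so that the multivariate normal draw can be written out by hand. `cholesky_2x2`, `sample_ci`, and all numbers are illustrative and not taken from the library; for a non-identity link one would presumably apply the inverse link to each draw before taking percentiles.

```python
import math
import random

def cholesky_2x2(v):
    """Lower-triangular L with L @ L.T == v for a 2x2 positive-definite v."""
    l00 = math.sqrt(v[0][0])
    l10 = v[1][0] / l00
    l11 = math.sqrt(v[1][1] - l10 * l10)
    return [[l00, 0.0], [l10, l11]]

def sample_ci(x_row, beta_hat, vcov, level=0.95, n_samples=1000, seed=0):
    """Percentile interval for x_row . beta from draws beta ~ N(beta_hat, vcov)."""
    rng = random.Random(seed)
    L = cholesky_2x2(vcov)
    draws = []
    for _ in range(n_samples):
        z0, z1 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        b0 = beta_hat[0] + L[0][0] * z0
        b1 = beta_hat[1] + L[1][0] * z0 + L[1][1] * z1
        draws.append(x_row[0] * b0 + x_row[1] * b1)
    draws.sort()
    tail = (1.0 - level) / 2.0
    lo = draws[int(tail * (n_samples - 1))]
    hi = draws[int((1.0 - tail) * (n_samples - 1))]
    # The point estimate comes from beta_hat itself, not from the draws.
    mean = x_row[0] * beta_hat[0] + x_row[1] * beta_hat[1]
    return mean, lo, hi

beta_hat = [1.0, 0.5]
vcov = [[0.04, 0.01], [0.01, 0.09]]
x_row = [1.0, 2.0]                      # intercept column, then one covariate

mean95, lo95, hi95 = sample_ci(x_row, beta_hat, vcov, level=0.95)
mean50, lo50, hi50 = sample_ci(x_row, beta_hat, vcov, level=0.5)
```

With the same seed the 50% interval is nested inside the 95% one, which is the "narrower" behaviour noted in the example above.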

The legacy predict_ci(X, alpha=0.05) → (lo, hi) form still works (emits a DeprecationWarning) so existing callers don't break.

4. Differences between predictions, with paired CI

# Per-row diff: to[i] - from[i]
diff = gam.predict_diff(from_X, to_X)

# With paired posterior CI — much narrower than predict_ci differences
diff, lo, hi = gam.predict_diff(from_X, to_X, level=0.95)

# One row broadcast against many
diff = gam.predict_diff(baseline_row, candidates, broadcast="from")
diff = gam.predict_diff(candidates, target_row, broadcast="to")

Identity-link only — for non-identity links the response-scale difference isn't linear in β, so the closed-form CI argument doesn't transfer.
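
Why the paired interval is narrower can be seen from the closed-form normal approximation under an identity link: the difference is (x_to − x_from)·β, so its variance is dᵀ V d with d = x_to − x_from, and the uncertainty shared by both predictions (the intercept, for instance) cancels. The sketch below compares that with subtracting the endpoints of two separate intervals. The function names and numbers are hypothetical, and the library may well use posterior draws rather than this closed form.

```python
import math
from statistics import NormalDist

def quad_form(d, V):
    """d^T V d for a list vector d and a nested-list matrix V."""
    n = len(d)
    return sum(d[i] * V[i][j] * d[j] for i in range(n) for j in range(n))

def paired_diff_ci(x_from, x_to, beta_hat, vcov, level=0.95):
    """CI for (x_to - x_from) . beta using the joint covariance of beta."""
    z = NormalDist().inv_cdf(0.5 + level / 2.0)
    d = [t - f for f, t in zip(x_from, x_to)]
    diff = sum(di * b for di, b in zip(d, beta_hat))
    se = math.sqrt(quad_form(d, vcov))
    return diff, diff - z * se, diff + z * se

def naive_diff_ci(x_from, x_to, beta_hat, vcov, level=0.95):
    """Subtract the endpoints of two separately computed intervals."""
    z = NormalDist().inv_cdf(0.5 + level / 2.0)

    def ci(x):
        m = sum(xi * b for xi, b in zip(x, beta_hat))
        se = math.sqrt(quad_form(x, vcov))
        return m - z * se, m + z * se

    lo_f, hi_f = ci(x_from)
    lo_t, hi_t = ci(x_to)
    return lo_t - hi_f, hi_t - lo_f

beta_hat = [1.0, 0.5]
vcov = [[0.04, 0.01], [0.01, 0.09]]
x_from = [1.0, 1.0]                     # intercept column, then covariate
x_to = [1.0, 1.5]

diff, lo, hi = paired_diff_ci(x_from, x_to, beta_hat, vcov)
naive_lo, naive_hi = naive_diff_ci(x_from, x_to, beta_hat, vcov)
```

The paired half-width is proportional to the V-norm of x_to − x_from, while the naive one is proportional to the sum of the two individual norms, so by the triangle inequality the paired interval can never be wider.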

5. Subset / marginal views

Indexing a fitted model with a list of names, as in gam[["x0"]], returns a view that includes only the listed smooths; add Gam.INTERCEPT to the list to keep the intercept. All predict methods are available on the view.

gam[["x0"]].predict(X)                              # marginal effect of x0, response scale
gam[["x0"]].predict(X, scale="deviation")           # link-scale, intercept zeroed (sum-to-zero on train)
gam[["x0", Gam.INTERCEPT]].predict(X)               # marginal effect + intercept
gam[["x0"]].partial_effect("x0").plot()             # what the smooth looks like

All sklearn attributes filter to the active features:

gam[["x0"]].coef_            # just the x0 block (+ intercept if active)
gam[["x0"]].feature_names_in_  # ["x0"]
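
One way such views can be structured is a `__getitem__` that returns a new lightweight object sharing the fitted pieces but carrying its own list of active terms. The class below is a toy illustration with an identity link; `TinyAdditiveModel` and its internals are hypothetical and say nothing about how Gam is implemented.

```python
class TinyAdditiveModel:
    """Toy additive model: intercept plus one callable per named smooth."""

    INTERCEPT = "(Intercept)"

    def __init__(self, intercept, smooths, active=None, use_intercept=True):
        self.intercept = intercept
        self.smooths = smooths                      # name -> callable
        self.active = list(smooths) if active is None else list(active)
        self.use_intercept = use_intercept

    def __getitem__(self, names):
        # The view shares the fitted pieces and only narrows the active terms.
        wanted = [n for n in names if n != self.INTERCEPT]
        unknown = [n for n in wanted if n not in self.smooths]
        if unknown:
            raise KeyError(f"unknown smooths: {unknown}")
        return TinyAdditiveModel(
            self.intercept,
            self.smooths,
            active=wanted,
            use_intercept=self.INTERCEPT in names,
        )

    @property
    def feature_names_in_(self):
        return list(self.active)

    def predict(self, rows):
        base = self.intercept if self.use_intercept else 0.0
        return [
            base + sum(self.smooths[n](row[n]) for n in self.active)
            for row in rows
        ]

model = TinyAdditiveModel(
    intercept=1.0,
    smooths={"x0": lambda v: 2.0 * v, "x1": lambda v: v * v},
)
rows = [{"x0": 1.0, "x1": 3.0}]

full = model.predict(rows)                                   # 1 + 2 + 9
only_x0 = model[["x0"]].predict(rows)                        # 2
x0_plus_intercept = model[["x0", TinyAdditiveModel.INTERCEPT]].predict(rows)
```

Returning a fresh object rather than mutating the parent keeps the full model usable after a view has been taken.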

6. Plotting, summary, score

gam.plot()                  # figure with one subplot per smooth (CI bands shaded)
gam.plot("x0")              # single-smooth axes
df = gam.partial_effect("x0", level=0.95)  # underlying data: columns x0, effect, lo, hi

print(gam.summary())
# Gam summary  family=gaussian  link=identity  n_obs=500
#   intercept (link)     = -0.0123
#   intercept (response) = -0.0123
#   scale (σ²)           = 0.0101
#   deviance             = 4.8732
#   R² (adj)             = 0.9612
#   smooths:
#     s(          x0)  k= 10  edf=  6.13  λ=2.345e-02
#     s(          x1)  k= 10  edf=  3.07  λ=4.890e-01

gam.score(X, y)              # adjusted R² for regression; accuracy@0.5 for binomial
gam.predict_proba(X)         # binomial only: (n, 2) [[P(0), P(1)]]
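
For reference, the two scores can be sketched in a few lines. The adjusted R² here follows the usual GAM convention of using the total effective degrees of freedom (intercept included) in place of a parameter count, along the lines of mgcv's r.sq; whether mgcv_rust computes exactly this quantity, and whether the 0.5 threshold is inclusive, are assumptions.

```python
from statistics import variance

def adjusted_r2(y, fitted, total_edf):
    """1 - var(residuals) * (n - 1) / (var(y) * residual_df),
    with residual_df = n - total_edf and sample variances."""
    n = len(y)
    residual_df = n - total_edf
    resid = [yi - fi for yi, fi in zip(y, fitted)]
    return 1.0 - variance(resid) * (n - 1) / (variance(y) * residual_df)

def accuracy_at_half(y, prob):
    """Share of rows where thresholding P(1) at 0.5 recovers the 0/1 label."""
    hits = sum(int(p >= 0.5) == int(yi) for yi, p in zip(y, prob))
    return hits / len(y)

y = [1.0, 2.0, 3.0, 4.0, 5.0]
fitted = [1.1, 1.9, 3.2, 3.8, 5.0]
r2 = adjusted_r2(y, fitted, total_edf=2.0)

labels = [0, 1, 1, 0]
prob = [0.2, 0.7, 0.4, 0.6]
acc = accuracy_at_half(labels, prob)
```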

7. Serving with GamPredictor

For deployment paths, wrap the fitted Gam in a GamPredictor. Same API surface, plus strict column validation and a round-trip assertion.

from mgcv_rust import GamPredictor

predictor = GamPredictor(gam)
predictor.check_against(gam, X[:50])   # raises if predictions diverge
predictor.predict(X_serve)             # ValueError if any expected column is missing
predictor[["x0"]].predict(X_serve)     # subset view returns another GamPredictor

# `__slots__` blocks attribute leaks — once built, the bound Gam's
# coef/vcov/schema are the contract.

This structurally closes two production bug classes:

  • Column index drift when the serving DataFrame's columns differ from training: feature_names_in_ is enforced on every predict call (re-ordered → realigned; missing → ValueError).
  • x-grid clamping in interpolation-based predictors: GamPredictor recomputes the basis at the requested X via evaluate_lpmatrix, so there's no pre-computed grid to clamp against.
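
The column-validation and round-trip ideas can be sketched without the library. The class below is a hypothetical stand-in (a plain linear predictor over a dict of columns), not GamPredictor itself; dropping extra columns silently is an assumption, since only the re-ordered and missing cases are described above.

```python
class FrozenPredictor:
    """Toy inference-only wrapper: fixed schema, no new attributes."""

    __slots__ = ("_names", "_coef")

    def __init__(self, names, coef):
        self._names = tuple(names)      # training-time column order
        self._coef = tuple(coef)

    def _align(self, frame):
        # Look columns up by name, so their order in `frame` is irrelevant.
        missing = [n for n in self._names if n not in frame]
        if missing:
            raise ValueError(f"missing columns: {missing}")
        return [frame[n] for n in self._names]

    def predict(self, frame):
        cols = self._align(frame)
        n_rows = len(cols[0])
        return [
            sum(c * col[i] for c, col in zip(self._coef, cols))
            for i in range(n_rows)
        ]

    def check_against(self, reference_predict, frame, tol=1e-9):
        """Raise if this wrapper and the reference disagree on the sample."""
        mine = self.predict(frame)
        theirs = reference_predict(frame)
        worst = max(abs(a - b) for a, b in zip(mine, theirs))
        if worst > tol:
            raise AssertionError(f"predictions diverge by {worst:g}")

predictor = FrozenPredictor(["x0", "x1"], [2.0, 0.5])
in_order = {"x0": [1.0, 2.0], "x1": [4.0, 6.0]}
shuffled = {"x1": [4.0, 6.0], "x0": [1.0, 2.0]}

same_either_way = predictor.predict(in_order) == predictor.predict(shuffled)
```

Because `__slots__` is declared and no `__dict__` exists, an assignment such as predictor.extra = 1 raises AttributeError instead of silently attaching state to the serving object.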

Families and links

| Family | Default link | Other links | Notes |
|---|---|---|---|
| gaussian | identity | | bs-spline / cubic-regression / random-effects bases supported |
| binomial | logit | | predict_proba() available |
| poisson | log | | |
| gamma | inverse | log | log-link via link="log" |
| tweedie | log | | tweedie_p= (default 1.5); fixed-p |
| nb / negbin | log | | nb profiles θ; negbin is fixed-θ |
| t-dist / scat | identity | | df profiled if not given (df ∈ [2, 100]) |
| quasipoisson, quasibinomial | log / logit | | dispersion-aware |
| quantile | identity | | see mgcv_rust.fit_quantile / fit_quantile_lss |
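
As a reminder of what the link column means, the sketch below pairs each link with its inverse and resolves a family's default when link=None. The dictionaries and `resolve_link` are illustrative only; they cover a subset of the families listed above and are not part of the mgcv_rust API.

```python
import math

# name -> (link, inverse link); the link maps the mean mu to eta.
LINKS = {
    "identity": (lambda mu: mu, lambda eta: eta),
    "log": (math.log, math.exp),
    "logit": (
        lambda mu: math.log(mu / (1.0 - mu)),
        lambda eta: 1.0 / (1.0 + math.exp(-eta)),
    ),
    "inverse": (lambda mu: 1.0 / mu, lambda eta: 1.0 / eta),
}

# Default link per family, as listed in the table above.
DEFAULT_LINK = {
    "gaussian": "identity",
    "binomial": "logit",
    "poisson": "log",
    "gamma": "inverse",
    "tweedie": "log",
    "nb": "log",
}

def resolve_link(family, link=None):
    """Return (name, link_fn, inverse_fn); None selects the family default."""
    name = DEFAULT_LINK[family] if link is None else link
    link_fn, inverse_fn = LINKS[name]
    return name, link_fn, inverse_fn

name, g, g_inv = resolve_link("binomial")
eta = g(0.3)                 # probability -> linear predictor
back = g_inv(eta)            # linear predictor -> probability

gamma_default = resolve_link("gamma")[0]
gamma_log = resolve_link("gamma", link="log")[0]
```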

Architecture (Rust side)

| File | Responsibility |
|---|---|
| src/gam.rs | GAM struct + fit / predict entry points |
| src/pirls.rs | Penalized IRLS inner loop |
| src/reml/ | Outer-loop REML / LAML optimization |
| src/smooth.rs | Basis functions (cubic-regression, B-spline, random effects, tensor products) |
| src/penalty.rs | Penalty-matrix construction |
| src/lib.rs | PyO3 bindings: PyGAM, compute_penalty_matrix, newton_pirls_py |

Build features:

  • python — enable PyO3 bindings.
  • blas / blas-system — link against system BLAS for matmul-heavy paths.

Parity and performance

Run the parity battery against R/mgcv:

pytest tests/parity/ -q
# 554 passed, 0 failed, 0 xfailed, 0 skipped

Microbench:

python3 scripts/python/bench_step_blend.py 5

Headline (2d_gaussian_additive_n50000_k15_cr, identity link):

| | Time |
|---|---|
| R mgcv | ~394 ms |
| mgcv_rust 0.11.0 | 97 ms |

Status and limitations

  • Joint (ρ, log φ) outer Newton not yet implemented — the binding constraint for closing the remaining performance gap on dispersion-bearing GLMs. Tracked in mgcv_rust - Backlog - Next Numerical Steps (Obsidian).
  • predict_diff is identity-link only; for non-identity links it raises an error whose message points to get_posterior_samples as a workaround.
  • Auto-k tuning is opt-in via Gam(auto_k=True). Default is a single fit at k_default=10 (mgcv's default), with term_k_mapping overrides — closer to mgcv's "tune k, run k.check" convention and avoids hidden multi-fit costs.
  • sklearn BaseEstimator mixin (for Pipeline / GridSearchCV) is not yet wrapped — soft-dep, deferred.

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. J. R. Stat. Soc. B, 73(1), 3–36.

License

MIT — see LICENSE.
