
Rust implementation of Generalized Additive Models with Python bindings


mgcv_rust

A Rust port of R's mgcv package — Generalized Additive Models with automatic smoothing-parameter selection (REML / LAML) and PIRLS fitting — with first-class Python bindings.

The numerical core targets byte-for-byte parity with mgcv on common families, and the Python API is designed to feel like sklearn: a single Gam class, named-predictor DataFrame inputs, fit() / predict() / score(), posterior CI, and subset views for marginal analysis.

| Metric | Value |
|---|---|
| Latest release | 0.11.0 |
| Parity | 554 passed, 0 failed, 0 xfailed on the mgcv comparison battery |
| Python tests | 211 passed, 1 xfailed |
| Headline perf | 2d_gaussian_additive_n50000_k15_cr: 394 ms (mgcv) → 97 ms (mgcv_rust), 4.05× faster than mgcv |
| Python wrapper | mgcv_rust.Gam (canonical); GAMFitter is a deprecated alias |

Why another GAM library?

  • R-equivalent answers, Python ergonomics. If you have R / mgcv users producing models and Python services consuming them, this gives both sides the same numbers from the same fit. The serialize() method produces a portable artefact for deployment.
  • Fast. Aggregate time across 80 parity fixtures: 2.0 s. On the largest Gaussian fixture (n=50,000, k=15), mgcv_rust is ~4× faster than R's mgcv.
  • Real CIs, paired diffs. predict_ci(X) returns (mean, lo, hi) so you never have to "+= intercept" by hand. predict_diff(from_X, to_X, level=...) gives a paired-posterior CI for the difference between two predictions — strictly tighter than the naive predict_ci difference.
  • Frozen serving view. GamPredictor(gam) is an __slots__-locked, inference-only wrapper with strict input validation (feature_names_in_ enforcement) and a check_against(gam, X_sample) round-trip assertion for deployment safety.

Installation

The Python package is built from this repo via maturin:

# Clone, then in a venv:
pip install maturin
maturin develop --features python,blas,blas-system --release

After that, import mgcv_rust works.

For a Rust-only consumer, add to your Cargo.toml:

[dependencies]
mgcv_rust = { git = "https://github.com/AlekJaworski/nn_exploring" }

Quick start

import numpy as np
import pandas as pd
from mgcv_rust import Gam

rng = np.random.default_rng(0)
n = 500
X = pd.DataFrame({
    "x0": rng.uniform(-2, 2, n),
    "x1": rng.uniform(0, 5, n),
})
y = np.sin(X["x0"]) + 0.3 * (X["x1"] - 2.5)**2 + rng.normal(0, 0.1, n)

gam = Gam(family="gaussian").fit(X, y)

gam.predict(X[:5])                  # response-scale predictions
mean, lo, hi = gam.predict_ci(X[:5])  # 95% CI, response scale
gam.score(X, y)                     # adjusted R²
print(gam.summary())                # mgcv-style block

That's it for the simple case. Read on for the curated tour, or jump to docs/GETTING_STARTED.md for a step-by-step tutorial with worked examples.


Tour

1. Build the model

Gam.fit() accepts pandas, polars, or numpy inputs. If you pass a DataFrame, its column names become the predictor names; otherwise you can pass them explicitly.

gam = Gam(
    family="gaussian",           # gaussian, binomial, poisson, gamma, tweedie, nb, t-dist, scat, ...
    link=None,                   # None → canonical link for the family
    k_default=10,                # default basis dimension per smooth (mgcv default)
    term_k_mapping={"x0": 25},   # per-predictor overrides
    method="REML",
)
gam.fit(X, y)

After fit, sklearn-style attributes are populated:

gam.coef_              # full β vector (intercept first)
gam.intercept_         # link-scale
gam.intercept_response_ # response-scale (= mean(y) for identity link)
gam.feature_names_in_  # predictor names, fitting order
gam.n_features_in_
gam.lambda_            # per-smooth smoothing parameter
gam.edf_               # per-smooth effective degrees of freedom
gam.k_, gam.bs_        # basis dimension / basis type per smooth
gam.vcov_              # posterior covariance of β̂
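For intuition about gam.edf_: in the standard penalized-regression construction (Wood, 2017), effective degrees of freedom come from the diagonal of F = (XᵀWX + S_λ)⁻¹XᵀWX, summed over each smooth's coefficient block. The sketch below shows that calculation for the Gaussian case (W = I). The function name and the `blocks` slice layout are illustrative assumptions, not mgcv_rust internals.

```python
import numpy as np

def per_block_edf(X, S_lambda, blocks):
    """Effective degrees of freedom per coefficient block (Gaussian, W = I).

    edf for a block = sum of diag(F) over that block's indices, where
    F = (X'X + S_lambda)^{-1} X'X. `blocks` maps a term name to a slice
    of coefficient indices (a hypothetical layout for illustration).
    """
    XtX = X.T @ X
    F = np.linalg.solve(XtX + S_lambda, XtX)  # (p, p)
    d = np.diag(F)
    return {name: float(d[sl].sum()) for name, sl in blocks.items()}

# Usage: one intercept plus a 4-column "smooth" block penalized by S.
rng = np.random.default_rng(1)
X = np.column_stack([np.ones(50), rng.normal(size=(50, 4))])
blocks = {"intercept": slice(0, 1), "s(x0)": slice(1, 5)}
S = np.zeros((5, 5))
S[1:, 1:] = np.eye(4)
unpenalized = per_block_edf(X, 0.0 * S, blocks)  # s(x0) edf = 4 (F = I)
heavy = per_block_edf(X, 1e6 * S, blocks)        # s(x0) edf shrinks toward 0
```

As the penalty grows, the smooth's edf falls toward the dimension of its penalty null space (zero in this toy example), which is how a fitted λ translates into the edf column of the summary.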

2. Predict on different scales

gam.predict(X)                       # response (default — same as mgcv predict(type="response"))
gam.predict(X, scale="link")         # linear predictor, no inverse link (mgcv type="link")

result = gam.predict(X, type="terms")  # decomposition
result.intercept                       # scalar, link scale
result.contributions                   # DataFrame (n × m_smooths), link-scale, centered on training
result.total                           # response-scale ≡ gam.predict(X)

Invariant: gam.predict(X) == inv_link(result.intercept + result.contributions.sum(axis=1)), up to floating-point rounding.

3. Confidence intervals

mean, lo, hi = gam.predict_ci(X, level=0.95)       # response, 95% CI
mean, lo, hi = gam.predict_ci(X, scale="link")     # link-scale CI
mean, lo, hi = gam.predict_ci(X, level=0.5)        # 50% CI (narrower)

The returned mean matches gam.predict(X, scale=scale) exactly. The interval bounds are built from n_samples (default 1000) draws of β ~ N(β̂, vcov).

The legacy predict_ci(X, alpha=0.05) → (lo, hi) form still works (emits a DeprecationWarning) so existing callers don't break.
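The draw-based interval described above can be sketched in plain NumPy. This is a minimal illustration under assumptions: `Xp` stands for the model's design rows at the prediction points (the lpmatrix), bounds use the empirical-percentile method, and `posterior_ci` is a hypothetical name. mgcv_rust's exact quantile rule is not specified here.

```python
import numpy as np

def identity(eta):
    return eta

def posterior_ci(Xp, beta_hat, vcov, level=0.95, n_samples=1000,
                 inv_link=identity, seed=0):
    """Simulation-based CI: draw beta ~ N(beta_hat, vcov), map each draw
    through the linear predictor and inverse link, take per-row quantiles."""
    rng = np.random.default_rng(seed)
    betas = rng.multivariate_normal(beta_hat, vcov, size=n_samples)  # (n_samples, p)
    draws = inv_link(Xp @ betas.T)                                    # (n_rows, n_samples)
    alpha = 1.0 - level
    lo = np.quantile(draws, alpha / 2, axis=1)
    hi = np.quantile(draws, 1 - alpha / 2, axis=1)
    # The point estimate comes from beta_hat itself, not from the draws,
    # which is why `mean` can match predict() exactly.
    mean = inv_link(Xp @ beta_hat)
    return mean, lo, hi
```

Returning the plug-in prediction rather than the average of the draws keeps the first element identical to the ordinary prediction, even on non-identity links where the two would differ.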

4. Differences between predictions, with paired CI

# Per-row diff: to[i] - from[i]
diff = gam.predict_diff(from_X, to_X)

# With paired posterior CI — much narrower than predict_ci differences
diff, lo, hi = gam.predict_diff(from_X, to_X, level=0.95)

# One row broadcast against many
diff = gam.predict_diff(baseline_row, candidates, broadcast="from")
diff = gam.predict_diff(candidates, target_row, broadcast="to")

Identity link only: for non-identity links the response-scale difference is not linear in β, so the closed-form CI argument does not carry over.
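Why the paired CI is narrower: on the identity link, the difference is D β̂ with D = X_to − X_from, so its variance is diag(D V Dᵀ). Columns shared by both rows (the intercept, for example) cancel in D and their uncertainty drops out, whereas differencing two separate CIs counts it twice. The sketch below assumes a normal approximation; `paired_diff_ci` and its arguments are hypothetical names, not the mgcv_rust API.

```python
import numpy as np
from statistics import NormalDist

def paired_diff_ci(Xp_from, Xp_to, beta_hat, vcov, level=0.95):
    """Closed-form CI for (Xp_to - Xp_from) @ beta on the identity link."""
    D = Xp_to - Xp_from                                   # contrast rows
    diff = D @ beta_hat
    se = np.sqrt(np.einsum("ij,jk,ik->i", D, vcov, D))    # sqrt(diag(D V D^T))
    z = NormalDist().inv_cdf(0.5 + level / 2)
    return diff, diff - z * se, diff + z * se

# Usage: the intercept variance (1.0) is large, but it cancels in D.
Xf = np.array([[1.0, 0.2]])
Xt = np.array([[1.0, 0.8]])
b = np.array([0.5, 2.0])
V = np.diag([1.0, 0.01])
diff, lo, hi = paired_diff_ci(Xf, Xt, b, V)
# diff -> [1.2]; half-width = z * 0.06, roughly 0.118
```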

5. Subset / marginal views

gam[names] returns a view that includes only the listed smooths (and, if Gam.INTERCEPT is listed, the intercept). All predict methods inherit it.

gam[["x0"]].predict(X)                              # marginal effect of x0, response scale
gam[["x0"]].predict(X, scale="deviation")           # link-scale, intercept zeroed (sum-to-zero on train)
gam[["x0", Gam.INTERCEPT]].predict(X)               # marginal effect + intercept
gam[["x0"]].partial_effect("x0").plot()             # what the smooth looks like

All sklearn attributes filter to the active features:

gam[["x0"]].coef_            # just the x0 block (+ intercept if active)
gam[["x0"]].feature_names_in_  # ["x0"]
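One way to picture a subset view such as gam[["x0"]] on the link scale: keep only the coefficients that belong to the listed terms and zero the rest before forming Xp @ β. This is a hypothetical sketch; the `blocks` slice layout and the helper name are assumptions, not mgcv_rust internals, and it covers only the link-scale step.

```python
import numpy as np

def subset_predict_link(Xp, beta_hat, blocks, active):
    """Link-scale prediction from the coefficient blocks named in `active`.

    Coefficients outside the active blocks are zeroed, so inactive smooths
    (and the intercept, unless it is listed) contribute nothing.
    """
    keep = np.zeros(beta_hat.shape[0], dtype=bool)
    for name in active:
        keep[blocks[name]] = True
    return Xp @ np.where(keep, beta_hat, 0.0)
```

Because the link-scale predictor is linear in β, the per-term pieces add back up to the full linear predictor. This is the same additivity that the type="terms" invariant relies on.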

6. Plotting, summary, score

gam.plot()                  # figure with one subplot per smooth (CI bands shaded)
gam.plot("x0")              # single-smooth axes
df = gam.partial_effect("x0", level=0.95)  # underlying data: columns x0, effect, lo, hi

print(gam.summary())
# Gam summary  family=gaussian  link=identity  n_obs=500
#   intercept (link)     = -0.0123
#   intercept (response) = -0.0123
#   scale (σ²)           = 0.0101
#   deviance             = 4.8732
#   R² (adj)             = 0.9612
#   smooths:
#     s(          x0)  k= 10  edf=  6.13  λ=2.345e-02
#     s(          x1)  k= 10  edf=  3.07  λ=4.890e-01

gam.score(X, y)              # adjusted R² for regression; accuracy@0.5 for binomial
gam.predict_proba(X)         # binomial only: (n, 2) [[P(0), P(1)]]
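For reference, mgcv's summary.gam defines adjusted R² (with unit weights) as 1 − var(residuals)·(n − 1) / (var(y)·residual_df), where residual_df = n minus the total edf. The sketch below implements that definition alongside accuracy at a 0.5 threshold. Whether score() uses exactly these formulas, and how it breaks ties at exactly 0.5, are assumptions; both function names are hypothetical.

```python
import numpy as np

def adjusted_r2(y, fitted, edf_total):
    """mgcv-style adjusted R^2 with unit weights."""
    y = np.asarray(y, dtype=float)
    resid = y - np.asarray(fitted, dtype=float)
    n = y.shape[0]
    residual_df = n - edf_total
    return 1.0 - np.var(resid, ddof=1) * (n - 1) / (np.var(y, ddof=1) * residual_df)

def accuracy_at_half(y, p1):
    """Share of 0/1 labels matched by thresholding P(y = 1) at 0.5.
    Ties (p1 == 0.5) are assigned to class 1 here, which is an assumption."""
    pred = (np.asarray(p1, dtype=float) >= 0.5).astype(int)
    return float(np.mean(pred == np.asarray(y)))
```

Dividing by residual_df rather than n − 1 is what penalizes wiggly fits: two models with the same residual variance score differently if one spends more edf.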

7. Serving with GamPredictor

For deployment paths, wrap the fitted Gam in a GamPredictor. Same API surface, plus strict column validation and a round-trip assertion.

from mgcv_rust import GamPredictor

predictor = GamPredictor(gam)
predictor.check_against(gam, X[:50])   # raises if predictions diverge
predictor.predict(X_serve)             # ValueError if any expected column is missing
predictor[["x0"]].predict(X_serve)     # subset view returns another GamPredictor

# `__slots__` blocks attribute leaks — once built, the bound Gam's
# coef/vcov/schema are the contract.

This structurally closes two production bug classes:

  • Column index drift when the serving DataFrame's columns differ from training: feature_names_in_ is enforced on every predict call (re-ordered → realigned; missing → ValueError).
  • x-grid clamping in interpolation-based predictors: GamPredictor recomputes the basis at the requested X via evaluate_lpmatrix, so there's no pre-computed grid to clamp against.
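The column-drift guard in the first bullet can be sketched with pandas: reorder the serving frame to the training order and fail loudly if anything is missing. This is an illustrative sketch, not GamPredictor's code; `align_columns` is a hypothetical name, and dropping extra columns is an assumption the README does not state.

```python
import pandas as pd

def align_columns(X, feature_names_in):
    """Realign serving columns to training order; raise on missing ones.

    Extra columns are silently dropped here (an assumption; the real
    GamPredictor may treat them differently).
    """
    missing = [c for c in feature_names_in if c not in X.columns]
    if missing:
        raise ValueError(f"missing expected columns: {missing}")
    return X.loc[:, list(feature_names_in)]

# Usage: columns arrive reordered, with an extra one.
serve = pd.DataFrame({"x1": [1.0, 2.0], "x0": [3.0, 4.0], "extra": [0.0, 0.0]})
aligned = align_columns(serve, ["x0", "x1"])  # columns -> ["x0", "x1"]
```

Selecting by name rather than by position is the whole point: a positional lookup would silently feed x1's values into the x0 smooth.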

Families and links

| Family | Default link | Other links | Notes |
|---|---|---|---|
| gaussian | identity | | B-spline / cubic-regression / random-effects bases supported |
| binomial | logit | | predict_proba() available |
| poisson | log | | |
| gamma | inverse | log | log link via link="log" |
| tweedie | log | | tweedie_p= (default 1.5); fixed p |
| nb / negbin | log | | nb profiles θ; negbin is fixed-θ |
| t-dist / scat | identity | | df profiled if not given (df ∈ [2, 100]) |
| quasipoisson, quasibinomial | log / logit | | dispersion-aware |
| quantile | identity | | see mgcv_rust.fit_quantile / fit_quantile_lss |
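Read as a lookup, the table says what link=None resolves to for each family. A hypothetical registry of that lookup, pairing each link with its inverse (the inverse is what maps the linear predictor back to the response scale), might look like this. LINKS, DEFAULT_LINK, and resolve_link are illustrative names only and are not part of the mgcv_rust API.

```python
import numpy as np

# (link g, inverse link g^{-1}) pairs for the default links in the table.
LINKS = {
    "identity": (lambda mu: mu, lambda eta: eta),
    "log": (np.log, np.exp),
    "logit": (lambda mu: np.log(mu / (1 - mu)), lambda eta: 1 / (1 + np.exp(-eta))),
    "inverse": (lambda mu: 1 / mu, lambda eta: 1 / eta),
}

# Family -> default link, transcribed from the table.
DEFAULT_LINK = {
    "gaussian": "identity", "binomial": "logit", "poisson": "log",
    "gamma": "inverse", "tweedie": "log", "nb": "log", "negbin": "log",
    "t-dist": "identity", "scat": "identity",
    "quasipoisson": "log", "quasibinomial": "logit", "quantile": "identity",
}

def resolve_link(family, link=None):
    """Return (link name, (g, g_inv)); link=None falls back to the family default."""
    name = DEFAULT_LINK[family] if link is None else link
    return name, LINKS[name]
```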

Architecture (Rust side)

| File | Responsibility |
|---|---|
| src/gam.rs | GAM struct + fit / predict entry points |
| src/pirls.rs | Penalized IRLS inner loop |
| src/reml/ | Outer-loop REML / LAML optimization |
| src/smooth.rs | Basis functions (cubic-regression, B-spline, random effects, tensor products) |
| src/penalty.rs | Penalty-matrix construction |
| src/lib.rs | PyO3 bindings: PyGAM, compute_penalty_matrix, newton_pirls_py |
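For readers new to the inner loop that src/pirls.rs implements, here is a minimal Python sketch of one standard PIRLS formulation for a Poisson model with log link at a fixed penalty S_λ. It is an illustration under stated assumptions (fixed λ, Poisson/log only, no step-halving, no rank handling), not a transcription of the Rust code; `pirls_poisson` is a hypothetical name.

```python
import numpy as np

def pirls_poisson(X, y, S_lambda, max_iter=100, tol=1e-10):
    """Penalized IRLS for a Poisson GLM with log link at a fixed penalty.

    Each iteration solves (X' W X + S_lambda) beta = X' W z, where for the
    log link the working weights are w = mu and the working response is
    z = eta + (y - mu) / mu.
    """
    y = np.asarray(y, dtype=float)
    mu = y + 0.1                    # start away from log(0)
    eta = np.log(mu)
    beta = np.zeros(X.shape[1])
    for _ in range(max_iter):
        w = mu                      # 1 / (V(mu) g'(mu)^2) with V(mu) = mu, g'(mu) = 1/mu
        z = eta + (y - mu) / mu     # working response
        XtW = X.T * w               # X' W without forming diag(w)
        beta_new = np.linalg.solve(XtW @ X + S_lambda, XtW @ z)
        eta = X @ beta_new
        mu = np.exp(eta)
        converged = np.max(np.abs(beta_new - beta)) < tol
        beta = beta_new
        if converged:
            break
    return beta

# Usage: intercept-only and unpenalized, this converges to log(mean(y)) = log(2.5).
b = pirls_poisson(np.ones((4, 1)), [1, 2, 3, 4], np.zeros((1, 1)))
```

The outer REML / LAML loop (src/reml/) would wrap a solve like this, adjusting λ (and hence S_λ) between inner fits.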

Build features:

  • python — enable PyO3 bindings.
  • blas / blas-system — link against system BLAS for matmul-heavy paths.

Parity and performance

Run the parity battery against R/mgcv:

pytest tests/parity/ -q
# 554 passed, 0 failed, 0 xfailed, 0 skipped

Microbench:

python3 scripts/python/bench_step_blend.py 5

Headline (2d_gaussian_additive_n50000_k15_cr, identity link):

| Implementation | Time |
|---|---|
| R mgcv | ~394 ms |
| mgcv_rust 0.11.0 | 97 ms |

Status and limitations

  • Joint (ρ, log φ) outer Newton is not yet implemented. It is the binding constraint on closing the remaining performance gap for dispersion-bearing GLMs; tracked in mgcv_rust - Backlog - Next Numerical Steps (Obsidian).
  • predict_diff is identity-link only; non-identity raises with a workaround pointing at get_posterior_samples.
  • Auto-k tuning is opt-in via Gam(auto_k=True). The default is a single fit at k_default=10 (mgcv's default) with term_k_mapping overrides; this is closer to mgcv's "tune k, run k.check" convention and avoids hidden multi-fit costs.
  • An sklearn BaseEstimator mixin (for Pipeline / GridSearchCV) is not yet wrapped; it is deferred, since scikit-learn would be a soft dependency.
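For non-identity links, the usual workaround is to compute the response-scale difference on every posterior draw and take quantiles. The sketch below does this directly in NumPy rather than through get_posterior_samples, whose signature is not documented here; `sampled_diff_ci`, the log-link default, and the percentile rule are all assumptions for illustration.

```python
import numpy as np

def sampled_diff_ci(Xp_from, Xp_to, beta_hat, vcov, inv_link=np.exp,
                    level=0.95, n_samples=2000, seed=0):
    """Response-scale difference inv_link(X_to b) - inv_link(X_from b),
    with a percentile CI over posterior draws of b."""
    rng = np.random.default_rng(seed)
    betas = rng.multivariate_normal(beta_hat, vcov, size=n_samples)  # (n_samples, p)
    draws = inv_link(Xp_to @ betas.T) - inv_link(Xp_from @ betas.T)  # (n_rows, n_samples)
    alpha = 1.0 - level
    lo, hi = np.quantile(draws, [alpha / 2, 1 - alpha / 2], axis=1)
    diff = inv_link(Xp_to @ beta_hat) - inv_link(Xp_from @ beta_hat)
    return diff, lo, hi
```

Pairing is preserved because both rows are evaluated on the same draw, so shared uncertainty still cancels even though the difference is no longer linear in β.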

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. J. R. Stat. Soc. B, 73(1), 3–36.

License

MIT — see LICENSE.



Download files


Source Distribution

  • mgcv_rust-0.13.0.tar.gz (9.8 MB)

Built Distributions

  • mgcv_rust-0.13.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (976.0 kB): PyPy, manylinux: glibc 2.17+ x86-64
  • mgcv_rust-0.13.0-cp314-cp314-win_amd64.whl (3.6 MB): CPython 3.14, Windows x86-64
  • mgcv_rust-0.13.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (978.7 kB): CPython 3.14, manylinux: glibc 2.17+ x86-64
  • mgcv_rust-0.13.0-cp314-cp314-macosx_11_0_arm64.whl (797.8 kB): CPython 3.14, macOS 11.0+ ARM64
  • mgcv_rust-0.13.0-cp313-cp313-win_amd64.whl (3.6 MB): CPython 3.13, Windows x86-64
  • mgcv_rust-0.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (977.1 kB): CPython 3.13, manylinux: glibc 2.17+ x86-64
  • mgcv_rust-0.13.0-cp313-cp313-macosx_11_0_arm64.whl (796.0 kB): CPython 3.13, macOS 11.0+ ARM64
  • mgcv_rust-0.13.0-cp312-cp312-win_amd64.whl (3.6 MB): CPython 3.12, Windows x86-64
  • mgcv_rust-0.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (977.3 kB): CPython 3.12, manylinux: glibc 2.17+ x86-64
  • mgcv_rust-0.13.0-cp312-cp312-macosx_11_0_arm64.whl (796.4 kB): CPython 3.12, macOS 11.0+ ARM64
  • mgcv_rust-0.13.0-cp311-cp311-win_amd64.whl (3.6 MB): CPython 3.11, Windows x86-64
  • mgcv_rust-0.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (976.5 kB): CPython 3.11, manylinux: glibc 2.17+ x86-64
  • mgcv_rust-0.13.0-cp311-cp311-macosx_11_0_arm64.whl (799.3 kB): CPython 3.11, macOS 11.0+ ARM64
  • mgcv_rust-0.13.0-cp310-cp310-win_amd64.whl (3.6 MB): CPython 3.10, Windows x86-64
  • mgcv_rust-0.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (978.5 kB): CPython 3.10, manylinux: glibc 2.17+ x86-64
  • mgcv_rust-0.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (979.9 kB): CPython 3.9, manylinux: glibc 2.17+ x86-64

File details

Details for the file mgcv_rust-0.13.0.tar.gz.

File metadata

  • Download URL: mgcv_rust-0.13.0.tar.gz
  • Upload date:
  • Size: 9.8 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: maturin/1.13.3

File hashes

Hashes for mgcv_rust-0.13.0.tar.gz
Algorithm Hash digest
SHA256 9f40f92983b0e357ad0e29d6ad4d432a4dee5e15fbb43b47a75a0e4d3f2e2974
MD5 8a9344003405a5f36a16b2f00166db88
BLAKE2b-256 2067bc6bfe97ff6ee790882692c6f026707d26d90ed4a4c8613a44204e7364cd

Hashes for mgcv_rust-0.13.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 9e576dc73ba0ec49f516957e844c9cf9d5c67c48852e50ad7b75739920f84fd1
MD5 55bda88b0658a7bc70e843a160bba2a0
BLAKE2b-256 2b97efc73a5e78b5796fec729450c271f38d98a38575c3841a89a019a1fe2ed6

Hashes for mgcv_rust-0.13.0-cp314-cp314-win_amd64.whl
Algorithm Hash digest
SHA256 8e04acacdc1a1e6fd6fdfe8b3942963b5a2d4b795c3b000b9ef8653ed6600e2d
MD5 86306786ca0a898e901dffcec7d617c0
BLAKE2b-256 d561ff06736b38569afc87f6cfc7ea1b1758eaebdf706dd0d6bbf721e7e7a32d

Hashes for mgcv_rust-0.13.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 4675d94e4efbdf86ebe45f020208442d6e4643d3676cb8f4b8360cd7cddf8f94
MD5 a37f6bfab079b3bab40f60b882d0b9a7
BLAKE2b-256 c326b6c200dc6836bad7ca493fd92564401254b235a5c51f28f75192f7de2153

Hashes for mgcv_rust-0.13.0-cp314-cp314-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 224e3bccf2b39a4878d4b098d1bf7315b384864bee86f46181f87f25cd149b97
MD5 8e02b7580a7e357209bd06b9444a0464
BLAKE2b-256 47ca8542bf651ab5f9ccb1017a9480da0ad779f7e47f8d7d800b767a79a0e82a

Hashes for mgcv_rust-0.13.0-cp313-cp313-win_amd64.whl
Algorithm Hash digest
SHA256 ab5165763607c73d81ffa1a2015363a36a9b7447c054d503920c05a368d01394
MD5 209623638dce4ee8e4f6915139757c0b
BLAKE2b-256 ea51e68e153251307594bc5f17386b67e2b13212e52fd03442f683777ad6423c

Hashes for mgcv_rust-0.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 7758fad2ec9e5b11946f3fdb6acfbb4c829b3d0cd7177a19cfdc92dc53a190e1
MD5 8cc778ccf2af52f84013945c1998a5b5
BLAKE2b-256 4bf711136a0522852d13ad60c99113af112b672852a1d69c7ff1851ea4631887

Hashes for mgcv_rust-0.13.0-cp313-cp313-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 f8d363447448a53484171c0fd0cde681bae9e3921e94676065bdc808d0eb32f8
MD5 585a47036a853b9e416f01e3a8e5a24e
BLAKE2b-256 63091742fa13833336a34de7c0bed8e14a1e7844bce2b271c3a7d012aa68f73c

Hashes for mgcv_rust-0.13.0-cp312-cp312-win_amd64.whl
Algorithm Hash digest
SHA256 9759fc127ba5ac4e1226e54d9fc002f67486d2b41f52eafca50e4398b1ff606e
MD5 1a82baa204fab46d02da914fef73258c
BLAKE2b-256 74754cb228ddca925929fa55542a3cca66b3e615996465558f5eba69540b8068

Hashes for mgcv_rust-0.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 23a7aa597fce00254a624b1826694a14c67df5bcbdcba3d4e204f9a50212b184
MD5 66fb92b9ec482c0d19f83a4f7adaf58c
BLAKE2b-256 d1d34cf1e0cc54f7936358f105be4028f946f369dc8b94aa8fae654995feac47

Hashes for mgcv_rust-0.13.0-cp312-cp312-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 2c9c58d4e479dd041cbef86249bea1d9d1be78f0fd65f7070d420ef5c1ff06c3
MD5 5cb822a42d0e1316396cd0d2ca99ead1
BLAKE2b-256 322e89a9795ef4aaf8f0fe83693db3b3af12f1b43400c67377e93325f5684397

Hashes for mgcv_rust-0.13.0-cp311-cp311-win_amd64.whl
Algorithm Hash digest
SHA256 2db9acb454448a184cb8aea0e891f62a5c1396822e4e7381ddec7029775cbd3b
MD5 9a0b91099b4c9696d3e389d7f8683154
BLAKE2b-256 9c1d70cd8c7fae6f127cface9c88cc6d8693b953cd02c84b3cec851ad486e5ae

Hashes for mgcv_rust-0.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 578ea3db28efe188e4545b079bd82e1ff1996a4342cb03ade65b7b051677dd41
MD5 28d8a85e8bc50b60e532694ecdc91153
BLAKE2b-256 27f6cd1e12d1393d255750a4a85b957c7f091b16173c8186647116a3836f34f4

Hashes for mgcv_rust-0.13.0-cp311-cp311-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 3d0740a3a868cdbb7cdf9dcbb2f24a0c067797da4e5487be545c65753c48e116
MD5 66fdefd969d80b72c247b06eda3fc983
BLAKE2b-256 de470b97f5563b5ddf28f879e4338afcb260dbc463adc6b5e461fda68ec5cc68

Hashes for mgcv_rust-0.13.0-cp310-cp310-win_amd64.whl
Algorithm Hash digest
SHA256 f9c6acd987f31896664f5e399190c65d4081f86c318164fa3ac235e2400247b8
MD5 b3fb03a17e1f1cb4321886568466ea31
BLAKE2b-256 bdb4215f5d36d040a43ff53ccf1765f7be3ffe7cc9addaa5894b74fb1c186673

Hashes for mgcv_rust-0.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 a75bc74c098a6dc76fad0d190347b1ae0c58fb6896ddbdb90d320427221c851a
MD5 51ef30260e05ff9b64f0ab2a87b08570
BLAKE2b-256 90f5b010861c9c66965cc9f82cbaa9918592e004b3396993c5f67b4c921517b1

Hashes for mgcv_rust-0.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 309f979ea86607682cb5c6318974341f86d2cdc15d20875f389c85250e80920d
MD5 d457c923a74eaf896cc10df573a83772
BLAKE2b-256 4e83ac171421fe4ed0ced18c1b859223e7817a896abcc58026e42ef2638d5ad3
