
Rust implementation of Generalized Additive Models with Python bindings


mgcv_rust

A Rust port of R's mgcv package — Generalized Additive Models with automatic smoothing-parameter selection (REML / LAML) and PIRLS fitting — with first-class Python bindings.

The numerical core targets byte-for-byte parity with mgcv on common families, and the Python API is designed to feel like sklearn: a single Gam class, named-predictor DataFrame inputs, fit() / predict() / score(), posterior CIs, and subset views for marginal analysis.

Latest release: 0.11.0
Parity: 554 passed / 0 failed / 0 xfailed on the mgcv comparison battery
Python tests: 211 passed, 1 xfailed
Headline perf: 2d_gaussian_additive_n50000_k15_cr, R mgcv 394 ms vs. mgcv_rust 97 ms (4.05× faster)
Python wrapper: mgcv_rust.Gam (canonical); GAMFitter is a deprecated alias

Why another GAM library?

  • R-equivalent answers, Python ergonomics. If you have R / mgcv users producing models and Python services consuming them, this gives both sides the same numbers from the same fit. The serialize() method produces a portable artefact for deployment.
  • Fast. Aggregate time across 80 parity fixtures: 2.0 s. On the largest Gaussian fixture (n=50,000, k=15), mgcv_rust is ~4× faster than R's mgcv.
  • Real CIs, paired diffs. predict_ci(X) returns (mean, lo, hi), so you never have to add the intercept back by hand ("+= intercept"). predict_diff(from_X, to_X, level=...) gives a paired-posterior CI for the difference between two predictions, which is strictly tighter than subtracting two separate predict_ci intervals.
  • Frozen serving view. GamPredictor(gam) is an __slots__-locked, inference-only wrapper with strict input validation (feature_names_in_ enforcement) and a check_against(gam, X_sample) round-trip assertion for deployment safety.

Installation

The Python package is built from this repo via maturin:

# Clone, then in a venv:
pip install maturin
maturin develop --features python,blas,blas-system --release

After that, import mgcv_rust works.

For a Rust-only consumer, add to your Cargo.toml:

[dependencies]
mgcv_rust = { git = "https://github.com/AlekJaworski/nn_exploring" }

Quick start

import numpy as np
import pandas as pd
from mgcv_rust import Gam

rng = np.random.default_rng(0)
n = 500
X = pd.DataFrame({
    "x0": rng.uniform(-2, 2, n),
    "x1": rng.uniform(0, 5, n),
})
y = np.sin(X["x0"]) + 0.3 * (X["x1"] - 2.5)**2 + rng.normal(0, 0.1, n)

gam = Gam(family="gaussian").fit(X, y)

gam.predict(X[:5])                  # response-scale predictions
mean, lo, hi = gam.predict_ci(X[:5])  # 95% CI, response scale
gam.score(X, y)                     # adjusted R²
print(gam.summary())                # mgcv-style block

That's it for the simple case. Read on for the curated tour, or jump to docs/GETTING_STARTED.md for a step-by-step tutorial with worked examples.


Tour

1. Build the model

Gam.fit() accepts pandas or polars DataFrames and NumPy arrays. If you pass a DataFrame, its column names become the predictor names; otherwise you can pass the names explicitly.

gam = Gam(
    family="gaussian",           # gaussian, binomial, poisson, gamma, tweedie, nb, t-dist, scat, ...
    link=None,                   # None → canonical link for the family
    k_default=10,                # default basis dimension per smooth (mgcv default)
    term_k_mapping={"x0": 25},   # per-predictor overrides
    method="REML",
)
gam.fit(X, y)

After fit, sklearn-style attributes are populated:

gam.coef_              # full β vector (intercept first)
gam.intercept_         # link-scale
gam.intercept_response_ # response-scale (= mean(y) for identity link)
gam.feature_names_in_  # predictor names, fitting order
gam.n_features_in_
gam.lambda_            # per-smooth smoothing parameter
gam.edf_               # per-smooth effective degrees of freedom
gam.k_, gam.bs_        # basis dimension / basis type per smooth
gam.vcov_              # posterior covariance of β̂

2. Predict on different scales

gam.predict(X)                       # response (default — same as mgcv predict(type="response"))
gam.predict(X, scale="link")         # linear predictor, no inverse link (mgcv type="link")

result = gam.predict(X, type="terms")  # decomposition
result.intercept                       # scalar, link scale
result.contributions                   # DataFrame (n × m_smooths), link-scale, centered on training
result.total                           # response-scale ≡ gam.predict(X)

Invariant: gam.predict(X) == inv_link(intercept + result.contributions.sum(axis=1)).
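To make the invariant concrete, here is a minimal NumPy sketch with made-up numbers (not output from mgcv_rust) for a logit link. The names `inv_logit`, `intercept`, and `contributions` are illustrative assumptions. The link-scale intercept plus the row sums of the per-smooth contributions gives the linear predictor, and the inverse link is applied only after summing.

```python
import numpy as np

def inv_logit(eta):
    # Inverse of the logit link: maps the linear predictor to a probability.
    return 1.0 / (1.0 + np.exp(-eta))

# Hypothetical decomposition for 3 rows and 2 smooths, all on the link scale.
intercept = -0.4
contributions = np.array([
    [0.5, -0.1],
    [-0.2, 0.3],
    [1.0, 0.0],
])

# Linear predictor = intercept + sum of per-smooth contributions (link scale).
eta = intercept + contributions.sum(axis=1)
response = inv_logit(eta)

# Summing on the response scale instead is NOT equivalent for a non-identity link.
wrong = inv_logit(intercept) + inv_logit(contributions).sum(axis=1)
```

For an identity link the two computations coincide, which is why the decomposition is reported on the link scale.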

3. Confidence intervals

mean, lo, hi = gam.predict_ci(X, level=0.95)       # response, 95% CI
mean, lo, hi = gam.predict_ci(X, scale="link")     # link-scale CI
mean, lo, hi = gam.predict_ci(X, level=0.5)        # 50% CI (narrower)

mean matches gam.predict(X, scale=scale) exactly. The CI is built from n_samples (default 1000) draws of β ~ N(β̂, vcov).

The legacy predict_ci(X, alpha=0.05) → (lo, hi) form still works (emits a DeprecationWarning) so existing callers don't break.
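The sampling scheme described above can be sketched in NumPy. This illustrates the general technique and is not mgcv_rust's code: it assumes you already have the prediction matrix `Xp` (the basis evaluated at the new rows), `beta_hat`, and `vcov`, and that the bounds are empirical quantiles of the draws. `simulate_ci` is a hypothetical helper.

```python
import numpy as np

def simulate_ci(Xp, beta_hat, vcov, level=0.95, n_samples=1000,
                inv_link=lambda eta: eta, seed=0):
    """Simulation-based interval for inv_link(Xp @ beta).

    Xp: (n, p) prediction matrix, beta_hat: (p,), vcov: (p, p).
    """
    rng = np.random.default_rng(seed)
    # Draw coefficient vectors from the Gaussian posterior N(beta_hat, vcov).
    betas = rng.multivariate_normal(beta_hat, vcov, size=n_samples)  # (n_samples, p)
    # Map every draw to the requested scale: one prediction per draw per row.
    draws = inv_link(betas @ Xp.T)  # (n_samples, n)
    alpha = 1.0 - level
    lo = np.quantile(draws, alpha / 2, axis=0)
    hi = np.quantile(draws, 1 - alpha / 2, axis=0)
    # The point estimate uses beta_hat directly, not the mean of the draws,
    # so it matches the plain prediction exactly.
    mean = inv_link(Xp @ beta_hat)
    return mean, lo, hi

# Toy example: intercept column plus one basis column.
Xp = np.array([[1.0, 0.0], [1.0, 1.0]])
beta_hat = np.array([2.0, 0.5])
vcov = np.diag([0.01, 0.04])
mean, lo, hi = simulate_ci(Xp, beta_hat, vcov, level=0.95)
```

Computing `mean` from `beta_hat` rather than averaging the draws is what keeps it identical to the point prediction, especially under a non-linear inverse link.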

4. Differences between predictions, with paired CI

# Per-row diff: to[i] - from[i]
diff = gam.predict_diff(from_X, to_X)

# With paired posterior CI — much narrower than predict_ci differences
diff, lo, hi = gam.predict_diff(from_X, to_X, level=0.95)

# One row broadcast against many
diff = gam.predict_diff(baseline_row, candidates, broadcast="from")
diff = gam.predict_diff(candidates, target_row, broadcast="to")

Identity-link only — for non-identity links the response-scale difference isn't linear in β, so the closed-form CI argument doesn't transfer.
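The sketch below shows why the paired interval is narrower. It assumes a closed-form Gaussian interval with a normal quantile; `paired_diff_ci` and `naive_diff_ci` are hypothetical helpers, not library functions. Both predictions share the same β, so for an identity link the variance of the difference is aᵀVa with a = x_to - x_from. That form cancels shared uncertainty such as the intercept's. Subtracting two separate intervals ignores that covariance.

```python
import numpy as np
from statistics import NormalDist

def paired_diff_ci(x_from, x_to, beta_hat, vcov, level=0.95):
    # Identity link: the difference is linear in beta, d = (x_to - x_from) @ beta.
    a = x_to - x_from
    d = a @ beta_hat
    se = np.sqrt(a @ vcov @ a)  # includes the covariance between the two predictions
    z = NormalDist().inv_cdf(0.5 + level / 2)
    return d, d - z * se, d + z * se

def naive_diff_ci(x_from, x_to, beta_hat, vcov, level=0.95):
    # Subtract two independent intervals: ignores the shared uncertainty in beta.
    z = NormalDist().inv_cdf(0.5 + level / 2)

    def ci(x):
        m, s = x @ beta_hat, np.sqrt(x @ vcov @ x)
        return m - z * s, m + z * s

    lo_f, hi_f = ci(x_from)
    lo_t, hi_t = ci(x_to)
    return lo_t - hi_f, hi_t - lo_f

# Toy model: an uncertain intercept and a well-determined slope-like coefficient.
x_from = np.array([1.0, 0.2])
x_to = np.array([1.0, 0.8])
beta_hat = np.array([1.0, 2.0])
vcov = np.array([[0.25, 0.0], [0.0, 0.01]])
d, lo, hi = paired_diff_ci(x_from, x_to, beta_hat, vcov)
nlo, nhi = naive_diff_ci(x_from, x_to, beta_hat, vcov)
```

Here the intercept's large variance cancels in the paired interval but inflates both naive intervals.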

5. Subset / marginal views

gam[names] takes a list of smooth names and returns a view restricted to those smooths (add Gam.INTERCEPT to the list to keep the intercept). All predict methods respect the restriction.

gam[["x0"]].predict(X)                              # marginal effect of x0, response scale
gam[["x0"]].predict(X, scale="deviation")           # link-scale, intercept zeroed (sum-to-zero on train)
gam[["x0", Gam.INTERCEPT]].predict(X)               # marginal effect + intercept
gam[["x0"]].partial_effect("x0").plot()             # what the smooth looks like

All sklearn attributes filter to the active features:

gam[["x0"]].coef_            # just the x0 block (+ intercept if active)
gam[["x0"]].feature_names_in_  # ["x0"]
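Conceptually, a subset view keeps only the coefficient blocks of the selected smooths. The sketch below illustrates the link-scale arithmetic only and is not the library's implementation. It assumes the intercept occupies column 0, consistent with `coef_` being "intercept first". It also assumes a hypothetical `blocks` mapping from smooth name to coefficient slice.

```python
import numpy as np

def subset_linear_predictor(Xp, beta, blocks, active, include_intercept=False):
    """Link-scale prediction using only the coefficient blocks of the active smooths.

    Xp: (n, p) prediction matrix with the intercept in column 0.
    blocks: dict mapping smooth name -> slice of coefficient indices (hypothetical layout).
    """
    mask = np.zeros(beta.shape[0], dtype=bool)
    mask[0] = include_intercept
    for name in active:
        mask[blocks[name]] = True
    # Zeroing the inactive coefficients drops those smooths (and optionally the intercept).
    return Xp @ np.where(mask, beta, 0.0)

# Toy layout: intercept, two coefficients for x0, one for x1.
beta = np.array([0.5, 1.0, -1.0, 2.0])
blocks = {"x0": slice(1, 3), "x1": slice(3, 4)}
Xp = np.array([
    [1.0, 0.2, 0.3, 0.1],
    [1.0, -0.4, 0.1, 0.5],
])
x0_only = subset_linear_predictor(Xp, beta, blocks, ["x0"])
```

Selecting every smooth plus the intercept reproduces the full linear predictor `Xp @ beta`.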

6. Plotting, summary, score

gam.plot()                  # figure with one subplot per smooth (CI bands shaded)
gam.plot("x0")              # single-smooth axes
df = gam.partial_effect("x0", level=0.95)  # underlying data: columns x0, effect, lo, hi

print(gam.summary())
# Gam summary  family=gaussian  link=identity  n_obs=500
#   intercept (link)     = -0.0123
#   intercept (response) = -0.0123
#   scale (σ²)           = 0.0101
#   deviance             = 4.8732
#   R² (adj)             = 0.9612
#   smooths:
#     s(          x0)  k= 10  edf=  6.13  λ=2.345e-02
#     s(          x1)  k= 10  edf=  3.07  λ=4.890e-01

gam.score(X, y)              # adjusted R² for regression; accuracy@0.5 for binomial
gam.predict_proba(X)         # binomial only: (n, 2) [[P(0), P(1)]]

7. Serving with GamPredictor

For deployment paths, wrap the fitted Gam in a GamPredictor. Same API surface, plus strict column validation and a round-trip assertion.

from mgcv_rust import GamPredictor

predictor = GamPredictor(gam)
predictor.check_against(gam, X[:50])   # raises if predictions diverge
predictor.predict(X_serve)             # ValueError if any expected column is missing
predictor[["x0"]].predict(X_serve)     # subset view returns another GamPredictor

# `__slots__` blocks attribute leaks — once built, the bound Gam's
# coef/vcov/schema are the contract.

This structurally closes two production bug classes:

  • Column index drift when the serving DataFrame's columns differ from training: feature_names_in_ is enforced on every predict call (re-ordered → realigned; missing → ValueError).
  • x-grid clamping in interpolation-based predictors: GamPredictor recomputes the basis at the requested X via evaluate_lpmatrix, so there's no pre-computed grid to clamp against.
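The column-validation rule in the first bullet can be sketched with pandas as follows. This is a hypothetical helper, not GamPredictor's code. Re-ordered columns are realigned by name, and a missing column raises ValueError. The sketch also silently drops extra columns; whether GamPredictor does the same is not stated above.

```python
import pandas as pd

def align_columns(X: pd.DataFrame, feature_names_in_):
    """Reorder X to the training column order; raise if any expected column is missing."""
    missing = [c for c in feature_names_in_ if c not in X.columns]
    if missing:
        raise ValueError(f"missing columns: {missing}")
    # Selecting by name (not position) realigns re-ordered columns and drops extras.
    return X[list(feature_names_in_)]

# Serving frame with swapped column order and an extra column.
X_serve = pd.DataFrame({"x1": [1, 2], "x0": [3, 4], "extra": [0, 0]})
aligned = align_columns(X_serve, ["x0", "x1"])
```

Selecting by name rather than by position is what removes the column-index-drift bug class.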

Families and links

| Family | Default link | Other links | Notes |
|---|---|---|---|
| gaussian | identity | | B-spline / cubic-regression / random-effects bases supported |
| binomial | logit | | predict_proba() available |
| poisson | log | | |
| gamma | inverse | log | log link via link="log" |
| tweedie | log | | tweedie_p= (default 1.5); p is fixed, not estimated |
| nb / negbin | log | | nb profiles θ; negbin uses a fixed θ |
| t-dist / scat | identity | | df profiled if not given (df ∈ [2, 100]) |
| quasipoisson / quasibinomial | log / logit | | dispersion-aware |
| quantile | identity | | see mgcv_rust.fit_quantile / fit_quantile_lss |

Architecture (Rust side)

| File | Responsibility |
|---|---|
| src/gam.rs | GAM struct + fit / predict entry points |
| src/pirls.rs | Penalized IRLS inner loop |
| src/reml/ | Outer-loop REML / LAML optimization |
| src/smooth.rs | Basis functions (cubic-regression, B-spline, random effects, tensor products) |
| src/penalty.rs | Penalty-matrix construction |
| src/lib.rs | PyO3 bindings: PyGAM, compute_penalty_matrix, newton_pirls_py |
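For orientation, here is a NumPy sketch of the kind of inner loop that src/pirls.rs is responsible for, in the simplest case. It assumes a Poisson family with log link and a *fixed* smoothing parameter λ and penalty matrix S. It is a textbook penalized-IRLS (Newton) iteration under those assumptions, not a transcription of the Rust code. The REML/LAML outer loop that chooses λ is omitted, and `pirls_poisson` is a hypothetical name.

```python
import numpy as np

def pirls_poisson(X, y, S, lam, n_iter=50, tol=1e-10):
    """Penalized IRLS for a Poisson GLM (log link) with a fixed smoothing parameter.

    Minimizes -loglik(beta) + 0.5 * lam * beta @ S @ beta.
    Assumes column 0 of X is the intercept (all ones).
    """
    beta = np.zeros(X.shape[1])
    beta[0] = np.log(y.mean())  # start the intercept at log(mean response)
    for _ in range(n_iter):
        eta = X @ beta
        mu = np.exp(eta)
        # Log link: working weights w = mu, working response z = eta + (y - mu) / mu.
        w = mu
        z = eta + (y - mu) / mu
        XtW = X.T * w  # equivalent to X.T @ diag(w)
        # Penalized weighted least-squares step: (X'WX + lam*S) beta = X'Wz.
        beta_new = np.linalg.solve(XtW @ X + lam * S, XtW @ z)
        if np.max(np.abs(beta_new - beta)) < tol:
            return beta_new
        beta = beta_new
    return beta
```

For a canonical link this step is exactly a Newton step on the penalized log-likelihood, so at convergence Xᵀ(y - μ) = λSβ.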

Build features:

  • python — enable PyO3 bindings.
  • blas / blas-system — link against system BLAS for matmul-heavy paths.

Parity and performance

Run the parity battery against R/mgcv:

pytest tests/parity/ -q
# 554 passed, 0 failed, 0 xfailed, 0 skipped

Microbench:

python3 scripts/python/bench_step_blend.py 5

Headline (2d_gaussian_additive_n50000_k15_cr, identity link):

| Implementation | Time |
|---|---|
| R mgcv | ~394 ms |
| mgcv_rust 0.11.0 | 97 ms |

Status and limitations

  • A joint outer Newton step over (ρ, log φ), i.e. the log smoothing parameters and the log scale/dispersion, is not yet implemented. It is the binding constraint on closing the remaining performance gap for dispersion-bearing GLMs, and is tracked in the "mgcv_rust - Backlog - Next Numerical Steps" note (Obsidian).
  • predict_diff supports identity links only; on a non-identity link it raises an error that points to get_posterior_samples as a workaround.
  • Auto-k tuning is opt-in via Gam(auto_k=True). By default the model is fit once with k_default=10 (mgcv's default), with per-predictor overrides via term_k_mapping. This stays closer to mgcv's convention of choosing k by hand and checking it with k.check, and avoids hidden multi-fit costs.
  • The sklearn BaseEstimator mixin (needed for Pipeline / GridSearchCV) is not wrapped yet; it is deferred so that scikit-learn can remain a soft (optional) dependency.
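The general idea behind the posterior-sample workaround for non-identity links can be sketched as follows. This sketch does not call get_posterior_samples, whose signature is not documented here. `simulated_diff_ci` is a hypothetical helper that assumes only the two prediction rows, `beta_hat`, and `vcov`. It evaluates both predictions with the same draw of β, which preserves the pairing, and takes empirical quantiles of the response-scale difference.

```python
import numpy as np

def simulated_diff_ci(x_from, x_to, beta_hat, vcov, inv_link,
                      level=0.95, n_samples=2000, seed=0):
    rng = np.random.default_rng(seed)
    betas = rng.multivariate_normal(beta_hat, vcov, size=n_samples)  # (n_samples, p)
    # Use the SAME draw for both rows so shared uncertainty cancels (paired difference).
    diffs = inv_link(betas @ x_to) - inv_link(betas @ x_from)
    alpha = 1.0 - level
    lo, hi = np.quantile(diffs, [alpha / 2, 1 - alpha / 2])
    # Point estimate on the response scale from beta_hat itself.
    point = inv_link(x_to @ beta_hat) - inv_link(x_from @ beta_hat)
    return point, lo, hi

# Toy log-link model: intercept plus one coefficient.
point, lo, hi = simulated_diff_ci(
    np.array([1.0, 0.0]), np.array([1.0, 1.0]),
    np.array([0.0, 0.5]), np.diag([0.01, 0.01]), np.exp,
)
```

Because the response-scale difference is non-linear in β, the interval comes from the draws rather than from a closed-form variance.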

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. J. R. Stat. Soc. B, 73(1), 3–36.

License

MIT — see LICENSE.



Download files

Source distribution

  • mgcv_rust-0.14.0.tar.gz (10.1 MB)

Built distributions (wheels), version 0.14.0

  • CPython 3.14: Windows x86-64 (3.6 MB); manylinux, glibc 2.17+, x86-64 (984.8 kB); macOS 11.0+ ARM64 (803.2 kB)
  • CPython 3.13: Windows x86-64 (3.6 MB); manylinux, glibc 2.17+, x86-64 (983.0 kB); macOS 11.0+ ARM64 (801.6 kB)
  • CPython 3.12: Windows x86-64 (3.6 MB); manylinux, glibc 2.17+, x86-64 (983.3 kB); macOS 11.0+ ARM64 (802.1 kB)
  • CPython 3.11: Windows x86-64 (3.6 MB); manylinux, glibc 2.17+, x86-64 (982.6 kB); macOS 11.0+ ARM64 (804.4 kB)
  • CPython 3.10: Windows x86-64 (3.6 MB); manylinux, glibc 2.17+, x86-64 (985.0 kB)
  • CPython 3.9: manylinux, glibc 2.17+, x86-64 (987.0 kB)
  • PyPy 3.11: manylinux, glibc 2.17+, x86-64 (981.8 kB)
