Rust implementation of Generalized Additive Models with Python bindings

Project description

mgcv_rust

A Rust port of R's mgcv package — Generalized Additive Models with automatic smoothing-parameter selection (REML / LAML) and PIRLS fitting — with first-class Python bindings.

The numerical core targets byte-for-byte parity with mgcv on common families, and the Python API is designed to feel like sklearn: a single Gam class, named-predictor DataFrame inputs, fit() / predict() / score(), posterior CIs, and subset views for marginal analysis.

Latest release: 0.11.0
Parity: 554 / 0 / 0 (passed / failed / xfailed) on the mgcv comparison battery
Python tests: 211 passed, 1 xfailed
Headline perf: 2d_gaussian_additive_n50000_k15_cr, 394 ms (R mgcv) → 97 ms (mgcv_rust), 4.05× faster
Python wrapper: mgcv_rust.Gam (canonical); GAMFitter is a deprecated alias

Why another GAM library?

  • R-equivalent answers, Python ergonomics. If you have R / mgcv users producing models and Python services consuming them, this gives both sides the same numbers from the same fit. The serialize() method produces a portable artefact for deployment.
  • Fast. Aggregate time across 80 parity fixtures: 2.0 s. On the largest Gaussian fixture (n=50,000, k=15), mgcv_rust is ~4× faster than R's mgcv.
  • Real CIs, paired diffs. predict_ci(X) returns (mean, lo, hi) so you never have to "+= intercept" by hand. predict_diff(from_X, to_X, level=...) gives a paired-posterior CI for the difference between two predictions — strictly tighter than the naive predict_ci difference.
  • Frozen serving view. GamPredictor(gam) is an __slots__-locked, inference-only wrapper with strict input validation (feature_names_in_ enforcement) and a check_against(gam, X_sample) round-trip assertion for deployment safety.

Installation

The Python package is built from this repo via maturin:

# Clone the repo, then, inside a virtualenv:
pip install maturin
# blas-system links against the system BLAS, so one must be installed
maturin develop --features python,blas,blas-system --release

After that, import mgcv_rust works.

For a Rust-only consumer, add to your Cargo.toml:

[dependencies]
mgcv_rust = { git = "https://github.com/AlekJaworski/nn_exploring" }

Quick start

import numpy as np
import pandas as pd
from mgcv_rust import Gam

rng = np.random.default_rng(0)
n = 500
X = pd.DataFrame({
    "x0": rng.uniform(-2, 2, n),
    "x1": rng.uniform(0, 5, n),
})
y = np.sin(X["x0"]) + 0.3 * (X["x1"] - 2.5)**2 + rng.normal(0, 0.1, n)

gam = Gam(family="gaussian").fit(X, y)

gam.predict(X[:5])                  # response-scale predictions
mean, lo, hi = gam.predict_ci(X[:5])  # 95% CI, response scale
gam.score(X, y)                     # adjusted R²
print(gam.summary())                # mgcv-style block

That's it for the simple case. Read on for the curated tour, or jump to docs/GETTING_STARTED.md for a step-by-step tutorial with worked examples.


Tour

1. Build the model

Gam.fit() accepts pandas, polars, or NumPy inputs. When X is a DataFrame, its column names become the predictor names; with a plain array, you can pass the names explicitly.

gam = Gam(
    family="gaussian",           # gaussian, binomial, poisson, gamma, tweedie, nb, t-dist, scat, ...
    link=None,                   # None → canonical link for the family
    k_default=10,                # default basis dimension per smooth (mgcv default)
    term_k_mapping={"x0": 25},   # per-predictor overrides
    method="REML",
)
gam.fit(X, y)
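A minimal sketch of how the two k settings above plausibly combine, assuming term_k_mapping simply overrides k_default for the predictors it names (resolve_k is a hypothetical helper, not part of mgcv_rust):

```python
def resolve_k(feature_names, k_default=10, term_k_mapping=None):
    """Hypothetical helper: basis dimension for each smooth, in fitting order."""
    overrides = term_k_mapping or {}
    # Named predictors get their override; every other smooth falls back to k_default
    return {name: overrides.get(name, k_default) for name in feature_names}


ks = resolve_k(["x0", "x1"], k_default=10, term_k_mapping={"x0": 25})
```

With the constructor call above, x0 would get a 25-dimensional basis and x1 the default 10.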

After fit, sklearn-style attributes are populated:

gam.coef_              # full β vector (intercept first)
gam.intercept_         # link-scale
gam.intercept_response_  # response-scale (= mean(y) for a Gaussian identity-link fit)
gam.feature_names_in_  # predictor names, fitting order
gam.n_features_in_
gam.lambda_            # per-smooth smoothing parameter
gam.edf_               # per-smooth effective degrees of freedom
gam.k_, gam.bs_        # basis dimension / basis type per smooth
gam.vcov_              # posterior covariance of β̂
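Why the response-scale intercept equals mean(y) in the Gaussian identity-link case: with training-centred smooth columns and an unpenalised intercept, the intercept row of the penalised normal equations reduces to n·β₀ = Σy. A toy numpy check; the bump basis and ridge penalty are illustrative stand-ins, not mgcv_rust's bases or penalties:

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 200, 8
x = rng.uniform(-2, 2, n)
y = np.sin(x) + rng.normal(0, 0.1, n)

# Toy smooth basis: Gaussian bumps, columns centred on the training data
# (a stand-in for mgcv's sum-to-zero identifiability constraint).
centres = np.linspace(-2, 2, k)
Z = np.exp(-((x[:, None] - centres[None, :]) ** 2) / 0.5)
Z -= Z.mean(axis=0)

X = np.column_stack([np.ones(n), Z])   # intercept first, as in coef_
S = np.zeros((k + 1, k + 1))
S[1:, 1:] = np.eye(k)                  # penalise the smooth block only
lam = 3.0

# Penalised least squares: (X'X + lam*S) beta = X'y
beta = np.linalg.solve(X.T @ X + lam * S, X.T @ y)
print(beta[0], y.mean())               # agree up to rounding error
```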

2. Predict on different scales

gam.predict(X)                       # response (default — same as mgcv predict(type="response"))
gam.predict(X, scale="link")         # linear predictor, no inverse link (mgcv type="link")

result = gam.predict(X, type="terms")  # decomposition
result.intercept                       # scalar, link scale
result.contributions                   # DataFrame (n × m_smooths), link-scale, centered on training
result.total                           # response-scale ≡ gam.predict(X)

Invariant: gam.predict(X) == inv_link(intercept + result.contributions.sum(axis=1)).
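The invariant says the decomposition is additive on the link scale, and the inverse link is applied once, to the sum. A small numpy illustration for a logit link; the intercept and contribution values are made up:

```python
import numpy as np


def expit(eta):
    return 1.0 / (1.0 + np.exp(-eta))


# Hypothetical terms decomposition: 3 rows, 2 smooths, link (logit) scale
intercept = -0.4
contributions = np.array([[0.3, -0.1],
                          [1.2, 0.5],
                          [-0.8, 0.0]])

eta = intercept + contributions.sum(axis=1)   # linear predictor per row
total = expit(eta)                            # response scale, like result.total

# Applying the inverse link term by term and then summing is NOT equivalent
wrong = expit(intercept) + expit(contributions).sum(axis=1)
print(np.round(total, 4))
```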

3. Confidence intervals

mean, lo, hi = gam.predict_ci(X, level=0.95)       # response, 95% CI
mean, lo, hi = gam.predict_ci(X, scale="link")     # link-scale CI
mean, lo, hi = gam.predict_ci(X, level=0.5)        # 50% CI (narrower)

The returned mean equals gam.predict(X, scale=...) for the same scale exactly. The CI bounds are built from n_samples (default 1000) draws of β ~ N(β̂, vcov).
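This is the usual posterior-simulation recipe (Wood 2017): draw coefficient vectors, push each through the linear predictor and the inverse link, and take empirical quantiles. A sketch under the assumption that the bounds are equal-tailed quantiles of the draws; simulate_ci and the prediction matrix Xp are hypothetical, not mgcv_rust API:

```python
import numpy as np


def simulate_ci(Xp, beta_hat, vcov, level=0.95, n_samples=1000,
                inv_link=lambda eta: eta, seed=0):
    """Hypothetical sketch of a posterior-simulation CI.

    Xp is an (n, p) prediction matrix (basis evaluated at the new X).
    Returns (mean, lo, hi) on the scale produced by inv_link.
    """
    rng = np.random.default_rng(seed)
    draws = rng.multivariate_normal(beta_hat, vcov, size=n_samples)  # (n_samples, p)
    mu_draws = inv_link(draws @ Xp.T)                                 # (n_samples, n)
    tail = (1.0 - level) / 2.0
    lo, hi = np.quantile(mu_draws, [tail, 1.0 - tail], axis=0)
    mean = inv_link(Xp @ beta_hat)  # the point prediction, not the average of the draws
    return mean, lo, hi


Xp = np.array([[1.0, 0.2], [1.0, -0.5]])
beta_hat = np.array([0.1, 2.0])
vcov = np.diag([0.01, 0.04])
mean, lo, hi = simulate_ci(Xp, beta_hat, vcov)                     # 95%, identity link
mean50, lo50, hi50 = simulate_ci(Xp, beta_hat, vcov, level=0.5)    # narrower 50% CI
```

Quantiling after the inverse link keeps the interval valid for decreasing links (e.g. gamma's inverse link) without any endpoint swapping.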

The legacy predict_ci(X, alpha=0.05) → (lo, hi) form still works (emits a DeprecationWarning) so existing callers don't break.
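The legacy alpha and the new level describe the same interval via level = 1 - alpha (alpha=0.05 is a 95% CI). A sketch of such a deprecation shim, assuming that mapping; level_from_legacy is a hypothetical name:

```python
import warnings


def level_from_legacy(alpha=None, level=0.95):
    """Hypothetical shim: map the legacy alpha argument onto the new level argument."""
    if alpha is not None:
        warnings.warn("alpha is deprecated; pass level=1 - alpha instead",
                      DeprecationWarning, stacklevel=2)
        return 1.0 - alpha
    return level
```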

4. Differences between predictions, with paired CI

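# Hypothetical inputs, reusing X from the quick start: shift x0 by +0.5 in every row
from_X = X
to_X = X.assign(x0=X["x0"] + 0.5)

# Per-row diff: to[i] - from[i]
diff = gam.predict_diff(from_X, to_X)

# With paired posterior CI (much narrower than differencing predict_ci bounds)
diff, lo, hi = gam.predict_diff(from_X, to_X, level=0.95)

# One row broadcast against many (one-row DataFrames assumed)
baseline_row = X.iloc[[0]]
target_row = X.iloc[[0]]
candidates = X.iloc[1:11]
diff = gam.predict_diff(baseline_row, candidates, broadcast="from")
diff = gam.predict_diff(candidates, target_row, broadcast="to")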

Identity link only: for other links the response-scale difference is not linear in β, so the closed-form CI argument does not carry over.
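Why the paired CI is tighter: both predictions share the same coefficient draws, so their common uncertainty (here, mostly the intercept) cancels in the difference. A numpy sketch with a made-up two-coefficient posterior; for the identity link the paired interval also matches the closed form built from d'Vd with d = x_to - x_from:

```python
import numpy as np

rng = np.random.default_rng(1)
# Hypothetical posterior for beta = (intercept, slope) with correlated uncertainty
beta_hat = np.array([1.0, 0.5])
vcov = np.array([[0.04, 0.01],
                 [0.01, 0.02]])
draws = rng.multivariate_normal(beta_hat, vcov, size=4000)

x_from = np.array([1.0, 0.0])   # prediction-matrix row at "from"
x_to = np.array([1.0, 0.3])     # prediction-matrix row at "to"
pred_from = draws @ x_from
pred_to = draws @ x_to

# Naive: combine the two marginal 95% intervals endpoint to endpoint
lo_f, hi_f = np.quantile(pred_from, [0.025, 0.975])
lo_t, hi_t = np.quantile(pred_to, [0.025, 0.975])
naive_width = (hi_t - lo_f) - (lo_t - hi_f)

# Paired: difference each posterior draw first, then take quantiles
paired = pred_to - pred_from    # equals draws @ (x_to - x_from)
lo_d, hi_d = np.quantile(paired, [0.025, 0.975])
paired_width = hi_d - lo_d

# Closed form for the identity link: sd of the difference is sqrt(d' V d)
d = x_to - x_from
sd_closed = np.sqrt(d @ vcov @ d)
print(naive_width, paired_width, 2 * 1.96 * sd_closed)
```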

5. Subset / marginal views

gam[names] takes a list of smooth names and returns a view that includes only those smooths (plus the intercept if Gam.INTERCEPT is listed). All predict methods on the view respect that selection.

gam[["x0"]].predict(X)                              # marginal effect of x0, response scale
gam[["x0"]].predict(X, scale="deviation")           # link-scale, intercept zeroed (sum-to-zero on train)
gam[["x0", Gam.INTERCEPT]].predict(X)               # marginal effect + intercept
gam[["x0"]].partial_effect("x0").plot()             # what the smooth looks like

All sklearn attributes filter to the active features:

gam[["x0"]].coef_            # just the x0 block (+ intercept if active)
gam[["x0"]].feature_names_in_  # ["x0"]
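On the link scale, a subset view amounts to dropping the coefficient blocks of the smooths that are not listed. A sketch of that bookkeeping, assuming coef_ is laid out as the intercept followed by one contiguous block per smooth; subset_eta and blocks are hypothetical names, and a response-scale view would apply the inverse link afterwards:

```python
import numpy as np


def subset_eta(Xp, beta, blocks, active, include_intercept=True):
    """Hypothetical sketch: linear predictor from selected coefficient blocks only.

    Xp     -- (n, p) prediction matrix with the intercept in column 0
    blocks -- dict mapping smooth name -> slice of columns for that smooth
    active -- names of the smooths to keep
    """
    keep = np.zeros(Xp.shape[1], dtype=bool)
    keep[0] = include_intercept
    for name in active:
        keep[blocks[name]] = True
    return Xp[:, keep] @ beta[keep]


rng = np.random.default_rng(2)
Xp = np.column_stack([np.ones(4), rng.normal(size=(4, 6))])  # intercept + 2 smooths x 3 cols
beta = rng.normal(size=7)
blocks = {"x0": slice(1, 4), "x1": slice(4, 7)}

full = subset_eta(Xp, beta, blocks, ["x0", "x1"])
# include_intercept=False corresponds to the intercept-zeroed "deviation" view
x0_only = subset_eta(Xp, beta, blocks, ["x0"], include_intercept=False)
x1_only = subset_eta(Xp, beta, blocks, ["x1"], include_intercept=False)
```

On the link scale the marginal pieces add back up to the full linear predictor.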

6. Plotting, summary, score

gam.plot()                  # figure with one subplot per smooth (CI bands shaded)
gam.plot("x0")              # single-smooth axes
df = gam.partial_effect("x0", level=0.95)  # underlying data: columns x0, effect, lo, hi

print(gam.summary())
# Gam summary  family=gaussian  link=identity  n_obs=500
#   intercept (link)     = -0.0123
#   intercept (response) = -0.0123
#   scale (σ²)           = 0.0101
#   deviance             = 4.8732
#   R² (adj)             = 0.9612
#   smooths:
#     s(          x0)  k= 10  edf=  6.13  λ=2.345e-02
#     s(          x1)  k= 10  edf=  3.07  λ=4.890e-01

gam.score(X, y)              # adjusted R² for regression; accuracy@0.5 for binomial
gam.predict_proba(X)         # binomial only: (n, 2) [[P(0), P(1)]]
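For reference, a sketch of the adjusted R² that mgcv's summary.gam reports as r.sq in the unweighted case, 1 - var(residuals)·(n - 1) / (var(y)·(n - total edf)), assuming score() follows the same convention; edf_total includes the intercept. The second helper shows the (n, 2) layout that predict_proba returns. Both helpers are hypothetical:

```python
import numpy as np


def adjusted_r2(y, fitted, edf_total):
    """Adjusted R^2 in the style of mgcv's summary.gam r.sq (no prior weights)."""
    y = np.asarray(y, dtype=float)
    resid = y - np.asarray(fitted, dtype=float)
    n = y.size
    residual_df = n - edf_total
    return 1.0 - np.var(resid, ddof=1) * (n - 1) / (np.var(y, ddof=1) * residual_df)


def proba_two_column(p1):
    """(n, 2) array [[P(0), P(1)]] from the fitted P(y = 1)."""
    p1 = np.asarray(p1, dtype=float)
    return np.column_stack([1.0 - p1, p1])
```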

7. Serving with GamPredictor

For deployment paths, wrap the fitted Gam in a GamPredictor. Same API surface, plus strict column validation and a round-trip assertion.

from mgcv_rust import GamPredictor

predictor = GamPredictor(gam)
predictor.check_against(gam, X[:50])    # raises if predictions diverge
X_serve = X.sample(20, random_state=1)  # stand-in for incoming serving data
predictor.predict(X_serve)              # ValueError if any expected column is missing
predictor[["x0"]].predict(X_serve)      # predictor[[...]] is itself a GamPredictor

# __slots__ stops new attributes from being attached after construction, so
# the bound Gam's coef/vcov/schema are the contract.

This rules out two classes of production bugs by construction:

  • Column index drift when the serving DataFrame's columns differ from training: feature_names_in_ is enforced on every predict call (re-ordered → realigned; missing → ValueError).
  • x-grid clamping in interpolation-based predictors: GamPredictor recomputes the basis at the requested X via evaluate_lpmatrix, so there's no pre-computed grid to clamp against.
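A minimal sketch of the two serving-side checks described above: name-based realignment with a ValueError on missing columns, and a __slots__ class that rejects ad-hoc attributes. FrozenSchema is hypothetical, not mgcv_rust's GamPredictor, and ignoring extra columns is an assumption:

```python
import numpy as np


class FrozenSchema:
    """Hypothetical sketch of serving-side input checks (not mgcv_rust's GamPredictor)."""

    __slots__ = ("feature_names_in_",)

    def __init__(self, feature_names):
        self.feature_names_in_ = tuple(feature_names)

    def align(self, columns):
        """columns: mapping name -> 1-D array. Returns an (n, m) array in training order."""
        missing = [c for c in self.feature_names_in_ if c not in columns]
        if missing:
            raise ValueError(f"missing expected columns: {missing}")
        # Re-ordered input is realigned by name; extra columns are ignored (an assumption)
        return np.column_stack([np.asarray(columns[c], dtype=float)
                                for c in self.feature_names_in_])


schema = FrozenSchema(["x0", "x1"])
aligned = schema.align({"x1": [3.0, 4.0], "x0": [1.0, 2.0]})  # columns come back as x0, x1
```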

Families and links

| Family | Default link | Other links | Notes |
|---|---|---|---|
| gaussian | identity | | B-spline / cubic-regression / random-effects bases supported |
| binomial | logit | | predict_proba() available |
| poisson | log | | |
| gamma | inverse | log | use link="log" for the log link |
| tweedie | log | | power set via tweedie_p= (default 1.5); p is fixed, not estimated |
| nb / negbin | log | | nb profiles θ; negbin uses a fixed θ |
| t-dist / scat | identity | | df is profiled if not given (df ∈ [2, 100]) |
| quasipoisson, quasibinomial | log / logit | | dispersion-aware |
| quantile | identity | | see mgcv_rust.fit_quantile / fit_quantile_lss |
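Resolving link=None amounts to a lookup in the Default link column above. A minimal sketch; the dictionaries and resolve_link are hypothetical (only the family/link pairs come from the table), with the inverse links that map the linear predictor back to the response scale:

```python
import numpy as np

# Default link per family, following the table above (subset of families)
DEFAULT_LINK = {
    "gaussian": "identity", "binomial": "logit", "poisson": "log",
    "gamma": "inverse", "tweedie": "log", "nb": "log", "negbin": "log",
    "quasipoisson": "log", "quasibinomial": "logit",
}

# Inverse link: linear predictor eta -> response-scale mean mu
INVERSE_LINK = {
    "identity": lambda eta: eta,
    "log": np.exp,
    "logit": lambda eta: 1.0 / (1.0 + np.exp(-eta)),
    "inverse": lambda eta: 1.0 / eta,
}


def resolve_link(family, link=None):
    """Hypothetical helper: link=None means 'use the family default'."""
    return DEFAULT_LINK[family] if link is None else link
```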

Architecture (Rust side)

| File | Responsibility |
|---|---|
| src/gam.rs | GAM struct + fit / predict entry points |
| src/pirls.rs | Penalized IRLS inner loop |
| src/reml/ | Outer-loop REML / LAML optimization |
| src/smooth.rs | Basis functions (cubic-regression, B-spline, random effects, tensor products) |
| src/penalty.rs | Penalty-matrix construction |
| src/lib.rs | PyO3 bindings: PyGAM, compute_penalty_matrix, newton_pirls_py |
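For readers new to the inner loop: penalized IRLS (Wood 2017) repeatedly solves a penalized weighted least-squares problem on a working response. A minimal numpy sketch for a Poisson family with log link at a fixed λ, assuming column 0 of X is the intercept. It is not a transcription of src/pirls.rs and omits the step-halving and divergence safeguards production implementations typically add:

```python
import numpy as np


def pirls_poisson(X, y, S, lam, max_iter=50, tol=1e-10):
    """Penalized IRLS for a Poisson GLM with log link at a fixed smoothing parameter."""
    beta = np.zeros(X.shape[1])
    beta[0] = np.log(y.mean() + 1e-8)      # start from the intercept-only fit
    for _ in range(max_iter):
        eta = X @ beta
        mu = np.exp(eta)
        # Log link: d(eta)/d(mu) = 1/mu and Var(y) = mu, so the IRLS weights are w = mu
        z = eta + (y - mu) / mu             # working response
        w = mu
        XtW = X.T * w                       # X'W without forming diag(w)
        beta_new = np.linalg.solve(XtW @ X + lam * S, XtW @ z)
        if np.max(np.abs(beta_new - beta)) < tol:
            return beta_new
        beta = beta_new
    return beta


rng = np.random.default_rng(2)
n = 300
x = rng.uniform(-2, 2, n)
X = np.column_stack([np.ones(n), x, x**2, x**3])   # toy basis, not a spline
y = rng.poisson(np.exp(0.5 + np.sin(x)))
S = np.diag([0.0, 1.0, 1.0, 1.0])                  # leave the intercept unpenalized
beta = pirls_poisson(X, y, S, lam=1.0)
```

At convergence the penalized score X'(y - μ) - λSβ is zero, which is a convenient check on any P-IRLS implementation.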

Build features:

  • python — enable PyO3 bindings.
  • blas / blas-system — link against system BLAS for matmul-heavy paths.
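To illustrate what the outer loop in src/reml/ optimizes, here is the Gaussian REML criterion from Wood (2011) with the scale φ profiled out, minimized by a plain grid search over ρ = log λ. This is a sketch only: the bump basis and second-difference penalty are illustrative choices, and the grid search stands in for the Newton iterations a real implementation uses:

```python
import numpy as np


def reml_criterion(rho, X, y, S, rank_S):
    """Gaussian REML with phi profiled out, up to an additive constant."""
    lam = np.exp(rho)
    n, p = X.shape
    A = X.T @ X + lam * S
    beta = np.linalg.solve(A, X.T @ y)
    resid = y - X @ beta
    pen_dev = resid @ resid + lam * beta @ S @ beta
    m_p = p - rank_S                    # dimension of the penalty null space
    phi = pen_dev / (n - m_p)           # profiled scale estimate
    _, logdet_A = np.linalg.slogdet(A)
    # log|lam*S|_+ = rank_S * rho + const, so only the rho term is kept
    score = 0.5 * (n - m_p) * np.log(phi) + 0.5 * logdet_A - 0.5 * rank_S * rho
    return score, phi


rng = np.random.default_rng(3)
n, k = 200, 20
x = np.sort(rng.uniform(-2, 2, n))
y = np.sin(2 * x) + rng.normal(0, 0.1, n)
centres = np.linspace(-2, 2, k)
B = np.exp(-((x[:, None] - centres[None, :]) ** 2) / (2 * 0.2**2))
X = np.column_stack([np.ones(n), B - B.mean(axis=0)])
D = np.diff(np.eye(k), n=2, axis=0)     # second-difference operator, shape (k-2, k)
S = np.zeros((k + 1, k + 1))
S[1:, 1:] = D.T @ D

grid = np.linspace(-12, 8, 101)
scores = [reml_criterion(r, X, y, S, rank_S=k - 2)[0] for r in grid]
best = grid[int(np.argmin(scores))]
```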

Parity and performance

Run the parity battery against R/mgcv:

pytest tests/parity/ -q
# 554 passed, 0 failed, 0 xfailed, 0 skipped

Microbench:

python3 scripts/python/bench_step_blend.py 5

Headline (2d_gaussian_additive_n50000_k15_cr, identity link):

| Implementation | Time |
|---|---|
| R mgcv | ~394 ms |
| mgcv_rust 0.11.0 | 97 ms |

Status and limitations

  • A joint outer Newton iteration over (ρ, log φ), i.e. the log smoothing parameters and the log scale parameter, is not yet implemented. It is the main obstacle to closing the remaining performance gap on GLMs with an estimated dispersion parameter. Tracked in the project backlog (mgcv_rust - Backlog - Next Numerical Steps, Obsidian).
  • predict_diff supports the identity link only; for other links it raises an error that points to get_posterior_samples as a workaround.
  • Auto-k tuning is opt-in via Gam(auto_k=True). By default the model is fit once at k_default=10 (mgcv's default), with term_k_mapping overrides. This mirrors mgcv's convention of choosing k yourself and checking it with k.check, and avoids hidden multi-fit costs.
  • Gam does not yet implement the sklearn BaseEstimator mixin, so it cannot be used directly in Pipeline / GridSearchCV. This is deferred because sklearn would be a soft dependency.
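For a non-identity link, the posterior-sample route can be sketched as follows: move each coefficient draw to the response scale first, then difference within each draw. This is a generic numpy sketch; response_diff_ci is hypothetical and does not reproduce get_posterior_samples, whose signature is not documented here:

```python
import numpy as np


def response_diff_ci(X_from, X_to, beta_hat, vcov, inv_link, level=0.95,
                     n_samples=2000, seed=0):
    """Paired response-scale difference via posterior draws (works for any link)."""
    rng = np.random.default_rng(seed)
    draws = rng.multivariate_normal(beta_hat, vcov, size=n_samples)
    mu_from = inv_link(draws @ X_from.T)   # (n_samples, n_rows)
    mu_to = inv_link(draws @ X_to.T)
    diff_draws = mu_to - mu_from           # differences taken draw by draw
    tail = (1.0 - level) / 2.0
    lo, hi = np.quantile(diff_draws, [tail, 1.0 - tail], axis=0)
    point = inv_link(X_to @ beta_hat) - inv_link(X_from @ beta_hat)
    return point, lo, hi


def expit(eta):
    return 1.0 / (1.0 + np.exp(-eta))


point, lo, hi = response_diff_ci(np.array([[1.0, 0.0]]), np.array([[1.0, 0.5]]),
                                 np.array([0.0, 1.0]), np.diag([0.01, 0.01]), expit)
```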

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. J. R. Stat. Soc. B, 73(1), 3–36.

License

MIT — see LICENSE.


Download files

Download the file for your platform.

Source Distribution

mgcv_rust-0.12.0.tar.gz (9.8 MB, source)

Built Distributions

| File | Size | Interpreter | Platform |
|---|---|---|---|
| mgcv_rust-0.12.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 971.5 kB | PyPy | manylinux: glibc 2.17+ x86-64 |
| mgcv_rust-0.12.0-cp314-cp314-win_amd64.whl | 3.6 MB | CPython 3.14 | Windows x86-64 |
| mgcv_rust-0.12.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 973.7 kB | CPython 3.14 | manylinux: glibc 2.17+ x86-64 |
| mgcv_rust-0.12.0-cp314-cp314-macosx_11_0_arm64.whl | 792.1 kB | CPython 3.14 | macOS 11.0+ ARM64 |
| mgcv_rust-0.12.0-cp313-cp313-win_amd64.whl | 3.6 MB | CPython 3.13 | Windows x86-64 |
| mgcv_rust-0.12.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 972.4 kB | CPython 3.13 | manylinux: glibc 2.17+ x86-64 |
| mgcv_rust-0.12.0-cp313-cp313-macosx_11_0_arm64.whl | 789.9 kB | CPython 3.13 | macOS 11.0+ ARM64 |
| mgcv_rust-0.12.0-cp312-cp312-win_amd64.whl | 3.6 MB | CPython 3.12 | Windows x86-64 |
| mgcv_rust-0.12.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 972.5 kB | CPython 3.12 | manylinux: glibc 2.17+ x86-64 |
| mgcv_rust-0.12.0-cp312-cp312-macosx_11_0_arm64.whl | 790.6 kB | CPython 3.12 | macOS 11.0+ ARM64 |
| mgcv_rust-0.12.0-cp311-cp311-win_amd64.whl | 3.6 MB | CPython 3.11 | Windows x86-64 |
| mgcv_rust-0.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 971.7 kB | CPython 3.11 | manylinux: glibc 2.17+ x86-64 |
| mgcv_rust-0.12.0-cp311-cp311-macosx_11_0_arm64.whl | 793.8 kB | CPython 3.11 | macOS 11.0+ ARM64 |
| mgcv_rust-0.12.0-cp310-cp310-win_amd64.whl | 3.6 MB | CPython 3.10 | Windows x86-64 |
| mgcv_rust-0.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 973.6 kB | CPython 3.10 | manylinux: glibc 2.17+ x86-64 |
| mgcv_rust-0.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 975.3 kB | CPython 3.9 | manylinux: glibc 2.17+ x86-64 |

File details

Details for the file mgcv_rust-0.12.0.tar.gz.

File metadata

  • Download URL: mgcv_rust-0.12.0.tar.gz
  • Upload date:
  • Size: 9.8 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: maturin/1.13.3

File hashes

Hashes for mgcv_rust-0.12.0.tar.gz
Algorithm Hash digest
SHA256 7ecbd51e6f44e12688fd7239b066da1b28d58d121ff7af12f4e259602e12165b
MD5 61a5b1c48d967e769129384f9f5d46d8
BLAKE2b-256 af6daafb0e28cef32fe581c300cf8bd0f84b3d507226f72a4792a72bad1841ee

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.12.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.12.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 9670e8526db11bf4e60671162999f38d1c9745e67fecfa7a348fe6207cd7e31c
MD5 d72440c15ff548c860d28d2044fb9c22
BLAKE2b-256 1db31cc22a6448d84b105c9c8ef65af4f598f8071980549d106ea6e4c3788534

File details

Details for the file mgcv_rust-0.12.0-cp314-cp314-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.12.0-cp314-cp314-win_amd64.whl
Algorithm Hash digest
SHA256 65ec7b53cb2145897894f437775f10b0193a1e239150ed7d441ad60313d3dfdd
MD5 ada3df881e3d4e6d7878654cde82f369
BLAKE2b-256 1428cf14a32af9415cac7c98812a85f225084669f1975a8e08500dd08ad9cc92

File details

Details for the file mgcv_rust-0.12.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.12.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 b82c6e6bd82d7dd4fef2bc46d7a380a641c9a381717a37b968e60c551536f2c9
MD5 e0804ec5aeec563fb9661c7882a34ce6
BLAKE2b-256 e8c0e12fcc6c87e01e15b462496b80864e031d93b411d5262d209edd34965f7d

File details

Details for the file mgcv_rust-0.12.0-cp314-cp314-macosx_11_0_arm64.whl.

File hashes

Hashes for mgcv_rust-0.12.0-cp314-cp314-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 6fa17efa472127aafb4d2aa14ebdd9376b3507fa309376f6796803937ed31cba
MD5 094524321367a6c679303212a3b28e3f
BLAKE2b-256 ed6e1c5c20c7b22e787fbf9c29245293d92dd562446514bd6e8af91c8ebb9f7f

File details

Details for the file mgcv_rust-0.12.0-cp313-cp313-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.12.0-cp313-cp313-win_amd64.whl
Algorithm Hash digest
SHA256 1cb9fdc87b9617539ada426ee749ea2b3dc2fd66f14e03d2622e0469f0cfd0d1
MD5 0c081f23d8a10801a8dec5cceef29a30
BLAKE2b-256 db103376e113867b05b2cc78159095ee60d9adfdc85b62d27c6e05c44474c753

File details

Details for the file mgcv_rust-0.12.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.12.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 e9ce0e29779df70a05c8f55bc33a01c956cb6f34fd2198967fd8b619e7bf3b4d
MD5 76c91286072214c15ea0966f7e02cd64
BLAKE2b-256 e223b3468278eb6e1c81166dade62ee7f0080572e02fa8220f215fa898b2d918

File details

Details for the file mgcv_rust-0.12.0-cp313-cp313-macosx_11_0_arm64.whl.

File hashes

Hashes for mgcv_rust-0.12.0-cp313-cp313-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 513306476490995b90d8332457805313a784e81264a1f33aeea45d3ea4a0085d
MD5 3426edbc76234be7de6e2de6065a07f9
BLAKE2b-256 20bbdf350e75d9bb7e2bf8ab7366919bb850d3dbea93ecbeca636cb5fa10a8cd

File details

Details for the file mgcv_rust-0.12.0-cp312-cp312-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.12.0-cp312-cp312-win_amd64.whl
Algorithm Hash digest
SHA256 9d215a77a05a28a842a322f39faea87a095208e2489741b807a68e238d94ed6f
MD5 e0c5692e04028e71abe3206407d9e752
BLAKE2b-256 20d2821c75d4f3ffa192b1eb4ee6bf1207401ed9827ad34c80216b9597401a92

File details

Details for the file mgcv_rust-0.12.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.12.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 c1dffbd2c0746c56173ba63b4afc6ec8ed2be43ff311627393f22a0301146bee
MD5 f89bfb93b91520f4dcdf4747dab51a81
BLAKE2b-256 73e9ace1eddda7da8ca511495c0a10f2e14a93f13e2dc94ab9f638ab1355db7a

File details

Details for the file mgcv_rust-0.12.0-cp312-cp312-macosx_11_0_arm64.whl.

File hashes

Hashes for mgcv_rust-0.12.0-cp312-cp312-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 7da79655a19444662d55d7c34c8c3666545739d8a1306c9c47adfcf87831159e
MD5 fed17173863c5d0915efcbc91749bfb9
BLAKE2b-256 93e5f8c89341f7d71bc21c42ad24e5845ad6ec4978cd920bcd9cf158104012e2

File details

Details for the file mgcv_rust-0.12.0-cp311-cp311-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.12.0-cp311-cp311-win_amd64.whl
Algorithm Hash digest
SHA256 11efb0321e2c89a2ae0d7d4f09dedf50bbb8b6071842df8321f8414d0d522c32
MD5 4248f42530e26dac41c87b27db0f95d0
BLAKE2b-256 f351183638e069588a1d5ddb5b77ce1eead9d1425fd3340aa54980a2abc48ae3

File details

Details for the file mgcv_rust-0.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 5d606325358b924d2ce2a1358b5d0711fed4fa94c238baed5391e71a7f0f872e
MD5 23a67627f3eb3ac7a46ce20783616c4c
BLAKE2b-256 1079603b3c1d5a74bd4e9ba6c126112d8d1f6ede45ab4fe50fc48605389b3ddc

File details

Details for the file mgcv_rust-0.12.0-cp311-cp311-macosx_11_0_arm64.whl.

File hashes

Hashes for mgcv_rust-0.12.0-cp311-cp311-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 72662187b9ef156e32cc744fd44043d5c69d51649b73ba5e31eee81a182410bd
MD5 df8ccb24a53c5978f5bbc973ea14e07a
BLAKE2b-256 22f0ac22c1ece383fecfab83a9fbe9702afab7dbd6b8fba472cb79b6a604450e

File details

Details for the file mgcv_rust-0.12.0-cp310-cp310-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.12.0-cp310-cp310-win_amd64.whl
Algorithm Hash digest
SHA256 d8fa0a86fcdf7da5f10232a83fe68353cad212bf03da04272bf971521ebf64ab
MD5 1a948b2fa181a9e659e3752c98421c08
BLAKE2b-256 4454ae6787cae86106580e1d4387f3b82b98b0caa9454f06f75cab8b205e48ef

File details

Details for the file mgcv_rust-0.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 6dfdc7e5e7306c12081109e2577e3c37079b9f1532a0751788d08ecfd1e834b0
MD5 02817e674d01bf7b5692d77073281145
BLAKE2b-256 a833d329cf458e01825da00beebd8066f028ccc74301464066ab6d0f5419df91

File details

Details for the file mgcv_rust-0.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 09bf63716ef5eb57f1b055f5b87a290449995b07238ab9f10c57fe9e7c64cd93
MD5 b193d0c1d4bae3f5ae9e185fe426fccf
BLAKE2b-256 2f915b71251905e2ee0dbf2cffb0d1246e8dec77c8a5031724028d212c3f752d
