
Rust implementation of Generalized Additive Models with Python bindings


mgcv_rust

A Rust port of R's mgcv package — Generalized Additive Models with automatic smoothing-parameter selection (REML / LAML) and PIRLS fitting — with first-class Python bindings.

The numerical core targets byte-for-byte parity with mgcv on common families, and the Python API is designed to feel like sklearn: a single Gam class, named-predictor DataFrame inputs, fit() / predict() / score(), posterior CI, and subset views for marginal analysis.

  • Latest release: 0.11.0
  • Parity: 554 / 0 / 0 (passed / failed / xfailed) on the mgcv comparison battery
  • Python tests: 211 passed, 1 xfailed
  • Headline perf: 2d_gaussian_additive_n50000_k15_cr, R mgcv 394 ms → mgcv_rust 97 ms (4.05× vs. mgcv)
  • Python wrapper: mgcv_rust.Gam (canonical); GAMFitter is a deprecated alias

Why another GAM library?

  • R-equivalent answers, Python ergonomics. If you have R / mgcv users producing models and Python services consuming them, this gives both sides the same numbers from the same fit. The serialize() method produces a portable artefact for deployment.
  • Fast. Aggregate time across 80 parity fixtures: 2.0 s. On the largest Gaussian fixture (n=50,000, k=15), mgcv_rust is ~4× faster than R's mgcv.
  • Real CIs, paired diffs. predict_ci(X) returns (mean, lo, hi) so you never have to "+= intercept" by hand. predict_diff(from_X, to_X, level=...) gives a paired-posterior CI for the difference between two predictions — strictly tighter than the naive predict_ci difference.
  • Frozen serving view. GamPredictor(gam) is an __slots__-locked, inference-only wrapper with strict input validation (feature_names_in_ enforcement) and a check_against(gam, X_sample) round-trip assertion for deployment safety.

Installation

The Python package is built from this repo via maturin:

# Clone, then in a venv:
pip install maturin
maturin develop --features python,blas,blas-system --release

After that, import mgcv_rust works.

For a Rust-only consumer, add to your Cargo.toml:

[dependencies]
mgcv_rust = { git = "https://github.com/AlekJaworski/nn_exploring" }

Quick start

import numpy as np
import pandas as pd
from mgcv_rust import Gam

rng = np.random.default_rng(0)
n = 500
X = pd.DataFrame({
    "x0": rng.uniform(-2, 2, n),
    "x1": rng.uniform(0, 5, n),
})
y = np.sin(X["x0"]) + 0.3 * (X["x1"] - 2.5)**2 + rng.normal(0, 0.1, n)

gam = Gam(family="gaussian").fit(X, y)

gam.predict(X[:5])                  # response-scale predictions
mean, lo, hi = gam.predict_ci(X[:5])  # 95% CI, response scale
gam.score(X, y)                     # adjusted R²
print(gam.summary())                # mgcv-style block

That's it for the simple case. Read on for the curated tour, or jump to docs/GETTING_STARTED.md for a step-by-step tutorial with worked examples.


Tour

1. Build the model

Gam.fit() accepts pandas DataFrames, polars DataFrames, or numpy arrays. If you pass a DataFrame, its column names become the predictor names; otherwise you can pass the names explicitly.

gam = Gam(
    family="gaussian",           # gaussian, binomial, poisson, gamma, tweedie, nb, t-dist, scat, ...
    link=None,                   # None → canonical link for the family
    k_default=10,                # default basis dimension per smooth (mgcv default)
    term_k_mapping={"x0": 25},   # per-predictor overrides
    method="REML",
)
gam.fit(X, y)

After fit, sklearn-style attributes are populated:

gam.coef_              # full β vector (intercept first)
gam.intercept_         # link-scale
gam.intercept_response_ # response-scale (= mean(y) for identity link)
gam.feature_names_in_  # predictor names, fitting order
gam.n_features_in_
gam.lambda_            # per-smooth smoothing parameter
gam.edf_               # per-smooth effective degrees of freedom
gam.k_, gam.bs_        # basis dimension / basis type per smooth
gam.vcov_              # posterior covariance of β̂
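Here is why intercept_response_ reduces to mean(y) in the Gaussian / identity case. The intercept column is unpenalized, and the smooth columns are constrained to sum to zero over the training rows. Together, these force the intercept's normal equation to give intercept = mean(y). The NumPy sketch below checks this on a toy penalized least-squares fit. The centered polynomial block and the ridge-style penalty are hypothetical stand-ins for a constrained spline basis and its penalty, not what mgcv_rust actually builds.

```python
import numpy as np

def penalized_ls(X, y, S, lam):
    """Solve the penalized normal equations (X^T X + lam*S) beta = X^T y."""
    return np.linalg.solve(X.T @ X + lam * S, X.T @ y)

rng = np.random.default_rng(3)
n = 100
x = rng.uniform(-2, 2, n)

# Hypothetical smooth block: polynomial columns centered on the training rows,
# standing in for a sum-to-zero constrained spline basis.
B = np.column_stack([x, x**2, x**3])
B = B - B.mean(axis=0)
X = np.column_stack([np.ones(n), B])

# Penalize the smooth block only; the intercept stays unpenalized.
S = np.zeros((4, 4))
S[1:, 1:] = np.eye(3)
y = np.sin(x) + rng.normal(0, 0.1, n)

beta = penalized_ls(X, y, S, lam=5.0)
print(beta[0], y.mean())  # the two numbers agree up to rounding
```

The first normal equation reads n·β₀ + 1ᵀBβ_s = Σy. Since 1ᵀB = 0 after centering, β₀ = mean(y) regardless of λ.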

2. Predict on different scales

gam.predict(X)                       # response (default — same as mgcv predict(type="response"))
gam.predict(X, scale="link")         # linear predictor, no inverse link (mgcv type="link")

result = gam.predict(X, type="terms")  # decomposition
result.intercept                       # scalar, link scale
result.contributions                   # DataFrame (n × m_smooths), link-scale, centered on training
result.total                           # response-scale ≡ gam.predict(X)

Invariant: gam.predict(X) == inv_link(intercept + result.contributions.sum(axis=1)).
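The invariant can be checked by hand on a toy log-link model. The sketch below is plain NumPy and does not use the mgcv_rust API. The contribution values, the intercept, and the helper name compose_prediction are all hypothetical. They only show how link-scale contributions, centered on the training rows, combine with the intercept and pass through the inverse link.

```python
import numpy as np

def compose_prediction(intercept, contributions, inv_link):
    """Rebuild response-scale predictions from a terms decomposition.

    intercept:     scalar, link scale
    contributions: (n, m) array, one link-scale column per smooth
    inv_link:      inverse link function (np.exp for a log link)
    """
    eta = intercept + contributions.sum(axis=1)  # linear predictor per row
    return inv_link(eta)

# Hypothetical log-link fit with two smooths evaluated on 4 training rows.
# Each column sums to zero over these rows ("centered on training").
contributions = np.array([
    [ 0.30, -0.10],
    [-0.20,  0.05],
    [ 0.10,  0.15],
    [-0.20, -0.10],
])
intercept = 1.5

mu = compose_prediction(intercept, contributions, np.exp)
print(mu)

# Because every column is centered on the training rows, the mean linear
# predictor over those rows equals the intercept.
print(np.isclose(np.log(mu).mean(), intercept))  # True
```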

3. Confidence intervals

mean, lo, hi = gam.predict_ci(X, level=0.95)       # response, 95% CI
mean, lo, hi = gam.predict_ci(X, scale="link")     # link-scale CI
mean, lo, hi = gam.predict_ci(X, level=0.5)        # 50% CI (narrower)

mean matches gam.predict(X, scale=scale) exactly. The CI is built from n_samples (default 1000) draws of β ~ N(β̂, vcov).

The legacy predict_ci(X, alpha=0.05) → (lo, hi) form still works (emits a DeprecationWarning) so existing callers don't break.
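The simulation described above can be sketched in NumPy as follows. The sketch rests on several assumptions, since the README does not document these internals: that the bounds are empirical quantiles at (1 − level)/2 and (1 + level)/2, that the returned mean is the plug-in prediction from β̂ rather than the average of the draws, and that Xp is the basis ("lpmatrix") evaluated at the prediction rows. The function name simulate_ci is hypothetical.

```python
import numpy as np

def simulate_ci(Xp, beta_hat, vcov, level=0.95, n_samples=1000,
                inv_link=None, seed=0):
    """Posterior-simulation CI for predictions Xp @ beta.

    Xp:       (n, p) basis rows at the prediction points
    beta_hat: (p,) fitted coefficients
    vcov:     (p, p) posterior covariance of beta_hat
    inv_link: None for the link scale, else a vectorised inverse link
    """
    rng = np.random.default_rng(seed)
    draws = rng.multivariate_normal(beta_hat, vcov, size=n_samples)  # (S, p)
    pred_draws = draws @ Xp.T      # link scale, (S, n)
    pred_hat = Xp @ beta_hat       # link scale, (n,)
    if inv_link is not None:
        pred_draws = inv_link(pred_draws)
        pred_hat = inv_link(pred_hat)
    tail = (1.0 - level) / 2.0     # level=0.95 -> 2.5% / 97.5% quantiles
    lo, hi = np.quantile(pred_draws, [tail, 1.0 - tail], axis=0)
    # The point estimate comes from beta_hat, not from the draws, so it
    # matches the plain prediction exactly.
    return pred_hat, lo, hi

Xp = np.column_stack([np.ones(5), np.linspace(-1, 1, 5)])
beta_hat = np.array([0.2, 0.8])
vcov = np.array([[0.010, 0.002],
                 [0.002, 0.020]])

mean, lo, hi = simulate_ci(Xp, beta_hat, vcov, level=0.95)
mean50, lo50, hi50 = simulate_ci(Xp, beta_hat, vcov, level=0.5)
mean_r, lo_r, hi_r = simulate_ci(Xp, beta_hat, vcov, inv_link=np.exp)
```

Under this reading, the legacy alpha argument maps to level = 1 − alpha (alpha=0.05 gives a 95% interval).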

4. Differences between predictions, with paired CI

# Per-row diff: to[i] - from[i]
diff = gam.predict_diff(from_X, to_X)

# With paired posterior CI — much narrower than predict_ci differences
diff, lo, hi = gam.predict_diff(from_X, to_X, level=0.95)

# One row broadcast against many
diff = gam.predict_diff(baseline_row, candidates, broadcast="from")
diff = gam.predict_diff(candidates, target_row, broadcast="to")

Identity-link only — for non-identity links the response-scale difference isn't linear in β, so the closed-form CI argument doesn't transfer.
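Under an identity link, the difference (X_to − X_from)β is linear in β. Its variance is therefore d V dᵀ for each contrast row d, and a normal interval follows in closed form. The NumPy sketch below, with hypothetical function names, compares that paired interval against naively subtracting two marginal intervals. In this example the shared intercept column cancels in the contrast, which is the main reason the paired interval is narrower. Whether predict_diff uses exactly this formula or posterior draws is not stated, so treat the sketch as illustrative.

```python
import numpy as np
from statistics import NormalDist

def paired_diff_ci(X_from, X_to, beta_hat, vcov, level=0.95):
    """Closed-form CI for (X_to - X_from) @ beta under an identity link."""
    D = X_to - X_from                                    # contrast rows
    diff = D @ beta_hat
    se = np.sqrt(np.einsum("ij,jk,ik->i", D, vcov, D))   # sqrt(diag(D V D^T))
    z = NormalDist().inv_cdf(0.5 + level / 2.0)
    return diff, diff - z * se, diff + z * se

def naive_diff_ci(X_from, X_to, beta_hat, vcov, level=0.95):
    """Subtract two marginal CIs, ignoring the correlation between them."""
    z = NormalDist().inv_cdf(0.5 + level / 2.0)

    def ci(X):
        m = X @ beta_hat
        s = np.sqrt(np.einsum("ij,jk,ik->i", X, vcov, X))
        return m - z * s, m + z * s

    lo_f, hi_f = ci(X_from)
    lo_t, hi_t = ci(X_to)
    return lo_t - hi_f, hi_t - lo_f

beta_hat = np.array([1.0, 0.5, -0.3])
vcov = np.array([[0.04, 0.01, 0.00],
                 [0.01, 0.02, 0.00],
                 [0.00, 0.00, 0.01]])
X_from = np.array([[1.0, 0.2, 0.1], [1.0, 0.5, 0.4]])
X_to = np.array([[1.0, 0.8, 0.1], [1.0, 0.9, 0.0]])

diff, lo, hi = paired_diff_ci(X_from, X_to, beta_hat, vcov)
nlo, nhi = naive_diff_ci(X_from, X_to, beta_hat, vcov)
print(hi - lo, nhi - nlo)  # paired widths vs naive widths
```

The paired width is never larger than the naive one, because sqrt(Var(a − b)) ≤ sd(a) + sd(b) always holds.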

5. Subset / marginal views

gam[names], where names is a list of smooth names (optionally including Gam.INTERCEPT), returns a view that includes only the listed smooths, plus the intercept if it is listed. All predict methods are available on the view.

gam[["x0"]].predict(X)                              # marginal effect of x0, response scale
gam[["x0"]].predict(X, scale="deviation")           # link-scale, intercept zeroed (sum-to-zero on train)
gam[["x0", Gam.INTERCEPT]].predict(X)               # marginal effect + intercept
gam[["x0"]].partial_effect("x0").plot()             # what the smooth looks like

All sklearn attributes filter to the active features:

gam[["x0"]].coef_            # just the x0 block (+ intercept if active)
gam[["x0"]].feature_names_in_  # ["x0"]
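One way to picture a subset view is to keep the coefficient blocks of the selected smooths (and optionally the intercept), zero the rest, and predict as usual. The NumPy sketch below does exactly that. The column layout in blocks, the intercept-in-column-0 convention, and the name subset_predict are hypothetical; whether mgcv_rust implements views this way is an assumption.

```python
import numpy as np

def subset_predict(Xp, beta_hat, blocks, active, include_intercept=False,
                   inv_link=None):
    """Predict using only the coefficient blocks named in `active`.

    Xp:     (n, p) basis matrix with the intercept in column 0
    blocks: mapping smooth name -> slice of coefficient columns
    active: iterable of smooth names to keep
    """
    mask = np.zeros_like(beta_hat)
    if include_intercept:
        mask[0] = 1.0
    for name in active:
        mask[blocks[name]] = 1.0
    eta = Xp @ (beta_hat * mask)   # link-scale partial prediction
    return eta if inv_link is None else inv_link(eta)

blocks = {"x0": slice(1, 3), "x1": slice(3, 5)}
rng = np.random.default_rng(1)
Xp = np.column_stack([np.ones(6), rng.normal(size=(6, 4))])
beta_hat = np.array([0.5, 1.0, -0.5, 0.3, 0.2])

x0_only = subset_predict(Xp, beta_hat, blocks, ["x0"])
x1_only = subset_predict(Xp, beta_hat, blocks, ["x1"])
full = subset_predict(Xp, beta_hat, blocks, ["x0", "x1"], include_intercept=True)

# On the link scale the pieces add back up to the full linear predictor.
print(np.allclose(x0_only + x1_only + beta_hat[0], full))  # True
```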

6. Plotting, summary, score

gam.plot()                  # figure with one subplot per smooth (CI bands shaded)
gam.plot("x0")              # single-smooth axes
df = gam.partial_effect("x0", level=0.95)  # underlying data: columns x0, effect, lo, hi

print(gam.summary())   # illustrative output; exact values depend on the data
# Gam summary  family=gaussian  link=identity  n_obs=500
#   intercept (link)     = -0.0123
#   intercept (response) = -0.0123
#   scale (σ²)           = 0.0101
#   deviance             = 4.8732
#   R² (adj)             = 0.9612
#   smooths:
#     s(          x0)  k= 10  edf=  6.13  λ=2.345e-02
#     s(          x1)  k= 10  edf=  3.07  λ=4.890e-01

gam.score(X, y)              # adjusted R² for regression; accuracy@0.5 for binomial
gam.predict_proba(X)         # binomial only: (n, 2) [[P(0), P(1)]]
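For reference, mgcv's summary.gam defines adjusted R² (unweighted case) as 1 − var(y − ŷ)·(n − 1) / (var(y)·df_resid), with df_resid = n − total EDF. The sketch below writes out that formula, the (n, 2) [[P(0), P(1)]] layout, and accuracy at a 0.5 threshold. It assumes that Gam.score follows these definitions and that ties at exactly 0.5 go to class 1; the README does not confirm either point. All helper names are hypothetical.

```python
import numpy as np

def adjusted_r2(y, fitted, total_edf):
    """Adjusted R^2 in the style of mgcv's summary.gam (unweighted case)."""
    y = np.asarray(y, dtype=float)
    fitted = np.asarray(fitted, dtype=float)
    n = y.size
    residual_df = n - total_edf
    return 1.0 - np.var(y - fitted) * (n - 1) / (np.var(y) * residual_df)

def predict_proba_from_p1(p1):
    """Stack P(1) into an (n, 2) array of [P(0), P(1)] rows."""
    p1 = np.asarray(p1, dtype=float)
    return np.column_stack([1.0 - p1, p1])

def accuracy_at_half(y, p1):
    """Share of rows where (P(1) >= 0.5) matches the 0/1 label.

    Sending ties at exactly 0.5 to class 1 is an assumption.
    """
    return np.mean((np.asarray(p1) >= 0.5) == (np.asarray(y) == 1))

y = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
fitted = np.array([1.1, 1.9, 3.2, 3.9, 4.9])
print(adjusted_r2(y, fitted, total_edf=2.0))

p1 = np.array([0.9, 0.2, 0.6, 0.4])
labels = np.array([1, 0, 0, 0])
print(predict_proba_from_p1(p1))
print(accuracy_at_half(labels, p1))  # 0.75
```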

7. Serving with GamPredictor

For deployment paths, wrap the fitted Gam in a GamPredictor. Same API surface, plus strict column validation and a round-trip assertion.

from mgcv_rust import GamPredictor

predictor = GamPredictor(gam)
predictor.check_against(gam, X[:50])   # raises if predictions diverge
predictor.predict(X_serve)             # ValueError if any expected column is missing
predictor[["x0"]].predict(X_serve)     # subset view returns another GamPredictor

# `__slots__` blocks attribute leaks — once built, the bound Gam's
# coef/vcov/schema are the contract.

This structurally closes two production bug classes:

  • Column index drift when the serving DataFrame's columns differ from training: feature_names_in_ is enforced on every predict call (re-ordered → realigned; missing → ValueError).
  • x-grid clamping in interpolation-based predictors: GamPredictor recomputes the basis at the requested X via evaluate_lpmatrix, so there's no pre-computed grid to clamp against.
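The column contract in the first bullet, and the round-trip check that check_against performs, can be sketched with pandas. The sketch reorders columns to the training order, raises ValueError when any are missing, and raises if two prediction paths disagree on a sample. The helpers align_to_schema and check_round_trip are hypothetical stand-ins, not GamPredictor's internals. Dropping extra columns is likewise an assumption, since the README does not say how extras are handled.

```python
import numpy as np
import pandas as pd

def align_to_schema(X, feature_names_in):
    """Reorder X's columns to the training order; raise on missing columns.

    Extra columns are dropped here (an assumption, not documented behaviour).
    """
    missing = [c for c in feature_names_in if c not in X.columns]
    if missing:
        raise ValueError(f"missing expected columns: {missing}")
    return X.loc[:, list(feature_names_in)]

def check_round_trip(predict_a, predict_b, X_sample, rtol=1e-10, atol=1e-12):
    """Raise AssertionError if two predictors disagree on the same rows."""
    a = np.asarray(predict_a(X_sample), dtype=float)
    b = np.asarray(predict_b(X_sample), dtype=float)
    if not np.allclose(a, b, rtol=rtol, atol=atol):
        worst = float(np.max(np.abs(a - b)))
        raise AssertionError(f"predictions diverge (max abs diff {worst:g})")

schema = ["x0", "x1"]
serve = pd.DataFrame({"x1": [1.0, 2.0], "x0": [0.5, -0.5], "extra": [9, 9]})
aligned = align_to_schema(serve, schema)
print(list(aligned.columns))  # ['x0', 'x1']

try:
    align_to_schema(serve.drop(columns="x0"), schema)
except ValueError as err:
    print(err)  # missing expected columns: ['x0']

def toy_predict(X):
    # Hypothetical stand-in for a fitted model's predict().
    return 2.0 * X["x0"] + X["x1"]

# A serving path that aligns columns first should reproduce the reference.
check_round_trip(toy_predict, lambda X: toy_predict(align_to_schema(X, schema)), serve)
```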

Families and links

| Family | Default link | Other links | Notes |
|---|---|---|---|
| gaussian | identity | | bs-spline / cubic-regression / random-effects bases supported |
| binomial | logit | | predict_proba() available |
| poisson | log | | |
| gamma | inverse | log | log link via link="log" |
| tweedie | log | | tweedie_p= (default 1.5); p is held fixed |
| nb / negbin | log | | nb profiles θ; negbin uses a fixed θ |
| t-dist / scat | identity | | df profiled if not given (df ∈ [2, 100]) |
| quasipoisson, quasibinomial | log / logit | | dispersion-aware |
| quantile | identity | | see mgcv_rust.fit_quantile / fit_quantile_lss; OOS presets documented in docs/qgam_oos_presets.md |
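The default links in the table correspond to the following (link, inverse link) pairs. This is a NumPy sketch for reading the invariant and CI sections, not mgcv_rust's internal family objects. The dictionaries and inv_link_for are hypothetical names.

```python
import numpy as np

# (link, inverse link) pairs for the default links listed in the table.
LINKS = {
    "identity": (lambda mu: mu, lambda eta: eta),
    "log": (np.log, np.exp),
    "logit": (lambda mu: np.log(mu / (1.0 - mu)),
              lambda eta: 1.0 / (1.0 + np.exp(-eta))),
    "inverse": (lambda mu: 1.0 / mu, lambda eta: 1.0 / eta),
}

DEFAULT_LINK = {
    "gaussian": "identity",
    "binomial": "logit",
    "poisson": "log",
    "gamma": "inverse",
    "tweedie": "log",
    "nb": "log",
    "t-dist": "identity",
    "quantile": "identity",
}

def inv_link_for(family, link=None):
    """Return the inverse link; link=None selects the family's default."""
    name = DEFAULT_LINK[family] if link is None else link
    return LINKS[name][1]

mu = np.array([0.2, 0.5, 0.8])
for name, (g, g_inv) in LINKS.items():
    print(name, np.allclose(g_inv(g(mu)), mu))  # each pair round-trips
```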

Architecture (Rust side)

| File | Responsibility |
|---|---|
| src/gam.rs | GAM struct + fit / predict entry points |
| src/pirls.rs | Penalized IRLS inner loop |
| src/reml/ | Outer-loop REML / LAML optimization |
| src/smooth.rs | Basis functions (cubic-regression, B-spline, random effects, tensor products) |
| src/penalty.rs | Penalty-matrix construction |
| src/lib.rs | PyO3 bindings: PyGAM, compute_penalty_matrix, newton_pirls_py |
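For readers new to the inner loop in src/pirls.rs, here is a textbook penalized IRLS for a Poisson model with log link and a fixed smoothing parameter λ, written in NumPy. It is a generic sketch, not a transcription of pirls.rs. It has no step-halving and no outer REML loop, and it uses a P-spline-style difference penalty and a hypothetical Gaussian-bump basis in place of mgcv's cubic regression spline basis and penalty.

```python
import numpy as np

def difference_penalty(k, order=2):
    """S = D^T D, with D the order-`order` difference matrix (P-spline style)."""
    D = np.diff(np.eye(k), n=order, axis=0)
    return D.T @ D

def pirls_poisson(X, y, S, lam, max_iter=100, tol=1e-10):
    """Penalized IRLS for a Poisson model with log link and a fixed lambda.

    Each step solves (X^T W X + lam*S) beta = X^T W z, where for the log link
    the working weights are W = mu and the working response is
    z = eta + (y - mu) / mu.
    """
    mu = y + 0.1              # common GLM starting values; avoids log(0)
    eta = np.log(mu)
    beta = None
    for _ in range(max_iter):
        z = eta + (y - mu) / mu
        XtW = X.T * mu        # X^T W without forming the n x n diagonal
        beta_new = np.linalg.solve(XtW @ X + lam * S, XtW @ z)
        if beta is not None and np.max(np.abs(beta_new - beta)) < tol:
            return beta_new
        beta = beta_new
        eta = X @ beta
        mu = np.exp(eta)
    return beta

# Toy data with a hypothetical Gaussian-bump basis (not mgcv's cr basis).
rng = np.random.default_rng(0)
n, k = 200, 8
x = np.linspace(0, 1, n)
centers = np.linspace(0, 1, k)
X = np.exp(-((x[:, None] - centers[None, :]) ** 2) / (2 * 0.1**2))
S = difference_penalty(k)
y = rng.poisson(np.exp(1.0 + np.sin(2 * np.pi * x)))

beta = pirls_poisson(X, y, S, lam=1.0)
```

At convergence the penalized score equation Xᵀ(y − μ) = λSβ holds. That is the fixed point the loop is iterating toward, and it gives a cheap correctness check.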

Build features:

  • python — enable PyO3 bindings.
  • blas / blas-system — link against system BLAS for matmul-heavy paths.

Parity and performance

Run the parity battery against R/mgcv:

pytest tests/parity/ -q
# 554 passed, 0 failed, 0 xfailed, 0 skipped

Microbench:

python3 scripts/python/bench_step_blend.py 5

Headline (2d_gaussian_additive_n50000_k15_cr, identity link):

| Implementation | Time |
|---|---|
| R mgcv | ~394 ms |
| mgcv_rust 0.11.0 | 97 ms |

Status and limitations

  • Joint (ρ, log φ) outer Newton not yet implemented — the binding constraint for closing the remaining performance gap on dispersion-bearing GLMs. Tracked in mgcv_rust - Backlog - Next Numerical Steps (Obsidian).
  • predict_diff is identity-link only; non-identity raises with a workaround pointing at get_posterior_samples.
  • Auto-k tuning is opt-in via Gam(auto_k=True). By default, Gam performs a single fit at k_default=10 (mgcv's default), with term_k_mapping overrides. This is closer to mgcv's "tune k, run k.check" convention and avoids hidden multi-fit costs.
  • An sklearn BaseEstimator mixin (for Pipeline / GridSearchCV) is not yet provided; it is deferred because sklearn would be a soft dependency.
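For non-identity links, the workaround mentioned above is to difference posterior samples on the response scale. The sketch below implements this for a log link in NumPy. The signature and return format of get_posterior_samples are not documented here, so the sketch draws its own coefficient samples. The function name response_diff_ci is hypothetical. Each draw passes through the inverse link before differencing, so the pairing survives the non-linear link.

```python
import numpy as np

def response_diff_ci(X_from, X_to, beta_hat, beta_draws, inv_link=np.exp,
                     level=0.95):
    """Paired response-scale difference CI from posterior coefficient draws.

    beta_draws: (S, p) array of posterior draws of beta
    """
    to_draws = inv_link(beta_draws @ X_to.T)      # (S, n)
    from_draws = inv_link(beta_draws @ X_from.T)  # (S, n)
    diff_draws = to_draws - from_draws            # paired, draw by draw
    tail = (1.0 - level) / 2.0
    lo, hi = np.quantile(diff_draws, [tail, 1.0 - tail], axis=0)
    diff = inv_link(X_to @ beta_hat) - inv_link(X_from @ beta_hat)
    return diff, lo, hi

rng = np.random.default_rng(2)
beta_hat = np.array([0.5, 0.8])
vcov = np.array([[0.010, 0.001],
                 [0.001, 0.005]])
beta_draws = rng.multivariate_normal(beta_hat, vcov, size=2000)

X_from = np.array([[1.0, 0.0], [1.0, 0.5]])
X_to = np.array([[1.0, 1.0], [1.0, 0.5]])  # second row: identical endpoints
diff, lo, hi = response_diff_ci(X_from, X_to, beta_hat, beta_draws)
print(diff, lo, hi)
```

Identical endpoints (the second row) give a zero-width interval, which is the signature of a correctly paired difference.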

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. J. R. Stat. Soc. B, 73(1), 3–36.

License

MIT — see LICENSE.

Download files

Download the file for your platform.

Source Distribution

  • mgcv_rust-0.17.0.tar.gz (13.7 MB)

Built Distributions

  • mgcv_rust-0.17.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB): PyPy, manylinux (glibc 2.17+) x86-64
  • mgcv_rust-0.17.0-cp314-cp314-win_amd64.whl (3.7 MB): CPython 3.14, Windows x86-64
  • mgcv_rust-0.17.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB): CPython 3.14, manylinux (glibc 2.17+) x86-64
  • mgcv_rust-0.17.0-cp314-cp314-macosx_11_0_arm64.whl (890.9 kB): CPython 3.14, macOS 11.0+ ARM64
  • mgcv_rust-0.17.0-cp313-cp313-win_amd64.whl (3.7 MB): CPython 3.13, Windows x86-64
  • mgcv_rust-0.17.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB): CPython 3.13, manylinux (glibc 2.17+) x86-64
  • mgcv_rust-0.17.0-cp313-cp313-macosx_11_0_arm64.whl (890.1 kB): CPython 3.13, macOS 11.0+ ARM64
  • mgcv_rust-0.17.0-cp312-cp312-win_amd64.whl (3.7 MB): CPython 3.12, Windows x86-64
  • mgcv_rust-0.17.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB): CPython 3.12, manylinux (glibc 2.17+) x86-64
  • mgcv_rust-0.17.0-cp312-cp312-macosx_11_0_arm64.whl (890.8 kB): CPython 3.12, macOS 11.0+ ARM64
  • mgcv_rust-0.17.0-cp311-cp311-win_amd64.whl (3.7 MB): CPython 3.11, Windows x86-64
  • mgcv_rust-0.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB): CPython 3.11, manylinux (glibc 2.17+) x86-64
  • mgcv_rust-0.17.0-cp311-cp311-macosx_11_0_arm64.whl (892.5 kB): CPython 3.11, macOS 11.0+ ARM64
  • mgcv_rust-0.17.0-cp310-cp310-win_amd64.whl (3.7 MB): CPython 3.10, Windows x86-64
  • mgcv_rust-0.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB): CPython 3.10, manylinux (glibc 2.17+) x86-64
  • mgcv_rust-0.17.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB): CPython 3.9, manylinux (glibc 2.17+) x86-64

File details

Details for the file mgcv_rust-0.17.0.tar.gz.

File metadata

  • Download URL: mgcv_rust-0.17.0.tar.gz
  • Size: 13.7 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: maturin/1.13.3

File hashes

Hashes for mgcv_rust-0.17.0.tar.gz
Algorithm Hash digest
SHA256 c130601965ff289150de52d04c1569fca7ec316738886a4fe6bf70bbb6633ffd
MD5 dd38b18d78f926e3122a6c1c537f2050
BLAKE2b-256 60967c38d492fd832eb6e53795395746b602101d5fe1afe5e8f8f1344dab9bc5

File details

Details for the file mgcv_rust-0.17.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.17.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 47e2deb0080f2f6b5ad6098a7f3ba419fe772492fe5718a618cebfabb4d6b7bd
MD5 10e3da10cb3553093df9dd17e3fddea3
BLAKE2b-256 bb72143f90ae636a0f4da2d5080aa2e1e92f81b47fee66c1f487bbe7cd189784

File details

Details for the file mgcv_rust-0.17.0-cp314-cp314-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.17.0-cp314-cp314-win_amd64.whl
Algorithm Hash digest
SHA256 9192241b452e0d4d222816d56cf37f703757fa0220b556b31c6f0f75582bf352
MD5 443b3dddc33fdb8caaa969c47ece65b3
BLAKE2b-256 c842d4d3cc0c3b8b6cf4bf240166664822920cb083cd28577d43c218c754b9a7

File details

Details for the file mgcv_rust-0.17.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.17.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 ce6422d4d0140e4d91c1be967480ee609fc1e0fb5a36a82e71e0d5a2e85a3d43
MD5 eb0b33b677065bca98c2befc62d45bc3
BLAKE2b-256 87759f5cadc520b573d5af0c5156a45d064a190d4e33c2cc7742d3912a50d3e2

File details

Details for the file mgcv_rust-0.17.0-cp314-cp314-macosx_11_0_arm64.whl.

File hashes

Hashes for mgcv_rust-0.17.0-cp314-cp314-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 e131193a4c58078ded250d21c3d9a94b2c319a1f031b75bad155c8bcfcf18575
MD5 34acfadc723fe1e86ed4d10875c24b95
BLAKE2b-256 1a6d28ada9098cf773bbf5683a0c9c186711efcaec0e0dfeed70464b359c4c54

File details

Details for the file mgcv_rust-0.17.0-cp313-cp313-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.17.0-cp313-cp313-win_amd64.whl
Algorithm Hash digest
SHA256 8f760c4242bd8e129721a15264613061b91d3a959c168d4a4e1abac2deabd13d
MD5 cc98af098f9c27d2bade666350e721ef
BLAKE2b-256 013994d3861590256ca4d357af62277ea5f956f0dcf38b5f67f53bb06f85de90

File details

Details for the file mgcv_rust-0.17.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.17.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 078f5aa17f2ededd3913140f4f75eea6b498c539148c211f83f42577a07dd59a
MD5 ca4eccd6704407c6b421189f68d5f249
BLAKE2b-256 7c1376f8adb2af61e21a266b06a4ad368b207ea4aa06acd9bbe930b89f00e7d3

File details

Details for the file mgcv_rust-0.17.0-cp313-cp313-macosx_11_0_arm64.whl.

File hashes

Hashes for mgcv_rust-0.17.0-cp313-cp313-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 664d81e17ae676c40c5afeccb16a692533005c04359724c84e5aeb2074464df1
MD5 f1b9161294f5b65aebd9cf4772ca3ed3
BLAKE2b-256 9363139129ee99d606ca69056ee48242c4b20ca338383a58fa846e98fff705e0

File details

Details for the file mgcv_rust-0.17.0-cp312-cp312-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.17.0-cp312-cp312-win_amd64.whl
Algorithm Hash digest
SHA256 505f53868d11ca7e9538a4352127a19477cc342886fac46764b257dbe50f59b0
MD5 beca155df63b6244e73e8c0308b9a223
BLAKE2b-256 71d6e4d46927f220fe43a185c567b5b6d417d682dab4a422b0ae12c93443dd14

File details

Details for the file mgcv_rust-0.17.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.17.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 491b320b9a61ac66171c497f16c21f393bc6998bbc845eed6eb228b8409c6614
MD5 28c2a6051d0b69efda013fef75260d7c
BLAKE2b-256 c9e4169ff02fbd9937b950bf73afa8b6bc16535b4028ce97d84541d50115d219

File details

Details for the file mgcv_rust-0.17.0-cp312-cp312-macosx_11_0_arm64.whl.

File hashes

Hashes for mgcv_rust-0.17.0-cp312-cp312-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 3cbd21ebb3b8724050278b95fcaf16eeeb10f6784f5dae63208441d38960971b
MD5 19020b510a83521179e3293ea57a50fb
BLAKE2b-256 afde0d1a274efcd8f853b5982121664e208e211b503428805a4696d9ccea6b40

File details

Details for the file mgcv_rust-0.17.0-cp311-cp311-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.17.0-cp311-cp311-win_amd64.whl
Algorithm Hash digest
SHA256 e6a6ed00537af1e8d259cb77a2edb15b80af9101ae173684de7cad4b3261cbb1
MD5 b657ec4131963ea29ddfe08ce493df48
BLAKE2b-256 55757b86d36b0a7dfad9d3021c88e0dbf4ebcfde7acc9b3038198bddf20799e8

File details

Details for the file mgcv_rust-0.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 77033d90706540e0c4f1f1c92a91cc05930daf512b07116e7413b3560ab51cd9
MD5 f18e5e6dd0047fd0bf8f9f17edc9f509
BLAKE2b-256 0917169ad4d97547bca4da3c306d88420c4b368ca09e50b0773822b4f22e32ce

File details

Details for the file mgcv_rust-0.17.0-cp311-cp311-macosx_11_0_arm64.whl.

File hashes

Hashes for mgcv_rust-0.17.0-cp311-cp311-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 a739b731338ad5a73293337219a907e2c20138ba4f83374263203b94d8b92280
MD5 5ad9b96f5e8d891fed6cde5ae255ea2e
BLAKE2b-256 025536cca3e52e2a564fb589c78fe24fcb2b9e346daa78ca4247f1fb22073c7e

File details

Details for the file mgcv_rust-0.17.0-cp310-cp310-win_amd64.whl.

File hashes

Hashes for mgcv_rust-0.17.0-cp310-cp310-win_amd64.whl
Algorithm Hash digest
SHA256 619b4d286d7b234dc3282108af9dc3d85417748cb13a49e2bf2abbb480615082
MD5 5fbeaf7536901d5efe2e05f3929691ba
BLAKE2b-256 3dcfae55ef81503b169a4b3606edae8a62b13796dca3dda647825fcd9c93995d

File details

Details for the file mgcv_rust-0.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 f53d0f0b122ea891da0cb5fd740b9adeeb81d6585fead3b6da4ca778a3d69931
MD5 48a6dd04f64c751ac564fb28c66904d5
BLAKE2b-256 d2abaf563c61592157aa3524c665ecb66ad13ff47c19f1e252e45db1bbdf727b

File details

Details for the file mgcv_rust-0.17.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for mgcv_rust-0.17.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 1614b985bc4760cd2f55646a1356fd0d200358976b33075e7652ca96763f1e09
MD5 fa050f901eafb03ee57143bc7bdd6523
BLAKE2b-256 fe84cffa7ef33dbd0e62e0b33f7ff9f34bd619cabd48ac0c13c30257f46c4fe5
