Rust implementation of Generalized Additive Models with Python bindings

mgcv_rust

A Rust port of R's mgcv package — Generalized Additive Models with automatic smoothing-parameter selection (REML / LAML) and PIRLS fitting — with first-class Python bindings.

The numerical core targets byte-for-byte parity with mgcv on common families. The Python API is designed to feel like scikit-learn: a single Gam class, DataFrame inputs with named predictors, fit() / predict() / score(), posterior confidence intervals, and subset views for marginal analysis.

  • Latest release: 0.11.0
  • Parity: 554 passed / 0 failed / 0 xfailed on the mgcv comparison battery
  • Python tests: 211 passed, 1 xfailed
  • Headline perf: 2d_gaussian_additive_n50000_k15_cr, 394 ms → 97 ms (4.05× vs. mgcv)
  • Python wrapper: mgcv_rust.Gam (canonical); GAMFitter is a deprecated alias

Why another GAM library?

  • R-equivalent answers, Python ergonomics. If you have R / mgcv users producing models and Python services consuming them, this gives both sides the same numbers from the same fit. The serialize() method produces a portable artefact for deployment.
  • Fast. Aggregate time across 80 parity fixtures: 2.0 s. On the largest Gaussian fixture (n=50,000, k=15), mgcv_rust is ~4× faster than R's mgcv.
  • Real CIs, paired diffs. predict_ci(X) returns (mean, lo, hi) so you never have to "+= intercept" by hand. predict_diff(from_X, to_X, level=...) gives a paired-posterior CI for the difference between two predictions — strictly tighter than the naive predict_ci difference.
  • Frozen serving view. GamPredictor(gam) is an __slots__-locked, inference-only wrapper with strict input validation (feature_names_in_ enforcement) and a check_against(gam, X_sample) round-trip assertion for deployment safety.

Installation

The Python package is built from this repo via maturin:

# Clone, then in a venv:
pip install maturin
maturin develop --features python,blas,blas-system --release

After that, import mgcv_rust works.

For a Rust-only consumer, add to your Cargo.toml:

[dependencies]
mgcv_rust = { git = "https://github.com/AlekJaworski/nn_exploring" }

Quick start

import numpy as np
import pandas as pd
from mgcv_rust import Gam

rng = np.random.default_rng(0)
n = 500
X = pd.DataFrame({
    "x0": rng.uniform(-2, 2, n),
    "x1": rng.uniform(0, 5, n),
})
y = np.sin(X["x0"]) + 0.3 * (X["x1"] - 2.5)**2 + rng.normal(0, 0.1, n)

gam = Gam(family="gaussian").fit(X, y)

gam.predict(X.iloc[:5])                    # response-scale predictions
mean, lo, hi = gam.predict_ci(X.iloc[:5])  # 95% CI, response scale
gam.score(X, y)                     # adjusted R²
print(gam.summary())                # mgcv-style block

That's it for the simple case. Read on for the curated tour, or jump to docs/GETTING_STARTED.md for a step-by-step tutorial with worked examples.


Tour

1. Build the model

fit() accepts pandas or polars DataFrames and NumPy arrays. If you pass a DataFrame, its column names become the predictor names; otherwise you can pass the names explicitly.

gam = Gam(
    family="gaussian",           # gaussian, binomial, poisson, gamma, tweedie, nb, t-dist, scat, ...
    link=None,                   # None → canonical link for the family
    k_default=10,                # default basis dimension per smooth (mgcv default)
    term_k_mapping={"x0": 25},   # per-predictor overrides
    method="REML",
)
gam.fit(X, y)

After fit, sklearn-style attributes are populated:

gam.coef_              # full β vector (intercept first)
gam.intercept_         # link-scale
gam.intercept_response_ # response-scale (= mean(y) for identity link)
gam.feature_names_in_  # predictor names, fitting order
gam.n_features_in_
gam.lambda_            # per-smooth smoothing parameter
gam.edf_               # per-smooth effective degrees of freedom
gam.k_, gam.bs_        # basis dimension / basis type per smooth
gam.vcov_              # posterior covariance of β̂
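
To make the link between lambda_ and edf_ concrete, here is a minimal sketch (not the library's code). It assumes a single smooth whose basis has been orthonormalized so that the penalty is diagonal with eigenvalues s_i; under that assumption the effective degrees of freedom are the sum of 1 / (1 + λ·s_i). The function name and the eigenvalues are made up for illustration.

```python
def effective_df(lam, penalty_eigenvalues):
    """EDF of one penalized smooth in an orthonormalized basis.

    Each coefficient is shrunk by the factor 1 / (1 + lam * s_i), and the
    EDF is the sum of those shrinkage factors.
    """
    return sum(1.0 / (1.0 + lam * s) for s in penalty_eigenvalues)


# Hypothetical penalty eigenvalues for a smooth with nine coefficients.
# The zero entry is an unpenalized direction that no lambda can shrink.
eigs = [0.0, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0]

for lam in (0.0, 0.1, 10.0, 1e6):
    print(f"lambda={lam:g}  edf={effective_df(lam, eigs):.2f}")
```

With λ = 0 the EDF equals the number of coefficients; as λ grows it falls towards the dimension of the unpenalized space, which is why a large lambda_ goes with an edf_ close to 1.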

2. Predict on different scales

gam.predict(X)                       # response (default — same as mgcv predict(type="response"))
gam.predict(X, scale="link")         # linear predictor, no inverse link (mgcv type="link")

result = gam.predict(X, type="terms")  # decomposition
result.intercept                       # scalar, link scale
result.contributions                   # DataFrame (n × m_smooths), link-scale, centered on training
result.total                           # response-scale ≡ gam.predict(X)

Invariant (up to floating-point rounding): gam.predict(X) == inv_link(result.intercept + result.contributions.sum(axis=1)).
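
The invariant can be checked by hand on toy numbers. The sketch below does not call mgcv_rust; it assumes a logit link and uses made-up link-scale values for the intercept and two smooths, purely to show the order of operations (sum on the link scale first, then invert the link).

```python
import math


def inv_logit(eta):
    """Inverse of the logit link: maps a linear predictor to a probability."""
    return 1.0 / (1.0 + math.exp(-eta))


# Hypothetical link-scale pieces for three rows of a binomial fit.
intercept = -0.4
contributions = {
    "x0": [0.9, -0.2, 0.1],
    "x1": [-0.3, 0.5, 0.0],
}

# Row-wise sum over smooths, plus the intercept, gives the linear predictor.
eta = [intercept + sum(col[i] for col in contributions.values()) for i in range(3)]

# Inverting the link gives the response-scale total.
total = [inv_logit(e) for e in eta]
print(total)
```

Summing the response-scale values of each smooth separately would not give the same result for a non-identity link, which is why the contributions are reported on the link scale.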

3. Confidence intervals

mean, lo, hi = gam.predict_ci(X, level=0.95)       # response, 95% CI
mean, lo, hi = gam.predict_ci(X, scale="link")     # link-scale CI
mean, lo, hi = gam.predict_ci(X, level=0.5)        # 50% CI (narrower)

mean matches gam.predict(X, scale=scale) exactly. The CI is built from n_samples (default 1000) draws of β ~ N(β̂, vcov).
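
As a rough illustration of the sampling idea, the following standard-library sketch draws β from N(β̂, vcov) through a Cholesky factor and takes percentiles of the resulting linear predictor. The function name simulate_ci, the percentile rule, and all numbers are assumptions of this sketch, not the library's implementation.

```python
import math
import random


def simulate_ci(x_row, beta_hat, chol, level=0.95, n_samples=1000, seed=0):
    """Percentile CI for the linear predictor x_row . beta from posterior draws.

    chol is the lower-triangular Cholesky factor L of vcov (vcov = L L^T),
    so beta_hat + L z with z ~ N(0, I) is a draw from N(beta_hat, vcov).
    """
    rng = random.Random(seed)
    p = len(beta_hat)
    draws = []
    for _ in range(n_samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(p)]
        beta = [beta_hat[i] + sum(chol[i][j] * z[j] for j in range(i + 1)) for i in range(p)]
        draws.append(sum(x * b for x, b in zip(x_row, beta)))
    draws.sort()
    tail = (1.0 - level) / 2.0
    lo = draws[int(math.floor(tail * (n_samples - 1)))]
    hi = draws[int(math.ceil((1.0 - tail) * (n_samples - 1)))]
    # The centre is the point prediction at beta_hat, not the average of the draws.
    mean = sum(x * b for x, b in zip(x_row, beta_hat))
    return mean, lo, hi


beta_hat = [1.0, 0.5]
chol = [[0.2, 0.0], [0.05, 0.1]]  # hypothetical Cholesky factor of vcov
mean, lo95, hi95 = simulate_ci([1.0, 2.0], beta_hat, chol, level=0.95)
_, lo50, hi50 = simulate_ci([1.0, 2.0], beta_hat, chol, level=0.5)
```

Because the centre is computed from β̂ directly, it matches the point prediction exactly, while lo and hi carry Monte Carlo noise that shrinks as n_samples grows.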

The legacy predict_ci(X, alpha=0.05) → (lo, hi) form still works (emits a DeprecationWarning) so existing callers don't break.

4. Differences between predictions, with paired CI

# Per-row diff: to[i] - from[i]
diff = gam.predict_diff(from_X, to_X)

# With paired posterior CI — much narrower than predict_ci differences
diff, lo, hi = gam.predict_diff(from_X, to_X, level=0.95)

# One row broadcast against many
diff = gam.predict_diff(baseline_row, candidates, broadcast="from")
diff = gam.predict_diff(candidates, target_row, broadcast="to")

Identity-link only — for non-identity links the response-scale difference isn't linear in β, so the closed-form CI argument doesn't transfer.
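
Why the paired interval is narrower can be seen from the variance of a difference. The sketch below is an illustration under a normal approximation on the identity link, with made-up variances and covariance; diff_ci_widths is a hypothetical helper, not part of mgcv_rust.

```python
import math
from statistics import NormalDist


def diff_ci_widths(var_from, var_to, cov, level=0.95):
    """Return (paired_width, naive_width) for the CI of to - from.

    paired: uses Var(to - from) = var_to + var_from - 2 * cov.
    naive:  subtracts opposite ends of the two separate CIs, which gives
            a width of 2 * z * (sd_to + sd_from) regardless of cov.
    """
    z = NormalDist().inv_cdf(0.5 + level / 2.0)
    paired_sd = math.sqrt(var_to + var_from - 2.0 * cov)
    naive = 2.0 * z * (math.sqrt(var_to) + math.sqrt(var_from))
    return 2.0 * z * paired_sd, naive


# Hypothetical values: two nearby rows of the same smooth are usually
# strongly positively correlated, so most of the uncertainty cancels.
paired, naive = diff_ci_widths(var_from=0.04, var_to=0.05, cov=0.04)
print(f"paired width={paired:.3f}  naive width={naive:.3f}")
```

The naive width ignores the covariance term entirely, so it can only match the paired width when the two predictions are perfectly negatively correlated.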

5. Subset / marginal views

Indexing with a list of names, gam[[...]], returns a view that includes only the listed smooths (and optionally the intercept, via Gam.INTERCEPT). All predict methods respect the selection.

gam[["x0"]].predict(X)                              # marginal effect of x0, response scale
gam[["x0"]].predict(X, scale="deviation")           # link-scale, intercept zeroed (sum-to-zero on train)
gam[["x0", Gam.INTERCEPT]].predict(X)               # marginal effect + intercept
gam[["x0"]].partial_effect("x0").plot()             # what the smooth looks like

All sklearn attributes filter to the active features:

gam[["x0"]].coef_            # just the x0 block (+ intercept if active)
gam[["x0"]].feature_names_in_  # ["x0"]
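
The arithmetic behind a subset view can be sketched in a few lines. This is an illustration on the link scale with made-up numbers; the helper names are hypothetical, and the sketch only shows the idea of centring a smooth on the training rows and then summing the selected terms.

```python
def center_on_training(raw_train):
    """Return (centered_values, shift) so the smooth sums to zero on train."""
    shift = sum(raw_train) / len(raw_train)
    return [v - shift for v in raw_train], shift


def subset_linear_predictor(contributions, active, intercept=None):
    """Sum only the active smooths; add the intercept when it is requested."""
    n = len(next(iter(contributions.values())))
    base = 0.0 if intercept is None else intercept
    return [base + sum(contributions[name][i] for name in active) for i in range(n)]


# Hypothetical uncentered values of one smooth at four training rows.
f_x0, shift = center_on_training([1.0, 2.0, 4.0, 5.0])
contributions = {"x0": f_x0, "x1": [0.3, -0.1, -0.4, 0.2]}

deviation = subset_linear_predictor(contributions, ["x0"])       # intercept zeroed
with_icpt = subset_linear_predictor(contributions, ["x0"], 0.7)  # smooth + intercept
print(deviation, with_icpt)
```

The centring shift is what the intercept absorbs, so a deviation-style view averages to zero over the training rows.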

6. Plotting, summary, score

gam.plot()                  # figure with one subplot per smooth (CI bands shaded)
gam.plot("x0")              # single-smooth axes
df = gam.partial_effect("x0", level=0.95)  # underlying data: columns x0, effect, lo, hi

print(gam.summary())
# Gam summary  family=gaussian  link=identity  n_obs=500
#   intercept (link)     = -0.0123
#   intercept (response) = -0.0123
#   scale (σ²)           = 0.0101
#   deviance             = 4.8732
#   R² (adj)             = 0.9612
#   smooths:
#     s(          x0)  k= 10  edf=  6.13  λ=2.345e-02
#     s(          x1)  k= 10  edf=  3.07  λ=4.890e-01

gam.score(X, y)              # adjusted R² for regression; accuracy@0.5 for binomial
gam.predict_proba(X)         # binomial only: (n, 2) [[P(0), P(1)]]
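
For the binomial case, accuracy@0.5 can be read as follows. The sketch assumes the (n, 2) layout [[P(0), P(1)], ...] described above and assigns ties at exactly 0.5 to class 1; that tie rule and the helper name are assumptions of this example.

```python
def accuracy_at_half(proba, y_true):
    """Fraction of rows where thresholding P(1) at 0.5 recovers the label.

    proba holds one [P(0), P(1)] pair per row.
    """
    predicted = [1 if p1 >= 0.5 else 0 for _, p1 in proba]
    hits = sum(int(p == t) for p, t in zip(predicted, y_true))
    return hits / len(y_true)


# Made-up probabilities and labels for four rows.
proba = [[0.9, 0.1], [0.2, 0.8], [0.4, 0.6], [0.7, 0.3]]
y_true = [0, 1, 0, 0]
acc = accuracy_at_half(proba, y_true)
print(acc)
```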

7. Serving with GamPredictor

For deployment paths, wrap the fitted Gam in a GamPredictor. Same API surface, plus strict column validation and a round-trip assertion.

from mgcv_rust import GamPredictor

predictor = GamPredictor(gam)
predictor.check_against(gam, X.iloc[:50])   # raises if predictions diverge
predictor.predict(X_serve)             # ValueError if any expected column is missing
predictor[["x0"]].predict(X_serve)     # subset view returns another GamPredictor

# `__slots__` blocks attribute leaks — once built, the bound Gam's
# coef/vcov/schema are the contract.

This structurally closes two production bug classes:

  • Column index drift when the serving DataFrame's columns differ from training: feature_names_in_ is enforced on every predict call (re-ordered → realigned; missing → ValueError).
  • x-grid clamping in interpolation-based predictors: GamPredictor recomputes the basis at the requested X via evaluate_lpmatrix, so there's no pre-computed grid to clamp against.
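
The two guarantees can be illustrated with a toy class. This is a sketch of the pattern only (a linear model stands in for the GAM, and FrozenPredictor is a hypothetical name); whether the real GamPredictor tolerates extra columns is not stated here, so treat that behaviour as an assumption.

```python
class FrozenPredictor:
    """Toy inference-only wrapper: fixed schema, no new attributes."""

    __slots__ = ("feature_names_in_", "coef_")

    def __init__(self, feature_names, coef):
        self.feature_names_in_ = tuple(feature_names)
        self.coef_ = tuple(coef)

    def _align(self, columns):
        """Pick columns by name in training order; fail loudly if one is absent."""
        missing = [name for name in self.feature_names_in_ if name not in columns]
        if missing:
            raise ValueError(f"missing columns: {missing}")
        return [columns[name] for name in self.feature_names_in_]

    def predict(self, columns):
        aligned = self._align(columns)
        n = len(aligned[0])
        return [sum(c * col[i] for c, col in zip(self.coef_, aligned)) for i in range(n)]


predictor = FrozenPredictor(["x0", "x1"], [2.0, -1.0])

# Column order differs from training and an extra column is present.
# Lookup is by name, so the result is unaffected.
shuffled = {"extra": [9.0, 9.0], "x1": [1.0, 2.0], "x0": [3.0, 4.0]}
preds = predictor.predict(shuffled)
print(preds)
```

Because __slots__ leaves the instance without a __dict__, assigning any attribute outside the declared names raises AttributeError, and a missing training column raises ValueError instead of silently shifting positions.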

Families and links

| Family | Default link | Other links | Notes |
|---|---|---|---|
| gaussian | identity | | bs-spline / cubic-regression / random-effects bases supported |
| binomial | logit | | predict_proba() available |
| poisson | log | | |
| gamma | inverse | log | log-link via link="log" |
| tweedie | log | | tweedie_p= (default 1.5); fixed-p |
| nb / negbin | log | | nb profiles θ; negbin is fixed-θ |
| t-dist / scat | identity | | df profiled if not given (df ∈ [2, 100]) |
| quasipoisson, quasibinomial | log / logit | | dispersion-aware |
| quantile | identity | | see mgcv_rust.fit_quantile / fit_quantile_lss; OOS presets documented in docs/qgam_oos_presets.md |
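
For reference, the default links listed above and their inverses can be written out directly. This is a standalone sketch of the standard formulas; the LINKS dictionary and helper names are illustrative and not part of the mgcv_rust API.

```python
import math

# (link, inverse link) pairs for the default links in the table above.
LINKS = {
    "identity": (lambda mu: mu, lambda eta: eta),
    "log": (math.log, math.exp),
    "logit": (
        lambda mu: math.log(mu / (1.0 - mu)),
        lambda eta: 1.0 / (1.0 + math.exp(-eta)),
    ),
    "inverse": (lambda mu: 1.0 / mu, lambda eta: 1.0 / eta),
}


def to_link(mu, link_name):
    """Map a response-scale mean to the linear-predictor scale."""
    return LINKS[link_name][0](mu)


def to_response(eta, link_name):
    """Map a linear-predictor value to the response scale."""
    return LINKS[link_name][1](eta)


for name in LINKS:
    print(name, to_link(0.3, name), to_response(to_link(0.3, name), name))
```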

Architecture (Rust side)

| File | Responsibility |
|---|---|
| src/gam.rs | GAM struct + fit / predict entry points |
| src/pirls.rs | Penalized IRLS inner loop |
| src/reml/ | Outer-loop REML / LAML optimization |
| src/smooth.rs | Basis functions (cubic-regression, B-spline, random effects, tensor products) |
| src/penalty.rs | Penalty-matrix construction |
| src/lib.rs | PyO3 bindings: PyGAM, compute_penalty_matrix, newton_pirls_py |
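
To give a feel for what a penalized IRLS inner loop does, here is a deliberately tiny Python sketch: a Poisson model with log link, a single coefficient, and a fixed ridge-type penalty. It is not a transcription of src/pirls.rs; the function names, the data, and the fixed λ are assumptions made for illustration (in the real fit, the outer REML / LAML loop chooses λ).

```python
import math


def pirls_poisson(x, y, lam, max_iter=50, tol=1e-10):
    """Penalized IRLS for y ~ Poisson(exp(beta * x)) with penalty lam * beta**2 / 2."""
    beta = 0.0
    for _ in range(max_iter):
        eta = [beta * xi for xi in x]
        mu = [math.exp(e) for e in eta]
        # Working response and weights for the log link.
        z = [e + (yi - m) / m for e, yi, m in zip(eta, y, mu)]
        w = mu
        # Penalized weighted least squares, here in one dimension.
        num = sum(wi * xi * zi for wi, xi, zi in zip(w, x, z))
        den = sum(wi * xi * xi for wi, xi in zip(w, x)) + lam
        new_beta = num / den
        if abs(new_beta - beta) < tol:
            return new_beta
        beta = new_beta
    return beta


def penalized_score(beta, x, y, lam):
    """Gradient of the penalized log-likelihood; zero at the PIRLS fixed point."""
    return sum(xi * (yi - math.exp(beta * xi)) for xi, yi in zip(x, y)) - lam * beta


x = [0.0, 0.5, 1.0, 1.5, 2.0]
y = [1, 1, 2, 4, 6]
beta_hat = pirls_poisson(x, y, lam=1.0)
print(beta_hat, penalized_score(beta_hat, x, y, 1.0))
```

Each pass solves a weighted least-squares problem with the penalty added to the normal equations, and the loop stops once the coefficient no longer moves.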

Build features:

  • python — enable PyO3 bindings.
  • blas / blas-system — link against system BLAS for matmul-heavy paths.

Parity and performance

Run the parity battery against R/mgcv:

pytest tests/parity/ -q
# 554 passed, 0 failed, 0 xfailed, 0 skipped

Microbench:

python3 scripts/python/bench_step_blend.py 5

Headline (2d_gaussian_additive_n50000_k15_cr, identity link):

| Implementation | Time |
|---|---|
| R mgcv | ~394 ms |
| mgcv_rust 0.11.0 | 97 ms |

Status and limitations

  • Joint (ρ, log φ) outer Newton not yet implemented — the binding constraint for closing the remaining performance gap on dispersion-bearing GLMs. Tracked in mgcv_rust - Backlog - Next Numerical Steps (Obsidian).
  • predict_diff is identity-link only; non-identity raises with a workaround pointing at get_posterior_samples.
  • Auto-k tuning is opt-in via Gam(auto_k=True). Default is a single fit at k_default=10 (mgcv's default), with term_k_mapping overrides — closer to mgcv's "tune k, run k.check" convention and avoids hidden multi-fit costs.
  • sklearn BaseEstimator integration (for Pipeline / GridSearchCV) is not yet implemented. It is deferred because scikit-learn would be a soft dependency.

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. J. R. Stat. Soc. B, 73(1), 3–36.

License

MIT — see LICENSE.
