
Rust implementation of Generalized Additive Models with Python bindings


mgcv_rust

A Rust port of R's mgcv package — Generalized Additive Models with automatic smoothing-parameter selection (REML / LAML) and PIRLS fitting — with first-class Python bindings.

The numerical core targets byte-for-byte parity with mgcv on common families. The Python API is designed to feel like scikit-learn: a single Gam class, DataFrame inputs with named predictors, fit() / predict() / score(), posterior confidence intervals, and subset views for marginal analysis.

  • Latest release: 0.11.0
  • Parity: 554 passed, 0 failed, 0 xfailed on the mgcv comparison battery
  • Python tests: 211 passed, 1 xfailed
  • Headline perf: 2d_gaussian_additive_n50000_k15_cr, R mgcv 394 ms → mgcv_rust 97 ms (4.05× faster)
  • Python wrapper: mgcv_rust.Gam (canonical); GAMFitter is a deprecated alias

Why another GAM library?

  • R-equivalent answers, Python ergonomics. If you have R / mgcv users producing models and Python services consuming them, this gives both sides the same numbers from the same fit. The serialize() method produces a portable artefact for deployment.
  • Fast. Aggregate time across 80 parity fixtures: 2.0 s. On the largest Gaussian fixture (n=50,000, k=15), mgcv_rust is ~4× faster than R's mgcv.
  • Real CIs, paired diffs. predict_ci(X) returns (mean, lo, hi), so you never have to add the intercept back by hand. predict_diff(from_X, to_X, level=...) gives a paired-posterior CI for the difference between two predictions. Because it accounts for the posterior correlation between the two predictions, it is typically much tighter than naively differencing two predict_ci intervals.
  • Frozen serving view. GamPredictor(gam) is an __slots__-locked, inference-only wrapper with strict input validation (feature_names_in_ enforcement) and a check_against(gam, X_sample) round-trip assertion for deployment safety.

Installation

The Python package is built from this repo via maturin:

# Clone, then in a venv:
pip install maturin
maturin develop --features python,blas,blas-system --release

After that, import mgcv_rust works.

For a Rust-only consumer, add to your Cargo.toml:

[dependencies]
mgcv_rust = { git = "https://github.com/AlekJaworski/nn_exploring" }

Quick start

import numpy as np
import pandas as pd
from mgcv_rust import Gam

rng = np.random.default_rng(0)
n = 500
X = pd.DataFrame({
    "x0": rng.uniform(-2, 2, n),
    "x1": rng.uniform(0, 5, n),
})
y = np.sin(X["x0"]) + 0.3 * (X["x1"] - 2.5)**2 + rng.normal(0, 0.1, n)

gam = Gam(family="gaussian").fit(X, y)

gam.predict(X[:5])                  # response-scale predictions
mean, lo, hi = gam.predict_ci(X[:5])  # 95% CI, response scale
gam.score(X, y)                     # adjusted R²
print(gam.summary())                # mgcv-style block

That's it for the simple case. Read on for the curated tour, or jump to docs/GETTING_STARTED.md for a step-by-step tutorial with worked examples.


Tour

1. Build the model

fit() accepts pandas, polars, or NumPy inputs. If you pass a DataFrame, its column names become the predictor names; otherwise you can pass the names explicitly.

gam = Gam(
    family="gaussian",           # gaussian, binomial, poisson, gamma, tweedie, nb, t-dist, scat, ...
    link=None,                   # None → canonical link for the family
    k_default=10,                # default basis dimension per smooth (mgcv default)
    term_k_mapping={"x0": 25},   # per-predictor overrides
    method="REML",
)
gam.fit(X, y)

After fit, sklearn-style attributes are populated:

gam.coef_              # full β vector (intercept first)
gam.intercept_         # link-scale
gam.intercept_response_ # response-scale (= mean(y) for identity link)
gam.feature_names_in_  # predictor names, fitting order
gam.n_features_in_
gam.lambda_            # per-smooth smoothing parameter
gam.edf_               # per-smooth effective degrees of freedom
gam.k_, gam.bs_        # basis dimension / basis type per smooth
gam.vcov_              # posterior covariance of β̂
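How edf_ is computed isn't spelled out above. The standard GAM definition takes the diagonal of F = (XᵀWX + S_λ)⁻¹XᵀWX and sums it over each smooth's coefficient block. The NumPy sketch below assumes a Gaussian model (W = I), a hypothetical column layout with the intercept in column 0, and plain dense solves; it is an illustration, not mgcv_rust's implementation.

```python
import numpy as np

def per_smooth_edf(X, S_blocks, lambdas, blocks):
    """Per-smooth effective degrees of freedom for a Gaussian additive model (W = I).

    X        : (n, p) model matrix, intercept in column 0
    S_blocks : one (p_j, p_j) penalty matrix per smooth
    lambdas  : one smoothing parameter per smooth
    blocks   : one slice per smooth giving its columns in X (hypothetical layout)
    """
    p = X.shape[1]
    S = np.zeros((p, p))
    for S_j, lam, blk in zip(S_blocks, lambdas, blocks):
        S[blk, blk] += lam * S_j              # assemble the total penalty S_lambda
    XtX = X.T @ X
    F = np.linalg.solve(XtX + S, XtX)         # F = (X'X + S_lambda)^{-1} X'X
    d = np.diag(F)
    return [float(d[blk].sum()) for blk in blocks], float(d.sum())

# Toy design: intercept + two 3-column "smooths" with identity penalties.
rng = np.random.default_rng(1)
n = 200
X = np.column_stack([np.ones(n), rng.normal(size=(n, 6))])
blocks = [slice(1, 4), slice(4, 7)]
S_blocks = [np.eye(3), np.eye(3)]
edf, total = per_smooth_edf(X, S_blocks, [1.0, 100.0], blocks)
```

With every λ = 0 each smooth's EDF equals its block size; the heavier-penalized second smooth gets the smaller EDF.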

2. Predict on different scales

gam.predict(X)                       # response (default — same as mgcv predict(type="response"))
gam.predict(X, scale="link")         # linear predictor, no inverse link (mgcv type="link")

result = gam.predict(X, type="terms")  # decomposition
result.intercept                       # scalar, link scale
result.contributions                   # DataFrame (n × m_smooths), link scale, each column centered over the training data
result.total                           # response-scale ≡ gam.predict(X)

Invariant: gam.predict(X) == inv_link(intercept + result.contributions.sum(axis=1)).
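The invariant says the terms add up on the link scale and only then pass through the inverse link. A small NumPy sketch with a log link makes the point; the numbers are synthetic and the array contributions merely stands in for result.contributions. Inverting each term separately and adding them does not reproduce the prediction; under a log link the terms combine multiplicatively on the response scale.

```python
import numpy as np

def predict_from_terms(intercept, contributions, inv_link):
    """Rebuild response-scale predictions from a link-scale term decomposition."""
    eta = intercept + np.asarray(contributions).sum(axis=1)   # linear predictor
    return inv_link(eta)

rng = np.random.default_rng(0)
intercept = 0.4                                    # link scale (log link in this example)
contributions = rng.normal(0.0, 0.5, size=(5, 2))  # synthetic stand-in: 5 rows, 2 smooths
mu = predict_from_terms(intercept, contributions, np.exp)

# Wrong: inverting each piece and adding on the response scale.
naive = np.exp(intercept) + np.exp(contributions).sum(axis=1)
# Right for a log link: the pieces multiply on the response scale.
mult = np.exp(intercept) * np.exp(contributions).prod(axis=1)
```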

3. Confidence intervals

mean, lo, hi = gam.predict_ci(X, level=0.95)       # response, 95% CI
mean, lo, hi = gam.predict_ci(X, scale="link")     # link-scale CI
mean, lo, hi = gam.predict_ci(X, level=0.5)        # 50% CI (narrower)

mean matches gam.predict(X, scale=scale) exactly. The CI is built from n_samples (default 1000) draws of β ~ N(β̂, vcov).

The legacy predict_ci(X, alpha=0.05) → (lo, hi) form still works (emits a DeprecationWarning) so existing callers don't break.
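The simulation behind predict_ci can be sketched in NumPy, assuming you already have the prediction matrix Xp (basis evaluated at the new rows, intercept column included), β̂ and vcov. The function name and signature are hypothetical. The quantile-based interval and the plug-in mean follow the description above, but mgcv_rust's exact quantile rule isn't documented here.

```python
import numpy as np

def simulate_ci(Xp, beta_hat, vcov, inv_link=lambda eta: eta, level=0.95,
                n_samples=1000, seed=0):
    """Posterior-simulation CI for predictions at the rows of Xp (sketch)."""
    rng = np.random.default_rng(seed)
    betas = rng.multivariate_normal(beta_hat, vcov, size=n_samples)  # (n_samples, p) draws
    draws = inv_link(betas @ Xp.T)                                   # (n_samples, m)
    alpha = 1.0 - level
    lo, hi = np.quantile(draws, [alpha / 2, 1 - alpha / 2], axis=0)
    mean = inv_link(Xp @ beta_hat)   # plug-in prediction, not the average of the draws
    return mean, lo, hi

Xp = np.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]])   # toy prediction matrix
beta_hat = np.array([0.5, 0.2])
vcov = np.array([[0.01, 0.0], [0.0, 0.004]])
mean, lo, hi = simulate_ci(Xp, beta_hat, vcov, inv_link=np.exp)
```

Returning the plug-in value as mean is what makes it match predict() exactly, independent of the random draws.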

4. Differences between predictions, with paired CI

# Per-row diff: to[i] - from[i]
diff = gam.predict_diff(from_X, to_X)

# With paired posterior CI: typically much narrower than differencing two predict_ci intervals
diff, lo, hi = gam.predict_diff(from_X, to_X, level=0.95)

# One row broadcast against many
diff = gam.predict_diff(baseline_row, candidates, broadcast="from")
diff = gam.predict_diff(candidates, target_row, broadcast="to")

Identity-link only — for non-identity links the response-scale difference isn't linear in β, so the closed-form CI argument doesn't transfer.
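Why the paired interval is narrower: with an identity link the difference is (X_to − X_from)β, so columns that are identical in both rows (the intercept, and the basis columns of smooths whose inputs don't change) cancel, and the covariance between the two predictions is used rather than ignored. The sketch below compares a closed-form normal interval for the paired difference with subtracting two separate intervals. It assumes prediction matrices, β̂ and vcov are available; the function names are hypothetical.

```python
import numpy as np
from statistics import NormalDist

def _se(A, vcov):
    """Row-wise standard errors sqrt(diag(A V A'))."""
    return np.sqrt(np.einsum("ij,jk,ik->i", A, vcov, A))

def paired_diff_ci(X_from, X_to, beta_hat, vcov, level=0.95):
    """Normal-approximation CI for to - from under an identity link."""
    D = X_to - X_from                       # shared columns (e.g. the intercept) cancel
    diff = D @ beta_hat
    half = NormalDist().inv_cdf(0.5 + level / 2) * _se(D, vcov)
    return diff, diff - half, diff + half

def naive_diff_ci(X_from, X_to, beta_hat, vcov, level=0.95):
    """Interval from subtracting two separate CIs (ignores their covariance)."""
    diff = (X_to - X_from) @ beta_hat
    half = NormalDist().inv_cdf(0.5 + level / 2) * (_se(X_from, vcov) + _se(X_to, vcov))
    return diff, diff - half, diff + half

beta_hat = np.array([2.0, 0.5, -0.3])       # intercept, two basis coefficients
vcov = np.array([[0.04, 0.0, 0.0],
                 [0.0, 0.01, 0.002],
                 [0.0, 0.002, 0.01]])
X_from = np.array([[1.0, 0.0, 1.0]])
X_to = np.array([[1.0, 1.0, 1.0]])
diff, lo, hi = paired_diff_ci(X_from, X_to, beta_hat, vcov)
_, lo_naive, hi_naive = naive_diff_ci(X_from, X_to, beta_hat, vcov)
```

Here the paired interval is 0.5 ± 0.196, while subtracting the two separate intervals gives roughly 0.5 ± 0.934, mostly because the intercept's uncertainty is counted twice.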

5. Subset / marginal views

gam[names], indexed with a list of smooth names, returns a view that includes only the listed smooths, plus the intercept if Gam.INTERCEPT is listed. All predict methods work on the view.

gam[["x0"]].predict(X)                              # marginal effect of x0, response scale
gam[["x0"]].predict(X, scale="deviation")           # link-scale, intercept zeroed (sum-to-zero on train)
gam[["x0", Gam.INTERCEPT]].predict(X)               # marginal effect + intercept
gam[["x0"]].partial_effect("x0").plot()             # what the smooth looks like

All sklearn attributes filter to the active features:

gam[["x0"]].coef_            # just the x0 block (+ intercept if active)
gam[["x0"]].feature_names_in_  # ["x0"]
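One way to picture a subset view on the link scale: keep only the active smooths' coefficient blocks (and the intercept, if requested) and drop everything else, which is roughly what scale="deviation" describes. The column layout and function name below are hypothetical stand-ins, not the library's internals.

```python
import numpy as np

def subset_link_prediction(Xp, beta_hat, blocks, active, include_intercept=False):
    """Link-scale prediction from the active smooths only (hypothetical column layout).

    Xp     : (m, p) prediction matrix, intercept in column 0
    blocks : dict mapping smooth name -> slice of its columns
    active : smooth names to keep; every other block is dropped
    """
    keep = np.zeros(Xp.shape[1], dtype=bool)
    keep[0] = include_intercept
    for name in active:
        keep[blocks[name]] = True
    return Xp[:, keep] @ beta_hat[keep]

rng = np.random.default_rng(2)
Xp = np.column_stack([np.ones(4), rng.normal(size=(4, 6))])
beta_hat = rng.normal(size=7)
blocks = {"x0": slice(1, 4), "x1": slice(4, 7)}
dev_x0 = subset_link_prediction(Xp, beta_hat, blocks, ["x0"])
dev_x1 = subset_link_prediction(Xp, beta_hat, blocks, ["x1"])
```

Because the link-scale pieces are additive, the intercept plus dev_x0 plus dev_x1 reproduces the full linear predictor Xp @ beta_hat.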

6. Plotting, summary, score

gam.plot()                  # figure with one subplot per smooth (CI bands shaded)
gam.plot("x0")              # single-smooth axes
df = gam.partial_effect("x0", level=0.95)  # underlying data: columns x0, effect, lo, hi

print(gam.summary())   # example output (illustrative; values depend on the data and settings)
# Gam summary  family=gaussian  link=identity  n_obs=500
#   intercept (link)     = -0.0123
#   intercept (response) = -0.0123
#   scale (σ²)           = 0.0101
#   deviance             = 4.8732
#   R² (adj)             = 0.9612
#   smooths:
#     s(          x0)  k= 10  edf=  6.13  λ=2.345e-02
#     s(          x1)  k= 10  edf=  3.07  λ=4.890e-01

gam.score(X, y)              # adjusted R² for regression; accuracy@0.5 for binomial
gam.predict_proba(X)         # binomial only: (n, 2) array, columns [P(y=0), P(y=1)]
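score() is described only as adjusted R² for regression and accuracy at a 0.5 threshold for binomial. The sketch below writes both out, assuming the adjusted R² follows the unweighted form used by mgcv's summary.gam (residual df = n minus the total EDF) and that a probability of exactly 0.5 counts as class 1; both are assumptions, not confirmed mgcv_rust behaviour. P(y=1) is taken from the second column of a predict_proba-shaped array.

```python
import numpy as np

def adjusted_r2(y, fitted, total_edf):
    """Adjusted R², unweighted, in the style of mgcv's summary.gam."""
    y = np.asarray(y, dtype=float)
    fitted = np.asarray(fitted, dtype=float)
    n = y.size
    residual_df = n - total_edf              # total EDF plays the role of a parameter count
    return 1.0 - np.var(y - fitted, ddof=1) * (n - 1) / (np.var(y, ddof=1) * residual_df)

def accuracy_at_half(y, proba):
    """Accuracy of thresholding P(y=1) at 0.5; proba is an (n, 2) [P(0), P(1)] array."""
    p1 = np.asarray(proba)[:, 1]
    return float(np.mean((p1 >= 0.5) == (np.asarray(y) == 1)))

p = np.array([0.2, 0.7, 0.4, 0.1])
proba = np.column_stack([1 - p, p])
acc = accuracy_at_half([0, 1, 1, 0], proba)  # 3 of 4 rows classified correctly
```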

7. Serving with GamPredictor

For deployment paths, wrap the fitted Gam in a GamPredictor. Same API surface, plus strict column validation and a round-trip assertion.

from mgcv_rust import GamPredictor

predictor = GamPredictor(gam)
predictor.check_against(gam, X[:50])   # raises if predictions diverge
predictor.predict(X_serve)             # ValueError if any expected column is missing
predictor[["x0"]].predict(X_serve)     # subset view returns another GamPredictor

# `__slots__` blocks attribute leaks — once built, the bound Gam's
# coef/vcov/schema are the contract.

This structurally closes two production bug classes:

  • Column index drift when the serving DataFrame's columns differ from training: feature_names_in_ is enforced on every predict call (re-ordered → realigned; missing → ValueError).
  • x-grid clamping in interpolation-based predictors: GamPredictor recomputes the basis at the requested X via evaluate_lpmatrix, so there's no pre-computed grid to clamp against.
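The two guards can be illustrated with a toy class; this is not GamPredictor, just a plain linear predictor that stores its state in __slots__ (so no new attributes can be attached after construction). Its predict() re-orders incoming columns to the training order and raises ValueError when one is missing. Dropping extra columns is this sketch's own choice; the text above doesn't say what GamPredictor does with them.

```python
import numpy as np
import pandas as pd

class FrozenLinearPredictor:
    """Toy inference-only wrapper (hypothetical; a linear model, not a GAM)."""
    __slots__ = ("_names", "_coef", "_intercept")   # no __dict__: new attributes are rejected

    def __init__(self, feature_names, coef, intercept):
        self._names = tuple(feature_names)
        self._coef = np.asarray(coef, dtype=float).copy()
        self._intercept = float(intercept)

    def _validate(self, X):
        missing = [c for c in self._names if c not in X.columns]
        if missing:
            raise ValueError(f"missing columns: {missing}")
        return X.loc[:, list(self._names)]          # realign to training order, drop extras

    def predict(self, X):
        Xa = self._validate(X).to_numpy(dtype=float)
        return self._intercept + Xa @ self._coef

p = FrozenLinearPredictor(["x0", "x1"], [2.0, -1.0], 0.5)
train_order = pd.DataFrame({"x0": [1.0, 2.0], "x1": [3.0, 4.0]})
shuffled = train_order[["x1", "x0"]]               # same data, columns swapped
```

Both train_order and shuffled give the same predictions, [-0.5, 0.5], because columns are matched by name rather than position.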

Families and links

Family                       | Default link | Other links | Notes
-----------------------------|--------------|-------------|------------------------------------------------------------
gaussian                     | identity     |             | B-spline / cubic-regression / random-effects bases supported
binomial                     | logit        |             | predict_proba() available
poisson                      | log          |             |
gamma                        | inverse      | log         | log link via link="log"
tweedie                      | log          |             | tweedie_p= (default 1.5); p is fixed, not estimated
nb / negbin                  | log          |             | nb profiles θ; negbin uses a fixed θ
t-dist / scat                | identity     |             | df profiled if not given (df ∈ [2, 100])
quasipoisson, quasibinomial  | log / logit  |             | dispersion-aware
quantile                     | identity     |             | see mgcv_rust.fit_quantile / fit_quantile_lss; OOS presets documented in docs/qgam_oos_presets.md
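For reference, here are the default links and variance functions behind the main rows of the table, written out in NumPy. They are the standard GLM definitions, with Var(y) = φ·V(μ), not code taken from mgcv_rust; the FAMILIES dict is a hypothetical layout.

```python
import numpy as np

def _expit(eta):
    """Inverse logit."""
    return 1.0 / (1.0 + np.exp(-eta))

# family -> (default link, inverse link, variance function V(mu))
FAMILIES = {
    "gaussian": ("identity", lambda eta: eta,       lambda mu: np.ones_like(mu)),
    "binomial": ("logit",    _expit,                lambda mu: mu * (1.0 - mu)),
    "poisson":  ("log",      np.exp,                lambda mu: mu),
    "gamma":    ("inverse",  lambda eta: 1.0 / eta, lambda mu: mu ** 2),
}

def tweedie_variance(mu, p=1.5):
    """Tweedie: V(mu) = mu**p; 1 < p < 2 is the compound Poisson-gamma range."""
    return mu ** p

def negbin_variance(mu, theta):
    """Negative binomial: V(mu) = mu + mu**2 / theta."""
    return mu + mu ** 2 / theta
```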

Architecture (Rust side)

File            | Responsibility
----------------|----------------------------------------------------------------
src/gam.rs      | GAM struct + fit / predict entry points
src/pirls.rs    | Penalized IRLS inner loop
src/reml/       | Outer-loop REML / LAML optimization
src/smooth.rs   | Basis functions (cubic-regression, B-spline, random effects, tensor products)
src/penalty.rs  | Penalty-matrix construction
src/lib.rs      | PyO3 bindings (PyGAM, compute_penalty_matrix, newton_pirls_py)
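For orientation, here is a minimal version of the kind of loop src/pirls.rs is responsible for: penalized IRLS for a Poisson GLM with log link at a fixed λ. Each step solves (XᵀWX + λS)β = XᵀWz with working weights w = μ and working response z = η + (y − μ)/μ. The sketch omits the step-halving, rank handling and numerically stable factorizations a production implementation would need; the toy design and penalty are arbitrary.

```python
import numpy as np

def pirls_poisson(X, y, S, lam, max_iter=50, tol=1e-10):
    """Penalized IRLS for a Poisson GLM with log link at fixed lambda.

    Minimizes -loglik(beta) + 0.5 * lam * beta' S beta. Column 0 of X is the intercept.
    """
    beta = np.zeros(X.shape[1])
    beta[0] = np.log(y.mean())                  # start from the intercept-only fit
    for _ in range(max_iter):
        eta = X @ beta
        mu = np.exp(eta)
        w = mu                                  # 1 / (V(mu) g'(mu)^2) with V = mu, g' = 1/mu
        z = eta + (y - mu) / mu                 # working response
        XtW = X.T * w                           # X' W without building diag(w)
        beta_new = np.linalg.solve(XtW @ X + lam * S, XtW @ z)
        converged = np.max(np.abs(beta_new - beta)) < tol * (1.0 + np.max(np.abs(beta)))
        beta = beta_new
        if converged:
            break
    return beta

rng = np.random.default_rng(0)
n = 300
x = rng.uniform(0.0, 1.0, n)
X = np.column_stack([np.ones(n), x, x ** 2, x ** 3])
y = rng.poisson(np.exp(0.5 + 0.5 * np.sin(2 * np.pi * x))).astype(float)
S = np.diag([0.0, 1.0, 1.0, 1.0])               # ridge-type penalty; intercept unpenalized
beta = pirls_poisson(X, y, S, lam=1.0)
```

At convergence the penalized score equations Xᵀ(y − μ) = λSβ hold, which is a convenient check on any PIRLS implementation.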

Build features:

  • python — enable PyO3 bindings.
  • blas / blas-system — link against system BLAS for matmul-heavy paths.

Parity and performance

Run the parity battery against R/mgcv:

pytest tests/parity/ -q
# 554 passed, 0 failed, 0 xfailed, 0 skipped

Microbench:

python3 scripts/python/bench_step_blend.py 5

Headline (2d_gaussian_additive_n50000_k15_cr, identity link):

Implementation    | Time
------------------|--------
R mgcv            | ~394 ms
mgcv_rust 0.11.0  | 97 ms

Status and limitations

  • Joint (ρ, log φ) outer Newton not yet implemented — the binding constraint for closing the remaining performance gap on dispersion-bearing GLMs. Tracked in mgcv_rust - Backlog - Next Numerical Steps (Obsidian).
  • predict_diff is identity-link only; for non-identity links it raises an error that points to get_posterior_samples as a workaround.
  • Auto-k tuning is opt-in via Gam(auto_k=True). The default is a single fit at k_default=10 (mgcv's default) with term_k_mapping overrides. This stays closer to mgcv's convention of choosing k up front and checking it with k.check, and avoids hidden multi-fit costs.
  • Gam is not yet wrapped as a scikit-learn BaseEstimator, so it can't be used directly in Pipeline / GridSearchCV. This is deferred so that scikit-learn can remain a soft dependency.
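Some context for the first bullet: in the Gaussian case the scale φ can be profiled out of the REML criterion in closed form, φ̂ = D_p/(n − M_p), where D_p is the penalized residual sum of squares and M_p the dimension of the penalty null space. That leaves a one-dimensional problem in ρ = log λ. The NumPy sketch below evaluates this profiled criterion on a grid for a single penalty. The basis, penalty and grid search are illustrative choices (real implementations typically take Newton steps using derivatives), and the sketch does not cover the dispersion-bearing GLMs the bullet is about.

```python
import numpy as np

def reml_profiled(rho, X, y, S):
    """Gaussian REML criterion (to minimize) at rho = log(lambda), scale phi profiled out.

    V(rho) = (n - Mp)/2 * (1 + log(2*pi*Dp/(n - Mp)))
             + 0.5*log|X'X + lambda S| - 0.5*log|lambda S|_+
    """
    lam = np.exp(rho)
    n, p = X.shape
    eig = np.linalg.eigvalsh(S)
    pos = eig[eig > 1e-10 * eig.max()]          # nonzero eigenvalues of S
    r = pos.size                                # rank of S
    nu = n - (p - r)                            # n - Mp, Mp = null-space dimension
    A = X.T @ X + lam * S
    beta = np.linalg.solve(A, X.T @ y)
    resid = y - X @ beta
    Dp = resid @ resid + lam * beta @ S @ beta  # penalized residual sum of squares
    _, logdet_A = np.linalg.slogdet(A)
    logdet_S = r * rho + np.log(pos).sum()      # log pseudo-determinant of lambda*S
    return 0.5 * nu * (1.0 + np.log(2.0 * np.pi * Dp / nu)) + 0.5 * logdet_A - 0.5 * logdet_S

def select_rho(X, y, S, grid):
    """Grid search over rho; a real optimizer would take Newton steps instead."""
    scores = np.array([reml_profiled(r, X, y, S) for r in grid])
    return grid[np.argmin(scores)], scores

# Toy smooth: unpenalized intercept + slope, plus 8 sine columns penalized by k^4.
rng = np.random.default_rng(0)
n = 200
x = np.sort(rng.uniform(0.0, 1.0, n))
ks = np.arange(1, 9)
X = np.column_stack([np.ones(n), x, np.sin(np.pi * np.outer(x, ks))])
S = np.diag(np.r_[0.0, 0.0, ks.astype(float) ** 4])
y = np.sin(2 * np.pi * x) + rng.normal(0.0, 0.2, n)
grid = np.linspace(-10.0, 10.0, 81)
best_rho, scores = select_rho(X, y, S, grid)
```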

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. J. R. Stat. Soc. B, 73(1), 3–36.

License

MIT — see LICENSE.
