Skip to main content

Rust implementation of Generalized Additive Models with Python bindings

Project description

mgcv_rust: Generalized Additive Models in Rust

A Rust implementation of Generalized Additive Models (GAMs) with automatic smoothing parameter selection using REML (Restricted Maximum Likelihood) and the PiRLS (Penalized Iteratively Reweighted Least Squares) algorithm, inspired by R's mgcv package.

Features

  • Multiple Distribution Families: Gaussian, Binomial, Poisson, and Gamma
  • Flexible Basis Functions:
    • Cubic B-splines with natural boundary conditions
    • Thin plate splines for smooth multivariate regression
  • Automatic Smoothing:
    • REML (Restricted Maximum Likelihood) criterion
    • GCV (Generalized Cross-Validation) criterion
  • PiRLS Algorithm: Efficient fitting via Penalized Iteratively Reweighted Least Squares
  • Pure Rust: No external BLAS/LAPACK dependencies
  • Test-Driven Development: Comprehensive test suite with 20+ passing tests

Installation

Add to your Cargo.toml:

[dependencies]
mgcv_rust = { path = "." }
ndarray = "0.16"

Quick Start

Python — high-level wrapper (recommended)

GAMFitter is a drop-in replacement for mgcv::gam / rpy2-based fitters. Accepts numpy / pandas / polars inputs, named predictors, per-term k overrides, posterior sampling, confidence intervals, and a serialize() method matching the schema consumed by downstream GamPredictor-style code.

import numpy as np, pandas as pd
from mgcv_rust import GAMFitter

df = pd.DataFrame({"days_ago": ..., "quality": ...})
y  = ...

gam = GAMFitter(
    predictors=("days_ago", "quality"),
    k_default=6,
    term_k_mapping={"days_ago": 25, "quality": 12},
    family="gaussian", link="identity",
)
gam.fit(df, y)
preds = gam.predict(df)
lo, hi = gam.predict_ci(df, alpha=0.05, n_samples=1000)
serialized = gam.serialize()

Python — low-level Rust core

import numpy as np
from mgcv_rust import GAM

# Generate data: y = sin(2πx) + noise
X = np.random.uniform(0, 1, (500, 2))
y = np.sin(2 * np.pi * X[:, 0]) + 0.5 * (X[:, 1] - 0.5)**2

# Fit GAM with automatic smooth setup
gam = GAM()
result = gam.fit(X, y, k=[10, 15])  # That's it!

print(f"Lambda values: {result['lambda']}")
print(f"Deviance: {result['deviance']}")

# Make predictions
X_test = np.random.uniform(0, 1, (100, 2))
predictions = gam.predict(X_test)

Performance: 1.5x - 65x faster than R's mgcv (problem-dependent)

See API_SIMPLIFICATION.md for more details on the simplified Python API.

Rust

use mgcv_rust::{GAM, Family, SmoothTerm, OptimizationMethod};
use ndarray::{Array1, Array2};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Generate data: y = sin(2πx) + noise
    let n = 100;
    let x_data: Vec<f64> = (0..n).map(|i| i as f64 / n as f64).collect();
    let y_data: Vec<f64> = x_data
        .iter()
        .map(|&xi| (2.0 * std::f64::consts::PI * xi).sin() + noise())
        .collect();

    let x = Array1::from_vec(x_data);
    let y = Array1::from_vec(y_data);
    let x_matrix = x.into_shape((n, 1))?;

    // Create GAM with cubic spline smooth
    let mut gam = GAM::new(Family::Gaussian);
    let smooth = SmoothTerm::cubic_spline("x".to_string(), 20, 0.0, 1.0)?;
    gam.add_smooth(smooth);

    // Fit with REML smoothing parameter selection
    gam.fit(
        &x_matrix,
        &y,
        OptimizationMethod::REML,
        5,    // max outer iterations
        50,   // max inner iterations (PiRLS)
        1e-4  // convergence tolerance
    )?;

    // Make predictions
    let predictions = gam.predict(&x_test)?;

    Ok(())
}

Architecture

Core Components

  1. basis.rs: Basis function implementations

    • CubicSpline: Cubic B-spline basis with configurable knots
    • ThinPlateSpline: Radial basis functions for smooth regression
  2. penalty.rs: Penalty matrix construction

    • Second derivative penalties for smoothness
    • Supports multiple penalty types per basis
  3. pirls.rs: Penalized IRLS fitting algorithm

    • Implements PiRLS for GLMs with penalties
    • Supports all standard GLM families
    • Automatic weight computation and convergence checking
  4. reml.rs: Smoothing parameter selection

    • REML criterion for optimal smoothing
    • GCV criterion as alternative
    • Log-determinant computations
  5. smooth.rs: Smoothing parameter optimization

    • Coordinate descent optimization
    • Grid search for initialization
    • Works in log-space for numerical stability
  6. gam.rs: Main GAM model interface

    • Combines all components
    • Handles multiple smooth terms
    • Outer loop for lambda optimization
  7. linalg.rs: Linear algebra operations

    • Gaussian elimination with partial pivoting
    • Matrix inversion via Gauss-Jordan
    • Determinant computation via LU decomposition

Mathematical Background

GAM Model

g(E[Y]) = β₀ + f₁(x₁) + f₂(x₂) + ... + fₚ(xₚ)

Where:

  • g() is the link function
  • fᵢ() are smooth functions represented by basis expansions
  • Each smooth is penalized by λᵢ ∫ (f''ᵢ(x))² dx

PiRLS Algorithm

  1. Initialize: η = g(y)
  2. Until convergence:
    • Compute μ = g⁻¹(η)
    • Compute weights: w = (g'(μ))² / V(μ)
    • Compute working response: z = η + (y - μ) / g'(μ)
    • Solve: β = (X'WX + λS)⁻¹ X'Wz
    • Update: η = Xβ

REML Criterion

REML(λ) = n·log(RSS) + log|X'WX + λS| - log|S|

Minimized with respect to λ to select optimal smoothing parameters.

Examples

See examples/simple_gam.rs for a complete working example:

cargo run --example simple_gam --release

Project Structure

├── src/                    # Core Rust library code
├── examples/               # Rust usage examples
├── benches/               # Rust benchmarks
├── tests/                 # Rust tests
├── scripts/               # Testing and benchmarking scripts
│   ├── python/            # Python scripts
│   │   ├── tests/         # Python test scripts
│   │   └── benchmarks/    # Python benchmark scripts
│   └── r/                 # R scripts
│       ├── tests/         # R test scripts
│       └── benchmarks/    # R benchmark scripts
├── docs/                  # Documentation and analysis
└── test_data/            # Test data and results

Testing

Run the Rust test suite:

cargo test

All 20 tests should pass, covering:

  • Basis function evaluation
  • Penalty matrix construction
  • Linear algebra operations
  • REML/GCV criteria
  • PiRLS convergence
  • Full GAM fitting pipeline

Additional tests and benchmarks are available in the scripts/ directory.

Implementation Notes

  • TDD Approach: Every feature was implemented with tests first
  • No External Dependencies: Custom linear algebra to avoid BLAS/LAPACK issues
  • Numerical Stability: Operations performed in log-space where appropriate
  • Extensible Design: Easy to add new basis types, families, or criteria

Limitations & Future Work

  • Smoothing parameter optimization could be improved with better algorithms (e.g., Newton-Raphson)
  • Eigendecomposition for handling penalty null spaces more rigorously
  • Confidence intervals and standard errors
  • Model diagnostics and residual analysis
  • Tensor product smooths for multivariate terms
  • Parallel processing for large datasets

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. JRSS-B, 73(1), 3-36.

License

MIT License - see LICENSE file for details

Author

Implemented as a Rust port of R's mgcv package core functionality.

Update: REML Implementation Fixed! ✅

You were absolutely right - the REML implementation had bugs that caused it to always select λ ≈ 0.

What Was Wrong

  1. Singular Penalty Handling: REML was incorrectly handling rank-deficient penalty matrices, setting log|S| = 0 which broke the criterion
  2. Lambda Passing: Optimization was passing λ = 1.0 with pre-multiplied penalties, confusing the rank(S)*log(λ) term
  3. Insufficient Data: Examples used n=30 with p=15 (ratio 2:1), which is too small for REML/GCV

What Was Fixed

  1. REML Criterion: Now correctly uses log|λS| = rank(S)*log(λ) + constant
  2. Optimization: Passes actual λ values to criterion functions
  3. Data Size: Increased to n=300 for proper n/p ratio (20:1)
  4. REML Search: Uses fine grid search (gradient descent had issues)

Current Performance (n=300)

GCV:  λ = 0.067, Test RMSE = 0.480  ✅
REML: λ = 0.058, Test RMSE = 0.480  ✅

Both methods now select nearly optimal smoothing parameters!

See IMPLEMENTATION_SUMMARY.md for complete details.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

mgcv_rust-0.8.0.tar.gz (9.7 MB view details)

Uploaded Source

Built Distributions

If you're not sure about the file name format, learn more about wheel file names.

mgcv_rust-0.8.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (829.6 kB view details)

Uploaded PyPymanylinux: glibc 2.17+ x86-64

mgcv_rust-0.8.0-cp314-cp314-win_amd64.whl (3.4 MB view details)

Uploaded CPython 3.14Windows x86-64

mgcv_rust-0.8.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (832.2 kB view details)

Uploaded CPython 3.14manylinux: glibc 2.17+ x86-64

mgcv_rust-0.8.0-cp314-cp314-macosx_11_0_arm64.whl (668.7 kB view details)

Uploaded CPython 3.14macOS 11.0+ ARM64

mgcv_rust-0.8.0-cp313-cp313-win_amd64.whl (3.4 MB view details)

Uploaded CPython 3.13Windows x86-64

mgcv_rust-0.8.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (830.9 kB view details)

Uploaded CPython 3.13manylinux: glibc 2.17+ x86-64

mgcv_rust-0.8.0-cp313-cp313-macosx_11_0_arm64.whl (667.9 kB view details)

Uploaded CPython 3.13macOS 11.0+ ARM64

mgcv_rust-0.8.0-cp312-cp312-win_amd64.whl (3.4 MB view details)

Uploaded CPython 3.12Windows x86-64

mgcv_rust-0.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (831.3 kB view details)

Uploaded CPython 3.12manylinux: glibc 2.17+ x86-64

mgcv_rust-0.8.0-cp312-cp312-macosx_11_0_arm64.whl (668.3 kB view details)

Uploaded CPython 3.12macOS 11.0+ ARM64

mgcv_rust-0.8.0-cp311-cp311-win_amd64.whl (3.4 MB view details)

Uploaded CPython 3.11Windows x86-64

mgcv_rust-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (830.7 kB view details)

Uploaded CPython 3.11manylinux: glibc 2.17+ x86-64

mgcv_rust-0.8.0-cp311-cp311-macosx_11_0_arm64.whl (668.5 kB view details)

Uploaded CPython 3.11macOS 11.0+ ARM64

mgcv_rust-0.8.0-cp310-cp310-win_amd64.whl (3.4 MB view details)

Uploaded CPython 3.10Windows x86-64

mgcv_rust-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (832.7 kB view details)

Uploaded CPython 3.10manylinux: glibc 2.17+ x86-64

mgcv_rust-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (833.8 kB view details)

Uploaded CPython 3.9manylinux: glibc 2.17+ x86-64

mgcv_rust-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (833.4 kB view details)

Uploaded CPython 3.8manylinux: glibc 2.17+ x86-64

File details

Details for the file mgcv_rust-0.8.0.tar.gz.

File metadata

  • Download URL: mgcv_rust-0.8.0.tar.gz
  • Upload date:
  • Size: 9.7 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: maturin/1.13.1

File hashes

Hashes for mgcv_rust-0.8.0.tar.gz
Algorithm Hash digest
SHA256 15edb327501e0c62da730b9c03358db451797c664e0c77827a6e616f742108d5
MD5 0b2d478164078d000141372b9624ad90
BLAKE2b-256 435e4fb918c6b04df62516907c41a74059e7c389b98267025a77e1b95f66fe05

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.8.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.8.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 f6dd42036404b2b6b9049be2d92760b98d25e60101c8049dd6becd99316aae45
MD5 1eecc3ced8779eaafe3fcee972dfe6ff
BLAKE2b-256 42e01ba6c9fbffa3bbe28cc56a0a1397ba786dd06719af049f3914655fc281ea

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.8.0-cp314-cp314-win_amd64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.8.0-cp314-cp314-win_amd64.whl
Algorithm Hash digest
SHA256 02ffc20f1495fabf85cf79cdad73bb21ffa79f5540f8456fbaad901a7157f009
MD5 1a2ec607e8779ebb0e3a8bbcb3c98681
BLAKE2b-256 b7190828634923605fd04149070ed9049e489ac18be480a3a0230f198a20a98e

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.8.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.8.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 3fb7dbed09010d3ca17183348d5454ef05fa9c75c80822aad8731b7c7f0b02ee
MD5 5b484a932a21b459933c450b1f3d5feb
BLAKE2b-256 e43948f58ca72162efdabe33bc7f54eae4501b14930c4d9c6054da358c2651aa

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.8.0-cp314-cp314-macosx_11_0_arm64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.8.0-cp314-cp314-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 a554055b1bc55e3c13f42c615146a25dfe2e245a53bcd7c59c6cff42f87a2b4b
MD5 f9501803557790e520fdfe4cee73d691
BLAKE2b-256 6701120c8d68c9ea2e897ba6a15731573a6e0f3f4bdb72e51c9379b5a6cad12b

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.8.0-cp313-cp313-win_amd64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.8.0-cp313-cp313-win_amd64.whl
Algorithm Hash digest
SHA256 aab143bc6018bebb87fd8ccbaa0be166819d4610a87c28c42d002ced3c10b50e
MD5 94ad6739ced93a40962cac03f4036d35
BLAKE2b-256 7ca58a9b18ac7423bf57cd411996b48536640f96b5bcd6d07d0ef04a0b1cab69

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.8.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.8.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 da0787c6b4660e68a17a7131c957a0e2e091ef7c3d3c299b797d4a72f3ba704d
MD5 d29baf7a12201408486e3537a4e1fd95
BLAKE2b-256 cce369456d5b7ca9fbff43685f796c8578c84fa8bb2f319ff7be97708ebe672f

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.8.0-cp313-cp313-macosx_11_0_arm64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.8.0-cp313-cp313-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 64e228572f28d1a71de01ae071e3088b2bef1cb7496ea4e1821367b13603a3bf
MD5 f5974fd471210bfb7f09ccedddfa757c
BLAKE2b-256 9f8ee50764250ffa2c9b603ff912bbd74820b67949230809102a651f481d430c

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.8.0-cp312-cp312-win_amd64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.8.0-cp312-cp312-win_amd64.whl
Algorithm Hash digest
SHA256 d37bfbb0ce58bfe7688aa87d18065b0eab334b05ccc5aa698b77a89830592a1c
MD5 8bc872391a3205b94c11c84704c9735a
BLAKE2b-256 b06c8af9477550d4a6829dd6f87b650d3d7a466e85a11f5cdbcfbab841891e72

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 51e3d789584dbe7fc04d69378a2e158410622f2a18e9da456f72b993be9a1368
MD5 dd128197cdb4a9658c54aa19db545b79
BLAKE2b-256 fd4f354278e573e0cf8e82eab019a96cf74efc17a533db0fce0f91fa72b2e6d6

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.8.0-cp312-cp312-macosx_11_0_arm64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.8.0-cp312-cp312-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 9fac144f842ef8a6955e6af7a6013504cf8ea4334d2c8ef0b4f6b39d258f0387
MD5 c3acd73232cbc6473d62aa8998ca5633
BLAKE2b-256 5392073f9feb5987a9279ccd78c3bc88abd727830d7fdde192ca432b6c73f5aa

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.8.0-cp311-cp311-win_amd64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.8.0-cp311-cp311-win_amd64.whl
Algorithm Hash digest
SHA256 546e25512f0e1de44004f0d0d69f657a4b12779be8fc2420b2d75707f13719fa
MD5 3eeb59252ec0d517ff77637149aa27d7
BLAKE2b-256 2e262b8627b9ad6f91c9cfc716abf0df90d2dc771a782fe310339f08d51b3a4b

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 fc14ceba49986b477f2f6822c899fc662171577856b432b74a8389e40248d509
MD5 0a7728ceebb14e081271cb5ed782e9e6
BLAKE2b-256 a6cddead78c0e1b826ead844b4a872695eebf426671e6351238478b507bf7c39

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.8.0-cp311-cp311-macosx_11_0_arm64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.8.0-cp311-cp311-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 87fc9bfd2748ef6c71c025f71fcd422a14020d6325d1f5c0f0933a9d101ba985
MD5 0bdefa9ad27de671005f7a8cbbe4335e
BLAKE2b-256 69d3affd477f34c698336ff6b095f5404d929b22307d1f74091359fe923042ed

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.8.0-cp310-cp310-win_amd64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.8.0-cp310-cp310-win_amd64.whl
Algorithm Hash digest
SHA256 586d7c531b3fd71b0006fc90a8ee3f9246df4bcdf055b0de0e931f1a94acaa50
MD5 9f7796ed05ac096d8a029fa10d6411c6
BLAKE2b-256 550b7ae7b13332064029bc0293751814acb2de47d435a4a7baa201a4f3b4f961

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 f65595e581e9a779017f2e686e8bfaf6e2727dad0a34eb206324947bb64647a9
MD5 46e01b554ecb5c295d58fc8f8f4370fe
BLAKE2b-256 c69a3625965a22e00dfbd7b34d76dab9c95762fd0f0795d404dcee86c52918f6

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 d4f12ae85e44772a6bd5907571e8839c9372e4a382f1a09d972086000616d7b8
MD5 8e2d6187713f0322f88c88221d6b9d5a
BLAKE2b-256 7768ed38c00376551a68f29b2fd5096da7a64aba46a1c3b16bb36d3689ade449

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 b276d472f240815f63a0cd54cd9bf91f7d01f30ecdea9fc98a7a7fc707b8bf9e
MD5 a07c6d129e9749836d346562dd8e4690
BLAKE2b-256 e166113eeb6df8c15aed02dab169887e62d57fca42a997985b18b18d63accf51

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page