Rust implementation of Generalized Additive Models with Python bindings

mgcv_rust: Generalized Additive Models in Rust

A Rust implementation of Generalized Additive Models (GAMs) with automatic smoothing parameter selection using REML (Restricted Maximum Likelihood) and the PiRLS (Penalized Iteratively Reweighted Least Squares) algorithm, inspired by R's mgcv package.

Features

  • Multiple Distribution Families: Gaussian, Binomial, Poisson, and Gamma
  • Flexible Basis Functions:
    • Cubic B-splines with natural boundary conditions
    • Thin plate splines for smooth multivariate regression
  • Automatic Smoothing:
    • REML (Restricted Maximum Likelihood) criterion
    • GCV (Generalized Cross-Validation) criterion
  • PiRLS Algorithm: Efficient fitting via Penalized Iteratively Reweighted Least Squares
  • Pure Rust: No external BLAS/LAPACK dependencies
  • Test-Driven Development: Comprehensive test suite with 20+ passing tests

Installation

Add to your Cargo.toml (the path shown assumes a local checkout; point it at the directory that contains this repository):

[dependencies]
mgcv_rust = { path = "." }
ndarray = "0.16"

Quick Start

Python — high-level wrapper (recommended)

GAMFitter is a drop-in replacement for mgcv::gam / rpy2-based fitters. It accepts NumPy, pandas, and polars inputs, and supports named predictors, per-term k overrides, posterior sampling, and confidence intervals. Its serialize() method produces output matching the schema consumed by downstream GamPredictor-style code.

import numpy as np
import pandas as pd
from mgcv_rust import GAMFitter

df = pd.DataFrame({"days_ago": ..., "quality": ...})
y  = ...

gam = GAMFitter(
    predictors=("days_ago", "quality"),
    k_default=6,
    term_k_mapping={"days_ago": 25, "quality": 12},
    family="gaussian", link="identity",
)
gam.fit(df, y)
preds = gam.predict(df)
lo, hi = gam.predict_ci(df, alpha=0.05, n_samples=1000)
serialized = gam.serialize()

Python — low-level Rust core

import numpy as np
from mgcv_rust import GAM

# Generate noise-free data: y = sin(2π·x0) + 0.5·(x1 - 0.5)²
X = np.random.uniform(0, 1, (500, 2))
y = np.sin(2 * np.pi * X[:, 0]) + 0.5 * (X[:, 1] - 0.5)**2

# Fit GAM with automatic smooth setup
gam = GAM()
result = gam.fit(X, y, k=[10, 15])  # That's it!

print(f"Lambda values: {result['lambda']}")
print(f"Deviance: {result['deviance']}")

# Make predictions
X_test = np.random.uniform(0, 1, (100, 2))
predictions = gam.predict(X_test)

Performance: 1.5x - 65x faster than R's mgcv (problem-dependent)

See API_SIMPLIFICATION.md for more details on the simplified Python API.

Rust

use mgcv_rust::{GAM, Family, SmoothTerm, OptimizationMethod};
use ndarray::Array1;
use std::f64::consts::PI;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Generate data: y = sin(2πx) + noise
    let n = 100;
    let x_data: Vec<f64> = (0..n).map(|i| i as f64 / n as f64).collect();
    let y_data: Vec<f64> = x_data
        .iter()
        .enumerate()
        .map(|(i, &xi)| {
            // Deterministic stand-in for random noise (roughly ±0.05);
            // swap in a real RNG for actual experiments.
            let noise = 0.1 * (((i * 37) % 17) as f64 / 17.0 - 0.5);
            (2.0 * PI * xi).sin() + noise
        })
        .collect();

    let x = Array1::from_vec(x_data);
    let y = Array1::from_vec(y_data);
    // ndarray 0.16 deprecates into_shape in favor of into_shape_with_order
    let x_matrix = x.into_shape_with_order((n, 1))?;

    // Create GAM with cubic spline smooth
    let mut gam = GAM::new(Family::Gaussian);
    let smooth = SmoothTerm::cubic_spline("x".to_string(), 20, 0.0, 1.0)?;
    gam.add_smooth(smooth);

    // Fit with REML smoothing parameter selection
    gam.fit(
        &x_matrix,
        &y,
        OptimizationMethod::REML,
        5,    // max outer iterations
        50,   // max inner iterations (PiRLS)
        1e-4  // convergence tolerance
    )?;

    // Make predictions on a grid of new points (one column, like x_matrix)
    let x_test = Array1::linspace(0.0, 1.0, 50).into_shape_with_order((50, 1))?;
    let predictions = gam.predict(&x_test)?;

    Ok(())
}

Architecture

Core Components

  1. basis.rs: Basis function implementations
    • CubicSpline: Cubic B-spline basis with configurable knots
    • ThinPlateSpline: Radial basis functions for smooth regression
  2. penalty.rs: Penalty matrix construction
    • Second derivative penalties for smoothness
    • Supports multiple penalty types per basis
  3. pirls.rs: Penalized IRLS fitting algorithm
    • Implements PiRLS for GLMs with penalties
    • Supports all standard GLM families
    • Automatic weight computation and convergence checking
  4. reml.rs: Smoothing parameter selection
    • REML criterion for optimal smoothing
    • GCV criterion as alternative
    • Log-determinant computations
  5. smooth.rs: Smoothing parameter optimization
    • Coordinate descent optimization
    • Grid search for initialization
    • Works in log-space for numerical stability
  6. gam.rs: Main GAM model interface
    • Combines all components
    • Handles multiple smooth terms
    • Outer loop for lambda optimization
  7. linalg.rs: Linear algebra operations
    • Gaussian elimination with partial pivoting
    • Matrix inversion via Gauss-Jordan
    • Determinant computation via LU decomposition
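
As an illustration of the kind of routine linalg.rs is described as providing, here is a minimal pure-Python sketch of Gaussian elimination with partial pivoting. The function name solve and the singularity threshold are assumptions made for this example; the sketch is not taken from the crate's source.

```python
def solve(a, b):
    """Solve a·x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    # Work on an augmented copy so the caller's data is left untouched.
    m = [list(row) + [rhs] for row, rhs in zip(a, b)]
    for col in range(n):
        # Partial pivoting: pick the row with the largest entry in this column.
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        if abs(m[pivot][col]) < 1e-12:
            raise ValueError("matrix is singular to working precision")
        m[col], m[pivot] = m[pivot], m[col]
        # Eliminate the entries below the pivot.
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    # Back substitution on the resulting upper-triangular system.
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        s = sum(m[r][c] * x[c] for c in range(r + 1, n))
        x[r] = (m[r][n] - s) / m[r][r]
    return x


# The zero in the top-left corner forces a row swap before elimination.
x = solve([[0.0, 2.0], [1.0, 1.0]], [4.0, 3.0])
```

Without the row swap, the first pivot would be zero and the elimination step would divide by zero; choosing the largest available pivot also limits the growth of rounding error.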

Mathematical Background

GAM Model

g(E[Y]) = β₀ + f₁(x₁) + f₂(x₂) + ... + fₚ(xₚ)

Where:

  • g() is the link function
  • fᵢ() are smooth functions represented by basis expansions
  • Each smooth is penalized by λᵢ ∫ (f''ᵢ(x))² dx
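
To make the roughness penalty concrete, the sketch below uses a discrete stand-in for ∫ (f''(x))² dx: a P-spline-style second-difference penalty on the basis coefficients, S = D'D. This is a swapped-in simplification for illustration, not necessarily the matrix that penalty.rs builds, and the helper names are hypothetical.

```python
def second_difference_penalty(k):
    """Return S = D'D (k x k), where D takes second differences of k coefficients."""
    d = [[0.0] * k for _ in range(k - 2)]
    for i in range(k - 2):
        d[i][i], d[i][i + 1], d[i][i + 2] = 1.0, -2.0, 1.0
    return [[sum(d[r][i] * d[r][j] for r in range(k - 2)) for j in range(k)]
            for i in range(k)]


def penalty(beta, s, lam):
    """Evaluate lam * beta' S beta."""
    k = len(beta)
    return lam * sum(beta[i] * s[i][j] * beta[j]
                     for i in range(k) for j in range(k))


S = second_difference_penalty(5)
linear = [0.0, 1.0, 2.0, 3.0, 4.0]   # straight line: no curvature
wiggly = [0.0, 1.0, 0.0, 1.0, 0.0]   # zig-zag: strong curvature

linear_cost = penalty(linear, S, lam=1.0)
wiggly_cost = penalty(wiggly, S, lam=1.0)
```

The straight-line coefficients incur zero penalty because they lie in the null space of S, while the zig-zag coefficients are penalized. This null space is why the penalty matrix is rank-deficient.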

PiRLS Algorithm

  1. Initialize: η = g(y)
  2. Until convergence:
    • Compute μ = g⁻¹(η)
    • Compute weights: w = 1 / (V(μ) · (g'(μ))²)
    • Compute working response: z = η + (y - μ) · g'(μ)
    • Solve: β = (X'WX + λS)⁻¹ X'Wz
    • Update: η = Xβ
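
The loop above can be sketched in a few lines for a deliberately tiny case: a Poisson model with log link, one intercept and one slope, and a ridge penalty on the slope only. It uses the standard IRLS quantities w = 1 / (V(μ)·g'(μ)²) and z = η + (y - μ)·g'(μ). The function name pirls_poisson, the 2x2 closed-form solve, and the penalty S = diag(0, 1) are assumptions chosen to keep the example self-contained; the crate's pirls.rs may be organized quite differently.

```python
import math


def pirls_poisson(x, y, lam, n_iter=50, tol=1e-10):
    """Penalized IRLS for log(E[y]) = b0 + b1*x with penalty lam * b1**2."""
    # Initialize eta = g(y); clamp small values so that log() is defined.
    eta = [math.log(max(yi, 0.1)) for yi in y]
    b0 = b1 = 0.0
    for _ in range(n_iter):
        mu = [math.exp(e) for e in eta]  # mu = g^{-1}(eta)
        # Log link: g'(mu) = 1/mu and V(mu) = mu, so w = 1/(V * g'^2) = mu.
        w = mu
        # Working response: z = eta + (y - mu) * g'(mu).
        z = [e + (yi - m) / m for e, yi, m in zip(eta, y, mu)]
        # Assemble (X'WX + lam*S) beta = X'Wz with S = diag(0, 1).
        a00 = sum(w)
        a01 = sum(wi * xi for wi, xi in zip(w, x))
        a11 = sum(wi * xi * xi for wi, xi in zip(w, x)) + lam
        r0 = sum(wi * zi for wi, zi in zip(w, z))
        r1 = sum(wi * xi * zi for wi, xi, zi in zip(w, x, z))
        # Solve the 2x2 system by Cramer's rule.
        det = a00 * a11 - a01 * a01
        new_b0 = (r0 * a11 - a01 * r1) / det
        new_b1 = (a00 * r1 - a01 * r0) / det
        converged = abs(new_b0 - b0) + abs(new_b1 - b1) < tol
        b0, b1 = new_b0, new_b1
        eta = [b0 + b1 * xi for xi in x]  # eta = X beta
        if converged:
            break
    return b0, b1


x = [0.0, 1.0, 2.0, 3.0, 4.0]
y = [math.exp(0.5 + 0.3 * xi) for xi in x]  # noise-free means, for checking

b0, b1 = pirls_poisson(x, y, lam=0.0)            # unpenalized fit
b0_pen, b1_pen = pirls_poisson(x, y, lam=100.0)  # heavily penalized slope
```

With lam = 0 the fit recovers the generating coefficients, and a large lam shrinks the slope toward zero, which is the same trade-off the smoothing parameter controls in a full GAM.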

REML Criterion

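REML(λ) = n·log(RSS) + log|X'WX + λS| - log|λS|₊

Here |λS|₊ is the pseudo-determinant (the product of the nonzero eigenvalues) of the penalty, which is usually rank-deficient, so log|λS|₊ = rank(S)·log(λ) + constant. The expression is stated up to additive constants and is minimized with respect to λ to select the smoothing parameters.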
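
A log-space grid search over λ can be sketched as follows. For brevity the example scores each candidate with GCV rather than REML, and it assumes a smoother that has already been diagonalized, so each coefficient is shrunk independently as βᵢ = yᵢ / (1 + λ·sᵢ). The function names, the grid bounds, and the toy data are all assumptions for illustration; the same search loop would apply to any criterion that can be evaluated at a given λ.

```python
import math


def gcv_score(y, s, lam):
    """GCV = n * RSS / (n - edf)^2 for the smoother beta_i = y_i / (1 + lam*s_i)."""
    n = len(y)
    shrink = [1.0 / (1.0 + lam * si) for si in s]
    rss = sum((yi - a * yi) ** 2 for yi, a in zip(y, shrink))
    edf = sum(shrink)  # effective degrees of freedom = trace of the smoother
    return n * rss / (n - edf) ** 2


def grid_search(y, s, log_lo=-6.0, log_hi=6.0, steps=121):
    """Return the lambda on a log10-spaced grid that minimizes GCV."""
    best_lam, best_score = None, math.inf
    for i in range(steps):
        lam = 10.0 ** (log_lo + (log_hi - log_lo) * i / (steps - 1))
        score = gcv_score(y, s, lam)
        if score < best_score:
            best_lam, best_score = lam, score
    return best_lam, best_score


# Penalty eigenvalues: two zeros (unpenalized null space), then increasing.
s = [0.0, 0.0, 1.0, 4.0, 9.0, 16.0, 25.0, 36.0]
# Toy coefficients: large smooth components, small rough ones.
y = [3.0, -2.0, 1.5, 0.2, -0.1, 0.15, -0.05, 0.1]

best_lam, best_score = grid_search(y, s)
```

Searching on a logarithmic grid matters because plausible values of λ span many orders of magnitude; a linear grid would either miss small values or waste most of its points on large ones.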

Examples

See examples/simple_gam.rs for a complete working example:

cargo run --example simple_gam --release

Project Structure

├── src/                    # Core Rust library code
├── examples/               # Rust usage examples
├── benches/               # Rust benchmarks
├── tests/                 # Rust tests
├── scripts/               # Testing and benchmarking scripts
│   ├── python/            # Python scripts
│   │   ├── tests/         # Python test scripts
│   │   └── benchmarks/    # Python benchmark scripts
│   └── r/                 # R scripts
│       ├── tests/         # R test scripts
│       └── benchmarks/    # R benchmark scripts
├── docs/                  # Documentation and analysis
└── test_data/            # Test data and results

Testing

Run the Rust test suite:

cargo test

All tests should pass. The suite covers:

  • Basis function evaluation
  • Penalty matrix construction
  • Linear algebra operations
  • REML/GCV criteria
  • PiRLS convergence
  • Full GAM fitting pipeline

Additional tests and benchmarks are available in the scripts/ directory.

Implementation Notes

  • TDD Approach: Every feature was implemented with tests first
  • No BLAS/LAPACK Dependency: Custom linear algebra routines avoid BLAS/LAPACK linking and portability issues
  • Numerical Stability: Operations performed in log-space where appropriate
  • Extensible Design: Easy to add new basis types, families, or criteria

Limitations & Future Work

  • Smoothing parameter optimization could be improved with better algorithms (e.g., Newton-Raphson)
  • Eigendecomposition for handling penalty null spaces more rigorously
  • Confidence intervals and standard errors in the low-level API (the Python GAMFitter wrapper already exposes predict_ci)
  • Model diagnostics and residual analysis
  • Tensor product smooths for multivariate terms
  • Parallel processing for large datasets

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. Journal of the Royal Statistical Society: Series B, 73(1), 3-36.

License

MIT License - see LICENSE file for details

Author

Implemented as a Rust port of R's mgcv package core functionality.

Update: REML Implementation Fixed ✅

Earlier versions of the REML implementation had bugs that caused it to always select λ ≈ 0.

What Was Wrong

  1. Singular Penalty Handling: REML was incorrectly handling rank-deficient penalty matrices, setting log|S| = 0 which broke the criterion
  2. Lambda Passing: Optimization was passing λ = 1.0 with pre-multiplied penalties, confusing the rank(S)*log(λ) term
  3. Insufficient Data: Examples used n=30 with p=15 (ratio 2:1), which is too small for REML/GCV

What Was Fixed

  1. REML Criterion: Now correctly uses log|λS| = rank(S)*log(λ) + constant
  2. Optimization: Passes actual λ values to criterion functions
  3. Data Size: Increased to n=300 for proper n/p ratio (20:1)
  4. REML Search: Uses fine grid search (gradient descent had issues)
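
The identity behind fix 1 can be checked numerically. The sketch below computes a log pseudo-determinant from a list of eigenvalues and confirms that scaling a rank-deficient penalty by λ adds rank(S)·log(λ). The eigenvalues, the tolerance, and the function name are made up for this example; λ = 0.058 is borrowed from the REML result reported in this update purely as a sample value.

```python
import math


def log_pseudo_det(eigenvalues, tol=1e-10):
    """Return (sum of logs of eigenvalues above tol, number of such eigenvalues)."""
    positive = [e for e in eigenvalues if e > tol]
    return sum(math.log(e) for e in positive), len(positive)


# Hypothetical penalty spectrum: two zero eigenvalues form the null space.
eig_s = [0.0, 0.0, 0.5, 2.0, 8.0]
lam = 0.058

log_det_s, rank = log_pseudo_det(eig_s)
log_det_lam_s, _ = log_pseudo_det([lam * e for e in eig_s])

# log|lam*S|+ should equal rank(S)*log(lam) + log|S|+
expected = rank * math.log(lam) + log_det_s
```

Setting the log-determinant of a singular S to 0 drops the rank(S)·log(λ) term, which removes the part of the criterion that penalizes very small λ. That is consistent with the λ ≈ 0 behavior described in this update. Note that the tolerance is applied after scaling here, so for extremely small λ the rank should be determined from S itself.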

Current Performance (n=300)

GCV:  λ = 0.067, Test RMSE = 0.480  ✅
REML: λ = 0.058, Test RMSE = 0.480  ✅

Both methods now select nearly optimal smoothing parameters!

See IMPLEMENTATION_SUMMARY.md for complete details.
