Rust implementation of Generalized Additive Models with Python bindings

mgcv_rust: Generalized Additive Models in Rust

A Rust implementation of Generalized Additive Models (GAMs) with automatic smoothing parameter selection using REML (Restricted Maximum Likelihood) and the PiRLS (Penalized Iteratively Reweighted Least Squares) algorithm, inspired by R's mgcv package.

Features

  • Multiple Distribution Families: Gaussian, Binomial, Poisson, and Gamma
  • Flexible Basis Functions:
    • Cubic B-splines with natural boundary conditions
    • Thin plate splines for smooth multivariate regression
  • Automatic Smoothing:
    • REML (Restricted Maximum Likelihood) criterion
    • GCV (Generalized Cross-Validation) criterion
  • PiRLS Algorithm: Efficient fitting via Penalized Iteratively Reweighted Least Squares
  • Pure Rust: No external BLAS/LAPACK dependencies
  • Test-Driven Development: Comprehensive test suite with 20+ passing tests

Installation

Add to your Cargo.toml (the path dependency below assumes a local checkout; point path at the crate's directory):

[dependencies]
mgcv_rust = { path = "." }
ndarray = "0.16"

Quick Start

Python — high-level wrapper (recommended)

GAMFitter is a drop-in replacement for mgcv::gam / rpy2-based fitters. It accepts NumPy, pandas, and polars inputs and supports named predictors, per-term k overrides, posterior sampling, and confidence intervals. Its serialize() method matches the schema consumed by downstream GamPredictor-style code.

import numpy as np, pandas as pd
from mgcv_rust import GAMFitter

df = pd.DataFrame({"days_ago": ..., "quality": ...})
y  = ...

gam = GAMFitter(
    predictors=("days_ago", "quality"),
    k_default=6,
    term_k_mapping={"days_ago": 25, "quality": 12},
    family="gaussian", link="identity",
)
gam.fit(df, y)
preds = gam.predict(df)
lo, hi = gam.predict_ci(df, alpha=0.05, n_samples=1000)
serialized = gam.serialize()

Python — low-level Rust core

import numpy as np
from mgcv_rust import GAM

# Generate data: y = sin(2πx₁) + 0.5(x₂ - 0.5)² + Gaussian noise
X = np.random.uniform(0, 1, (500, 2))
y = np.sin(2 * np.pi * X[:, 0]) + 0.5 * (X[:, 1] - 0.5)**2 + np.random.normal(0, 0.1, 500)

# Fit GAM with automatic smooth setup
gam = GAM()
result = gam.fit(X, y, k=[10, 15])  # That's it!

print(f"Lambda values: {result['lambda']}")
print(f"Deviance: {result['deviance']}")

# Make predictions
X_test = np.random.uniform(0, 1, (100, 2))
predictions = gam.predict(X_test)

Performance: 1.5x - 65x faster than R's mgcv (problem-dependent)

See API_SIMPLIFICATION.md for more details on the simplified Python API.

Rust

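use mgcv_rust::{GAM, Family, SmoothTerm, OptimizationMethod};
use ndarray::Array1;

// Deterministic stand-in for random noise so the example needs no extra crates;
// swap in a real RNG (e.g. the `rand` crate) for actual experiments.
fn noise(i: usize) -> f64 {
    0.1 * (i as f64 * 12.9898).sin()
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Generate data: y = sin(2πx) + noise
    let n = 100;
    let x_data: Vec<f64> = (0..n).map(|i| i as f64 / n as f64).collect();
    let y_data: Vec<f64> = x_data
        .iter()
        .enumerate()
        .map(|(i, &xi)| (2.0 * std::f64::consts::PI * xi).sin() + noise(i))
        .collect();

    let x = Array1::from_vec(x_data);
    let y = Array1::from_vec(y_data);
    // Reshape the predictor into an (n, 1) design matrix
    let x_matrix = x.into_shape_with_order((n, 1))?;

    // Create GAM with cubic spline smooth (20 basis functions on [0, 1])
    let mut gam = GAM::new(Family::Gaussian);
    let smooth = SmoothTerm::cubic_spline("x".to_string(), 20, 0.0, 1.0)?;
    gam.add_smooth(smooth);

    // Fit with REML smoothing parameter selection
    gam.fit(
        &x_matrix,
        &y,
        OptimizationMethod::REML,
        5,    // max outer iterations
        50,   // max inner iterations (PiRLS)
        1e-4  // convergence tolerance
    )?;

    // Make predictions on an evenly spaced test grid inside [0, 1]
    let n_test = 50;
    let x_test = Array1::linspace(0.0, 1.0, n_test).into_shape_with_order((n_test, 1))?;
    let predictions = gam.predict(&x_test)?;

    Ok(())
}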

Architecture

Core Components

  1. basis.rs: Basis function implementations

    • CubicSpline: Cubic B-spline basis with configurable knots
    • ThinPlateSpline: Radial basis functions for smooth regression
  2. penalty.rs: Penalty matrix construction

    • Second derivative penalties for smoothness
    • Supports multiple penalty types per basis
  3. pirls.rs: Penalized IRLS fitting algorithm

    • Implements PiRLS for GLMs with penalties
    • Supports all standard GLM families
    • Automatic weight computation and convergence checking
  4. reml.rs: Smoothing parameter selection

    • REML criterion for optimal smoothing
    • GCV criterion as alternative
    • Log-determinant computations
  5. smooth.rs: Smoothing parameter optimization

    • Coordinate descent optimization
    • Grid search for initialization
    • Works in log-space for numerical stability
  6. gam.rs: Main GAM model interface

    • Combines all components
    • Handles multiple smooth terms
    • Outer loop for lambda optimization
  7. linalg.rs: Linear algebra operations

    • Gaussian elimination with partial pivoting
    • Matrix inversion via Gauss-Jordan
    • Determinant computation via LU decomposition
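
As an illustration of the kind of basis evaluation that basis.rs is described as providing, here is a minimal Python sketch of the Cox-de Boor recursion for cubic B-splines. It is not taken from the crate: the function name bspline_basis and the clamped knot vector are assumptions chosen for the example, and it does not impose the natural boundary conditions mentioned in the feature list.

```python
def bspline_basis(x, knots, degree=3):
    """Evaluate every B-spline basis function of the given degree at x (Cox-de Boor)."""
    n_spans = len(knots) - 1
    # Degree 0: indicator functions of the half-open knot spans [t_i, t_{i+1}).
    b = [1.0 if knots[i] <= x < knots[i + 1] else 0.0 for i in range(n_spans)]
    # Close the right end so that x equal to the last knot lands in the
    # final non-empty span instead of evaluating to zero everywhere.
    if x == knots[-1]:
        for i in range(n_spans - 1, -1, -1):
            if knots[i] < knots[i + 1]:
                b[i] = 1.0
                break
    # Raise the degree one step at a time; zero-width spans contribute nothing.
    for d in range(1, degree + 1):
        nxt = []
        for i in range(len(knots) - d - 1):
            left_den = knots[i + d] - knots[i]
            right_den = knots[i + d + 1] - knots[i + 1]
            left = (x - knots[i]) / left_den * b[i] if left_den > 0 else 0.0
            right = (knots[i + d + 1] - x) / right_den * b[i + 1] if right_den > 0 else 0.0
            nxt.append(left + right)
        b = nxt
    return b  # length is len(knots) - degree - 1


# Hypothetical clamped knot vector on [0, 1] giving 7 cubic basis functions.
knots = [0.0] * 4 + [0.25, 0.5, 0.75] + [1.0] * 4
row = bspline_basis(0.3, knots)  # one row of a design matrix
```

Evaluating bspline_basis at each observation and stacking the rows gives the design matrix X for one smooth term. With a clamped knot vector the entries of each row are non-negative and sum to one, and at most four of them are non-zero for a cubic basis.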

Mathematical Background

GAM Model

g(E[Y]) = β₀ + f₁(x₁) + f₂(x₂) + ... + fₚ(xₚ)

Where:

  • g() is the link function
  • fᵢ() are smooth functions represented by basis expansions
  • Each smooth is penalized by λᵢ ∫ (f''ᵢ(x))² dx
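
To make the penalty concrete, the sketch below builds a discrete stand-in for λ ∫ (f''(x))² dx: a second-difference penalty S = DᵀD on the basis coefficients, as used in P-splines. This is an assumption made for illustration, not the construction in penalty.rs, and the helper names are hypothetical.

```python
def second_difference_penalty(k):
    """Return S = D^T D, where D is the (k - 2) x k second-difference matrix."""
    # Each row of D holds the stencil (1, -2, 1), a discrete second derivative.
    D = [[0.0] * k for _ in range(k - 2)]
    for r in range(k - 2):
        D[r][r], D[r][r + 1], D[r][r + 2] = 1.0, -2.0, 1.0
    return [[sum(D[r][i] * D[r][j] for r in range(k - 2)) for j in range(k)]
            for i in range(k)]


def penalty_value(beta, S, lam):
    """Evaluate lam * beta^T S beta."""
    k = len(beta)
    return lam * sum(beta[i] * S[i][j] * beta[j] for i in range(k) for j in range(k))


S = second_difference_penalty(6)
straight = [0.5 * i + 1.0 for i in range(6)]  # coefficients on a straight line
wiggly = [(-1.0) ** i for i in range(6)]      # alternating coefficients

flat_cost = penalty_value(straight, S, lam=2.0)   # 0.0: linear trends are in the null space
rough_cost = penalty_value(wiggly, S, lam=1.0)    # 64.0: four second differences of size 4
```

Constant and linear coefficient sequences cost nothing, so S is rank-deficient with a two-dimensional null space. That rank deficiency is why the log-determinant of the penalty needs special handling in the REML criterion.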

PiRLS Algorithm

  1. Initialize: η = g(y), adjusting y away from boundary values (e.g. y = 0 under a log link) where g is undefined
  2. Until convergence:
    • Compute μ = g⁻¹(η)
    • Compute weights: w = 1 / (V(μ)·(g'(μ))²)
    • Compute working response: z = η + (y - μ)·g'(μ)
    • Solve: β = (X'WX + λS)⁻¹ X'Wz
    • Update: η = Xβ
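
The loop can be sketched in plain Python for one concrete case, a Poisson response with log link. The sketch uses the standard IRLS quantities w = 1 / (V(μ)·g'(μ)²) and z = η + (y − μ)·g'(μ), which for this family reduce to w = μ and z = η + (y − μ)/μ. The function names, the dense list-of-lists matrices, and the small solver are assumptions for illustration and do not mirror pirls.rs.

```python
import math


def solve(A, b):
    """Solve A x = b by Gaussian elimination with partial pivoting (A must be non-singular)."""
    n = len(b)
    M = [row[:] + [b[i]] for i, row in enumerate(A)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= f * M[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):  # back substitution
        x[r] = (M[r][n] - sum(M[r][c] * x[c] for c in range(r + 1, n))) / M[r][r]
    return x


def pirls_poisson(X, y, S, lam, max_iter=50, tol=1e-10):
    """Penalized IRLS for a Poisson response with log link g(mu) = log(mu)."""
    n, p = len(X), len(X[0])
    # Initialize eta from the data, nudged away from zero so log() is defined.
    eta = [math.log(max(yi, 0.1)) for yi in y]
    beta = [0.0] * p
    for _ in range(max_iter):
        mu = [math.exp(e) for e in eta]
        w = mu  # V(mu) = mu and g'(mu) = 1/mu, so 1 / (V g'^2) = mu
        z = [eta[i] + (y[i] - mu[i]) / mu[i] for i in range(n)]  # working response
        # Penalized normal equations: (X'WX + lam * S) beta = X'Wz
        A = [[sum(w[i] * X[i][a] * X[i][b] for i in range(n)) + lam * S[a][b]
              for b in range(p)] for a in range(p)]
        rhs = [sum(w[i] * X[i][a] * z[i] for i in range(n)) for a in range(p)]
        new_beta = solve(A, rhs)
        eta = [sum(X[i][a] * new_beta[a] for a in range(p)) for i in range(n)]
        converged = max(abs(new_beta[a] - beta[a]) for a in range(p)) < tol
        beta = new_beta
        if converged:
            break
    return beta


# Intercept-only check: without a penalty the fit is the log of the sample mean.
X = [[1.0] for _ in range(5)]
y = [1.0, 2.0, 3.0, 4.0, 5.0]
beta_free = pirls_poisson(X, y, S=[[0.0]], lam=0.0)
beta_shrunk = pirls_poisson(X, y, S=[[1.0]], lam=1e6)  # heavy penalty pulls beta toward 0
```

For the Gaussian family with identity link the weights are constant and z equals y, so the loop reduces to a single penalized least-squares solve.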

REML Criterion

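REML(λ) = n·log(RSS) + log|X'WX + λS| - log|λS|₊

Here |λS|₊ is the product of the nonzero eigenvalues of λS, so log|λS|₊ = rank(S)·log(λ) + log|S|₊. The first two terms both grow with λ, so the -rank(S)·log(λ) term is what keeps the criterion from always favoring λ ≈ 0.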

Minimized with respect to λ to select optimal smoothing parameters.
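
A minimal sketch of how such a minimization can be organized as a grid search over log₁₀(λ) follows. The criterion is passed in as a function. The toy criterion used here is a stand-in chosen so the minimizer is known, not the REML score, and the function names and grid bounds are assumptions.

```python
import math


def grid_search_log_lambda(criterion, log10_lo=-6.0, log10_hi=6.0, n_grid=121):
    """Return (best_lambda, best_score) over a log-spaced grid of lambda values."""
    best_lam, best_score = None, math.inf
    for i in range(n_grid):
        # Step evenly in log10(lambda) so every order of magnitude is covered equally.
        log10_lam = log10_lo + (log10_hi - log10_lo) * i / (n_grid - 1)
        lam = 10.0 ** log10_lam
        score = criterion(lam)
        if score < best_score:
            best_lam, best_score = lam, score
    return best_lam, best_score


def toy_criterion(lam):
    """Stand-in for a smoothing criterion: a bowl in log10(lambda) centered at lambda = 0.01."""
    return (math.log10(lam) + 2.0) ** 2


best_lam, best_score = grid_search_log_lambda(toy_criterion)
```

Searching in log-space matters because plausible values of λ span many orders of magnitude, so a linear grid would spend nearly all of its points on the largest values.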

Examples

See examples/simple_gam.rs for a complete working example:

cargo run --example simple_gam --release

Project Structure

├── src/                    # Core Rust library code
├── examples/               # Rust usage examples
├── benches/               # Rust benchmarks
├── tests/                 # Rust tests
├── scripts/               # Testing and benchmarking scripts
│   ├── python/            # Python scripts
│   │   ├── tests/         # Python test scripts
│   │   └── benchmarks/    # Python benchmark scripts
│   └── r/                 # R scripts
│       ├── tests/         # R test scripts
│       └── benchmarks/    # R benchmark scripts
├── docs/                  # Documentation and analysis
└── test_data/            # Test data and results

Testing

Run the Rust test suite:

cargo test

All tests should pass, covering:

  • Basis function evaluation
  • Penalty matrix construction
  • Linear algebra operations
  • REML/GCV criteria
  • PiRLS convergence
  • Full GAM fitting pipeline

Additional tests and benchmarks are available in the scripts/ directory.

Implementation Notes

  • TDD Approach: Every feature was implemented with tests first
  • No BLAS/LAPACK Dependencies: Custom linear algebra routines avoid linking against external BLAS/LAPACK libraries
  • Numerical Stability: Operations performed in log-space where appropriate
  • Extensible Design: Easy to add new basis types, families, or criteria

Limitations & Future Work

  • Smoothing parameter optimization could be improved with better algorithms (e.g., Newton-Raphson)
  • Eigendecomposition for handling penalty null spaces more rigorously
  • Confidence intervals and standard errors
  • Model diagnostics and residual analysis
  • Tensor product smooths for multivariate terms
  • Parallel processing for large datasets

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. JRSS-B, 73(1), 3-36.

License

MIT License - see LICENSE file for details

Author

Implemented as a Rust port of R's mgcv package core functionality.

Update: REML Implementation Fixed! ✅

Earlier versions of the REML implementation had bugs that caused it to always select λ ≈ 0.

What Was Wrong

  1. Singular Penalty Handling: REML was incorrectly handling rank-deficient penalty matrices, setting log|S| = 0 which broke the criterion
  2. Lambda Passing: Optimization was passing λ = 1.0 with pre-multiplied penalties, confusing the rank(S)*log(λ) term
  3. Insufficient Data: Examples used n=30 with p=15 (ratio 2:1), which is too small for REML/GCV

What Was Fixed

  1. REML Criterion: Now correctly uses log|λS| = rank(S)*log(λ) + constant
  2. Optimization: Passes actual λ values to criterion functions
  3. Data Size: Increased to n=300 for proper n/p ratio (20:1)
  4. REML Search: Uses fine grid search (gradient descent had issues)
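
The identity behind fix 1 can be checked numerically. In the sketch below, log|·| for a rank-deficient penalty is taken as the log of the product of its nonzero eigenvalues. The eigenvalues are hypothetical and the helper name is an assumption, not the crate's API.

```python
import math


def log_pseudo_det(eigenvalues, tol=1e-12):
    """Return (log of the product of eigenvalues above tol, number of such eigenvalues)."""
    positive = [e for e in eigenvalues if e > tol]
    return sum(math.log(e) for e in positive), len(positive)


# Hypothetical eigenvalues of a penalty matrix S with a two-dimensional null space.
eigs = [0.0, 0.0, 0.5, 2.0, 8.0]
lam = 0.058

log_det_S, rank = log_pseudo_det(eigs)             # the lambda-independent constant, rank 3
lhs, _ = log_pseudo_det([lam * e for e in eigs])   # log|lam * S| computed directly
rhs = rank * math.log(lam) + log_det_S             # rank(S) * log(lam) + constant
```

Up to floating-point rounding, lhs and rhs agree. Treating the zero eigenvalues as if they contributed would instead send the log-determinant to minus infinity, and fixing log|S| = 0 drops the dependence on λ entirely.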

Current Performance (n=300)

GCV:  λ = 0.067, Test RMSE = 0.480  ✅
REML: λ = 0.058, Test RMSE = 0.480  ✅

Both methods now select nearly optimal smoothing parameters!

See IMPLEMENTATION_SUMMARY.md for complete details.
