Rust implementation of Generalized Additive Models with Python bindings

mgcv_rust: Generalized Additive Models in Rust

A Rust implementation of Generalized Additive Models (GAMs) with automatic smoothing parameter selection using REML (Restricted Maximum Likelihood) and the PiRLS (Penalized Iteratively Reweighted Least Squares) algorithm, inspired by R's mgcv package.

Features

  • Multiple Distribution Families: Gaussian, Binomial, Poisson, and Gamma
  • Flexible Basis Functions:
    • Cubic B-splines with natural boundary conditions
    • Thin plate splines for smooth multivariate regression
  • Automatic Smoothing:
    • REML (Restricted Maximum Likelihood) criterion
    • GCV (Generalized Cross-Validation) criterion
  • PiRLS Algorithm: Efficient fitting via Penalized Iteratively Reweighted Least Squares
  • Pure Rust: No external BLAS/LAPACK dependencies
  • Test-Driven Development: Comprehensive test suite with 20+ passing tests

Installation

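For the Python bindings, install the published package from PyPI:

pip install mgcv_rust

For Rust, add to your Cargo.toml (set path to the location of your local checkout of the repository):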

[dependencies]
mgcv_rust = { path = "." }
ndarray = "0.16"

Quick Start

Python (Recommended)

import numpy as np
from mgcv_rust import GAM

# Generate noise-free data: y = sin(2π·x0) + 0.5·(x1 - 0.5)²
X = np.random.uniform(0, 1, (500, 2))
y = np.sin(2 * np.pi * X[:, 0]) + 0.5 * (X[:, 1] - 0.5)**2

# Fit GAM with automatic smooth setup
gam = GAM()
result = gam.fit(X, y, k=[10, 15])  # That's it!

print(f"Lambda values: {result['lambda']}")
print(f"Deviance: {result['deviance']}")

# Make predictions
X_test = np.random.uniform(0, 1, (100, 2))
predictions = gam.predict(X_test)

Performance: 1.5x - 65x faster than R's mgcv (problem-dependent)

See API_SIMPLIFICATION.md for more details on the simplified Python API.

Rust

use mgcv_rust::{GAM, Family, SmoothTerm, OptimizationMethod};
use ndarray::{Array1, Array2};

fn main() -> Result<(), Box<dyn std::error::Error>> {
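    // Generate data: y = sin(2πx) + small deterministic pseudo-noise
    let n = 100;
    let x_data: Vec<f64> = (0..n).map(|i| i as f64 / n as f64).collect();
    let y_data: Vec<f64> = x_data
        .iter()
        .enumerate()
        .map(|(i, &xi)| {
            // Stand-in for random noise; swap in a real RNG if desired
            let noise = 0.1 * (((i * 37) % 17) as f64 / 17.0 - 0.5);
            (2.0 * std::f64::consts::PI * xi).sin() + noise
        })
        .collect();

    // Design matrix with a single column (n × 1)
    let x_matrix = Array2::from_shape_vec((n, 1), x_data)?;
    let y = Array1::from_vec(y_data);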

    // Create GAM with cubic spline smooth
    let mut gam = GAM::new(Family::Gaussian);
    let smooth = SmoothTerm::cubic_spline("x".to_string(), 20, 0.0, 1.0)?;
    gam.add_smooth(smooth);

    // Fit with REML smoothing parameter selection
    gam.fit(
        &x_matrix,
        &y,
        OptimizationMethod::REML,
        5,    // max outer iterations
        50,   // max inner iterations (PiRLS)
        1e-4  // convergence tolerance
    )?;

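    // Make predictions on new inputs (same single-column layout as the training matrix)
    let x_test = Array2::from_shape_vec((5, 1), vec![0.1, 0.3, 0.5, 0.7, 0.9])?;
    let predictions = gam.predict(&x_test)?;
    println!("{:?}", predictions);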

    Ok(())
}

Architecture

Core Components

  1. basis.rs: Basis function implementations

    • CubicSpline: Cubic B-spline basis with configurable knots
    • ThinPlateSpline: Radial basis functions for smooth regression
  2. penalty.rs: Penalty matrix construction

    • Second derivative penalties for smoothness
    • Supports multiple penalty types per basis
  3. pirls.rs: Penalized IRLS fitting algorithm

    • Implements PiRLS for GLMs with penalties
    • Supports all standard GLM families
    • Automatic weight computation and convergence checking
  4. reml.rs: Smoothing parameter selection

    • REML criterion for optimal smoothing
    • GCV criterion as alternative
    • Log-determinant computations
  5. smooth.rs: Smoothing parameter optimization

    • Coordinate descent optimization
    • Grid search for initialization
    • Works in log-space for numerical stability
  6. gam.rs: Main GAM model interface

    • Combines all components
    • Handles multiple smooth terms
    • Outer loop for lambda optimization
  7. linalg.rs: Linear algebra operations

    • Gaussian elimination with partial pivoting
    • Matrix inversion via Gauss-Jordan
    • Determinant computation via LU decomposition
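
The linear solves listed for linalg.rs can be illustrated with a short, dependency-free sketch. The following is a generic Gaussian elimination with partial pivoting written in Python for readability; the function name `solve` and the list-of-lists layout are assumptions made for this illustration and do not mirror the crate's actual API.

```python
def solve(a, b):
    """Solve a x = b by Gaussian elimination with partial pivoting.

    `a` is a list of n rows (each a list of n floats), `b` a list of n floats.
    The inputs are copied, not modified.
    """
    n = len(a)
    # Augmented matrix [a | b]
    m = [list(row) + [bi] for row, bi in zip(a, b)]
    for col in range(n):
        # Partial pivoting: pick the row with the largest entry in this column
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        if abs(m[pivot][col]) < 1e-12:
            raise ValueError("matrix is singular to working precision")
        m[col], m[pivot] = m[pivot], m[col]
        # Eliminate the entries below the pivot
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    # Back substitution on the upper-triangular system
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        s = sum(m[r][c] * x[c] for c in range(r + 1, n))
        x[r] = (m[r][n] - s) / m[r][r]
    return x


# The first pivot is zero here, so the row swap is actually exercised
A = [[0.0, 2.0, 1.0],
     [1.0, 1.0, 1.0],
     [2.0, 1.0, 3.0]]
b = [7.0, 6.0, 13.0]
x = solve(A, b)
print(x)  # close to [1, 2, 3] up to floating-point rounding
```

Pivoting on the largest available entry keeps the elimination factors bounded by 1 in magnitude, which is the usual reason to prefer it over plain elimination when no BLAS/LAPACK routine is available.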

Mathematical Background

GAM Model

g(E[Y]) = β₀ + f₁(x₁) + f₂(x₂) + ... + fₚ(xₚ)

Where:

  • g() is the link function
  • fᵢ() are smooth functions represented by basis expansions
  • Each smooth is penalized by λᵢ ∫ (f''ᵢ(x))² dx
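
To make "represented by basis expansions" concrete: each smooth can be written as fᵢ(x) = Σⱼ βⱼ Bⱼ(x) for a fixed set of basis functions Bⱼ. The sketch below evaluates a clamped cubic B-spline basis with the Cox-de Boor recursion. The knot vector, the function names `bspline_basis` and `smooth_value`, and the coefficients are illustrative assumptions, not the contents of basis.rs.

```python
def bspline_basis(x, knots, degree=3):
    """Evaluate every B-spline basis function of `degree` at x (Cox-de Boor).

    Returns len(knots) - degree - 1 values.
    """
    n0 = len(knots) - 1
    # Degree 0: indicator of each half-open knot interval [t_i, t_{i+1})
    b = [1.0 if knots[i] <= x < knots[i + 1] else 0.0 for i in range(n0)]
    if x == knots[-1]:
        # Close the last non-empty interval so the right endpoint is covered
        last = max(i for i in range(n0) if knots[i] < knots[i + 1])
        b[last] = 1.0
    # Raise the degree one step at a time
    for d in range(1, degree + 1):
        nxt = []
        for i in range(n0 - d):
            left_den = knots[i + d] - knots[i]
            right_den = knots[i + d + 1] - knots[i + 1]
            # Terms with zero-width support are defined as 0
            left = (x - knots[i]) / left_den * b[i] if left_den > 0 else 0.0
            right = (knots[i + d + 1] - x) / right_den * b[i + 1] if right_den > 0 else 0.0
            nxt.append(left + right)
        b = nxt
    return b


def smooth_value(x, knots, beta):
    """f(x) = sum_j beta_j * B_j(x)."""
    return sum(bj * Bj for bj, Bj in zip(beta, bspline_basis(x, knots)))


# Clamped cubic knot vector on [0, 1]: 11 knots give 7 basis functions
knots = [0.0] * 4 + [0.25, 0.5, 0.75] + [1.0] * 4
row = bspline_basis(0.4, knots)
print(len(row), sum(row))
```

Stacking one such row per observation gives the block of the design matrix X that belongs to that smooth term.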

PiRLS Algorithm

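  1. Initialize: η = g(y), with y nudged away from values where g is undefined (e.g. y = 0 under a log link)
  2. Until convergence:
    • Compute μ = g⁻¹(η)
    • Compute weights: w = 1 / (V(μ) · (g'(μ))²)
    • Compute working response: z = η + (y - μ) · g'(μ)
    • Solve: β = (X'WX + λS)⁻¹ X'Wz, where W = diag(w)
    • Update: η = Xβ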
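
The loop can be sketched for one concrete family. For the Poisson family with a log link, the weights reduce to w = μ and the working response to z = η + (y - μ)/μ. The code below is a minimal pure-Python illustration under those assumptions; `pirls_poisson`, the quadratic design matrix, the ridge-style penalty S and the value of λ are all hypothetical choices for the example and are not taken from pirls.rs.

```python
import math


def solve(a, b):
    """Solve a x = b by Gaussian elimination with partial pivoting."""
    n = len(a)
    m = [list(row) + [bi] for row, bi in zip(a, b)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            f = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= f * m[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (m[r][n] - sum(m[r][c] * x[c] for c in range(r + 1, n))) / m[r][r]
    return x


def pirls_poisson(X, y, S, lam, max_iter=50, tol=1e-8):
    """Penalized IRLS for a Poisson response with log link."""
    n, p = len(X), len(X[0])
    # Start on the linear-predictor scale; +0.1 keeps log() finite when y == 0
    eta = [math.log(yi + 0.1) for yi in y]
    beta = [0.0] * p
    for _ in range(max_iter):
        mu = [math.exp(e) for e in eta]
        w = mu  # Poisson with log link: w = mu
        z = [e + (yi - m) / m for e, yi, m in zip(eta, y, mu)]
        # Penalized normal equations: (X'WX + lam*S) beta = X'Wz
        lhs = [[sum(w[i] * X[i][j] * X[i][k] for i in range(n)) + lam * S[j][k]
                for k in range(p)] for j in range(p)]
        rhs = [sum(w[i] * X[i][j] * z[i] for i in range(n)) for j in range(p)]
        new_beta = solve(lhs, rhs)
        eta = [sum(X[i][j] * new_beta[j] for j in range(p)) for i in range(n)]
        converged = max(abs(a - b) for a, b in zip(new_beta, beta)) < tol
        beta = new_beta
        if converged:
            break
    return beta


# Toy data: rounded counts following exp(0.5 + 1.5 x) on a regular grid
xs = [i / 19 for i in range(20)]
X = [[1.0, xi, xi * xi] for xi in xs]
y = [float(round(math.exp(0.5 + 1.5 * xi))) for xi in xs]
S = [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]  # intercept unpenalized
lam = 0.1
beta = pirls_poisson(X, y, S, lam)
print(beta)
```

At convergence the coefficients satisfy the penalized score equation X'(y - μ) = λSβ, which gives a simple check that the iteration has reached a fixed point.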

REML Criterion

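REML(λ) = n·log(RSS) + log|X'WX + λS| - log|λS|₊

Here |λS|₊ is the product of the nonzero eigenvalues of λS, so log|λS|₊ = rank(S)·log(λ) + constant. The criterion is minimized with respect to λ to select the smoothing parameters.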
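
As a rough illustration of minimizing such a criterion over λ, the sketch below uses a deliberately tiny Gaussian problem with a single coefficient, no intercept and S = [1]. It uses the log|λS| = rank(S)·log(λ) + constant form given in the update notes, so every determinant is a scalar. The function names, the toy data and the grid bounds are assumptions for this example; the criterion is the simplified expression from this README, not the full REML score implemented in reml.rs.

```python
import math


def reml_like(lam, x, y):
    """Simplified criterion for one coefficient, no intercept, S = [[1]]."""
    n = len(x)
    xtx = sum(xi * xi for xi in x)
    xty = sum(xi * yi for xi, yi in zip(x, y))
    beta = xty / (xtx + lam)  # penalized least-squares estimate
    rss = sum((yi - beta * xi) ** 2 for xi, yi in zip(x, y))
    # log|X'X + lam*S| = log(xtx + lam); log|lam*S| = rank(S)*log(lam) = log(lam)
    return n * math.log(rss) + math.log(xtx + lam) - math.log(lam)


def grid_search(criterion, log10_lo=-6.0, log10_hi=4.0, num=101):
    """Return the grid point minimizing `criterion` on a log-spaced grid."""
    step = (log10_hi - log10_lo) / (num - 1)
    grid = [10 ** (log10_lo + step * i) for i in range(num)]
    return min(grid, key=criterion)


# Toy data: a straight line through the origin plus a deterministic wiggle
x = [(i + 1) / 50 for i in range(50)]
y = [2.0 * xi + 0.3 * math.sin(7 * i) for i, xi in enumerate(x)]

best = grid_search(lambda lam: reml_like(lam, x, y))
print(best, reml_like(best, x, y))
```

Searching on a log-spaced grid matches the scale on which λ matters: the -log(λ) term penalizes λ near zero, while the n·log(RSS) term penalizes oversmoothing, so the minimum falls strictly inside the grid for this data.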

Examples

See examples/simple_gam.rs for a complete working example:

cargo run --example simple_gam --release

Project Structure

├── src/                    # Core Rust library code
├── examples/               # Rust usage examples
├── benches/               # Rust benchmarks
├── tests/                 # Rust tests
├── scripts/               # Testing and benchmarking scripts
│   ├── python/            # Python scripts
│   │   ├── tests/         # Python test scripts
│   │   └── benchmarks/    # Python benchmark scripts
│   └── r/                 # R scripts
│       ├── tests/         # R test scripts
│       └── benchmarks/    # R benchmark scripts
├── docs/                  # Documentation and analysis
└── test_data/            # Test data and results

Testing

Run the Rust test suite:

cargo test

All 20 tests should pass, covering:

  • Basis function evaluation
  • Penalty matrix construction
  • Linear algebra operations
  • REML/GCV criteria
  • PiRLS convergence
  • Full GAM fitting pipeline

Additional tests and benchmarks are available in the scripts/ directory.

Implementation Notes

  • TDD Approach: Every feature was implemented with tests first
  • No External Dependencies: Custom linear algebra to avoid BLAS/LAPACK issues
  • Numerical Stability: Operations performed in log-space where appropriate
  • Extensible Design: Easy to add new basis types, families, or criteria

Limitations & Future Work

  • Smoothing parameter optimization could be improved with better algorithms (e.g., Newton-Raphson)
  • Eigendecomposition for handling penalty null spaces more rigorously
  • Confidence intervals and standard errors
  • Model diagnostics and residual analysis
  • Tensor product smooths for multivariate terms
  • Parallel processing for large datasets

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. JRSS-B, 73(1), 3-36.

License

MIT License - see LICENSE file for details

Author

Implemented as a Rust port of R's mgcv package core functionality.

Update: REML Implementation Fixed! ✅

An earlier version of the REML implementation had bugs that caused it to always select λ ≈ 0.

What Was Wrong

  1. Singular Penalty Handling: REML was incorrectly handling rank-deficient penalty matrices, setting log|S| = 0 which broke the criterion
  2. Lambda Passing: Optimization was passing λ = 1.0 with pre-multiplied penalties, confusing the rank(S)*log(λ) term
  3. Insufficient Data: Examples used n=30 with p=15 (ratio 2:1), which is too small for REML/GCV

What Was Fixed

  1. REML Criterion: Now correctly uses log|λS| = rank(S)*log(λ) + constant
  2. Optimization: Passes actual λ values to criterion functions
  3. Data Size: Increased to n=300 for proper n/p ratio (20:1)
  4. REML Search: Uses fine grid search (gradient descent had issues)
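
The identity behind the first fix can be checked numerically. The sketch below assumes a P-spline style second-difference penalty S = D'D as a stand-in for the crate's penalty matrices (an assumption for illustration only). Such an S has rank k - 2, its nonzero eigenvalues coincide with those of the full-rank matrix DD', and scaling by λ therefore shifts the log pseudo-determinant by exactly rank(S)·log(λ). All function names are hypothetical.

```python
import math


def second_difference(k):
    """(k-2) x k matrix D whose rows are [..., 1, -2, 1, ...]."""
    d = [[0.0] * k for _ in range(k - 2)]
    for i in range(k - 2):
        d[i][i], d[i][i + 1], d[i][i + 2] = 1.0, -2.0, 1.0
    return d


def logdet_spd(m):
    """Log-determinant of a symmetric positive definite matrix via Cholesky."""
    n = len(m)
    low = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = m[i][j] - sum(low[i][t] * low[j][t] for t in range(j))
            low[i][j] = math.sqrt(s) if i == j else s / low[j][j]
    return 2.0 * sum(math.log(low[i][i]) for i in range(n))


k = 8
D = second_difference(k)
rank = k - 2

# S = D'D is k x k and singular; DD' is (k-2) x (k-2) and positive definite
S = [[sum(D[r][i] * D[r][j] for r in range(rank)) for j in range(k)] for i in range(k)]
DDt = [[sum(D[i][c] * D[j][c] for c in range(k)) for j in range(rank)] for i in range(rank)]

lam = 0.058
log_pdet_S = logdet_spd(DDt)  # log|S|+ through the nonzero spectrum
log_pdet_lamS = logdet_spd([[lam * v for v in row] for row in DDt])
print(log_pdet_lamS, rank * math.log(lam) + log_pdet_S)

# Constant and linear coefficient vectors lie in the null space of S
ones = [1.0] * k
ramp = [float(i) for i in range(k)]
S_ones = [sum(S[i][j] * ones[j] for j in range(k)) for i in range(k)]
S_ramp = [sum(S[i][j] * ramp[j] for j in range(k)) for i in range(k)]
```

Because S is singular, its ordinary determinant is zero and log|S| is undefined; working with the nonzero eigenvalues only is what keeps the rank(S)·log(λ) term in the criterion.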

Current Performance (n=300)

GCV:  λ = 0.067, Test RMSE = 0.480  ✅
REML: λ = 0.058, Test RMSE = 0.480  ✅

Both methods now select nearly optimal smoothing parameters!

See IMPLEMENTATION_SUMMARY.md for complete details.
