
Rust implementation of Generalized Additive Models with Python bindings


mgcv_rust: Generalized Additive Models in Rust

A Rust implementation of Generalized Additive Models (GAMs) with automatic smoothing parameter selection using REML (Restricted Maximum Likelihood) and the PiRLS (Penalized Iteratively Reweighted Least Squares) algorithm, inspired by R's mgcv package.

Features

  • Multiple Distribution Families: Gaussian, Binomial, Poisson, and Gamma
  • Flexible Basis Functions:
    • Cubic B-splines with natural boundary conditions
    • Thin plate splines for smooth multivariate regression
  • Automatic Smoothing:
    • REML (Restricted Maximum Likelihood) criterion
    • GCV (Generalized Cross-Validation) criterion
  • PiRLS Algorithm: Efficient fitting via Penalized Iteratively Reweighted Least Squares
  • Pure Rust: No external BLAS/LAPACK dependencies
  • Test-Driven Development: Comprehensive test suite with 20+ passing tests

Installation

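Python: install the prebuilt wheels from PyPI:

pip install mgcv_rust

Rust: add the crate to your Cargo.toml as a path dependency pointing at a local checkout of the repository (replace path/to/mgcv_rust with its location):

[dependencies]
mgcv_rust = { path = "path/to/mgcv_rust" }
ndarray = "0.16"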

Quick Start

Python (Recommended)

import numpy as np
from mgcv_rust import GAM

# Generate data: y = sin(2πx₁) + 0.5(x₂ - 0.5)² + noise
X = np.random.uniform(0, 1, (500, 2))
y = np.sin(2 * np.pi * X[:, 0]) + 0.5 * (X[:, 1] - 0.5)**2 + np.random.normal(0, 0.1, 500)

# Fit GAM with automatic smooth setup
gam = GAM()
result = gam.fit(X, y, k=[10, 15])  # That's it!

print(f"Lambda values: {result['lambda']}")
print(f"Deviance: {result['deviance']}")

# Make predictions
X_test = np.random.uniform(0, 1, (100, 2))
predictions = gam.predict(X_test)

Performance: 1.5x to 65x faster than R's mgcv, depending on the problem (benchmark scripts are in scripts/python/benchmarks/ and scripts/r/benchmarks/).

See API_SIMPLIFICATION.md for more details on the simplified Python API.

Rust

use mgcv_rust::{GAM, Family, SmoothTerm, OptimizationMethod};
use ndarray::{Array1, Array2};

fn main() -> Result<(), Box<dyn std::error::Error>> {
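    // Generate data: y = sin(2πx) + noise
    let n = 100;
    // Deterministic xorshift noise, uniform on [-0.1, 0.1), so the example
    // needs no extra crates (in real code, use the `rand` crate instead)
    let mut state: u64 = 42;
    let mut noise = move || {
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        ((state >> 11) as f64 / (1u64 << 53) as f64 - 0.5) * 0.2
    };
    let x_data: Vec<f64> = (0..n).map(|i| i as f64 / n as f64).collect();
    let y_data: Vec<f64> = x_data
        .iter()
        .map(|&xi| (2.0 * std::f64::consts::PI * xi).sin() + noise())
        .collect();

    let x = Array1::from_vec(x_data);
    let y = Array1::from_vec(y_data);
    // Reshape to an n x 1 design matrix (`into_shape` is deprecated in ndarray 0.16)
    let x_matrix = x.into_shape_with_order((n, 1))?;

    // Create GAM with cubic spline smooth
    let mut gam = GAM::new(Family::Gaussian);
    let smooth = SmoothTerm::cubic_spline("x".to_string(), 20, 0.0, 1.0)?;
    gam.add_smooth(smooth);

    // Fit with REML smoothing parameter selection
    gam.fit(
        &x_matrix,
        &y,
        OptimizationMethod::REML,
        5,    // max outer iterations
        50,   // max inner iterations (PiRLS)
        1e-4  // convergence tolerance
    )?;

    // Make predictions at 50 evenly spaced new points in [0, 1]
    let x_test = Array2::from_shape_fn((50, 1), |(i, _)| i as f64 / 49.0);
    let predictions = gam.predict(&x_test)?;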

    Ok(())
}

Architecture

Core Components

  1. basis.rs: Basis function implementations

    • CubicSpline: Cubic B-spline basis with configurable knots
    • ThinPlateSpline: Radial basis functions for smooth regression
  2. penalty.rs: Penalty matrix construction

    • Second derivative penalties for smoothness
    • Supports multiple penalty types per basis
  3. pirls.rs: Penalized IRLS fitting algorithm

    • Implements PiRLS for GLMs with penalties
    • Supports all standard GLM families
    • Automatic weight computation and convergence checking
  4. reml.rs: Smoothing parameter selection

    • REML criterion for optimal smoothing
    • GCV criterion as alternative
    • Log-determinant computations
  5. smooth.rs: Smoothing parameter optimization

    • Coordinate descent optimization
    • Grid search for initialization
    • Works in log-space for numerical stability
  6. gam.rs: Main GAM model interface

    • Combines all components
    • Handles multiple smooth terms
    • Outer loop for lambda optimization
  7. linalg.rs: Linear algebra operations

    • Gaussian elimination with partial pivoting
    • Matrix inversion via Gauss-Jordan
    • Determinant computation via LU decomposition
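linalg.rs is described as solving systems by Gaussian elimination with partial pivoting and computing determinants from an LU decomposition. The sketch below shows both ideas on dense Vec<Vec<f64>> matrices; solve and log_abs_det are hypothetical names rather than the crate's API, and the 1e-12 singularity threshold is an arbitrary choice. Returning log|det A| instead of det A keeps the value in range for larger penalized systems, in line with the log-space approach noted under Implementation Notes.

```rust
/// Solve A x = b by Gaussian elimination with partial pivoting.
/// Returns None if A is (numerically) singular.
pub fn solve(a: &[Vec<f64>], b: &[f64]) -> Option<Vec<f64>> {
    let n = b.len();
    // Work on an augmented copy [A | b] so the inputs are left untouched.
    let mut m: Vec<Vec<f64>> = a
        .iter()
        .zip(b)
        .map(|(row, &bi)| {
            let mut r = row.clone();
            r.push(bi);
            r
        })
        .collect();

    for col in 0..n {
        // Partial pivoting: move the row with the largest |entry| in this column up.
        let pivot = (col..n).max_by(|&i, &j| m[i][col].abs().total_cmp(&m[j][col].abs()))?;
        if m[pivot][col].abs() < 1e-12 {
            return None;
        }
        m.swap(col, pivot);
        // Eliminate the entries below the pivot.
        let pivot_row = m[col].clone();
        for row in m.iter_mut().skip(col + 1) {
            let factor = row[col] / pivot_row[col];
            for k in col..=n {
                row[k] -= factor * pivot_row[k];
            }
        }
    }

    // Back substitution on the resulting upper-triangular system.
    let mut x = vec![0.0; n];
    for i in (0..n).rev() {
        let s: f64 = (i + 1..n).map(|k| m[i][k] * x[k]).sum();
        x[i] = (m[i][n] - s) / m[i][i];
    }
    Some(x)
}

/// log|det A| from the pivots of the same elimination (the U factor of PA = LU).
/// Returns negative infinity for an exactly singular matrix.
pub fn log_abs_det(a: &[Vec<f64>]) -> f64 {
    let n = a.len();
    let mut m = a.to_vec();
    let mut log_det = 0.0;
    for col in 0..n {
        let pivot = (col..n)
            .max_by(|&i, &j| m[i][col].abs().total_cmp(&m[j][col].abs()))
            .unwrap();
        if m[pivot][col] == 0.0 {
            return f64::NEG_INFINITY;
        }
        m.swap(col, pivot);
        // |det A| is the product of the pivots; row swaps only flip the sign.
        log_det += m[col][col].abs().ln();
        let pivot_row = m[col].clone();
        for row in m.iter_mut().skip(col + 1) {
            let factor = row[col] / pivot_row[col];
            for k in col..n {
                row[k] -= factor * pivot_row[k];
            }
        }
    }
    log_det
}
```

Pivoting matters even for tiny systems: [[0, 1], [1, 0]] x = [2, 3] has a zero in the first pivot position and only solves after the row swap.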

Mathematical Background

GAM Model

g(E[Y]) = β₀ + f₁(x₁) + f₂(x₂) + ... + fₚ(xₚ)

Where:

  • g() is the link function
  • fᵢ() are smooth functions represented by basis expansions
  • Each smooth is penalized by λᵢ ∫ (f''ᵢ(x))² dx
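A common discrete stand-in for the integrated squared second derivative is a difference penalty on the basis coefficients, as used for P-splines: with D the matrix of second differences, the penalty is βᵀSβ with S = DᵀD. The sketch below builds S this way. It is an assumption for illustration (penalty.rs may instead integrate f'' exactly for its bases), and second_difference_penalty and penalty_value are hypothetical names.

```rust
/// Second-difference penalty S = DᵀD for k basis coefficients, where
/// (Dβ)ᵢ = βᵢ - 2βᵢ₊₁ + βᵢ₊₂ is a discrete stand-in for f''.
pub fn second_difference_penalty(k: usize) -> Vec<Vec<f64>> {
    assert!(k >= 3, "need at least three coefficients");
    let mut s = vec![vec![0.0; k]; k];
    // Each row d of D contributes the outer product dᵀd to S.
    for i in 0..k - 2 {
        let d = [(i, 1.0), (i + 1, -2.0), (i + 2, 1.0)];
        for &(a, da) in &d {
            for &(b, db) in &d {
                s[a][b] += da * db;
            }
        }
    }
    s
}

/// The penalty βᵀSβ for a coefficient vector β.
pub fn penalty_value(s: &[Vec<f64>], beta: &[f64]) -> f64 {
    s.iter()
        .zip(beta)
        .map(|(row, &bi)| bi * row.iter().zip(beta).map(|(sij, bj)| sij * bj).sum::<f64>())
        .sum()
}
```

Constant and linear coefficient vectors are not penalized, so S has rank k - 2 and a two-dimensional null space. That rank deficiency is why the log-determinant of the penalty has to be treated as a pseudo-determinant when selecting λ.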

PiRLS Algorithm

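  1. Initialize: μ = y, nudged away from values where g is undefined (e.g. y + 0.1 for zero counts under a log link), then η = g(μ)
  2. Until convergence:
    • Compute μ = g⁻¹(η)
    • Compute weights: w = 1 / (V(μ)·g'(μ)²), where g'(μ) = dη/dμ
    • Compute working response: z = η + (y - μ)·g'(μ)
    • Solve: β = (X'WX + λS)⁻¹ X'Wz, with W = diag(w) (for several smooths, λS stands for Σᵢ λᵢSᵢ)
    • Update: η = Xβ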
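The loop can be sketched for the Poisson family with log link, where g'(μ) = 1/μ and V(μ) = μ, so the standard IRLS weight 1/(V(μ)g'(μ)²) reduces to w = μ and the working response to z = η + (y - μ)/μ. This is a minimal sketch under stated assumptions, not the crate's pirls.rs: matrices are dense Vec<Vec<f64>>, the start μ = y + 0.1 mirrors the initialization used by R's poisson() family, iteration stops when the largest coefficient change falls below tol, and pirls_poisson and solve are hypothetical names.

```rust
/// Solve the small dense system A x = b by Gaussian elimination with partial
/// pivoting (no singularity check: X'WX + λS is assumed positive definite).
fn solve(mut a: Vec<Vec<f64>>, mut b: Vec<f64>) -> Vec<f64> {
    let n = b.len();
    for c in 0..n {
        let p = (c..n)
            .max_by(|&i, &j| a[i][c].abs().total_cmp(&a[j][c].abs()))
            .unwrap();
        a.swap(c, p);
        b.swap(c, p);
        let (pivot_row, pivot_b) = (a[c].clone(), b[c]);
        for r in c + 1..n {
            let f = a[r][c] / pivot_row[c];
            for k in c..n {
                a[r][k] -= f * pivot_row[k];
            }
            b[r] -= f * pivot_b;
        }
    }
    let mut x = vec![0.0; n];
    for i in (0..n).rev() {
        let s: f64 = (i + 1..n).map(|k| a[i][k] * x[k]).sum();
        x[i] = (b[i] - s) / a[i][i];
    }
    x
}

/// Penalized Poisson GLM with log link, fitted by PiRLS.
/// x: n rows of p basis values, s: p x p penalty, lambda: smoothing parameter.
pub fn pirls_poisson(
    x: &[Vec<f64>],
    y: &[f64],
    s: &[Vec<f64>],
    lambda: f64,
    max_iter: usize,
    tol: f64,
) -> Vec<f64> {
    let p = s.len();
    // Start from mu = y + 0.1 so that eta = log(mu) is finite for zero counts.
    let mut eta: Vec<f64> = y.iter().map(|&yi| (yi + 0.1).ln()).collect();
    let mut beta = vec![0.0; p];
    for _ in 0..max_iter {
        // Accumulate X'WX + lambda*S and X'Wz.
        let mut lhs: Vec<Vec<f64>> = s
            .iter()
            .map(|row| row.iter().map(|v| lambda * v).collect::<Vec<f64>>())
            .collect();
        let mut rhs = vec![0.0; p];
        for (row, (&yi, &eta_i)) in x.iter().zip(y.iter().zip(&eta)) {
            let mu = eta_i.exp();
            // Log link: g'(mu) = 1/mu; Poisson: V(mu) = mu; so w = 1/(V g'^2) = mu.
            let w = mu;
            // Working response z = eta + (y - mu) g'(mu).
            let z = eta_i + (yi - mu) / mu;
            for a in 0..p {
                rhs[a] += row[a] * w * z;
                for b in 0..p {
                    lhs[a][b] += row[a] * w * row[b];
                }
            }
        }
        let new_beta = solve(lhs, rhs);
        // Update the linear predictor eta = X beta.
        eta = x
            .iter()
            .map(|row| row.iter().zip(&new_beta).map(|(xi, bi)| xi * bi).sum::<f64>())
            .collect();
        let change = new_beta
            .iter()
            .zip(&beta)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0, f64::max);
        beta = new_beta;
        if change < tol {
            break;
        }
    }
    beta
}
```

At convergence the fit satisfies Xᵀ(y - μ) = λSβ, so with an unpenalized intercept column the fitted means sum to the observed counts, which makes a convenient correctness check.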

REML Criterion

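REML(λ) = n·log(RSS) + log|X'WX + λS| - log|λS|₊

Here |·|₊ is the pseudo-determinant (the product of the non-zero eigenvalues), needed because S is usually rank-deficient, so log|λS|₊ = rank(S)·log(λ) + log|S|₊. This λ-dependent term is what keeps the criterion from always favoring λ → 0, since RSS and log|X'WX + λS| never increase as λ shrinks.

The criterion is minimized with respect to λ to select the smoothing parameters.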
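The fine grid search over λ mentioned in the update notes can be sketched as follows. Each λ is scored with the profiled Gaussian REML criterion (n - M_p)·log(RSS + λβ̂ᵀSβ̂) + log|XᵀX + λS| - rank(S)·log(λ), dropping the λ-free constant log|S|₊, where M_p is the dimension of the penalty null space. This is a slightly fuller form than the simplified formula above, which omits M_p and the β̂ᵀSβ̂ term. The tent (piecewise-linear) basis, the second-difference penalty, and the names reml_score and select_lambda are all assumptions for illustration, not the crate's reml.rs or smooth.rs.

```rust
/// Tent (piecewise-linear B-spline) basis with k equally spaced knots on [0, 1].
fn tent_basis(x: f64, k: usize) -> Vec<f64> {
    let h = 1.0 / (k - 1) as f64;
    (0..k)
        .map(|j| (1.0 - (x - j as f64 * h).abs() / h).max(0.0))
        .collect()
}

/// Second-difference penalty S = D'D: rank k - 2, null space = straight lines.
fn diff_penalty(k: usize) -> Vec<Vec<f64>> {
    let mut s = vec![vec![0.0; k]; k];
    for i in 0..k - 2 {
        let d = [(i, 1.0), (i + 1, -2.0), (i + 2, 1.0)];
        for &(a, da) in &d {
            for &(b, db) in &d {
                s[a][b] += da * db;
            }
        }
    }
    s
}

/// Gaussian elimination with partial pivoting; returns (x, log|det A|).
fn solve_logdet(mut a: Vec<Vec<f64>>, mut b: Vec<f64>) -> (Vec<f64>, f64) {
    let n = b.len();
    let mut log_det = 0.0;
    for c in 0..n {
        let p = (c..n)
            .max_by(|&i, &j| a[i][c].abs().total_cmp(&a[j][c].abs()))
            .unwrap();
        a.swap(c, p);
        b.swap(c, p);
        log_det += a[c][c].abs().ln();
        let (pivot_row, pivot_b) = (a[c].clone(), b[c]);
        for r in c + 1..n {
            let f = a[r][c] / pivot_row[c];
            for j in c..n {
                a[r][j] -= f * pivot_row[j];
            }
            b[r] -= f * pivot_b;
        }
    }
    let mut x = vec![0.0; n];
    for i in (0..n).rev() {
        let s: f64 = (i + 1..n).map(|j| a[i][j] * x[j]).sum();
        x[i] = (b[i] - s) / a[i][i];
    }
    (x, log_det)
}

/// Profiled Gaussian REML score, up to an additive constant:
/// (n - M_p) ln(RSS + λ β'Sβ) + ln|X'X + λS| - rank(S) ln λ.
pub fn reml_score(xs: &[f64], y: &[f64], k: usize, lambda: f64) -> f64 {
    let x: Vec<Vec<f64>> = xs.iter().map(|&xi| tent_basis(xi, k)).collect();
    let s = diff_penalty(k);
    // Assemble A = X'X + λS and b = X'y.
    let mut a: Vec<Vec<f64>> = s
        .iter()
        .map(|row| row.iter().map(|v| lambda * v).collect::<Vec<f64>>())
        .collect();
    let mut b = vec![0.0; k];
    for (row, &yi) in x.iter().zip(y) {
        for i in 0..k {
            b[i] += row[i] * yi;
            for j in 0..k {
                a[i][j] += row[i] * row[j];
            }
        }
    }
    let (beta, log_det_a) = solve_logdet(a, b);
    let rss: f64 = x
        .iter()
        .zip(y)
        .map(|(row, &yi)| {
            let fit: f64 = row.iter().zip(&beta).map(|(r, bj)| r * bj).sum();
            (yi - fit).powi(2)
        })
        .sum();
    let pen: f64 = (0..k)
        .map(|i| (0..k).map(|j| beta[i] * s[i][j] * beta[j]).sum::<f64>())
        .sum();
    // M_p = 2 (constants and lines are unpenalized); rank(S) = k - 2.
    let (n, m_p, rank) = (y.len() as f64, 2.0, (k - 2) as f64);
    (n - m_p) * (rss + lambda * pen).ln() + log_det_a - rank * lambda.ln()
}

/// Choose λ from a grid by minimizing the REML score.
pub fn select_lambda(xs: &[f64], y: &[f64], k: usize, grid: &[f64]) -> f64 {
    grid.iter()
        .map(|&l| (l, reml_score(xs, y, k, l)))
        .min_by(|p, q| p.1.total_cmp(&q.1))
        .map(|(l, _)| l)
        .expect("grid must not be empty")
}
```

On noisy sin(2πx) data with k = 8 and a log-spaced grid from 1e-6 to 1e4, this should pick an interior λ: the -rank(S)·log(λ) term rules out λ → 0, while the growing residual sum of squares rules out very large λ.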

Examples

See examples/simple_gam.rs for a complete working example:

cargo run --example simple_gam --release

Project Structure

├── src/                   # Core Rust library code
├── examples/              # Rust usage examples
├── benches/               # Rust benchmarks
├── tests/                 # Rust tests
├── scripts/               # Testing and benchmarking scripts
│   ├── python/            # Python scripts
│   │   ├── tests/         # Python test scripts
│   │   └── benchmarks/    # Python benchmark scripts
│   └── r/                 # R scripts
│       ├── tests/         # R test scripts
│       └── benchmarks/    # R benchmark scripts
├── docs/                  # Documentation and analysis
└── test_data/             # Test data and results

Testing

Run the Rust test suite:

cargo test

All 20 tests should pass, covering:

  • Basis function evaluation
  • Penalty matrix construction
  • Linear algebra operations
  • REML/GCV criteria
  • PiRLS convergence
  • Full GAM fitting pipeline

Additional tests and benchmarks are available in the scripts/ directory.

Implementation Notes

  • TDD Approach: Every feature was implemented with tests first
  • No BLAS/LAPACK Dependency: Custom linear algebra (linalg.rs) avoids BLAS/LAPACK issues
  • Numerical Stability: Operations performed in log-space where appropriate
  • Extensible Design: Easy to add new basis types, families, or criteria

Limitations & Future Work

  • Smoothing parameter optimization could be improved with better algorithms (e.g., Newton-Raphson)
  • Eigendecomposition for handling penalty null spaces more rigorously
  • Confidence intervals and standard errors
  • Model diagnostics and residual analysis
  • Tensor product smooths for multivariate terms
  • Parallel processing for large datasets

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. JRSS-B, 73(1), 3-36.

License

MIT License - see LICENSE file for details

Author

Implemented as a Rust port of R's mgcv package core functionality.

Update: REML Implementation Fixed! ✅

The earlier REML implementation had bugs that caused it to always select λ ≈ 0.

What Was Wrong

  1. Singular Penalty Handling: REML handled rank-deficient penalty matrices incorrectly, setting log|S| = 0, which broke the criterion
  2. Lambda Passing: The optimizer passed λ = 1.0 together with pre-multiplied penalties, so the rank(S)*log(λ) term always evaluated to rank(S)*log(1) = 0 and carried no information about the actual smoothing parameter
  3. Insufficient Data: Examples used n=30 with p=15 (a 2:1 ratio), which is too small for reliable REML/GCV selection

What Was Fixed

  1. REML Criterion: Now correctly uses log|λS|₊ = rank(S)*log(λ) + log|S|₊, where |·|₊ is the pseudo-determinant (product of the non-zero eigenvalues)
  2. Optimization: Passes actual λ values to criterion functions
  3. Data Size: Increased to n=300 for proper n/p ratio (20:1)
  4. REML Search: Uses fine grid search (gradient descent had issues)

Current Performance (n=300)

GCV:  λ = 0.067, Test RMSE = 0.480  ✅
REML: λ = 0.058, Test RMSE = 0.480  ✅

Both methods now select nearly optimal smoothing parameters!

See IMPLEMENTATION_SUMMARY.md for complete details.
