Skip to main content

Rust implementation of Generalized Additive Models with Python bindings

Project description

mgcv_rust: Generalized Additive Models in Rust

A Rust implementation of Generalized Additive Models (GAMs) with automatic smoothing parameter selection using REML (Restricted Maximum Likelihood) and the PiRLS (Penalized Iteratively Reweighted Least Squares) algorithm, inspired by R's mgcv package.

Features

  • Multiple Distribution Families: Gaussian, Binomial, Poisson, and Gamma
  • Flexible Basis Functions:
    • Cubic B-splines with natural boundary conditions
    • Thin plate splines for smooth multivariate regression
  • Automatic Smoothing:
    • REML (Restricted Maximum Likelihood) criterion
    • GCV (Generalized Cross-Validation) criterion
  • PiRLS Algorithm: Efficient fitting via Penalized Iteratively Reweighted Least Squares
  • Pure Rust: No external BLAS/LAPACK dependencies
  • Test-Driven Development: Comprehensive test suite with 20+ passing tests

Installation

Add to your Cargo.toml:

[dependencies]
mgcv_rust = { path = "." }
ndarray = "0.16"

Quick Start

Python (Recommended)

import numpy as np
from mgcv_rust import GAM

# Generate data: y = sin(2πx) + noise
X = np.random.uniform(0, 1, (500, 2))
y = np.sin(2 * np.pi * X[:, 0]) + 0.5 * (X[:, 1] - 0.5)**2

# Fit GAM with automatic smooth setup
gam = GAM()
result = gam.fit(X, y, k=[10, 15])  # That's it!

print(f"Lambda values: {result['lambda']}")
print(f"Deviance: {result['deviance']}")

# Make predictions
X_test = np.random.uniform(0, 1, (100, 2))
predictions = gam.predict(X_test)

Performance: 1.5x - 65x faster than R's mgcv (problem-dependent)

See API_SIMPLIFICATION.md for more details on the simplified Python API.

Rust

use mgcv_rust::{GAM, Family, SmoothTerm, OptimizationMethod};
use ndarray::{Array1, Array2};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Generate data: y = sin(2πx) + noise
    let n = 100;
    let x_data: Vec<f64> = (0..n).map(|i| i as f64 / n as f64).collect();
    let y_data: Vec<f64> = x_data
        .iter()
        .map(|&xi| (2.0 * std::f64::consts::PI * xi).sin() + noise())
        .collect();

    let x = Array1::from_vec(x_data);
    let y = Array1::from_vec(y_data);
    let x_matrix = x.into_shape((n, 1))?;

    // Create GAM with cubic spline smooth
    let mut gam = GAM::new(Family::Gaussian);
    let smooth = SmoothTerm::cubic_spline("x".to_string(), 20, 0.0, 1.0)?;
    gam.add_smooth(smooth);

    // Fit with REML smoothing parameter selection
    gam.fit(
        &x_matrix,
        &y,
        OptimizationMethod::REML,
        5,    // max outer iterations
        50,   // max inner iterations (PiRLS)
        1e-4  // convergence tolerance
    )?;

    // Make predictions
    let predictions = gam.predict(&x_test)?;

    Ok(())
}

Architecture

Core Components

  1. basis.rs: Basis function implementations

    • CubicSpline: Cubic B-spline basis with configurable knots
    • ThinPlateSpline: Radial basis functions for smooth regression
  2. penalty.rs: Penalty matrix construction

    • Second derivative penalties for smoothness
    • Supports multiple penalty types per basis
  3. pirls.rs: Penalized IRLS fitting algorithm

    • Implements PiRLS for GLMs with penalties
    • Supports all standard GLM families
    • Automatic weight computation and convergence checking
  4. reml.rs: Smoothing parameter selection

    • REML criterion for optimal smoothing
    • GCV criterion as alternative
    • Log-determinant computations
  5. smooth.rs: Smoothing parameter optimization

    • Coordinate descent optimization
    • Grid search for initialization
    • Works in log-space for numerical stability
  6. gam.rs: Main GAM model interface

    • Combines all components
    • Handles multiple smooth terms
    • Outer loop for lambda optimization
  7. linalg.rs: Linear algebra operations

    • Gaussian elimination with partial pivoting
    • Matrix inversion via Gauss-Jordan
    • Determinant computation via LU decomposition

Mathematical Background

GAM Model

g(E[Y]) = β₀ + f₁(x₁) + f₂(x₂) + ... + fₚ(xₚ)

Where:

  • g() is the link function
  • fᵢ() are smooth functions represented by basis expansions
  • Each smooth is penalized by λᵢ ∫ (f''ᵢ(x))² dx

PiRLS Algorithm

  1. Initialize: η = g(y)
  2. Until convergence:
    • Compute μ = g⁻¹(η)
    • Compute weights: w = (g'(μ))² / V(μ)
    • Compute working response: z = η + (y - μ) / g'(μ)
    • Solve: β = (X'WX + λS)⁻¹ X'Wz
    • Update: η = Xβ

REML Criterion

REML(λ) = n·log(RSS) + log|X'WX + λS| - log|S|

Minimized with respect to λ to select optimal smoothing parameters.

Examples

See examples/simple_gam.rs for a complete working example:

cargo run --example simple_gam --release

Project Structure

├── src/                    # Core Rust library code
├── examples/               # Rust usage examples
├── benches/               # Rust benchmarks
├── tests/                 # Rust tests
├── scripts/               # Testing and benchmarking scripts
│   ├── python/            # Python scripts
│   │   ├── tests/         # Python test scripts
│   │   └── benchmarks/    # Python benchmark scripts
│   └── r/                 # R scripts
│       ├── tests/         # R test scripts
│       └── benchmarks/    # R benchmark scripts
├── docs/                  # Documentation and analysis
└── test_data/            # Test data and results

Testing

Run the Rust test suite:

cargo test

All 20 tests should pass, covering:

  • Basis function evaluation
  • Penalty matrix construction
  • Linear algebra operations
  • REML/GCV criteria
  • PiRLS convergence
  • Full GAM fitting pipeline

Additional tests and benchmarks are available in the scripts/ directory.

Implementation Notes

  • TDD Approach: Every feature was implemented with tests first
  • No External Dependencies: Custom linear algebra to avoid BLAS/LAPACK issues
  • Numerical Stability: Operations performed in log-space where appropriate
  • Extensible Design: Easy to add new basis types, families, or criteria

Limitations & Future Work

  • Smoothing parameter optimization could be improved with better algorithms (e.g., Newton-Raphson)
  • Eigendecomposition for handling penalty null spaces more rigorously
  • Confidence intervals and standard errors
  • Model diagnostics and residual analysis
  • Tensor product smooths for multivariate terms
  • Parallel processing for large datasets

References

  • Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
  • Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. JRSS-B, 73(1), 3-36.

License

MIT License - see LICENSE file for details

Author

Implemented as a Rust port of R's mgcv package core functionality.

Update: REML Implementation Fixed! ✅

You were absolutely right - the REML implementation had bugs that caused it to always select λ ≈ 0.

What Was Wrong

  1. Singular Penalty Handling: REML was incorrectly handling rank-deficient penalty matrices, setting log|S| = 0 which broke the criterion
  2. Lambda Passing: Optimization was passing λ = 1.0 with pre-multiplied penalties, confusing the rank(S)*log(λ) term
  3. Insufficient Data: Examples used n=30 with p=15 (ratio 2:1), which is too small for REML/GCV

What Was Fixed

  1. REML Criterion: Now correctly uses log|λS| = rank(S)*log(λ) + constant
  2. Optimization: Passes actual λ values to criterion functions
  3. Data Size: Increased to n=300 for proper n/p ratio (20:1)
  4. REML Search: Uses fine grid search (gradient descent had issues)

Current Performance (n=300)

GCV:  λ = 0.067, Test RMSE = 0.480  ✅
REML: λ = 0.058, Test RMSE = 0.480  ✅

Both methods now select nearly optimal smoothing parameters!

See IMPLEMENTATION_SUMMARY.md for complete details.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

mgcv_rust-0.4.0.tar.gz (2.6 MB view details)

Uploaded Source

Built Distributions

If you're not sure about the file name format, learn more about wheel file names.

mgcv_rust-0.4.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (695.4 kB view details)

Uploaded PyPymanylinux: glibc 2.17+ x86-64

mgcv_rust-0.4.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (695.7 kB view details)

Uploaded CPython 3.14manylinux: glibc 2.17+ x86-64

mgcv_rust-0.4.0-cp314-cp314-macosx_11_0_arm64.whl (546.2 kB view details)

Uploaded CPython 3.14macOS 11.0+ ARM64

mgcv_rust-0.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (696.4 kB view details)

Uploaded CPython 3.13manylinux: glibc 2.17+ x86-64

mgcv_rust-0.4.0-cp313-cp313-macosx_11_0_arm64.whl (547.4 kB view details)

Uploaded CPython 3.13macOS 11.0+ ARM64

mgcv_rust-0.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (696.6 kB view details)

Uploaded CPython 3.12manylinux: glibc 2.17+ x86-64

mgcv_rust-0.4.0-cp312-cp312-macosx_11_0_arm64.whl (547.4 kB view details)

Uploaded CPython 3.12macOS 11.0+ ARM64

mgcv_rust-0.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (695.6 kB view details)

Uploaded CPython 3.11manylinux: glibc 2.17+ x86-64

mgcv_rust-0.4.0-cp311-cp311-macosx_11_0_arm64.whl (549.6 kB view details)

Uploaded CPython 3.11macOS 11.0+ ARM64

mgcv_rust-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (697.9 kB view details)

Uploaded CPython 3.10manylinux: glibc 2.17+ x86-64

mgcv_rust-0.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (700.7 kB view details)

Uploaded CPython 3.9manylinux: glibc 2.17+ x86-64

mgcv_rust-0.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (700.4 kB view details)

Uploaded CPython 3.8manylinux: glibc 2.17+ x86-64

File details

Details for the file mgcv_rust-0.4.0.tar.gz.

File metadata

  • Download URL: mgcv_rust-0.4.0.tar.gz
  • Upload date:
  • Size: 2.6 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: maturin/1.10.2

File hashes

Hashes for mgcv_rust-0.4.0.tar.gz
Algorithm Hash digest
SHA256 e036aa86a9cf398c89693dda2a60f79ae8ee1f9cc81db25a1529e69f80a55dff
MD5 83ff734b944806f38a29f47996e25e5c
BLAKE2b-256 73dbaef1fa2550f082282f2772f43858cc4523c2e29e3c67391940fa3126c561

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.4.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.4.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 7dfa1eb2069a99a12d425282c8b70a22422e117210abfacbdd0eadca0ee40ee9
MD5 5d3fc9315881c6df17b3c45ed1efbf2a
BLAKE2b-256 ee1c1f0d2476051d4ca686198374cc78aff503253a57b0f850bd2d04adc93240

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.4.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.4.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 80d0c8e69c3475ff01403d30cf48bfc352a5b3c5640951192a115940dc09252a
MD5 341e01c8713b47fc7c2a8c57106d7e26
BLAKE2b-256 4469c1c9f49d0aa01bb98e400859a5fc7f7739bd3a69fab60eb3bc98b9104bc7

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.4.0-cp314-cp314-macosx_11_0_arm64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.4.0-cp314-cp314-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 b1cf966498582f70f99311a6d4b52d1f071691c561d3edb9deb30dfbb7159323
MD5 654fe434b71da7b748f8dd76580c3598
BLAKE2b-256 61b269aa9ab9bf5af992b37630bf9fce43dcee72f227394264e666eadb28d581

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 386fd5a339ac060771d4817eb1305063f0d112103ff9f9e6b3026269beefc9f5
MD5 504398336a83b1f8978aab4600fd9c7c
BLAKE2b-256 c3cf68a4043c085bcf3792896b713c0cd25f8cf6eaf625de5cc964cca2362a39

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.4.0-cp313-cp313-macosx_11_0_arm64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.4.0-cp313-cp313-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 e3140b120b6a304f790d661cd1eebaac94654d725c72592700d66e7070444441
MD5 3bbe1eea5cecc783f2dc6ee0ba668958
BLAKE2b-256 79823ccef1b9610c0b7ca981acdbae39e5318b62dbf97a4c9b86cde336fae47b

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 ab60ed90d9587962563584d9b69fcefd55bb4cf8bf46cc17e4079a96c732665f
MD5 201bdea80a9c888c437c6617d8c09c28
BLAKE2b-256 9f3fff9062ce6f83e1c315e9325ec52063c92878de79dd42295282aa7937808c

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.4.0-cp312-cp312-macosx_11_0_arm64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.4.0-cp312-cp312-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 0119df1df7df99d75427f204fb6d38d78eebe063261b9dfcb2cebbb1adb39e18
MD5 9c08ddfdbfe0eb1caac3751b02ef4b34
BLAKE2b-256 96bf42fae7455b9d304b66118bec50bb603f3278d5a6a7a24edee72707dd66de

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 390acd403c7e39ca6702d590da81148ea28c7672ed43d23c2ea705b76cd79845
MD5 e9b020bd9afa66d5cc7288e235885a93
BLAKE2b-256 2312c4c2f3496877f94217c717eb8c1f3991f6840fb0c3ea571e78dd4548d8bf

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.4.0-cp311-cp311-macosx_11_0_arm64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.4.0-cp311-cp311-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 a8857b19a9d4535bbf249387a15b9f3922d33140e67a8f1a1064483c240e346d
MD5 13bbf6eac4e5223605ac8fb75a28a61e
BLAKE2b-256 3a2a11f57d663dfcc21d9655a3268305ec7e265d5f18dc02e5e4c430b6ba34a3

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 dbb4b088484d76cf2e7ba1e46f2b5e0855810c445b090361be9d870f6ab73ac8
MD5 ae06e4a45e0980855e8161d26d291ae3
BLAKE2b-256 206fabde8a8d26eb36a98f2bb6af650b38f7d7d797ba9acfb587b66b446b079b

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 f1b4ecc25b4b70e12de331a220233852b0a93bfd4d9d6d57cd3e5a38e97673fa
MD5 acfe54d8478602ccdb0a9dca317ed427
BLAKE2b-256 1d4e70e52052fcab2ddf4ce96ebe722cf713479d7b80212629f4226c267205c5

See more details on using hashes here.

File details

Details for the file mgcv_rust-0.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for mgcv_rust-0.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 e48e396ad3775f07b5b23f964fe4751501dbd924140760710d39f57385d552e0
MD5 a6968c294c37b2d6f36f51510685ee99
BLAKE2b-256 e630b223362e6d94f1f3d3e2f707ea3fbc8f3120640bb5c72899a7282c21c089

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page