Rust implementation of Generalized Additive Models with Python bindings
mgcv_rust: Generalized Additive Models in Rust
A Rust implementation of Generalized Additive Models (GAMs) with automatic smoothing parameter selection using REML (Restricted Maximum Likelihood) and the PiRLS (Penalized Iteratively Reweighted Least Squares) algorithm, inspired by R's mgcv package.
Features
- Multiple Distribution Families: Gaussian, Binomial, Poisson, and Gamma
- Flexible Basis Functions:
  - Cubic B-splines with natural boundary conditions
  - Thin plate splines for smooth multivariate regression
- Automatic Smoothing:
  - REML (Restricted Maximum Likelihood) criterion
  - GCV (Generalized Cross-Validation) criterion
- PiRLS Algorithm: Efficient fitting via Penalized Iteratively Reweighted Least Squares
- Pure Rust: No external BLAS/LAPACK dependencies
- Test-Driven Development: Comprehensive test suite with 20+ passing tests
Installation
Add to your Cargo.toml:
[dependencies]
mgcv_rust = { path = "." }  # adjust to the location of your mgcv_rust checkout
ndarray = "0.16"
Quick Start
Python (Recommended)
import numpy as np
from mgcv_rust import GAM
# Generate data: y = sin(2π·x0) + 0.5·(x1 - 0.5)² + noise
X = np.random.uniform(0, 1, (500, 2))
y = np.sin(2 * np.pi * X[:, 0]) + 0.5 * (X[:, 1] - 0.5)**2 + np.random.normal(0, 0.1, 500)
# Fit GAM with automatic smooth setup
gam = GAM()
result = gam.fit(X, y, k=[10, 15]) # That's it!
print(f"Lambda values: {result['lambda']}")
print(f"Deviance: {result['deviance']}")
# Make predictions
X_test = np.random.uniform(0, 1, (100, 2))
predictions = gam.predict(X_test)
Performance: 1.5x - 65x faster than R's mgcv (problem-dependent)
See API_SIMPLIFICATION.md for more details on the simplified Python API.
Rust
use mgcv_rust::{GAM, Family, SmoothTerm, OptimizationMethod};
use ndarray::{Array1, Array2};

// Deterministic stand-in for random noise so the example needs no extra crates
// (illustrative helper; use the `rand` crate for real simulations).
fn noise(i: usize) -> f64 {
    0.1 * (12.9898 * i as f64).sin()
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Generate data: y = sin(2πx) + noise
    let n = 100;
    let x_data: Vec<f64> = (0..n).map(|i| i as f64 / n as f64).collect();
    let y_data: Vec<f64> = x_data
        .iter()
        .enumerate()
        .map(|(i, &xi)| (2.0 * std::f64::consts::PI * xi).sin() + noise(i))
        .collect();

    let x = Array1::from_vec(x_data);
    let y = Array1::from_vec(y_data);
    let x_matrix = x.into_shape((n, 1))?;

    // Create GAM with cubic spline smooth
    let mut gam = GAM::new(Family::Gaussian);
    let smooth = SmoothTerm::cubic_spline("x".to_string(), 20, 0.0, 1.0)?;
    gam.add_smooth(smooth);

    // Fit with REML smoothing parameter selection
    gam.fit(
        &x_matrix,
        &y,
        OptimizationMethod::REML,
        5,    // max outer iterations
        50,   // max inner iterations (PiRLS)
        1e-4, // convergence tolerance
    )?;

    // Make predictions at a few new points in [0, 1]
    let x_test = Array2::from_shape_vec((3, 1), vec![0.25, 0.5, 0.75])?;
    let predictions = gam.predict(&x_test)?;

    Ok(())
}
Architecture
Core Components
- basis.rs: Basis function implementations
  - CubicSpline: Cubic B-spline basis with configurable knots
  - ThinPlateSpline: Radial basis functions for smooth regression
- penalty.rs: Penalty matrix construction
  - Second derivative penalties for smoothness
  - Supports multiple penalty types per basis
- pirls.rs: Penalized IRLS fitting algorithm
  - Implements PiRLS for GLMs with penalties
  - Supports all standard GLM families
  - Automatic weight computation and convergence checking
- reml.rs: Smoothing parameter selection
  - REML criterion for optimal smoothing
  - GCV criterion as alternative
  - Log-determinant computations
- smooth.rs: Smoothing parameter optimization
  - Coordinate descent optimization
  - Grid search for initialization
  - Works in log-space for numerical stability
- gam.rs: Main GAM model interface
  - Combines all components
  - Handles multiple smooth terms
  - Outer loop for lambda optimization
- linalg.rs: Linear algebra operations
  - Gaussian elimination with partial pivoting
  - Matrix inversion via Gauss-Jordan
  - Determinant computation via LU decomposition
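To make the linear algebra layer concrete, here is a minimal Python sketch of a log-determinant computed by Gaussian elimination with partial pivoting, the kind of routine the REML log-determinant terms rely on. The function name `slogdet` and its return convention are assumptions made for this illustration; they are not taken from linalg.rs.

```python
import math

def slogdet(A):
    """Return (sign, log|det A|) via Gaussian elimination with partial pivoting.

    Hypothetical helper for illustration only; not the crate's API.
    """
    n = len(A)
    M = [row[:] for row in A]          # work on a copy
    sign, logabs = 1.0, 0.0
    for c in range(n):
        # Partial pivoting: pick the row with the largest entry in column c.
        p = max(range(c, n), key=lambda r: abs(M[r][c]))
        if M[p][c] == 0.0:
            return 0.0, float("-inf")  # singular matrix: det = 0
        if p != c:
            M[c], M[p] = M[p], M[c]
            sign = -sign               # a row swap flips the determinant's sign
        pivot = M[c][c]
        if pivot < 0:
            sign = -sign
        logabs += math.log(abs(pivot)) # det is the product of the pivots
        for r in range(c + 1, n):
            f = M[r][c] / pivot
            for k in range(c, n):
                M[r][k] -= f * M[c][k]
    return sign, logabs

sign, logdet = slogdet([[4.0, 1.0], [1.0, 3.0]])
print(sign, math.exp(logdet))
```

Accumulating the logarithms of the pivots, rather than multiplying the pivots and taking the log at the end, avoids overflow and underflow for large matrices.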
Mathematical Background
GAM Model
g(E[Y]) = β₀ + f₁(x₁) + f₂(x₂) + ... + fₚ(xₚ)
Where:
- g() is the link function
- fᵢ() are smooth functions represented by basis expansions
- Each smooth is penalized by λᵢ ∫ (fᵢ''(x))² dx
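As an illustration of what "represented by basis expansions" means, the sketch below evaluates a cubic B-spline basis with the Cox-de Boor recursion, so that a smooth can be written as f(x) = Σⱼ βⱼ Bⱼ(x). It uses a plain clamped knot vector and does not impose the natural boundary conditions mentioned under Features; the function name `bspline_basis` and the knot layout are assumptions for this example, not the crate's API.

```python
def bspline_basis(x, knots, degree=3):
    """Evaluate every B-spline basis function of `degree` at x (Cox-de Boor).

    Hypothetical helper for illustration only.
    """
    m = len(knots) - 1
    # Degree 0: indicator functions of the half-open knot spans.
    B = [1.0 if knots[i] <= x < knots[i + 1] else 0.0 for i in range(m)]
    # Make the right end point belong to the last non-empty span.
    if x == knots[-1]:
        for i in range(m - 1, -1, -1):
            if knots[i] < knots[i + 1]:
                B[i] = 1.0
                break
    # Raise the degree one step at a time.
    for d in range(1, degree + 1):
        nxt = []
        for i in range(m - d):
            left = right = 0.0
            den1 = knots[i + d] - knots[i]
            if den1 > 0:
                left = (x - knots[i]) / den1 * B[i]
            den2 = knots[i + d + 1] - knots[i + 1]
            if den2 > 0:
                right = (knots[i + d + 1] - x) / den2 * B[i + 1]
            nxt.append(left + right)
        B = nxt
    return B

# Clamped cubic knot vector on [0, 1] with three interior knots -> 7 basis functions.
knots = [0.0, 0.0, 0.0, 0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0, 1.0]
beta = [0.0, 0.5, 1.0, 0.2, -0.4, 0.3, 0.0]   # made-up coefficients

def f(x):
    """Smooth term as a basis expansion: f(x) = sum_j beta_j * B_j(x)."""
    return sum(b * v for b, v in zip(beta, bspline_basis(x, knots)))

print(f(0.3))
```

The basis functions are non-negative and sum to one at every point of [0, 1], so setting every coefficient to the same constant reproduces that constant.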
PiRLS Algorithm
- Initialize: η = g(y) (adjusting y where g(y) is undefined, e.g. y = 0 under a log link)
- Until convergence:
  - Compute μ = g⁻¹(η)
  - Compute weights: w = 1 / (V(μ) · (g'(μ))²), where g'(μ) = dη/dμ
  - Compute working response: z = η + (y - μ) · g'(μ)
  - Solve: β = (X'WX + λS)⁻¹ X'Wz
  - Update: η = Xβ
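The loop can be sketched in a few lines of Python for a Poisson model with log link. This is a minimal illustration, not the code in pirls.rs: it uses the standard IRLS weights w = 1 / (V(μ)·g'(μ)²) and working response z = η + (y - μ)·g'(μ) with g'(μ) = dη/dμ, which for the log link reduce to w = μ and z = η + (y - μ)/μ. The names `solve` and `pirls_poisson`, and the 0.1 offset used to initialize η, are assumptions made for the example.

```python
import math

def solve(A, b):
    """Solve A x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    M = [row[:] + [bi] for row, bi in zip(A, b)]
    for c in range(n):
        p = max(range(c, n), key=lambda r: abs(M[r][c]))
        M[c], M[p] = M[p], M[c]
        for r in range(c + 1, n):
            f = M[r][c] / M[c][c]
            for k in range(c, n + 1):
                M[r][k] -= f * M[c][k]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (M[r][n] - sum(M[r][k] * x[k] for k in range(r + 1, n))) / M[r][r]
    return x

def pirls_poisson(X, y, S, lam, max_iter=50, tol=1e-8):
    """Penalized IRLS for a Poisson GLM with log link (illustrative sketch)."""
    n, p = len(X), len(X[0])
    eta = [math.log(yi + 0.1) for yi in y]   # offset avoids log(0)
    beta = [0.0] * p
    for _ in range(max_iter):
        mu = [math.exp(e) for e in eta]
        w = mu                                # 1 / (V * g'^2) with V = mu, g' = 1/mu
        z = [e + (yi - m) / m for e, yi, m in zip(eta, y, mu)]
        # Normal equations: (X'WX + lam*S) beta = X'Wz
        A = [[sum(w[i] * X[i][a] * X[i][b] for i in range(n)) + lam * S[a][b]
              for b in range(p)] for a in range(p)]
        rhs = [sum(w[i] * X[i][a] * z[i] for i in range(n)) for a in range(p)]
        new_beta = solve(A, rhs)
        change = max(abs(nb - ob) for nb, ob in zip(new_beta, beta))
        beta = new_beta
        eta = [sum(X[i][a] * beta[a] for a in range(p)) for i in range(n)]
        if change < tol:
            break
    return beta

# Intercept-only model without penalty: the fit should approach log(mean(y)).
X = [[1.0] for _ in range(5)]
y = [1.0, 2.0, 3.0, 4.0, 5.0]
print(pirls_poisson(X, y, S=[[0.0]], lam=0.0))
```

With λ = 0 this is ordinary IRLS; increasing λ shrinks the coefficients in the directions penalized by S.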
REML Criterion
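REML(λ) = n·log(RSS) + log|X'WX + λS| - log|λS|₊
Here |·|₊ is the pseudo-determinant (product of the nonzero eigenvalues), so log|λS|₊ = rank(S)·log(λ) + log|S|₊ stays finite for rank-deficient penalties.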
Minimized with respect to λ to select optimal smoothing parameters.
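A toy example may help show what minimizing over λ looks like. The sketch below applies the simplified criterion as written in this README to a one-coefficient model y ≈ β·x with penalty λβ² (so S = [1] and the penalty term is log|λS| = log λ), and searches a log10-spaced grid. The data, the grid bounds and the function names are made up for illustration; the crate's own optimizer in smooth.rs may differ.

```python
import math

def reml_toy(lam, x, y):
    """README-style criterion for y ~ beta*x with penalty lam * beta^2 (S = [1])."""
    sxx = sum(xi * xi for xi in x)
    sxy = sum(xi * yi for xi, yi in zip(x, y))
    beta = sxy / (sxx + lam)                       # penalized least squares fit
    rss = sum((yi - beta * xi) ** 2 for xi, yi in zip(x, y))
    n = len(y)
    # n*log(RSS) + log|X'X + lam*S| - log|lam*S|; every term is a scalar here.
    return n * math.log(rss) + math.log(sxx + lam) - math.log(lam)

def grid_search_log_lambda(criterion, lo=-4.0, hi=4.0, steps=81):
    """Return the lambda on a log10-spaced grid that minimizes `criterion`."""
    grid = [10.0 ** (lo + (hi - lo) * i / (steps - 1)) for i in range(steps)]
    return min(grid, key=criterion)

x = [1.0, 2.0, 3.0, 4.0, 5.0]
y = [1.1, 1.9, 3.2, 3.9, 5.1]
best = grid_search_log_lambda(lambda lam: reml_toy(lam, x, y))
print(f"selected lambda = {best:.3g}")
```

Searching over log λ rather than λ gives equal resolution to every order of magnitude, which matters because sensible smoothing parameters can span many decades.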
Examples
See examples/simple_gam.rs for a complete working example:
cargo run --example simple_gam --release
Project Structure
├── src/                  # Core Rust library code
├── examples/             # Rust usage examples
├── benches/              # Rust benchmarks
├── tests/                # Rust tests
├── scripts/              # Testing and benchmarking scripts
│   ├── python/           # Python scripts
│   │   ├── tests/        # Python test scripts
│   │   └── benchmarks/   # Python benchmark scripts
│   └── r/                # R scripts
│       ├── tests/        # R test scripts
│       └── benchmarks/   # R benchmark scripts
├── docs/                 # Documentation and analysis
└── test_data/            # Test data and results
Testing
Run the Rust test suite:
cargo test
All 20 tests should pass, covering:
- Basis function evaluation
- Penalty matrix construction
- Linear algebra operations
- REML/GCV criteria
- PiRLS convergence
- Full GAM fitting pipeline
Additional tests and benchmarks are available in the scripts/ directory.
Implementation Notes
- TDD Approach: Every feature was implemented with tests first
- No BLAS/LAPACK Dependencies: Custom linear algebra avoids BLAS/LAPACK linking and build issues
- Numerical Stability: Operations performed in log-space where appropriate
- Extensible Design: Easy to add new basis types, families, or criteria
Limitations & Future Work
- Smoothing parameter optimization could be improved with better algorithms (e.g., Newton-Raphson)
- Eigendecomposition for handling penalty null spaces more rigorously
- Confidence intervals and standard errors
- Model diagnostics and residual analysis
- Tensor product smooths for multivariate terms
- Parallel processing for large datasets
References
- Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
- Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. JRSS-B, 73(1), 3-36.
License
MIT License - see LICENSE file for details
Author
Implemented as a Rust port of R's mgcv package core functionality.
Update: REML Implementation Fixed! ✅
An earlier version of the REML implementation had bugs that caused it to always select λ ≈ 0.
What Was Wrong
- Singular Penalty Handling: REML was incorrectly handling rank-deficient penalty matrices, setting log|S| = 0, which broke the criterion
- Lambda Passing: Optimization was passing λ = 1.0 with pre-multiplied penalties, confusing the rank(S)*log(λ) term
- Insufficient Data: Examples used n=30 with p=15 (ratio 2:1), which is too small for REML/GCV
What Was Fixed
- REML Criterion: Now correctly uses log|λS| = rank(S)*log(λ) + constant
- Optimization: Passes actual λ values to criterion functions
- Data Size: Increased to n=300 for proper n/p ratio (20:1)
- REML Search: Uses fine grid search (gradient descent had issues)
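The identity behind the corrected penalty term follows directly from the eigenvalues of S. A short derivation, assuming S is symmetric positive semi-definite with rank r and writing |·|₊ for the pseudo-determinant (product of the nonzero eigenvalues):

```latex
% S: symmetric positive semi-definite penalty matrix of rank r
% d_1, ..., d_r: the nonzero eigenvalues of S
\begin{align*}
|\lambda S|_{+} &= \prod_{i=1}^{r} \lambda d_i
                 = \lambda^{r} \prod_{i=1}^{r} d_i
                 = \lambda^{r}\,|S|_{+} \\
\log|\lambda S|_{+} &= r \log\lambda + \log|S|_{+}
                     = \operatorname{rank}(S)\,\log\lambda + \text{constant}
\end{align*}
```

The constant log|S|₊ does not depend on λ, so it can be computed once. Setting the whole term to zero for a singular S drops the rank(S)·log(λ) part, which is the piece that penalizes very small λ.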
Current Performance (n=300)
GCV: λ = 0.067, Test RMSE = 0.480 ✅
REML: λ = 0.058, Test RMSE = 0.480 ✅
Both methods now select nearly optimal smoothing parameters!
See IMPLEMENTATION_SUMMARY.md for complete details.