
RustyStats 🦀📊

High-performance Generalized Linear Models with a Rust backend and Python API

Codebase Documentation: pricingfrontier.github.io/rustystats/

Performance Benchmarks

RustyStats vs Statsmodels on synthetic data: 101 features (10 continuous + 10 categorical with 10 levels each).

| Family      | 10K rows | 250K rows | 500K rows |
|-------------|----------|-----------|-----------|
| Gaussian    | 15.6x    | 5.7x      | 4.3x      |
| Poisson     | 16.3x    | 6.2x      | 4.2x      |
| Binomial    | 19.5x    | 6.8x      | 4.4x      |
| Gamma       | 33.7x    | 13.4x     | 8.4x      |
| NegBinomial | 26.7x    | 6.7x      | 5.0x      |

Average speedup: 10.5x (range: 4.2x–33.7x)

Memory Usage

RustyStats uses significantly less RAM by reusing buffers and avoiding Python object overhead:

| Rows | RustyStats | Statsmodels | Reduction |
|------|------------|-------------|-----------|
| 10K  | 38 MB      | 72 MB       | 1.9x      |
| 250K | 460 MB     | 1,796 MB    | 3.9x      |
| 500K | 836 MB     | 3,590 MB    | 4.3x      |

Memory advantage grows with data size: at 500K rows, RustyStats uses ~4x less RAM.

Full benchmark details
| Family      | Rows    | RustyStats | Statsmodels | Speedup |
|-------------|---------|------------|-------------|---------|
| Gaussian    | 10,000  | 0.100s     | 1.559s      | 15.6x   |
| Gaussian    | 250,000 | 1.991s     | 11.363s     | 5.7x    |
| Gaussian    | 500,000 | 4.023s     | 17.386s     | 4.3x    |
| Poisson     | 10,000  | 0.165s     | 2.692s      | 16.3x   |
| Poisson     | 250,000 | 2.429s     | 15.072s     | 6.2x    |
| Poisson     | 500,000 | 5.668s     | 23.693s     | 4.2x    |
| Binomial    | 10,000  | 0.112s     | 2.189s      | 19.5x   |
| Binomial    | 250,000 | 1.946s     | 13.155s     | 6.8x    |
| Binomial    | 500,000 | 4.708s     | 20.862s     | 4.4x    |
| Gamma       | 10,000  | 0.129s     | 4.353s      | 33.7x   |
| Gamma       | 250,000 | 2.385s     | 31.885s     | 13.4x   |
| Gamma       | 500,000 | 5.499s     | 46.167s     | 8.4x    |
| NegBinomial | 10,000  | 0.119s     | 3.177s      | 26.7x   |
| NegBinomial | 250,000 | 2.281s     | 15.278s     | 6.7x    |
| NegBinomial | 500,000 | 4.821s     | 24.331s     | 5.0x    |

Times are the median of 3 runs. Benchmark scripts are in benchmarks/.


Features

  • Fast - Parallel Rust backend, 4-30x faster than statsmodels
  • Memory Efficient - 4x less RAM than statsmodels at scale
  • Stable - Step-halving IRLS, warm starts for robust convergence
  • Splines - B-splines bs() and natural splines ns() in formulas
  • Polynomials - Identity terms I(x ** 2) for polynomial and arithmetic expressions
  • Target Encoding - CatBoost-style TE() for high-cardinality categoricals (exposure-aware)
  • Regularization - Ridge, Lasso, and Elastic Net via coordinate descent
  • Validation - Design matrix checks with fix suggestions before fitting
  • Complete - 8 families, robust SEs, full diagnostics, VIF, partial dependence
  • Minimal - Only numpy and polars required
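For intuition about the "step-halving IRLS" feature above, here is a minimal numpy sketch of the algorithm for a Poisson/log model. This is an illustration of the general technique only, not the rustystats Rust solver; the function name is hypothetical.

```python
import numpy as np

def irls_poisson(X, y, max_iter=50, tol=1e-8):
    """Poisson GLM (log link) via IRLS with step-halving on the deviance."""
    beta = np.zeros(X.shape[1])
    dev_old = np.inf
    for _ in range(max_iter):
        eta = X @ beta
        mu = np.exp(eta)
        z = eta + (y - mu) / mu                  # working response (log link)
        w = mu                                   # IRLS weights
        beta_prop = np.linalg.solve(X.T @ (w[:, None] * X), X.T @ (w * z))
        direction = beta_prop - beta
        step = 1.0
        while True:                              # halve the step until deviance improves
            cand = beta + step * direction
            mu_c = np.exp(X @ cand)
            dev = 2.0 * np.sum(y * np.log(np.where(y > 0, y / mu_c, 1.0)) - (y - mu_c))
            if dev <= dev_old or step < 1e-4:
                break
            step *= 0.5
        beta = cand
        if abs(dev_old - dev) < tol * (abs(dev) + 1.0):
            break
        dev_old = dev
    return beta

# Recover known coefficients from simulated data
rng = np.random.default_rng(0)
X = np.column_stack([np.ones(5000), rng.normal(size=5000)])
y = rng.poisson(np.exp(X @ np.array([0.2, 0.5]))).astype(float)
beta_hat = irls_poisson(X, y)
```

Step-halving matters when a full Newton step overshoots (e.g. extreme starting values with a log link); shrinking the update guarantees the deviance never increases.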

Installation

uv add rustystats

Quick Start

import rustystats as rs
import polars as pl

# Load data
data = pl.read_parquet("insurance.parquet")

# Fit a Poisson GLM for claim frequency
result = rs.glm(
    "ClaimCount ~ VehAge + VehPower + C(Area) + C(Region)",
    data=data,
    family="poisson",
    offset="Exposure"
).fit()

# View results
print(result.summary())

Families & Links

| Family        | Default Link | Use Case                                   |
|---------------|--------------|--------------------------------------------|
| gaussian      | identity     | Linear regression                          |
| poisson       | log          | Claim frequency                            |
| binomial      | logit        | Binary outcomes                            |
| gamma         | log          | Claim severity                             |
| tweedie       | log          | Pure premium (var_power=1.5)               |
| quasipoisson  | log          | Overdispersed counts                       |
| quasibinomial | logit        | Overdispersed binary                       |
| negbinomial   | log          | Overdispersed counts (proper distribution) |
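For reference, the variance functions these families assume can be written out directly. These are the standard GLM definitions shown in numpy for illustration, not rustystats internals:

```python
import numpy as np

# Mean values in (0, 1) so every family's variance function is defined
mu = np.array([0.2, 0.5, 0.8])

variance = {
    "gaussian": np.ones_like(mu),    # V(mu) = 1
    "poisson":  mu,                  # V(mu) = mu
    "binomial": mu * (1 - mu),       # V(mu) = mu(1 - mu)
    "gamma":    mu ** 2,             # V(mu) = mu^2
    "tweedie":  mu ** 1.5,          # V(mu) = mu^p with p = var_power
}
```

The variance function is what distinguishes the families during fitting: together with the link, it fully determines the IRLS weights.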

Formula Syntax

# Main effects
"y ~ x1 + x2 + C(category)"

# Single-level categorical indicators
"y ~ C(Region, level='Paris')"              # 0/1 indicator for Paris only
"y ~ C(Region, levels=['Paris', 'Lyon'])"   # Indicators for specific levels

# Interactions
"y ~ x1*x2"              # x1 + x2 + x1:x2
"y ~ C(area):age"        # Area-specific age effects
"y ~ C(area)*C(brand)"   # Categorical ร— categorical

# Splines (non-linear effects)
"y ~ bs(age, df=5)"      # B-spline basis
"y ~ ns(income, df=4)"   # Natural spline (better extrapolation)

# Identity terms (polynomial/arithmetic expressions)
"y ~ I(age ** 2)"        # Polynomial terms
"y ~ I(x1 * x2)"         # Explicit products
"y ~ I(income / 1000)"   # Scaled variables

# Target encoding (high-cardinality categoricals)
"y ~ TE(brand) + TE(model)"

# Combined
"y ~ bs(age, df=5) + C(region)*income + ns(vehicle_age, df=3) + TE(brand) + I(age ** 2)"

Results Methods

# Coefficients & Inference
result.params              # Coefficients
result.fittedvalues        # Predicted means
result.deviance            # Model deviance
result.bse()               # Standard errors
result.tvalues()           # z-statistics
result.pvalues()           # P-values
result.conf_int(alpha)     # Confidence intervals

# Robust Standard Errors (sandwich estimators)
result.bse_robust("HC1")   # Robust SE (HC0, HC1, HC2, HC3)
result.tvalues_robust()    # z-stats with robust SE
result.pvalues_robust()    # P-values with robust SE
result.conf_int_robust()   # Confidence intervals with robust SE
result.cov_robust()        # Full robust covariance matrix

# Diagnostics (statsmodels-compatible)
result.resid_response()    # Raw residuals (y - μ)
result.resid_pearson()     # Pearson residuals
result.resid_deviance()    # Deviance residuals
result.resid_working()     # Working residuals
result.llf()               # Log-likelihood
result.aic()               # Akaike Information Criterion
result.bic()               # Bayesian Information Criterion
result.null_deviance()     # Null model deviance
result.pearson_chi2()      # Pearson chi-squared
result.scale()             # Dispersion (deviance-based)
result.scale_pearson()     # Dispersion (Pearson-based)
result.family              # Family name

Regularization

CV-Based Regularization (Recommended)

# Just specify regularization type - cv=5 is automatic
result = rs.glm("y ~ x1 + x2 + C(cat)", data, family="poisson").fit(
    regularization="ridge"  # "ridge", "lasso", or "elastic_net"
)

print(f"Selected alpha: {result.alpha}")
print(f"CV deviance: {result.cv_deviance}")

Options:

  • regularization: "ridge" (L2), "lasso" (L1), or "elastic_net" (mix)
  • selection: "min" (best fit, the default) or "1se" (more conservative)
  • cv: Number of folds (default: 5)

Explicit Alpha

# Skip CV, use specific alpha
result = rs.glm("y ~ x1 + x2", data).fit(alpha=0.1, l1_ratio=0.0)  # Ridge
result = rs.glm("y ~ x1 + x2", data).fit(alpha=0.1, l1_ratio=1.0)  # Lasso
result = rs.glm("y ~ x1 + x2", data).fit(alpha=0.1, l1_ratio=0.5)  # Elastic Net
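The coordinate descent behind the lasso path is a soft-thresholding update; a minimal gaussian-family sketch follows. This is an illustration of the technique under simplifying assumptions (no intercept, fixed iteration count), not the rustystats implementation:

```python
import numpy as np

def soft_threshold(z, gamma):
    return np.sign(z) * np.maximum(np.abs(z) - gamma, 0.0)

def lasso_cd(X, y, alpha, n_iter=200):
    """Cyclic coordinate descent for (1/2n)||y - Xb||^2 + alpha * ||b||_1."""
    n, p = X.shape
    beta = np.zeros(p)
    col_sq = (X ** 2).sum(axis=0) / n
    for _ in range(n_iter):
        for j in range(p):
            partial = y - X @ beta + X[:, j] * beta[j]   # residual excluding j
            rho = X[:, j] @ partial / n
            beta[j] = soft_threshold(rho, alpha) / col_sq[j]
    return beta

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 3))
y = X @ np.array([1.5, 0.0, -2.0]) + rng.normal(size=1000)
beta_hat = lasso_cd(X, y, alpha=0.1)   # middle coefficient is shrunk to ~0
```

The soft-threshold step is what produces exact zeros, which is why lasso performs variable selection while ridge only shrinks.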

Interaction Terms

# Continuous × Continuous interaction (main effects + interaction)
result = rs.glm(
    "ClaimNb ~ Age*VehPower",  # Equivalent to Age + VehPower + Age:VehPower
    data, family="poisson", offset="Exposure"
).fit()

# Categorical × Continuous interaction
result = rs.glm(
    "ClaimNb ~ C(Area)*Age",  # Each area level has different age effect
    data, family="poisson", offset="Exposure"
).fit()

# Categorical × Categorical interaction
result = rs.glm(
    "ClaimNb ~ C(Area)*C(VehBrand)",
    data, family="poisson", offset="Exposure"
).fit()

# Pure interaction (no main effects added)
result = rs.glm(
    "ClaimNb ~ Age + C(Area):VehPower",  # Area-specific VehPower slopes
    data, family="poisson", offset="Exposure"
).fit()

Spline Basis Functions

# Use splines in formulas - automatic parsing
result = rs.glm(
    "ClaimNb ~ bs(Age, df=5) + ns(VehPower, df=4) + C(Region)",
    data=data,
    family="poisson",
    offset="Exposure"
).fit()

# Combine splines with interactions
result = rs.glm(
    "y ~ bs(age, df=4)*C(gender) + ns(income, df=3)",
    data=data,
    family="gaussian"
).fit()

# Direct basis computation for custom use
import numpy as np
x = np.linspace(0, 10, 100)
basis = rs.bs(x, df=5)  # 5 degrees of freedom (4 basis columns)
basis_ns = rs.ns(x, df=5)  # Natural splines - linear extrapolation at boundaries

When to use each spline type:

  • B-splines (bs): Standard choice, more flexible at boundaries
  • Natural splines (ns): Better extrapolation, linear beyond boundaries (recommended for actuarial work)

Quasi-Families for Overdispersion

# Fit a standard Poisson model first
result_poisson = rs.glm("ClaimNb ~ Age + C(Region)", data, family="poisson", offset="Exposure").fit()

# Check for overdispersion: Pearson χ² / df >> 1 indicates overdispersion
dispersion_ratio = result_poisson.pearson_chi2() / result_poisson.df_resid
print(f"Dispersion ratio: {dispersion_ratio:.2f}")  # If >> 1, use quasi-family

# Fit QuasiPoisson if overdispersed
result_quasi = rs.glm("ClaimNb ~ Age + C(Region)", data, family="quasipoisson", offset="Exposure").fit()

# Coefficients are IDENTICAL to Poisson, but standard errors are inflated by √φ
print(f"Estimated dispersion (φ): {result_quasi.scale():.3f}")

# For binary data with overdispersion
result_qb = rs.glm("Binary ~ x1 + x2", data, family="quasibinomial").fit()

Key properties of quasi-families:

  • Point estimates: Identical to base family (Poisson/Binomial)
  • Standard errors: Inflated by √φ, where φ = Pearson χ² / (n - p)
  • P-values: More conservative (larger), accounting for extra variance
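The arithmetic of that adjustment is simple enough to verify by hand. The numbers below are hypothetical, purely to make the scaling concrete:

```python
import numpy as np

pearson_chi2 = 1800.0                 # hypothetical Pearson chi-squared
n, p = 1000, 10                       # observations, parameters
phi = pearson_chi2 / (n - p)          # dispersion estimate (~1.82 here)
se_poisson = np.array([0.05, 0.12])   # hypothetical base-family SEs
se_quasi = se_poisson * np.sqrt(phi)  # quasi-family SEs, inflated by sqrt(phi)
```

Because every SE is multiplied by the same √φ, z-statistics shrink uniformly and p-values grow, while the coefficients themselves are untouched.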

Negative Binomial for Overdispersed Counts

# Automatic θ estimation (default when theta is not supplied)
result = rs.glm("ClaimNb ~ Age + C(Region)", data, family="negbinomial", offset="Exposure").fit()
print(result.family)  # "NegativeBinomial(theta=2.1234)"

# Fixed θ value
result = rs.glm("ClaimNb ~ Age + C(Region)", data, family="negbinomial", theta=1.0, offset="Exposure").fit()

# θ controls overdispersion: Var(Y) = μ + μ²/θ
# - θ=0.5: Strong overdispersion (variance = μ + 2μ²)
# - θ=1.0: Moderate overdispersion (variance = μ + μ²)
# - θ→∞: Approaches Poisson (variance = μ)

NegativeBinomial vs QuasiPoisson:

| Aspect               | QuasiPoisson   | NegativeBinomial |
|----------------------|----------------|------------------|
| Variance             | φ × μ          | μ + μ²/θ         |
| True distribution    | No (quasi)     | Yes              |
| AIC/BIC valid        | Questionable   | Yes              |
| Prediction intervals | Not principled | Proper           |
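The variance formula is easy to sanity-check numerically, including the Poisson limit as θ grows:

```python
import numpy as np

mu = np.array([0.5, 1.0, 2.0])

var_theta1 = mu + mu ** 2 / 1.0       # theta = 1: moderate overdispersion
var_theta_big = mu + mu ** 2 / 1e9    # theta -> infinity: Poisson limit, Var ~ mu
```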

Target Encoding for High-Cardinality Categoricals

# Formula API - TE() in formulas
result = rs.glm(
    "ClaimNb ~ TE(Brand) + TE(Model) + Age + C(Region)",
    data=data,
    family="poisson",
    offset="Exposure"
).fit()

# With options
result = rs.glm(
    "y ~ TE(brand, prior_weight=2.0, n_permutations=8) + age",
    data=data,
    family="gaussian"
).fit()

# Sklearn-style API
encoder = rs.TargetEncoder(prior_weight=1.0, n_permutations=4)
train_encoded = encoder.fit_transform(train_categories, train_target)
test_encoded = encoder.transform(test_categories)

Key benefits:

  • No target leakage: Ordered target statistics
  • Regularization: Prior weight controls shrinkage toward global mean
  • High-cardinality: Single column instead of thousands of dummies
  • Exposure-aware: For frequency models with offset="Exposure", TE() automatically uses claim rate (ClaimCount/Exposure) instead of raw counts, preventing near-constant encoded values
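The "ordered target statistics" idea behind the no-leakage property can be sketched as follows. This is a simplified illustration of the principle only; the rustystats encoder additionally averages over multiple permutations and handles exposure, and the function name here is hypothetical:

```python
import numpy as np

def ordered_te(categories, target, prior_weight=1.0, seed=0):
    """Encode each row using only rows earlier in a random order,
    shrunk toward the global mean -- a row never sees its own target."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(categories))
    global_mean = target.mean()
    sums, counts = {}, {}
    encoded = np.empty(len(categories))
    for i in order:
        c = categories[i]
        s, n = sums.get(c, 0.0), counts.get(c, 0)
        encoded[i] = (s + prior_weight * global_mean) / (n + prior_weight)
        sums[c] = s + target[i]
        counts[c] = n + 1
    return encoded

cats = np.array(["a"] * 50 + ["b"] * 50)
y = np.array([1.0] * 50 + [0.0] * 50)      # "a" rows have a high target
enc = ordered_te(cats, y)
```

The first row seen for each category receives exactly the global mean; subsequent rows move toward the category mean as evidence accumulates, with prior_weight controlling how fast.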

Identity Terms for Polynomials

# Polynomial terms
result = rs.glm(
    "y ~ age + I(age ** 2) + I(age ** 3)",
    data=data,
    family="gaussian"
).fit()

# Arithmetic expressions
result = rs.glm(
    "y ~ I(income / 1000) + I(weight * height)",
    data=data,
    family="gaussian"
).fit()

Supported operations: +, -, *, /, ** (power)


Design Matrix Validation

# Check for issues before fitting
model = rs.glm("y ~ ns(x, df=4) + C(cat)", data, family="poisson")
results = model.validate()  # Prints diagnostics

if not results['valid']:
    print("Issues:", results['suggestions'])

# Validation runs automatically on fit failure with helpful suggestions

Checks performed:

  • Rank deficiency (linearly dependent columns)
  • High multicollinearity (condition number)
  • Zero variance columns
  • NaN/Inf values
  • Highly correlated column pairs (>0.999)
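In numpy terms, those checks amount to roughly the following. This is an illustrative re-implementation, not the rustystats code; note that a real intercept column would trip the zero-variance check as written:

```python
import numpy as np

def check_design(X, cond_limit=1e10):
    """Return a list of issue names for a design matrix (no intercept column)."""
    if not np.isfinite(X).all():
        return ["nan_or_inf"]                  # remaining checks need finite values
    issues = []
    if (X.std(axis=0) == 0).any():
        issues.append("zero_variance_column")
    if np.linalg.matrix_rank(X) < X.shape[1]:
        issues.append("rank_deficiency")
    if np.linalg.cond(X) > cond_limit:
        issues.append("high_multicollinearity")
    return issues

rng = np.random.default_rng(0)
X_ok = rng.normal(size=(50, 3))
X_bad = np.column_stack([X_ok, X_ok[:, 0]])   # duplicated column: rank deficient
```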

Model Diagnostics

# Compute all diagnostics at once
diagnostics = result.diagnostics(
    data=data,
    categorical_factors=["Region", "VehBrand", "Area"],  # Including non-fitted
    continuous_factors=["Age", "Income", "VehPower"],    # Including non-fitted
)

# Export as compact JSON (optimized for LLM consumption)
json_str = diagnostics.to_json()

# Pre-fit data exploration (no model needed)
exploration = rs.explore_data(
    data=data,
    response="ClaimNb",
    categorical_factors=["Region", "VehBrand", "Area"],
    continuous_factors=["Age", "VehPower", "Income"],
    exposure="Exposure",
    family="poisson",
    detect_interactions=True,
)

Diagnostic Features:

  • Calibration: Overall A/E ratio, calibration by decile with CIs, Hosmer-Lemeshow test
  • Discrimination: Gini coefficient, AUC, KS statistic, lift metrics
  • Factor Diagnostics: A/E by level/bin for ALL factors (fitted and non-fitted)
  • VIF/Multicollinearity: Variance inflation factors for design matrix columns
  • Partial Dependence: Effect plots with shape detection and recommendations
  • Overfitting Detection: Compare train vs test metrics when test data provided
  • Interaction Detection: Greedy residual-based detection of potential interactions
  • Warnings: Auto-generated alerts for high dispersion, poor calibration, missing factors
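Two of these diagnostics are easy to state precisely; below is a hedged numpy sketch of the overall A/E ratio and a Lorenz-curve Gini. These are simplified textbook versions for orientation, not necessarily what rustystats computes internally:

```python
import numpy as np

def actual_expected(y, mu):
    """Overall actual/expected ratio; ~1.0 means well calibrated in aggregate."""
    return y.sum() / mu.sum()

def lorenz_gini(y, mu):
    """Gini from the Lorenz curve of actuals ordered by predicted risk."""
    order = np.argsort(mu)                     # ascending predicted risk
    cum_share = np.cumsum(y[order]) / y.sum()
    return 1.0 - 2.0 * cum_share.mean()        # rectangle approximation of the area

rng = np.random.default_rng(0)
mu = np.linspace(0.1, 2.0, 1000)               # simulated predictions
y = rng.poisson(mu).astype(float)              # outcomes consistent with them
ae = actual_expected(y, mu)                    # close to 1.0 by construction
g = lorenz_gini(y, mu)                         # positive: predictions rank risk
```

A Gini near zero means the model ranks risks no better than chance; a well-calibrated model can still have a low Gini if its predictions barely discriminate.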

RustyStats vs Statsmodels

| Feature                             | RustyStats                                  | Statsmodels              |
|-------------------------------------|---------------------------------------------|--------------------------|
| Parallel IRLS Solver                | ✅ Multi-threaded                           | ❌ Single-threaded only  |
| Native Polars Support               | ✅ Polars only                              | ❌ Pandas only           |
| Built-in Lasso/Elastic Net for GLMs | ✅ Fast coordinate descent with all families | ⚠️ Limited              |
| Relativities Table                  | ✅ result.relativities() for pricing        | ❌ Must compute manually |
| Robust Standard Errors              | ✅ HC0, HC1, HC2, HC3 sandwich estimators   | ✅ HC0-HC3               |

Project Structure

rustystats/
├── Cargo.toml                    # Workspace config
├── pyproject.toml                # Python package config
│
├── crates/
│   ├── rustystats-core/          # Pure Rust GLM library
│   │   └── src/
│   │       ├── families/         # Gaussian, Poisson, Binomial, Gamma, Tweedie, Quasi, NegativeBinomial
│   │       ├── links/            # Identity, Log, Logit
│   │       ├── solvers/          # IRLS, coordinate descent
│   │       ├── inference/        # P-values, CIs, robust SE (HC0-HC3)
│   │       ├── interactions/     # Lazy interaction term computation
│   │       ├── splines/          # B-spline and natural spline basis functions
│   │       ├── design_matrix/    # Categorical encoding, interaction matrices
│   │       ├── formula/          # R-style formula parsing
│   │       ├── target_encoding/  # Ordered target statistics
│   │       └── diagnostics/      # Residuals, dispersion, AIC/BIC, calibration, loss
│   │
│   └── rustystats/               # Python bindings (PyO3)
│       └── src/lib.rs
│
├── python/rustystats/            # Python package
│   ├── __init__.py               # Main exports
│   ├── formula.py                # Formula API with DataFrame support
│   ├── interactions.py           # Interaction terms, I() expressions, design matrix
│   ├── splines.py                # bs() and ns() spline basis functions
│   ├── target_encoding.py        # Target encoding (exposure-aware)
│   ├── diagnostics.py            # Model diagnostics with JSON export
│   └── families.py               # Family wrappers
│
├── examples/
│   └── frequency.ipynb           # Claim frequency example
│
└── tests/python/                 # Python test suite

Dependencies

Rust

  • ndarray, nalgebra - Linear algebra
  • rayon - Parallel iterators (multi-threading)
  • statrs - Statistical distributions
  • pyo3 - Python bindings

Python

  • numpy - Array operations (required)
  • polars - DataFrame support (required)

License

MIT
