Fast Generalized Linear Models with a Rust backend - statsmodels compatible

RustyStats 🦀📊

High-performance Generalized Linear Models with a Rust backend and Python API

Built for actuarial applications. Fits a Poisson GLM on 678K rows in about 1 second.

Features

  • Fast: Parallel IRLS solver in Rust (Rayon)
  • Complete: Families, regularization, inference, diagnostics
  • Flexible: R-style formulas with interactions and splines
  • Minimal: Core requires only numpy and polars
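
The IRLS (iteratively reweighted least squares) idea behind the solver can be sketched in a few lines. The following is a minimal pure-Python illustration for a Poisson model with a log link, an intercept, and one covariate. It is not the RustyStats Rust implementation, and the name `fit_poisson_irls` is made up for this example.

```python
import math

def fit_poisson_irls(x, y, n_iter=25, tol=1e-10):
    """Fit log(mu) = b0 + b1 * x by IRLS and return (b0, b1)."""
    # Start from the intercept-only fit.
    b0, b1 = math.log(sum(y) / len(y)), 0.0
    for _ in range(n_iter):
        # Accumulate X'WX and X'Wz for the two-column design [1, x].
        s_w = s_wx = s_wxx = s_wz = s_wxz = 0.0
        for xi, yi in zip(x, y):
            eta = b0 + b1 * xi
            mu = math.exp(eta)
            w = mu                        # Poisson with log link: weight = mu
            z = eta + (yi - mu) / mu      # working response
            s_w += w
            s_wx += w * xi
            s_wxx += w * xi * xi
            s_wz += w * z
            s_wxz += w * xi * z
        # Solve the 2x2 weighted normal equations in closed form.
        det = s_w * s_wxx - s_wx * s_wx
        new_b0 = (s_wxx * s_wz - s_wx * s_wxz) / det
        new_b1 = (s_w * s_wxz - s_wx * s_wz) / det
        converged = abs(new_b0 - b0) + abs(new_b1 - b1) < tol
        b0, b1 = new_b0, new_b1
        if converged:
            break
    return b0, b1

# Toy data: a binary covariate, so the fitted group means equal the observed ones.
x = [0, 0, 0, 1, 1, 1]
y = [1, 0, 2, 3, 4, 2]
b0, b1 = fit_poisson_irls(x, y)
print(f"intercept={b0:.4f}, slope={b1:.4f}")
```

With a binary covariate the maximum-likelihood fit reproduces the observed group means (1 and 3 here), which gives a simple way to check the loop.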

Installation

# Development install
git clone https://github.com/PricingFrontier/rustystats.git
cd rustystats
uv run maturin develop --release

# Run tests
uv run pytest tests/python/

Quick Start

import rustystats as rs
import polars as pl

# Load data
data = pl.read_parquet("insurance.parquet")

# Fit a Poisson GLM for claim frequency
result = rs.glm(
    "ClaimCount ~ VehAge + VehPower + C(Area) + C(Region)",
    data=data,
    family="poisson",
    offset="Exposure"
).fit()

# View results
print(result.summary())
print(result.relativities())  # exp(coef) for pricing
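
As the comment above notes, a relativity is the exponential of a coefficient: under a log link each coefficient acts multiplicatively on the expected claim count. The sketch below shows that arithmetic in plain Python. The coefficient names and values are invented for illustration and are not output from `result.relativities()`.

```python
import math

# Hypothetical log-scale coefficients (illustrative values only).
coefficients = {"Intercept": -2.0, "VehAge": 0.0, "C(Area)[B]": 0.10}

def to_relativities(coefs):
    """Convert log-link coefficients to multiplicative relativities."""
    return {name: math.exp(beta) for name, beta in coefs.items()}

relativities = to_relativities(coefficients)
for name, rel in relativities.items():
    print(f"{name:12s} {rel:.4f}")
```

A coefficient of 0 gives a relativity of exactly 1 (no effect), and a coefficient of 0.10 gives a loading of roughly 10.5%.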

Families & Links

Family | Default Link | Use Case
gaussian | identity | Linear regression
poisson | log | Claim frequency
binomial | logit | Binary outcomes
gamma | log | Claim severity
tweedie | log | Pure premium (var_power=1.5)
quasipoisson | log | Overdispersed counts
quasibinomial | logit | Overdispersed binary
negbinomial | log | Overdispersed counts (proper distribution)
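
For reference, the three default links in the table and their inverses can be written out directly. This is a standalone sketch of the standard definitions, not RustyStats code. `LINKS` and `predict_mean` are names chosen for this example.

```python
import math

# (link, inverse link) pairs for the three links named in the table.
LINKS = {
    "identity": (lambda mu: mu, lambda eta: eta),
    "log": (math.log, math.exp),
    "logit": (
        lambda mu: math.log(mu / (1.0 - mu)),
        lambda eta: 1.0 / (1.0 + math.exp(-eta)),
    ),
}

def predict_mean(link_name, eta):
    """Map a linear predictor eta back to the mean scale."""
    _, inverse = LINKS[link_name]
    return inverse(eta)

print(predict_mean("log", -2.0))    # mean on the count scale, exp(-2)
print(predict_mean("logit", 0.0))   # probability 0.5
```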

Formula Syntax

# Main effects
"y ~ x1 + x2 + C(category)"

# Interactions
"y ~ x1*x2"              # x1 + x2 + x1:x2
"y ~ C(area):age"        # Area-specific age effects
"y ~ C(area)*C(brand)"   # Categorical × categorical

# Splines (non-linear effects)
"y ~ bs(age, df=5)"      # B-spline basis
"y ~ ns(income, df=4)"   # Natural spline (better extrapolation)

# Target encoding (high-cardinality categoricals)
"y ~ TE(brand) + TE(model)"

# Combined
"y ~ bs(age, df=5) + C(region)*income + ns(vehicle_age, df=3) + TE(brand)"

Results Methods

# Coefficients & Inference
result.params              # Coefficients
result.fittedvalues        # Predicted means
result.deviance            # Model deviance
result.bse()               # Standard errors
result.tvalues()           # z-statistics
result.pvalues()           # P-values
result.conf_int(alpha)     # Confidence intervals
result.significance_codes()  # *, **, *** markers

# Robust Standard Errors (sandwich estimators)
result.bse_robust("HC1")   # Robust SE (HC0, HC1, HC2, HC3)
result.tvalues_robust()    # z-stats with robust SE
result.pvalues_robust()    # P-values with robust SE
result.conf_int_robust()   # Confidence intervals with robust SE
result.cov_robust()        # Full robust covariance matrix

# Diagnostics (statsmodels-compatible)
result.resid_response()    # Raw residuals (y - μ)
result.resid_pearson()     # Pearson residuals
result.resid_deviance()    # Deviance residuals
result.resid_working()     # Working residuals
result.llf()               # Log-likelihood
result.aic()               # Akaike Information Criterion
result.bic()               # Bayesian Information Criterion
result.null_deviance()     # Null model deviance
result.pearson_chi2()      # Pearson chi-squared
result.scale()             # Dispersion (deviance-based)
result.scale_pearson()     # Dispersion (Pearson-based)
result.family              # Family name
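
The inference quantities listed above are related by the usual Wald formulas. The sketch below assumes a standard normal reference distribution (the listing describes `tvalues()` as z-statistics). The helper `wald_summary` is hypothetical and only illustrates the arithmetic.

```python
from statistics import NormalDist

def wald_summary(coef, se, alpha=0.05):
    """Return (z statistic, two-sided p-value, confidence interval)."""
    normal = NormalDist()
    z = coef / se
    p_value = 2.0 * (1.0 - normal.cdf(abs(z)))
    critical = normal.inv_cdf(1.0 - alpha / 2.0)
    return z, p_value, (coef - critical * se, coef + critical * se)

# Illustrative numbers, not taken from a fitted model.
z, p_value, ci = wald_summary(coef=0.5, se=0.25)
print(f"z={z:.2f}, p={p_value:.4f}, 95% CI=({ci[0]:.3f}, {ci[1]:.3f})")
```

Robust (sandwich) inference follows the same pattern with the robust standard error substituted for `se`.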

Regularization

# Ridge (L2) - shrinks coefficients, keeps all variables
result = rs.glm("y ~ x1 + x2 + C(cat)", data, family="gaussian").fit(
    alpha=0.1, l1_ratio=0.0
)

# Lasso (L1) - variable selection, zeros out weak predictors
result = rs.glm("y ~ x1 + x2 + C(cat)", data, family="poisson").fit(
    alpha=0.1, l1_ratio=1.0
)
print(f"Selected {result.n_nonzero()} variables")
print(f"Features: {result.selected_features()}")

# Elastic Net - mix of L1 and L2
result = rs.glm("y ~ x1 + x2 + C(cat)", data, family="gaussian").fit(
    alpha=0.1, l1_ratio=0.5
)
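
To make the roles of `alpha` and `l1_ratio` concrete, the sketch below writes out the elastic net penalty in the parameterization used by glmnet and scikit-learn. Whether RustyStats scales the penalty in exactly this way is an assumption. The soft-thresholding operator shown is the standard building block of coordinate descent for the L1 part.

```python
def elastic_net_penalty(beta, alpha, l1_ratio):
    """alpha * (l1_ratio * ||beta||_1 + 0.5 * (1 - l1_ratio) * ||beta||_2^2)."""
    l1 = sum(abs(b) for b in beta)
    l2 = sum(b * b for b in beta)
    return alpha * (l1_ratio * l1 + 0.5 * (1.0 - l1_ratio) * l2)

def soft_threshold(value, threshold):
    """Shrink value toward zero by threshold; small values become exactly 0."""
    if value > threshold:
        return value - threshold
    if value < -threshold:
        return value + threshold
    return 0.0

beta = [1.0, -2.0]
print(elastic_net_penalty(beta, alpha=0.1, l1_ratio=1.0))  # pure Lasso penalty
print(elastic_net_penalty(beta, alpha=0.1, l1_ratio=0.0))  # pure Ridge penalty
print(soft_threshold(0.05, 0.1))                           # weak effect is zeroed
```

The exact zero returned by `soft_threshold` for small inputs is what lets Lasso drop weak predictors, while Ridge only shrinks them.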

Interaction Terms

# Continuous × Continuous interaction (main effects + interaction)
result = rs.glm(
    "ClaimNb ~ Age*VehPower",  # Equivalent to Age + VehPower + Age:VehPower
    data, family="poisson", offset="Exposure"
).fit()

# Categorical × Continuous interaction
result = rs.glm(
    "ClaimNb ~ C(Area)*Age",  # Each area level has different age effect
    data, family="poisson", offset="Exposure"
).fit()

# Categorical × Categorical interaction
result = rs.glm(
    "ClaimNb ~ C(Area)*C(VehBrand)",
    data, family="poisson", offset="Exposure"
).fit()

# Pure interaction (no main effects added)
result = rs.glm(
    "ClaimNb ~ Age + C(Area):VehPower",  # Area-specific VehPower slopes
    data, family="poisson", offset="Exposure"
).fit()
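
To see what a term such as `C(Area):VehPower` contributes to the design matrix, the sketch below builds one slope column per area level by hand. It is illustrative only: the function name is invented, and how RustyStats chooses reference levels or orders columns is not covered here.

```python
def area_specific_slopes(area, veh_power, levels):
    """One column per level: VehPower where the row is in that level, else 0."""
    return [
        [power if a == level else 0.0 for level in levels]
        for a, power in zip(area, veh_power)
    ]

rows = area_specific_slopes(
    area=["A", "B", "A"],
    veh_power=[5.0, 7.0, 9.0],
    levels=["A", "B"],
)
for row in rows:
    print(row)
```

Each column gets its own coefficient, so every area ends up with its own VehPower slope.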

Spline Basis Functions

# Use splines in formulas - automatic parsing
result = rs.glm(
    "ClaimNb ~ bs(Age, df=5) + ns(VehPower, df=4) + C(Region)",
    data=data,
    family="poisson",
    offset="Exposure"
).fit()

# Combine splines with interactions
result = rs.glm(
    "y ~ bs(age, df=4)*C(gender) + ns(income, df=3)",
    data=data,
    family="gaussian"
).fit()

# Direct basis computation for custom use
import numpy as np
x = np.linspace(0, 10, 100)
basis = rs.bs(x, df=5)  # 5 degrees of freedom (4 basis columns)
basis_ns = rs.ns(x, df=5)  # Natural splines - linear extrapolation at boundaries

When to use each spline type:

  • B-splines (bs): Standard choice, more flexible at boundaries
  • Natural splines (ns): Better extrapolation, linear beyond boundaries (recommended for actuarial work)
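
For readers who want to see what a B-spline basis is, the sketch below evaluates one with the Cox-de Boor recursion in plain Python. It is a textbook construction, not the RustyStats implementation. The knot vector is chosen by hand, and how `rs.bs` maps `df` to knots is not assumed.

```python
def bspline_basis(x, knots, degree):
    """Evaluate all B-spline basis functions of the given degree at x."""
    n_intervals = len(knots) - 1
    # Degree 0: indicator of the half-open knot interval containing x.
    basis = [1.0 if knots[i] <= x < knots[i + 1] else 0.0 for i in range(n_intervals)]
    # Close the last non-empty interval on the right so x == knots[-1] is covered.
    if x == knots[-1]:
        for i in range(n_intervals - 1, -1, -1):
            if knots[i] < knots[i + 1]:
                basis[i] = 1.0
                break
    # Raise the degree one step at a time (Cox-de Boor recursion).
    for d in range(1, degree + 1):
        raised = []
        for i in range(len(knots) - d - 1):
            left = right = 0.0
            if knots[i + d] > knots[i]:
                left = (x - knots[i]) / (knots[i + d] - knots[i]) * basis[i]
            if knots[i + d + 1] > knots[i + 1]:
                right = (
                    (knots[i + d + 1] - x)
                    / (knots[i + d + 1] - knots[i + 1])
                    * basis[i + 1]
                )
            raised.append(left + right)
        basis = raised
    return basis

# Clamped cubic knot vector on [0, 10] with three interior knots.
knots = [0, 0, 0, 0, 2.5, 5.0, 7.5, 10, 10, 10, 10]
values = bspline_basis(4.0, knots, degree=3)
print(len(values), round(sum(values), 12))
```

The basis functions are non-negative and sum to 1 at every point of the range, which is why they behave like smooth local weights.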

Quasi-Families for Overdispersion

# Fit a standard Poisson model first
result_poisson = rs.glm("ClaimNb ~ Age + C(Region)", data, family="poisson", offset="Exposure").fit()

# Check for overdispersion: Pearson χ² / df >> 1 indicates overdispersion
dispersion_ratio = result_poisson.pearson_chi2() / result_poisson.df_resid
print(f"Dispersion ratio: {dispersion_ratio:.2f}")  # If >> 1, use quasi-family

# Fit QuasiPoisson if overdispersed
result_quasi = rs.glm("ClaimNb ~ Age + C(Region)", data, family="quasipoisson", offset="Exposure").fit()

# Coefficients are IDENTICAL to Poisson, but standard errors are inflated by √φ
print(f"Estimated dispersion (φ): {result_quasi.scale():.3f}")

# For binary data with overdispersion
result_qb = rs.glm("Binary ~ x1 + x2", data, family="quasibinomial").fit()

Key properties of quasi-families:

  • Point estimates: Identical to base family (Poisson/Binomial)
  • Standard errors: Inflated by √φ where φ = Pearson χ²/(n-p)
  • P-values: More conservative (larger), accounting for extra variance
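
The dispersion arithmetic in the properties above can be checked by hand. The sketch below computes φ = Pearson χ²/(n-p) with the Poisson variance function V(μ) = μ and rescales a standard error. The observations, fitted means, and standard error are invented, and `pearson_dispersion` is a hypothetical helper.

```python
import math

def pearson_dispersion(y, mu, n_params):
    """phi = sum((y - mu)^2 / mu) / (n - p), using the Poisson variance V(mu) = mu."""
    chi2 = sum((yi - mi) ** 2 / mi for yi, mi in zip(y, mu))
    return chi2 / (len(y) - n_params)

y = [0, 3, 1, 7, 0, 5]                  # observed counts (made up)
mu = [1.0, 2.0, 2.0, 4.0, 1.0, 3.0]     # hypothetical fitted means
phi = pearson_dispersion(y, mu, n_params=2)

poisson_se = 0.10                       # hypothetical Poisson standard error
quasi_se = poisson_se * math.sqrt(phi)  # quasi-Poisson standard error
print(f"phi={phi:.3f}, quasi SE={quasi_se:.4f}")
```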

Negative Binomial for Overdispersed Counts

# Automatic θ estimation (default when theta not supplied)
result = rs.glm("ClaimNb ~ Age + C(Region)", data, family="negbinomial", offset="Exposure").fit()
print(result.family)  # "NegativeBinomial(theta=2.1234)"

# Fixed θ value
result = rs.glm("ClaimNb ~ Age + C(Region)", data, family="negbinomial", theta=1.0, offset="Exposure").fit()

# θ controls overdispersion: Var(Y) = μ + μ²/θ
# - θ=0.5: Strong overdispersion (variance = μ + 2μ²)
# - θ=1.0: Moderate overdispersion (variance = μ + μ²)
# - θ→∞: Approaches Poisson (variance = μ)

NegativeBinomial vs QuasiPoisson:

Aspect | QuasiPoisson | NegativeBinomial
Variance | φ × μ | μ + μ²/θ
True distribution | No (quasi) | Yes
AIC/BIC valid | Questionable | Yes
Prediction intervals | Not principled | Proper
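
The practical difference between the two variance functions is how they grow with the mean: linearly for QuasiPoisson, quadratically for NegativeBinomial. A small sketch of the two formulas from the table, with arbitrary values of φ and θ:

```python
def quasipoisson_variance(mu, phi):
    """Variance proportional to the mean: phi * mu."""
    return phi * mu

def negbinomial_variance(mu, theta):
    """Variance with a quadratic term: mu + mu^2 / theta."""
    return mu + mu ** 2 / theta

for mu in (0.5, 2.0, 10.0):
    qp = quasipoisson_variance(mu, phi=2.0)
    nb = negbinomial_variance(mu, theta=1.0)
    print(f"mu={mu:5.1f}  quasi-Poisson={qp:6.2f}  negative binomial={nb:7.2f}")
```

For small means the two can be close, but at larger means the negative binomial variance dominates, so the two models weight high-exposure observations differently.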

Target Encoding for High-Cardinality Categoricals

# Formula API - TE() in formulas
result = rs.glm(
    "ClaimNb ~ TE(Brand) + TE(Model) + Age + C(Region)",
    data=data,
    family="poisson",
    offset="Exposure"
).fit()

# With options
result = rs.glm(
    "y ~ TE(brand, prior_weight=2.0, n_permutations=8) + age",
    data=data,
    family="gaussian"
).fit()

# Sklearn-style API
encoder = rs.TargetEncoder(prior_weight=1.0, n_permutations=4)
train_encoded = encoder.fit_transform(train_categories, train_target)
test_encoded = encoder.transform(test_categories)

Key benefits:

  • No target leakage: CatBoost-style ordered target statistics
  • Regularization: Prior weight controls shrinkage toward global mean
  • High-cardinality: Single column instead of thousands of dummies
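
The ordered target statistic behind the first two benefits can be sketched as follows: rows are visited in a random order, and each row is encoded using only the targets of same-category rows seen earlier, shrunk toward the global mean by `prior_weight`. This is a simplified single-permutation illustration under those assumptions, not the RustyStats code. Averaging over several permutations (as `n_permutations` suggests) is left out.

```python
import random

def ordered_target_encode(categories, target, prior_weight=1.0, seed=0):
    """Encode each row from earlier same-category rows in a random order."""
    n = len(categories)
    prior = sum(target) / n                 # global mean used as the prior
    order = list(range(n))
    random.Random(seed).shuffle(order)      # one random permutation of the rows
    sums, counts = {}, {}
    encoded = [0.0] * n
    for idx in order:
        cat = categories[idx]
        seen_sum = sums.get(cat, 0.0)
        seen_count = counts.get(cat, 0)
        # The row's own target is not used in its own encoding.
        encoded[idx] = (seen_sum + prior_weight * prior) / (seen_count + prior_weight)
        sums[cat] = seen_sum + target[idx]
        counts[cat] = seen_count + 1
    return encoded

categories = ["a", "b", "a", "a", "b"]
target = [1, 0, 1, 0, 1]
print(ordered_target_encode(categories, target, prior_weight=1.0))
```

The first row visited in each category has no history, so it is encoded as the prior itself. Larger `prior_weight` values keep rare categories closer to that prior.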

Model Diagnostics

# Compute all diagnostics at once
diagnostics = result.diagnostics(
    data=data,
    categorical_factors=["Region", "VehBrand", "Area"],  # Including non-fitted
    continuous_factors=["Age", "Income", "VehPower"],    # Including non-fitted
)

# Export as compact JSON (optimized for LLM consumption)
json_str = diagnostics.to_json()

# Pre-fit data exploration (no model needed)
exploration = rs.explore_data(
    data=data,
    response="ClaimNb",
    categorical_factors=["Region", "VehBrand", "Area"],
    continuous_factors=["Age", "VehPower", "Income"],
    exposure="Exposure",
    family="poisson",
    detect_interactions=True,
)

Diagnostic Features:

  • Calibration: Overall A/E ratio, calibration by decile with CIs, Hosmer-Lemeshow test
  • Discrimination: Gini coefficient, AUC, KS statistic, lift metrics
  • Factor Diagnostics: A/E by level/bin for ALL factors (fitted and non-fitted)
  • Interaction Detection: Greedy residual-based detection of potential interactions
  • Warnings: Auto-generated alerts for high dispersion, poor calibration, missing factors
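
The A/E (actual over expected) ratio that underlies the calibration and factor diagnostics is the sum of observed values divided by the sum of model predictions within each group. A minimal sketch with made-up numbers follows. The helper name is hypothetical and says nothing about the layout of the `diagnostics` object.

```python
from collections import defaultdict

def actual_vs_expected(levels, actual, expected):
    """Return {level: sum(actual) / sum(expected)} for each factor level."""
    totals = defaultdict(lambda: [0.0, 0.0])
    for level, a, e in zip(levels, actual, expected):
        totals[level][0] += a
        totals[level][1] += e
    return {level: a / e for level, (a, e) in totals.items()}

ratios = actual_vs_expected(
    levels=["North", "South", "North"],
    actual=[3, 1, 6],
    expected=[3.0, 2.0, 3.0],
)
print(ratios)
```

A ratio near 1 means the model is well calibrated for that level. Values well above or below 1 on a non-fitted factor suggest the factor may be worth adding.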

RustyStats vs Statsmodels

Feature | RustyStats | Statsmodels
Parallel IRLS Solver | ✅ Multi-threaded via Rayon | ❌ Single-threaded only
Native Polars Support | ✅ Formula API works with Polars DataFrames | ❌ Pandas only
Built-in Lasso/Elastic Net for GLMs | ✅ Fast coordinate descent with all families | ⚠️ Limited
Relativities Table | ✅ result.relativities() for pricing | ❌ Must compute manually
Robust Standard Errors | ✅ HC0, HC1, HC2, HC3 sandwich estimators | ✅ HC0-HC3

Performance Comparison (678,012 rows × 28 features)

Operation | RustyStats | Statsmodels
Poisson GLM | ~1.0s | ~5-10s
Ridge GLM | ~1.0s | ~5-10s
Lasso GLM | ~2.8s | Not available for GLMs

When to Use RustyStats

  • Large datasets - Parallel solver scales better
  • Regularized GLMs - Built-in Lasso/Ridge/Elastic Net for any family
  • Actuarial/Insurance - Relativities tables, Tweedie, exposure offsets
  • Polars workflows - Native Polars DataFrame support

When to Use Statsmodels

  • Broader model coverage - OLS, WLS, GLS, mixed effects, time series
  • Established ecosystem - More documentation, Stack Overflow answers

Project Structure

rustystats/
├── Cargo.toml                    # Workspace config
├── pyproject.toml                # Python package config
│
├── crates/
│   ├── rustystats-core/          # Pure Rust GLM library
│   │   └── src/
│   │       ├── families/         # Gaussian, Poisson, Binomial, Gamma, Tweedie, Quasi, NegativeBinomial
│   │       ├── links/            # Identity, Log, Logit
│   │       ├── solvers/          # IRLS, coordinate descent
│   │       ├── inference/        # P-values, CIs, robust SE (HC0-HC3)
│   │       ├── interactions/     # Lazy interaction term computation
│   │       ├── splines/          # B-spline and natural spline basis functions
│   │       ├── design_matrix/    # Categorical encoding, interaction matrices
│   │       ├── formula/          # R-style formula parsing
│   │       ├── target_encoding/  # CatBoost-style ordered target statistics
│   │       └── diagnostics/      # Residuals, dispersion, AIC/BIC, calibration, loss
│   │
│   └── rustystats/               # Python bindings (PyO3)
│       └── src/lib.rs
│
├── python/rustystats/            # Python package
│   ├── __init__.py               # Main exports
│   ├── formula.py                # Formula API with DataFrame support
│   ├── splines.py                # bs() and ns() spline basis functions
│   ├── target_encoding.py        # CatBoost-style target encoding
│   ├── diagnostics.py            # Model diagnostics with JSON export
│   └── families.py               # Family wrappers
│
├── examples/
│   └── frequency.ipynb           # Claim frequency example
│
└── tests/python/                 # Python test suite

Dependencies

Rust

  • ndarray, nalgebra - Linear algebra
  • rayon - Parallel iterators (multi-threading)
  • statrs - Statistical distributions
  • pyo3 - Python bindings

Python

  • numpy - Array operations (required)
  • polars - DataFrame support (required)

License

MIT
