Fast Generalized Linear Models with a Rust backend - statsmodels compatible
RustyStats 🦀
High-performance Generalized Linear Models with a Rust backend and Python API
Codebase Documentation: pricingfrontier.github.io/rustystats/
Performance Benchmarks
RustyStats vs Statsmodels speedup on synthetic data: 101 features, i.e. an intercept plus 10 continuous and 10 categorical variables with 10 levels each (10 + 10 × 9 dummy columns + 1 = 101).
| Family | 10K rows | 250K rows | 500K rows |
|---|---|---|---|
| Gaussian | 15.6x | 5.7x | 4.3x |
| Poisson | 16.3x | 6.2x | 4.2x |
| Binomial | 19.5x | 6.8x | 4.4x |
| Gamma | 33.7x | 13.4x | 8.4x |
| NegBinomial | 26.7x | 6.7x | 5.0x |
Average speedup: 11.8x (arithmetic mean of the 15 results; range 4.2x to 33.7x)
Memory Usage
RustyStats uses significantly less RAM by reusing buffers and avoiding Python object overhead:
| Rows | RustyStats | Statsmodels | Reduction |
|---|---|---|---|
| 10K | 38 MB | 72 MB | 1.9x |
| 250K | 460 MB | 1,796 MB | 3.9x |
| 500K | 836 MB | 3,590 MB | 4.3x |
The memory advantage grows with data size: at 500K rows, RustyStats uses about 4.3x less RAM.
Full benchmark details
| Family | Rows | RustyStats | Statsmodels | Speedup |
|---|---|---|---|---|
| Gaussian | 10,000 | 0.100s | 1.559s | 15.6x |
| Gaussian | 250,000 | 1.991s | 11.363s | 5.7x |
| Gaussian | 500,000 | 4.023s | 17.386s | 4.3x |
| Poisson | 10,000 | 0.165s | 2.692s | 16.3x |
| Poisson | 250,000 | 2.429s | 15.072s | 6.2x |
| Poisson | 500,000 | 5.668s | 23.693s | 4.2x |
| Binomial | 10,000 | 0.112s | 2.189s | 19.5x |
| Binomial | 250,000 | 1.946s | 13.155s | 6.8x |
| Binomial | 500,000 | 4.708s | 20.862s | 4.4x |
| Gamma | 10,000 | 0.129s | 4.353s | 33.7x |
| Gamma | 250,000 | 2.385s | 31.885s | 13.4x |
| Gamma | 500,000 | 5.499s | 46.167s | 8.4x |
| NegBinomial | 10,000 | 0.119s | 3.177s | 26.7x |
| NegBinomial | 250,000 | 2.281s | 15.278s | 6.7x |
| NegBinomial | 500,000 | 4.821s | 24.331s | 5.0x |
Times are the median of 3 runs. Benchmark scripts are in `benchmarks/`.
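The summary figures can be recomputed from the timings in this table. The sketch below is plain Python over the published numbers; besides the arithmetic mean it reports a geometric mean and a total-time ratio, two common alternatives when speedups span an order of magnitude. None of this is part of the benchmark scripts themselves.

```python
import math

# (family, rows) -> (rustystats_seconds, statsmodels_seconds), copied from the table above
TIMINGS = {
    ("Gaussian", 10_000): (0.100, 1.559),
    ("Gaussian", 250_000): (1.991, 11.363),
    ("Gaussian", 500_000): (4.023, 17.386),
    ("Poisson", 10_000): (0.165, 2.692),
    ("Poisson", 250_000): (2.429, 15.072),
    ("Poisson", 500_000): (5.668, 23.693),
    ("Binomial", 10_000): (0.112, 2.189),
    ("Binomial", 250_000): (1.946, 13.155),
    ("Binomial", 500_000): (4.708, 20.862),
    ("Gamma", 10_000): (0.129, 4.353),
    ("Gamma", 250_000): (2.385, 31.885),
    ("Gamma", 500_000): (5.499, 46.167),
    ("NegBinomial", 10_000): (0.119, 3.177),
    ("NegBinomial", 250_000): (2.281, 15.278),
    ("NegBinomial", 500_000): (4.821, 24.331),
}

speedups = [sm_s / rust_s for rust_s, sm_s in TIMINGS.values()]
mean_speedup = sum(speedups) / len(speedups)
geo_mean_speedup = math.exp(sum(math.log(s) for s in speedups) / len(speedups))
# Total-time ratio weights each run by its duration, so the slow 500K-row fits dominate
total_ratio = sum(sm for _, sm in TIMINGS.values()) / sum(rust for rust, _ in TIMINGS.values())

print(f"arithmetic mean {mean_speedup:.1f}x, geometric mean {geo_mean_speedup:.1f}x, "
      f"total-time ratio {total_ratio:.1f}x, range {min(speedups):.1f}x to {max(speedups):.1f}x")
```

On these timings the arithmetic mean is about 11.8x, the geometric mean about 9.3x and the total-time ratio about 6.4x: the largest ratios come from the short 10K-row fits, so the three summaries diverge.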
Features
- Fast - Parallel Rust backend, roughly 4x to 34x faster than statsmodels in the benchmarks above
- Memory Efficient - About 4x less RAM than statsmodels at 250K-500K rows
- Stable - Step-halving IRLS and warm starts for robust convergence
- Splines - B-splines `bs()` and natural splines `ns()` in formulas
- Polynomials - Identity terms `I(x ** 2)` for polynomial and arithmetic expressions
- Target Encoding - CatBoost-style `TE()` for high-cardinality categoricals (exposure-aware)
- Regularization - Ridge, Lasso, and Elastic Net via coordinate descent
- Validation - Design matrix checks with fix suggestions before fitting
- Complete - 8 families, robust SEs, full diagnostics, VIF, partial dependence
- Minimal - Only `numpy` and `polars` required
Installation
uv add rustystats
Quick Start
import rustystats as rs
import polars as pl
# Load data
data = pl.read_parquet("insurance.parquet")
# Fit a Poisson GLM for claim frequency
result = rs.glm(
"ClaimCount ~ VehAge + VehPower + C(Area) + C(Region)",
data=data,
family="poisson",
offset="Exposure"
).fit()
# View results
print(result.summary())
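For context on `offset="Exposure"`: in a log-link frequency model, exposure usually enters the linear predictor as a fixed term, log(μ) = Xβ + log(Exposure), so the coefficients describe claim rates per unit of exposure. This page does not say whether RustyStats takes the log of the named column for you, so treat that as an assumption. The sketch below is a textbook Poisson IRLS loop with step-halving (the stabilization listed under Features) in plain NumPy; `fit_poisson_irls` is a hypothetical helper written to show the mechanics, not RustyStats' solver.

```python
import numpy as np

def poisson_deviance(y, mu):
    """Poisson deviance 2 * sum(y*log(y/mu) - (y - mu)), with y*log(y/mu) = 0 when y = 0."""
    with np.errstate(divide="ignore", invalid="ignore"):
        ylogy = np.where(y > 0, y * np.log(y / mu), 0.0)
    return 2.0 * np.sum(ylogy - (y - mu))

def fit_poisson_irls(X, y, offset, max_iter=50, tol=1e-8):
    """Hypothetical helper: log-link Poisson GLM via IRLS with step-halving."""
    beta = np.zeros(X.shape[1])
    eta = X @ beta + offset
    mu = np.exp(eta)
    dev = poisson_deviance(y, mu)
    for _ in range(max_iter):
        # Log link: working response z = (eta - offset) + (y - mu)/mu, working weights w = mu
        z = (eta - offset) + (y - mu) / mu
        XtW = X.T * mu
        step = np.linalg.solve(XtW @ X, XtW @ z) - beta
        # Step-halving: shrink the step until the deviance stops increasing
        for _ in range(20):
            eta_new = X @ (beta + step) + offset
            mu_new = np.exp(eta_new)
            dev_new = poisson_deviance(y, mu_new)
            if dev_new <= dev:
                break
            step = step / 2.0
        beta, eta, mu = beta + step, eta_new, mu_new
        converged = abs(dev - dev_new) <= tol * (abs(dev_new) + 0.1)
        dev = dev_new
        if converged:
            break
    return beta

# Simulated data; exposure enters as log(exposure), the assumed meaning of offset="Exposure"
rng = np.random.default_rng(0)
n = 5_000
x = rng.normal(size=n)
exposure = rng.uniform(0.5, 1.5, size=n)
X = np.column_stack([np.ones(n), x])
y = rng.poisson(exposure * np.exp(X @ np.array([-1.0, 0.5]))).astype(float)
beta_hat = fit_poisson_irls(X, y, offset=np.log(exposure))
print(beta_hat)  # should land close to the simulated coefficients [-1.0, 0.5]
```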
Families & Links
| Family | Default Link | Use Case |
|---|---|---|
| `gaussian` | `identity` | Linear regression |
| `poisson` | `log` | Claim frequency |
| `binomial` | `logit` | Binary outcomes |
| `gamma` | `log` | Claim severity |
| `tweedie` | `log` | Pure premium (var_power=1.5) |
| `quasipoisson` | `log` | Overdispersed counts |
| `quasibinomial` | `logit` | Overdispersed binary |
| `negbinomial` | `log` | Overdispersed counts (proper distribution) |
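Each family in the table fixes a variance function V(μ), with Var(Y) = φ·V(μ), and each link maps the linear predictor back to the mean. The sketch below lists the standard textbook forms, using the table's Tweedie power of 1.5; the `theta=1.0` default for the negative binomial is only illustrative, and the dictionaries are hypothetical helpers rather than RustyStats objects.

```python
import numpy as np

# Textbook variance functions V(mu), with Var(Y) = phi * V(mu)
VARIANCE = {
    "gaussian": lambda mu: np.ones_like(mu),
    "poisson": lambda mu: mu,
    "binomial": lambda mu: mu * (1.0 - mu),
    "gamma": lambda mu: mu ** 2,
    "tweedie": lambda mu, var_power=1.5: mu ** var_power,
    "negbinomial": lambda mu, theta=1.0: mu + mu ** 2 / theta,  # phi fixed at 1
}
# Quasi-families keep the base variance function but estimate phi from the data
VARIANCE["quasipoisson"] = VARIANCE["poisson"]
VARIANCE["quasibinomial"] = VARIANCE["binomial"]

# Inverse links: map the linear predictor eta back to the mean mu
INVERSE_LINK = {
    "identity": lambda eta: eta,
    "log": np.exp,
    "logit": lambda eta: 1.0 / (1.0 + np.exp(-eta)),
}

mu = np.array([0.2, 0.5])  # means inside (0, 1) so the binomial entry is also valid
print({family: v(mu) for family, v in VARIANCE.items()})
```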
Formula Syntax
# Main effects
"y ~ x1 + x2 + C(category)"
# Single-level categorical indicators
"y ~ C(Region, level='Paris')" # 0/1 indicator for Paris only
"y ~ C(Region, levels=['Paris', 'Lyon'])" # Indicators for specific levels
# Interactions
"y ~ x1*x2" # x1 + x2 + x1:x2
"y ~ C(area):age" # Area-specific age effects
"y ~ C(area)*C(brand)" # Categorical × categorical
# Splines (non-linear effects)
"y ~ bs(age, df=5)" # B-spline basis
"y ~ ns(income, df=4)" # Natural spline (better extrapolation)
# Identity terms (polynomial/arithmetic expressions)
"y ~ I(age ** 2)" # Polynomial terms
"y ~ I(x1 * x2)" # Explicit products
"y ~ I(income / 1000)" # Scaled variables
# Target encoding (high-cardinality categoricals)
"y ~ TE(brand) + TE(model)"
# Combined
"y ~ bs(age, df=5) + C(region)*income + ns(vehicle_age, df=3) + TE(brand) + I(age ** 2)"
Results Methods
# Coefficients & Inference
result.params # Coefficients
result.fittedvalues # Predicted means
result.deviance # Model deviance
result.bse() # Standard errors
result.tvalues() # z-statistics
result.pvalues() # P-values
result.conf_int(alpha) # Confidence intervals
# Robust Standard Errors (sandwich estimators)
result.bse_robust("HC1") # Robust SE (HC0, HC1, HC2, HC3)
result.tvalues_robust() # z-stats with robust SE
result.pvalues_robust() # P-values with robust SE
result.conf_int_robust() # Confidence intervals with robust SE
result.cov_robust() # Full robust covariance matrix
# Diagnostics (statsmodels-compatible)
result.resid_response() # Raw residuals (y - μ)
result.resid_pearson() # Pearson residuals
result.resid_deviance() # Deviance residuals
result.resid_working() # Working residuals
result.llf() # Log-likelihood
result.aic() # Akaike Information Criterion
result.bic() # Bayesian Information Criterion
result.null_deviance() # Null model deviance
result.pearson_chi2() # Pearson chi-squared
result.scale() # Dispersion (deviance-based)
result.scale_pearson() # Dispersion (Pearson-based)
result.family # Family name
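The residual types above have standard textbook definitions, which as far as we know match the statsmodels conventions this API mirrors. The NumPy sketch below spells them out for a log-link Poisson fit; `poisson_residuals` is a hypothetical helper, not a RustyStats function, and prior weights are ignored.

```python
import numpy as np

def poisson_residuals(y, mu):
    """Textbook residual types for a log-link Poisson fit with fitted means mu."""
    y = np.asarray(y, dtype=float)
    mu = np.asarray(mu, dtype=float)
    response = y - mu                                     # raw residuals (y - mu)
    pearson = response / np.sqrt(mu)                      # scaled by sqrt(V(mu)), V(mu) = mu
    with np.errstate(divide="ignore", invalid="ignore"):
        ylogy = np.where(y > 0, y * np.log(y / mu), 0.0)
    unit_dev = np.maximum(2.0 * (ylogy - response), 0.0)  # per-row deviance contribution
    deviance = np.sign(response) * np.sqrt(unit_dev)
    working = response / mu                               # (y - mu) * d(eta)/d(mu) for the log link
    return {"response": response, "pearson": pearson, "deviance": deviance, "working": working}

res = poisson_residuals(y=[0, 1, 3], mu=[0.5, 1.0, 2.0])
pearson_chi2 = (res["pearson"] ** 2).sum()  # sum of squared Pearson residuals
deviance = (res["deviance"] ** 2).sum()     # squared deviance residuals sum to the deviance
```

Dividing either sum by the residual degrees of freedom (n - p) gives the Pearson-based and deviance-based dispersion estimates.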
Regularization
# Ridge (L2) - shrinks coefficients, keeps all variables
result = rs.glm("y ~ x1 + x2 + C(cat)", data, family="gaussian").fit(
alpha=0.1, l1_ratio=0.0
)
# Lasso (L1) - variable selection, zeros out weak predictors
result = rs.glm("y ~ x1 + x2 + C(cat)", data, family="poisson").fit(
alpha=0.1, l1_ratio=1.0
)
print(f"Selected {result.n_nonzero()} variables")
print(f"Features: {result.selected_features()}")
# Elastic Net - mix of L1 and L2
result = rs.glm("y ~ x1 + x2 + C(cat)", data, family="gaussian").fit(
alpha=0.1, l1_ratio=0.5
)
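Coordinate descent updates one coefficient at a time with a soft-thresholding step. The sketch below covers only the Gaussian, no-intercept case on centered data, and assumes the glmnet/scikit-learn parameterization: loss (1/2n)·‖y - Xβ‖² plus penalty alpha·(l1_ratio·‖β‖₁ + (1 - l1_ratio)/2·‖β‖²). That RustyStats scales `alpha` and `l1_ratio` exactly this way is an assumption; for other families, solvers in this style typically run the same update on the IRLS working response and weights. `elastic_net_cd` is a hypothetical helper.

```python
import numpy as np

def soft_threshold(z, gamma):
    """S(z, gamma) = sign(z) * max(|z| - gamma, 0): the L1 proximal step."""
    return np.sign(z) * max(abs(z) - gamma, 0.0)

def elastic_net_cd(X, y, alpha, l1_ratio, max_sweeps=1000, tol=1e-10):
    """Minimize (1/2n)||y - Xb||^2 + alpha*(l1_ratio*||b||_1 + (1 - l1_ratio)/2*||b||^2)."""
    n, p = X.shape
    beta = np.zeros(p)
    resid = y.copy()                       # y - X @ beta for beta = 0
    col_sq = (X ** 2).sum(axis=0) / n      # x_j' x_j / n
    for _ in range(max_sweeps):
        max_change = 0.0
        for j in range(p):
            old = beta[j]
            # Correlation of x_j with the partial residual that excludes feature j
            rho = X[:, j] @ resid / n + col_sq[j] * old
            new = soft_threshold(rho, alpha * l1_ratio) / (col_sq[j] + alpha * (1.0 - l1_ratio))
            if new != old:
                resid -= X[:, j] * (new - old)
                beta[j] = new
                max_change = max(max_change, abs(new - old))
        if max_change < tol:
            break
    return beta

rng = np.random.default_rng(1)
X = rng.normal(size=(200, 5))
X -= X.mean(axis=0)
y = X @ np.array([1.5, 0.0, -2.0, 0.0, 0.5]) + rng.normal(scale=0.5, size=200)
y -= y.mean()
beta_unpenalized = elastic_net_cd(X, y, alpha=0.0, l1_ratio=0.5)  # alpha=0: least squares
beta_lasso = elastic_net_cd(X, y, alpha=0.1, l1_ratio=1.0)
beta_all_zero = elastic_net_cd(X, y, alpha=100.0, l1_ratio=1.0)   # penalty dominates
```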
Interaction Terms
# Continuous × Continuous interaction (main effects + interaction)
result = rs.glm(
"ClaimNb ~ Age*VehPower", # Equivalent to Age + VehPower + Age:VehPower
data, family="poisson", offset="Exposure"
).fit()
# Categorical × Continuous interaction
result = rs.glm(
"ClaimNb ~ C(Area)*Age", # Each area level has different age effect
data, family="poisson", offset="Exposure"
).fit()
# Categorical × Categorical interaction
result = rs.glm(
"ClaimNb ~ C(Area)*C(VehBrand)",
data, family="poisson", offset="Exposure"
).fit()
# Interaction only: C(Area):VehPower adds no C(Area) or VehPower main effect
result = rs.glm(
"ClaimNb ~ Age + C(Area):VehPower", # Area-specific VehPower slopes
data, family="poisson", offset="Exposure"
).fit()
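R-style formulas conventionally expand `C(x)` with treatment (reference-level) coding: one 0/1 column per level except a reference level, which the intercept absorbs. This is why 10 categorical variables with 10 levels each contribute 90 columns. The NumPy sketch below shows that coding and the usual column sets for `*` and `:` terms. The helper `treatment_code`, the made-up data, and the choice of the first sorted level as reference are assumptions for illustration; RustyStats' exact column layout may differ.

```python
import numpy as np

def treatment_code(values):
    """Hypothetical helper: 0/1 columns for every level except the first sorted level."""
    values = np.asarray(values)
    kept = sorted(set(values.tolist()))[1:]
    return np.column_stack([(values == level).astype(float) for level in kept]), kept

age = np.array([25.0, 40.0, 60.0, 33.0])
veh_power = np.array([4.0, 6.0, 5.0, 7.0])
area = np.array(["A", "B", "C", "B"])
brand = np.array(["X", "Y", "X", "Y"])

# C(Area): one dummy per non-reference level; "A" is absorbed by the intercept
area_dummies, area_names = treatment_code(area)
# C(Area, level='B'): a single indicator column
area_is_b = (area == "B").astype(float)

# Age*VehPower -> Age, VehPower, and their elementwise product Age:VehPower
age_star_power = np.column_stack([age, veh_power, age * veh_power])

# C(Area):VehPower with no Area or VehPower main effect -> one VehPower slope per Area level
area_slopes = np.column_stack([(area == level) * veh_power for level in sorted(set(area.tolist()))])

# C(Area)*C(VehBrand) -> Area dummies, brand dummies, and each pairwise product of the two
brand_dummies, _ = treatment_code(brand)
products = np.column_stack([area_dummies[:, i] * brand_dummies[:, j]
                            for i in range(area_dummies.shape[1])
                            for j in range(brand_dummies.shape[1])])
area_star_brand = np.column_stack([area_dummies, brand_dummies, products])
```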
Spline Basis Functions
# Use splines in formulas - automatic parsing
result = rs.glm(
"ClaimNb ~ bs(Age, df=5) + ns(VehPower, df=4) + C(Region)",
data=data,
family="poisson",
offset="Exposure"
).fit()
# Combine splines with interactions
result = rs.glm(
"y ~ bs(age, df=4)*C(gender) + ns(income, df=3)",
data=data,
family="gaussian"
).fit()
# Direct basis computation for custom use
import numpy as np
x = np.linspace(0, 10, 100)
basis = rs.bs(x, df=5) # 5 degrees of freedom (4 basis columns)
basis_ns = rs.ns(x, df=5) # Natural splines - linear extrapolation at boundaries
When to use each spline type:
- B-splines (`bs`): Standard choice, more flexible at the boundaries
- Natural splines (`ns`): Better extrapolation, linear beyond the boundary knots (recommended for actuarial work)
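A B-spline basis can be built with the Cox-de Boor recursion. The NumPy sketch below (a hypothetical `bspline_basis`, not `rs.bs`) returns the full cubic basis with clamped boundary knots. Its columns sum to one at every x, which is why spline terms in a model that already has an intercept usually drop one column. How `df` maps to interior knots in RustyStats is not described on this page, so the knots here are chosen by hand.

```python
import numpy as np

def bspline_basis(x, interior_knots, degree=3, lower=None, upper=None):
    """Hypothetical helper: full clamped B-spline basis via the Cox-de Boor recursion."""
    x = np.asarray(x, dtype=float)
    lower = x.min() if lower is None else lower
    upper = x.max() if upper is None else upper
    # Clamped knot vector: each boundary knot repeated degree + 1 times
    t = np.concatenate([[lower] * (degree + 1), np.sort(interior_knots), [upper] * (degree + 1)])
    # Degree 0: indicators of [t_i, t_{i+1}); the right boundary goes to the last interval
    B = np.zeros((x.size, len(t) - 1))
    for i in range(len(t) - 1):
        if t[i] < t[i + 1]:
            B[:, i] = (x >= t[i]) & (x < t[i + 1])
    last = np.flatnonzero(t[:-1] < t[1:]).max()
    B[x == upper, last] = 1.0
    # Raise the degree one step at a time
    for d in range(1, degree + 1):
        B_next = np.zeros((x.size, len(t) - d - 1))
        for i in range(len(t) - d - 1):
            left = t[i + d] - t[i]
            right = t[i + d + 1] - t[i + 1]
            if left > 0:
                B_next[:, i] += (x - t[i]) / left * B[:, i]
            if right > 0:
                B_next[:, i] += (t[i + d + 1] - x) / right * B[:, i + 1]
        B = B_next
    return B  # shape (len(x), len(interior_knots) + degree + 1)

x = np.linspace(0, 10, 101)
basis = bspline_basis(x, interior_knots=[2.5, 5.0, 7.5])
```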
Quasi-Families for Overdispersion
# Fit a standard Poisson model first
result_poisson = rs.glm("ClaimNb ~ Age + C(Region)", data, family="poisson", offset="Exposure").fit()
# Check for overdispersion: Pearson χ² / df >> 1 indicates overdispersion
dispersion_ratio = result_poisson.pearson_chi2() / result_poisson.df_resid
print(f"Dispersion ratio: {dispersion_ratio:.2f}") # If >> 1, use quasi-family
# Fit QuasiPoisson if overdispersed
result_quasi = rs.glm("ClaimNb ~ Age + C(Region)", data, family="quasipoisson", offset="Exposure").fit()
# Coefficients are IDENTICAL to Poisson, but standard errors are inflated by √φ
print(f"Estimated dispersion (φ): {result_quasi.scale():.3f}")
# For binary data with overdispersion
result_qb = rs.glm("Binary ~ x1 + x2", data, family="quasibinomial").fit()
Key properties of quasi-families:
- Point estimates: Identical to base family (Poisson/Binomial)
- Standard errors: Inflated by √φ, where φ = Pearson χ²/(n - p)
- P-values: More conservative (larger), accounting for extra variance
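The standard-error adjustment is simple arithmetic on the base-family fit. The sketch below uses made-up numbers and a hypothetical helper `quasi_standard_errors`; it only illustrates the φ = Pearson χ²/(n - p) scaling described above.

```python
import numpy as np

def quasi_standard_errors(se_base, pearson_chi2, n_obs, n_params):
    """Pearson-based dispersion phi and the base-family SEs scaled by sqrt(phi)."""
    phi = pearson_chi2 / (n_obs - n_params)
    return phi, np.asarray(se_base, dtype=float) * np.sqrt(phi)

# Hypothetical fit: Poisson SEs 0.02 and 0.05, Pearson chi-squared 3992 on 1000 rows, 2 parameters
phi, se_quasi = quasi_standard_errors([0.02, 0.05], pearson_chi2=3992.0, n_obs=1000, n_params=2)
# phi = 3992 / 998 = 4.0, so every SE doubles and every z-statistic halves
```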
Negative Binomial for Overdispersed Counts
# Automatic θ estimation (default when theta is not supplied)
result = rs.glm("ClaimNb ~ Age + C(Region)", data, family="negbinomial", offset="Exposure").fit()
print(result.family) # "NegativeBinomial(theta=2.1234)"
# Fixed θ value
result = rs.glm("ClaimNb ~ Age + C(Region)", data, family="negbinomial", theta=1.0, offset="Exposure").fit()
# θ controls overdispersion: Var(Y) = μ + μ²/θ
# - θ=0.5: Strong overdispersion (variance = μ + 2μ²)
# - θ=1.0: Moderate overdispersion (variance = μ + μ²)
# - θ→∞: Approaches Poisson (variance = μ)
NegativeBinomial vs QuasiPoisson:
| Aspect | QuasiPoisson | NegativeBinomial |
|---|---|---|
| Variance | φ × μ | μ + μ²/θ |
| True distribution | No (quasi) | Yes |
| AIC/BIC valid | Questionable | Yes |
| Prediction intervals | Not principled | Proper |
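The negative binomial variance formula can be checked by simulation. NumPy's `negative_binomial(n, p)` counts failures before n successes; choosing n = θ and p = θ/(θ + μ) gives mean μ and variance μ + μ²/θ. That mapping is standard algebra, not something taken from RustyStats, and `nb_variance` is a hypothetical helper.

```python
import numpy as np

def nb_variance(mu, theta):
    """NB2 variance: Var(Y) = mu + mu^2 / theta."""
    return mu + mu ** 2 / theta

rng = np.random.default_rng(0)
mu, theta = 2.0, 0.5
# n = theta, p = theta / (theta + mu) gives mean mu and variance mu + mu^2 / theta
draws = rng.negative_binomial(theta, theta / (theta + mu), size=400_000)
print(draws.mean(), draws.var(), nb_variance(mu, theta))  # sample moments vs. mu and mu + 2*mu^2
```

With a very large θ the formula collapses to the Poisson variance μ, matching the θ→∞ comment above.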
Target Encoding for High-Cardinality Categoricals
# Formula API - TE() in formulas
result = rs.glm(
"ClaimNb ~ TE(Brand) + TE(Model) + Age + C(Region)",
data=data,
family="poisson",
offset="Exposure"
).fit()
# With options
result = rs.glm(
"y ~ TE(brand, prior_weight=2.0, n_permutations=8) + age",
data=data,
family="gaussian"
).fit()
# Sklearn-style API
encoder = rs.TargetEncoder(prior_weight=1.0, n_permutations=4)
train_encoded = encoder.fit_transform(train_categories, train_target)
test_encoded = encoder.transform(test_categories)
Key benefits:
- No target leakage: Ordered target statistics
- Regularization: Prior weight controls shrinkage toward global mean
- High-cardinality: Single column instead of thousands of dummies
- Exposure-aware: For frequency models with `offset="Exposure"`, `TE()` automatically uses the claim rate (ClaimCount/Exposure) instead of raw counts, preventing near-constant encoded values
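Ordered target statistics, the CatBoost-style idea behind `TE()`, encode each training row using only rows that come before it in a random permutation, then shrink toward the global mean with a prior weight. The sketch below is a plain NumPy illustration of that technique; `ordered_target_encode` and `target_encoding_map` are hypothetical helpers whose parameter names echo `prior_weight` and `n_permutations` above, but their exact semantics in RustyStats (and its exposure handling) are assumptions.

```python
import numpy as np

def ordered_target_encode(categories, target, prior_weight=1.0, n_permutations=4, seed=0):
    """Training-time encoding: each row only sees target values of rows placed before it."""
    categories = np.asarray(categories)
    target = np.asarray(target, dtype=float)
    prior = target.mean()                  # global mean used as the shrinkage target
    rng = np.random.default_rng(seed)
    encoded = np.zeros(target.size)
    for _ in range(n_permutations):
        sums, counts = {}, {}
        for idx in rng.permutation(target.size):
            cat = categories[idx]
            s, k = sums.get(cat, 0.0), counts.get(cat, 0)
            # Smoothed mean of earlier rows in this category; target[idx] itself is excluded
            encoded[idx] += (s + prior_weight * prior) / (k + prior_weight)
            sums[cat] = s + target[idx]
            counts[cat] = k + 1
    return encoded / n_permutations        # average over permutations to reduce variance

def target_encoding_map(categories, target, prior_weight=1.0):
    """Test-time mapping: smoothed mean over all training rows of each category."""
    categories = np.asarray(categories)
    target = np.asarray(target, dtype=float)
    prior = target.mean()
    return {
        cat: (target[categories == cat].sum() + prior_weight * prior)
        / ((categories == cat).sum() + prior_weight)
        for cat in np.unique(categories).tolist()
    }

train_encoded = ordered_target_encode(["a", "b", "a", "c"], [1.0, 0.0, 1.0, 0.0])
mapping = target_encoding_map(["a", "a", "b"], [1.0, 1.0, 0.0])
```

Categories unseen at training time would typically fall back to the global prior.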
Identity Terms for Polynomials
# Polynomial terms
result = rs.glm(
"y ~ age + I(age ** 2) + I(age ** 3)",
data=data,
family="gaussian"
).fit()
# Arithmetic expressions
result = rs.glm(
"y ~ I(income / 1000) + I(weight * height)",
data=data,
family="gaussian"
).fit()
Supported operations: +, -, *, /, ** (power)
Design Matrix Validation
# Check for issues before fitting
model = rs.glm("y ~ ns(x, df=4) + C(cat)", data, family="poisson")
results = model.validate() # Prints diagnostics
if not results['valid']:
    print("Issues:", results['suggestions'])
# Validation runs automatically on fit failure with helpful suggestions
Checks performed:
- Rank deficiency (linearly dependent columns)
- High multicollinearity (condition number)
- Zero variance columns
- NaN/Inf values
- Highly correlated column pairs (>0.999)
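The checks listed above can each be expressed in a few lines of NumPy. The sketch below is a hypothetical `check_design_matrix`, not `model.validate()`; the condition-number limit and the convention that the constant column is named `"Intercept"` are assumptions, while the 0.999 correlation threshold follows the list.

```python
import numpy as np

def check_design_matrix(X, names, cond_limit=1e10, corr_limit=0.999):
    """Hypothetical pre-fit checks mirroring the list above; thresholds are illustrative."""
    X = np.asarray(X, dtype=float)
    if not np.isfinite(X).all():
        return ["NaN/Inf values present"]      # other checks are meaningless until fixed
    issues = []
    rank = np.linalg.matrix_rank(X)
    if rank < X.shape[1]:
        issues.append(f"rank deficient: rank {rank} < {X.shape[1]} columns")
    cond = np.linalg.cond(X)
    if cond > cond_limit:
        issues.append(f"high condition number: {cond:.3g}")
    std = X.std(axis=0)
    for j in np.flatnonzero(std == 0):
        if names[j] != "Intercept":             # a constant intercept column is expected
            issues.append(f"zero variance: {names[j]}")
    varying = np.flatnonzero(std > 0)
    if varying.size >= 2:
        corr = np.corrcoef(X[:, varying], rowvar=False)
        for a in range(varying.size):
            for b in range(a + 1, varying.size):
                if abs(corr[a, b]) > corr_limit:
                    issues.append(f"highly correlated: {names[varying[a]]} / {names[varying[b]]}")
    return issues

x = np.linspace(0.0, 1.0, 50)
X = np.column_stack([np.ones(50), x, 2.0 * x])  # third column is an exact multiple of the second
issues = check_design_matrix(X, ["Intercept", "x", "x_doubled"])
print(issues)
```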
Model Diagnostics
# Compute all diagnostics at once
diagnostics = result.diagnostics(
data=data,
categorical_factors=["Region", "VehBrand", "Area"], # may include factors not in the model
continuous_factors=["Age", "Income", "VehPower"], # may include factors not in the model
)
# Export as compact JSON (optimized for LLM consumption)
json_str = diagnostics.to_json()
# Pre-fit data exploration (no model needed)
exploration = rs.explore_data(
data=data,
response="ClaimNb",
categorical_factors=["Region", "VehBrand", "Area"],
continuous_factors=["Age", "VehPower", "Income"],
exposure="Exposure",
family="poisson",
detect_interactions=True,
)
Diagnostic Features:
- Calibration: Overall A/E ratio, calibration by decile with CIs, Hosmer-Lemeshow test
- Discrimination: Gini coefficient, AUC, KS statistic, lift metrics
- Factor Diagnostics: A/E by level/bin for ALL factors (fitted and non-fitted)
- VIF/Multicollinearity: Variance inflation factors for design matrix columns
- Partial Dependence: Effect plots with shape detection and recommendations
- Overfitting Detection: Compare train vs test metrics when test data provided
- Interaction Detection: Greedy residual-based detection of potential interactions
- Warnings: Auto-generated alerts for high dispersion, poor calibration, missing factors
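The calibration figures rest on a simple ratio: actual outcomes divided by predicted (expected) outcomes, overall and within bins of the prediction. The sketch below computes both with equal-count bins; `actual_vs_expected` is a hypothetical helper, and whether RustyStats bins by row count or by exposure is an assumption not settled here.

```python
import numpy as np

def actual_vs_expected(y, mu, n_bins=10):
    """Overall A/E and A/E within equal-count bins of the predicted values."""
    y = np.asarray(y, dtype=float)
    mu = np.asarray(mu, dtype=float)
    order = np.argsort(mu, kind="stable")  # rows ranked from lowest to highest prediction
    bins = np.array_split(order, n_bins)   # deciles when n_bins = 10
    overall = y.sum() / mu.sum()
    by_bin = np.array([y[idx].sum() / mu[idx].sum() for idx in bins])
    return overall, by_bin

rng = np.random.default_rng(0)
mu = rng.uniform(0.05, 0.3, size=10_000)   # hypothetical predicted claim frequencies
y = rng.poisson(mu)                        # outcomes drawn from those same means
overall, by_decile = actual_vs_expected(y, mu)
```

A well-calibrated model gives ratios near 1 in every bin; a trend across deciles points to a mis-specified shape rather than a level shift.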
RustyStats vs Statsmodels
| Feature | RustyStats | Statsmodels |
|---|---|---|
| Parallel IRLS Solver | ✅ Multi-threaded | ❌ Single-threaded only |
| Native Polars Support | ✅ Polars only | ❌ Pandas only |
| Built-in Lasso/Elastic Net for GLMs | ✅ Fast coordinate descent with all families | ⚠️ Limited |
| Relativities Table | ✅ `result.relativities()` for pricing | ❌ Must compute manually |
| Robust Standard Errors | ✅ HC0, HC1, HC2, HC3 sandwich estimators | ✅ HC0-HC3 |
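In a log-link pricing model, a relativity is the multiplicative effect of a level relative to its reference: exp(β), with the reference level fixed at 1. The sketch below shows that arithmetic on made-up coefficients; the output format of `result.relativities()` is not documented on this page, so `relativities` here is a hypothetical helper.

```python
import math

def relativities(params, link="log"):
    """Turn log-link coefficients into multiplicative relativities exp(beta)."""
    if link != "log":
        raise ValueError("multiplicative relativities assume a log link")
    return {name: math.exp(beta) for name, beta in params.items()}

# Hypothetical coefficients for C(Area) with reference level "A" (relativity 1 by construction)
rels = relativities({"C(Area)[B]": math.log(1.2), "C(Area)[C]": -0.1})
```

Here area B carries a relativity of 1.2 (20% higher expected frequency than A, all else equal) and area C about 0.905 (roughly 9.5% lower).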
Project Structure
rustystats/
├── Cargo.toml                    # Workspace config
├── pyproject.toml                # Python package config
│
├── crates/
│   ├── rustystats-core/          # Pure Rust GLM library
│   │   └── src/
│   │       ├── families/         # Gaussian, Poisson, Binomial, Gamma, Tweedie, Quasi, NegativeBinomial
│   │       ├── links/            # Identity, Log, Logit
│   │       ├── solvers/          # IRLS, coordinate descent
│   │       ├── inference/        # P-values, CIs, robust SE (HC0-HC3)
│   │       ├── interactions/     # Lazy interaction term computation
│   │       ├── splines/          # B-spline and natural spline basis functions
│   │       ├── design_matrix/    # Categorical encoding, interaction matrices
│   │       ├── formula/          # R-style formula parsing
│   │       ├── target_encoding/  # Ordered target statistics
│   │       └── diagnostics/      # Residuals, dispersion, AIC/BIC, calibration, loss
│   │
│   └── rustystats/               # Python bindings (PyO3)
│       └── src/lib.rs
│
├── python/rustystats/            # Python package
│   ├── __init__.py               # Main exports
│   ├── formula.py                # Formula API with DataFrame support
│   ├── interactions.py           # Interaction terms, I() expressions, design matrix
│   ├── splines.py                # bs() and ns() spline basis functions
│   ├── target_encoding.py        # Target encoding (exposure-aware)
│   ├── diagnostics.py            # Model diagnostics with JSON export
│   └── families.py               # Family wrappers
│
├── examples/
│   └── frequency.ipynb           # Claim frequency example
│
└── tests/python/                 # Python test suite
Dependencies
Rust
- `ndarray`, `nalgebra` - Linear algebra
- `rayon` - Parallel iterators (multi-threading)
- `statrs` - Statistical distributions
- `pyo3` - Python bindings
Python
- `numpy` - Array operations (required)
- `polars` - DataFrame support (required)
License
MIT
File details
Details for the file rustystats-0.1.12.tar.gz.
File metadata
- Download URL: rustystats-0.1.12.tar.gz
- Size: 194.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: maturin/1.10.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 176070042a40844ba283d9a0b61c033de78034b7ed8777f5149b1d8e0e068698 |
| MD5 | 1519b581624a0af309da9a74744aacbd |
| BLAKE2b-256 | 8a574c02f7f232554fb638038fca3ce1efcf135b48339f620803938de1d49f01 |
File details
Details for the file rustystats-0.1.12-cp313-cp313-manylinux_2_34_x86_64.whl.
File metadata
- Download URL: rustystats-0.1.12-cp313-cp313-manylinux_2_34_x86_64.whl
- Size: 839.0 kB
- Tags: CPython 3.13, manylinux: glibc 2.34+ x86-64
- Uploaded using Trusted Publishing? No
- Uploaded via: maturin/1.10.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1ccb4783fd448c71b97605f8168a5d8001cfc0a63a7a054442f3463b02acf208 |
| MD5 | 9809797f2914b079500e89da7853bb5c |
| BLAKE2b-256 | e25ac6d71ddb28d2c108f7ba1261d2f2dbfabbbd24460e1dfb139df68f32914a |