Polars OLS

Least squares extension in Polars

Supports linear model estimation in Polars.

This package provides efficient Rust implementations of common linear regression variants (OLS, WLS, Ridge, Elastic Net, non-negative least squares, recursive least squares) and exposes them as simple Polars expressions which can easily be integrated into your workflow.

Why?

  1. High Performance: implementations are written in Rust and make use of optimized Rust linear-algebra crates & LAPACK routines. See the benchmark section.
  2. Polars Integration: avoids unnecessary conversions from lazy to eager mode and round-trips to external libraries (e.g. numpy, sklearn) for simple linear regressions. Chain least-squares formulae like any other expression in Polars.
  3. Efficient Implementations:
    • Numerically stable algorithms are chosen where appropriate (e.g. QR, Cholesky).
    • Flexible model specification allows arbitrary combinations of sample weighting, L1/L2 regularization, & non-negativity constraints on parameters.
    • Efficient rank-1 update algorithms are used for moving window regressions.
  4. Easy Parallelism: computing OLS predictions in parallel across groups could not be easier: call .over() or group_by just like with any other Polars expression and benefit from full Rust parallelism.
  5. Formula API: supports building models via patsy syntax: y ~ x1 + x2 + x3:x4 -1 (as in statsmodels), which is automatically converted to equivalent Polars expressions.
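For intuition on point 3, the rank-1 update behind recursive least squares can be sketched in plain NumPy via the Sherman-Morrison identity. This is a didactic sketch on synthetic data, not this package's Rust implementation:

```python
import numpy as np

def rls_update(P, beta, x, y, forgetting=1.0):
    """One recursive-least-squares step: rank-1 update of the inverse
    Gram matrix P ~ (X'X)^-1 via the Sherman-Morrison formula."""
    Px = P @ x
    k = Px / (forgetting + x @ Px)          # gain vector
    beta = beta + k * (y - x @ beta)        # correct by prediction error
    P = (P - np.outer(k, Px)) / forgetting  # downdate the inverse Gram
    return P, beta

# Feeding samples one at a time converges to the batch OLS solution.
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 2))
y = X @ np.array([2.0, -1.0]) + 0.01 * rng.normal(size=500)

P = np.eye(2) * 1e6   # large initial covariance ~ diffuse prior
beta = np.zeros(2)
for xi, yi in zip(X, y):
    P, beta = rls_update(P, beta, xi, yi)
```

Each step costs O(p^2) rather than the O(p^3) of refitting from scratch, which is why moving window regressions benefit from this formulation.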

Installation

First, you need to install Polars. Then run the following to install the polars-ols extension:

pip install polars-ols

API & Examples

Importing polars_ols registers the least_squares namespace provided by this package. You can build models either by specifying polars expressions (e.g. pl.col(...)) for your targets and features or by using the formula API (patsy syntax). All models support the following general (optional) arguments:

  • mode - a literal which determines the type of output produced by the model
  • null_policy - a literal which determines how to deal with missing data
  • add_intercept - a boolean specifying if an intercept feature should be added to the features
  • sample_weights - a column or expression providing non-negative weights applied to the samples

Remaining parameters are model specific, for example the alpha penalty parameter used by regularized least-squares models.

See below for basic usage examples. Please refer to the tests or demo notebook for detailed examples.

import polars as pl
import polars_ols as pls  # registers 'least_squares' namespace

df = pl.DataFrame({"y": [1.16, -2.16, -1.57, 0.21, 0.22, 1.6, -2.11, -2.92, -0.86, 0.47],
                   "x1": [0.72, -2.43, -0.63, 0.05, -0.07, 0.65, -0.02, -1.64, -0.92, -0.27],
                   "x2": [0.24, 0.18, -0.95, 0.23, 0.44, 1.01, -2.08, -1.36, 0.01, 0.75],
                   "group": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
                   "weights": [0.34, 0.97, 0.39, 0.8, 0.57, 0.41, 0.19, 0.87, 0.06, 0.34],
                   })

lasso_expr = pl.col("y").least_squares.lasso(pl.col("x1"), pl.col("x2"), alpha=0.0001, add_intercept=True).over("group")
wls_expr = pls.compute_least_squares_from_formula("y ~ x1 + x2 -1", sample_weights=pl.col("weights"))

predictions = df.with_columns(lasso_expr.round(2).alias("predictions_lasso"),
                              wls_expr.round(2).alias("predictions_wls"))

print(predictions.head(5))
shape: (5, 7)
┌───────┬───────┬───────┬───────┬─────────┬───────────────────┬─────────────────┐
│ y     ┆ x1    ┆ x2    ┆ group ┆ weights ┆ predictions_lasso ┆ predictions_wls │
│ ---   ┆ ---   ┆ ---   ┆ ---   ┆ ---     ┆ ---               ┆ ---             │
│ f64   ┆ f64   ┆ f64   ┆ i64   ┆ f64     ┆ f32               ┆ f32             │
╞═══════╪═══════╪═══════╪═══════╪═════════╪═══════════════════╪═════════════════╡
│ 1.16  ┆ 0.72  ┆ 0.24  ┆ 1     ┆ 0.34    ┆ 0.97              ┆ 0.93            │
│ -2.16 ┆ -2.43 ┆ 0.18  ┆ 1     ┆ 0.97    ┆ -2.23             ┆ -2.18           │
│ -1.57 ┆ -0.63 ┆ -0.95 ┆ 1     ┆ 0.39    ┆ -1.54             ┆ -1.54           │
│ 0.21  ┆ 0.05  ┆ 0.23  ┆ 1     ┆ 0.8     ┆ 0.29              ┆ 0.27            │
│ 0.22  ┆ -0.07 ┆ 0.44  ┆ 1     ┆ 0.57    ┆ 0.37              ┆ 0.36            │
└───────┴───────┴───────┴───────┴─────────┴───────────────────┴─────────────────┘
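As an aside on what sample_weights does mathematically: WLS with weights w is equivalent to OLS on rows scaled by sqrt(w). A plain-NumPy sketch of this identity on hypothetical data (not the frame above, and not this package's code):

```python
import numpy as np

# WLS minimises sum_i w_i * (y_i - x_i @ beta)^2, i.e. ordinary least
# squares after scaling each row of X and y by sqrt(w_i).
rng = np.random.default_rng(1)
X = rng.normal(size=(100, 2))
w = rng.uniform(0.1, 1.0, size=100)  # non-negative sample weights
y = X @ np.array([1.0, 0.5]) + rng.normal(size=100)

sw = np.sqrt(w)
beta_wls = np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)[0]

# Equivalent closed form via the weighted normal equations.
beta_ne = np.linalg.solve(X.T @ (w[:, None] * X), X.T @ (w * y))
```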

The mode parameter sets the type of output returned by all methods: "predictions", "residuals", or "coefficients". It defaults to returning predictions matching the input's length.

If "coefficients" is set, the output is a polars Struct with coefficients as values and feature names as fields. Its output shape 'broadcasts' depending on context; see below:

coefficients = df.select(pl.col("y").least_squares.from_formula("x1 + x2", mode="coefficients")
                         .alias("coefficients"))

coefficients_group = df.select("group", pl.col("y").least_squares.from_formula("x1 + x2", mode="coefficients").over("group")
                        .alias("coefficients_group")).unique(maintain_order=True)

print(coefficients)
print(coefficients_group)
shape: (1, 1)
┌──────────────────────────────┐
│ coefficients                 │
│ ---                          │
│ struct[3]                    │
╞══════════════════════════════╡
│ {0.977375,0.987413,0.000757} │  # <--- coef for x1, x2, and intercept added by formula API
└──────────────────────────────┘
shape: (2, 2)
┌───────┬───────────────────────────────┐
│ group ┆ coefficients_group            │
│ ---   ┆ ---                           │
│ i64   ┆ struct[3]                     │
╞═══════╪═══════════════════════════════╡
│ 1     ┆ {0.995157,0.977495,0.014344}  │
│ 2     ┆ {0.939217,0.997441,-0.017599} │  # <--- (unique) coefficients per group
└───────┴───────────────────────────────┘

For dynamic models (like rolling_ols), or in a .over, .group_by, or .with_columns context, the coefficients take the shape of the data they are applied on. For example:

coefficients = df.with_columns(pl.col("y").least_squares.rls(pl.col("x1"), pl.col("x2"), mode="coefficients")
                         .over("group").alias("coefficients"))

print(coefficients.head())
shape: (5, 6)
┌───────┬───────┬───────┬───────┬─────────┬─────────────────────┐
│ y     ┆ x1    ┆ x2    ┆ group ┆ weights ┆ coefficients        │
│ ---   ┆ ---   ┆ ---   ┆ ---   ┆ ---     ┆ ---                 │
│ f64   ┆ f64   ┆ f64   ┆ i64   ┆ f64     ┆ struct[2]           │
╞═══════╪═══════╪═══════╪═══════╪═════════╪═════════════════════╡
│ 1.16  ┆ 0.72  ┆ 0.24  ┆ 1     ┆ 0.34    ┆ {1.235503,0.411834} │
│ -2.16 ┆ -2.43 ┆ 0.18  ┆ 1     ┆ 0.97    ┆ {0.963515,0.760769} │
│ -1.57 ┆ -0.63 ┆ -0.95 ┆ 1     ┆ 0.39    ┆ {0.975484,0.966029} │
│ 0.21  ┆ 0.05  ┆ 0.23  ┆ 1     ┆ 0.8     ┆ {0.975657,0.953735} │
│ 0.22  ┆ -0.07 ┆ 0.44  ┆ 1     ┆ 0.57    ┆ {0.97898,0.909793}  │
└───────┴───────┴───────┴───────┴─────────┴─────────────────────┘

Finally, for convenience, you can compute out-of-sample predictions using least_squares.{predict, predict_from_formula}. This saves you the effort of un-nesting the coefficients and doing the dot product in Python; instead it is done in Rust, as an expression. Usage is as follows:

df_test.select(pl.col("coefficients_train").least_squares.predict(pl.col("x1"), pl.col("x2")).alias("predictions_test"))
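Conceptually, this amounts to a dot product of coefficients estimated on training data with the test features. A plain-NumPy sketch of the equivalent computation on hypothetical data:

```python
import numpy as np

# Fit on a training sample, then predict out-of-sample: the predict
# expression performs the equivalent of X_test @ beta, but in Rust.
rng = np.random.default_rng(2)
X_train = rng.normal(size=(80, 2))
X_test = rng.normal(size=(20, 2))
beta_true = np.array([0.7, -1.3])
y_train = X_train @ beta_true + 0.1 * rng.normal(size=80)

beta = np.linalg.lstsq(X_train, y_train, rcond=None)[0]  # train
predictions_test = X_test @ beta                         # predict
```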

Supported Models

Currently, this extension package supports the following variants:

  • Ordinary Least Squares: least_squares.ols
  • Weighted Least Squares: least_squares.wls
  • Regularized Least Squares (Lasso / Ridge / Elastic Net): least_squares.{lasso, ridge, elastic_net}
  • Non-negative Least Squares: least_squares.nnls

As well as efficient implementations of moving window models:

  • Recursive Least Squares: least_squares.rls
  • Rolling / Expanding Window OLS: least_squares.{rolling_ols, expanding_ols}

An arbitrary combination of sample_weights, L1/L2 penalties, and non-negativity constraints can be specified with the least_squares.from_formula and least_squares.least_squares entry-points.
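For intuition on the elastic-net solver (the coordinate-descent variant that appears in the benchmarks below), here is a minimal NumPy sketch of cyclic coordinate descent with soft-thresholding. This is didactic code, not this package's implementation; the objective scaling follows scikit-learn's convention:

```python
import numpy as np

def soft_threshold(z, t):
    """Proximal operator of t*|.|: shrink z towards zero by t."""
    return np.sign(z) * np.maximum(np.abs(z) - t, 0.0)

def elastic_net_cd(X, y, alpha, l1_ratio, n_iter=200):
    """Cyclic coordinate descent for the objective
    (1/2n)||y - Xb||^2 + alpha*l1_ratio*||b||_1
                       + (alpha*(1 - l1_ratio)/2)*||b||^2."""
    n, p = X.shape
    beta = np.zeros(p)
    col_sq = (X ** 2).sum(axis=0) / n
    for _ in range(n_iter):
        for j in range(p):
            # partial residual excluding feature j's contribution
            r_j = y - X @ beta + X[:, j] * beta[j]
            rho = X[:, j] @ r_j / n
            beta[j] = soft_threshold(rho, alpha * l1_ratio) / (
                col_sq[j] + alpha * (1.0 - l1_ratio))
    return beta

# A tiny demo: with a small penalty the fit recovers the true coefficients.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = X @ np.array([1.5, 0.0, -2.0]) + 0.05 * rng.normal(size=200)
beta_hat = elastic_net_cd(X, y, alpha=0.001, l1_ratio=0.5)
```

Setting l1_ratio=1 recovers the lasso and l1_ratio=0 recovers ridge, which is why a single coordinate-descent routine can cover all three regularized variants.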

Solve Methods

polars-ols provides a choice between multiple supported numerical approaches per model (via the solve_method flag), trading off performance against numerical accuracy. These choices are exposed to the user for full control; if left unspecified, the package chooses a reasonable default depending on context.

For example, if you know you are dealing with highly collinear data and an unregularized OLS model, you may want to explicitly set solve_method="svd" so that the minimum-norm solution is obtained.
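To see why: with perfectly collinear features the least-squares solution is not unique, and an SVD-based solver picks the one with smallest norm, whereas a normal-equations solve would hit a singular matrix. A small NumPy illustration on hypothetical data:

```python
import numpy as np

# Rank-deficient design: x2 is an exact copy of x1, so X'X is singular
# and a Cholesky / normal-equations solve would fail.
rng = np.random.default_rng(3)
x1 = rng.normal(size=50)
X = np.column_stack([x1, x1])
y = 2.0 * x1

# SVD-based lstsq returns the minimum-norm solution: any beta with
# beta[0] + beta[1] == 2 fits exactly, and the SVD splits the weight
# evenly, giving [1.0, 1.0].
beta_svd = np.linalg.lstsq(X, y, rcond=None)[0]
```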

Benchmark

The usual caveats of benchmarks apply, but the results below should still be indicative of the type of performance improvement to expect when using this package.

This benchmark was run on randomly generated data with pyperf on an Apple M2 Max MacBook (32GB RAM, macOS Sonoma 14.2.1). See benchmark.py for the implementation.

n_samples=2_000, n_features=5

| Model                   | polars_ols        | Python Benchmark  | Benchmark Type     | Speed-up vs Python Benchmark |
|-------------------------|-------------------|-------------------|--------------------|------------------------------|
| Least Squares (QR)      | 300 us ± 7 us     | 1.01 ms ± 0.81 ms | Numpy (QR)         | 3.4x                         |
| Least Squares (SVD)     | 351 us ± 4 us     | 853 us ± 417 us   | Numpy (SVD)        | 2.4x                         |
| Ridge (Cholesky)        | 279 us ± 6 us     | 1.63 ms ± 0.69 ms | Sklearn (Cholesky) | 5.8x                         |
| Ridge (SVD)             | 351 us ± 5 us     | 1.95 ms ± 1.12 ms | Sklearn (SVD)      | 5.6x                         |
| Weighted Least Squares  | 531 us ± 4 us     | 2.54 ms ± 0.40 ms | Statsmodels        | 4.8x                         |
| Elastic Net (CD)        | 339 us ± 5 us     | 2.17 ms ± 0.77 ms | Sklearn            | 6.4x                         |
| Recursive Least Squares | 1.42 ms ± 0.02 ms | 18.5 ms ± 1.4 ms  | Statsmodels        | 13.0x                        |
| Rolling Least Squares   | 2.78 ms ± 0.07 ms | 22.8 ms ± 0.2 ms  | Statsmodels        | 8.2x                         |

n_samples=10_000, n_features=100

| Model                   | polars_ols        | Python Benchmark    | Benchmark Type     | Speed-up vs Python Benchmark |
|-------------------------|-------------------|---------------------|--------------------|------------------------------|
| Least Squares (QR)      | 12.4 ms ± 0.2 ms  | 68.3 ms ± 13.7 ms   | Numpy (QR)         | 5.5x                         |
| Least Squares (SVD)     | 14.5 ms ± 0.5 ms  | 44.9 ms ± 10.3 ms   | Numpy (SVD)        | 3.1x                         |
| Ridge (Cholesky)        | 6.10 ms ± 0.14 ms | 9.91 ms ± 2.86 ms   | Sklearn (Cholesky) | 1.6x                         |
| Ridge (SVD)             | 24.9 ms ± 2.1 ms  | 390 ms ± 63 ms      | Sklearn (SVD)      | 15.7x                        |
| Weighted Least Squares  | 14.8 ms ± 2.4 ms  | 114 ms ± 35 ms      | Statsmodels        | 7.7x                         |
| Elastic Net (CD)        | 21.7 ms ± 1.2 ms  | 111 ms ± 54 ms      | Sklearn            | 5.1x                         |
| Recursive Least Squares | 163 ms ± 28 ms    | 65.7 sec ± 28.2 sec | Statsmodels        | 403.1x                       |
| Rolling Least Squares   | 390 ms ± 10 ms    | 3.99 sec ± 0.54 sec | Statsmodels        | 10.2x                        |

  • Numpy's lstsq (which uses divide-and-conquer SVD) is already a highly optimized call into LAPACK, so the scope for speed-up is relatively limited; the same applies to simple approaches like directly solving the normal equations with Cholesky.
  • However, even on such problems the polars-ols Rust implementations of matching numerical algorithms tend to outperform by ~2-3x.
  • More substantial speed-ups are achieved for the more complex models by working entirely in Rust and avoiding the overhead of round-trips into Python.
  • Expect an additional order-of-magnitude speed-up in your workflow if it involves repeated re-estimation of models in (Python) loops.

Credits & Related Projects

  • The Rust linear-algebra libraries faer and ndarray underpin the implementations provided by this extension package
  • This package was templated around the very helpful polars-plugin-tutorial
  • The python package patsy is used for (optionally) building models from formulae
  • Please check out the extension package polars-ds for general data-science functionality in polars

Future Work / TODOs

  • Support generic types in the Rust implementations so that both f32 and f64 inputs are recognized. Right now data is cast to f32 prior to estimation.
  • Add more detailed documentation on supported models, signatures, and API
