Polars OLS

Least squares extension in Polars

Supports linear model estimation in Polars.

This package provides efficient Rust implementations of common linear regression variants (OLS, WLS, Ridge, Elastic Net, non-negative least squares, recursive least squares) and exposes them as simple Polars expressions which can easily be integrated into your workflow.

Why?

  1. High Performance: implementations are written in Rust and make use of optimized Rust linear-algebra crates & LAPACK routines; see the benchmark section below.
  2. Polars Integration: avoids unnecessary conversions from lazy to eager mode and to external libraries (e.g. numpy, sklearn) for simple linear regressions. Chain least-squares formulae like any other expression in Polars.
  3. Efficient Implementations:
    • Numerically stable algorithms are chosen where appropriate (e.g. QR, Cholesky).
    • Flexible model specification allows arbitrary combinations of sample weighting, L1/L2 regularization, & non-negativity constraints on parameters.
    • Efficient rank-1 update algorithms are used for moving-window regressions.
  4. Easy Parallelism: computing OLS predictions in parallel across groups could not be easier: call .over() or group_by just like any other Polars expression and benefit from full Rust parallelism.
  5. Formula API: supports building models via patsy syntax, e.g. y ~ x1 + x2 + x3:x4 -1 (like statsmodels), which is automatically converted to equivalent Polars expressions.

Installation

First, you need to install Polars. Then run the following to install the polars-ols extension:

pip install polars-ols

API & Examples

Importing polars_ols will register the least_squares namespace provided by this package. You can build models either by specifying polars expressions (e.g. pl.col(...)) for your targets and features, or by using the formula API (patsy syntax). All models support the following general (optional) arguments:

  • mode - a literal which determines the type of output produced by the model
  • null_policy - a literal which determines how to deal with missing data
  • add_intercept - a boolean specifying if an intercept feature should be added to the features
  • sample_weights - a column or expression providing non-negative weights applied to the samples

Remaining parameters are model specific: for example, the alpha penalty parameter used by regularized least-squares models.

See below for basic usage examples. Please refer to the tests or demo notebook for detailed examples.

import polars as pl
import polars_ols as pls  # registers 'least_squares' namespace

df = pl.DataFrame({"y": [1.16, -2.16, -1.57, 0.21, 0.22, 1.6, -2.11, -2.92, -0.86, 0.47],
                   "x1": [0.72, -2.43, -0.63, 0.05, -0.07, 0.65, -0.02, -1.64, -0.92, -0.27],
                   "x2": [0.24, 0.18, -0.95, 0.23, 0.44, 1.01, -2.08, -1.36, 0.01, 0.75],
                   "group": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
                   "weights": [0.34, 0.97, 0.39, 0.8, 0.57, 0.41, 0.19, 0.87, 0.06, 0.34],
                   })

lasso_expr = pl.col("y").least_squares.lasso("x1", "x2", alpha=0.0001, add_intercept=True).over("group")
wls_expr = pls.compute_least_squares_from_formula("y ~ x1 + x2 -1", sample_weights=pl.col("weights"))

predictions = df.with_columns(lasso_expr.round(2).alias("predictions_lasso"),
                              wls_expr.round(2).alias("predictions_wls"))

print(predictions.head(5))
shape: (5, 7)
┌───────┬───────┬───────┬───────┬─────────┬───────────────────┬─────────────────┐
│ y     ┆ x1    ┆ x2    ┆ group ┆ weights ┆ predictions_lasso ┆ predictions_wls │
│ ---   ┆ ---   ┆ ---   ┆ ---   ┆ ---     ┆ ---               ┆ ---             │
│ f64   ┆ f64   ┆ f64   ┆ i64   ┆ f64     ┆ f64               ┆ f64             │
╞═══════╪═══════╪═══════╪═══════╪═════════╪═══════════════════╪═════════════════╡
│ 1.16  ┆ 0.72  ┆ 0.24  ┆ 1     ┆ 0.34    ┆ 0.97              ┆ 0.93            │
│ -2.16 ┆ -2.43 ┆ 0.18  ┆ 1     ┆ 0.97    ┆ -2.23             ┆ -2.18           │
│ -1.57 ┆ -0.63 ┆ -0.95 ┆ 1     ┆ 0.39    ┆ -1.54             ┆ -1.54           │
│ 0.21  ┆ 0.05  ┆ 0.23  ┆ 1     ┆ 0.8     ┆ 0.29              ┆ 0.27            │
│ 0.22  ┆ -0.07 ┆ 0.44  ┆ 1     ┆ 0.57    ┆ 0.37              ┆ 0.36            │
└───────┴───────┴───────┴───────┴─────────┴───────────────────┴─────────────────┘

The mode parameter is used to set the type of the output returned by all methods ("predictions", "residuals", "coefficients"). It defaults to returning predictions matching the input's length.

If "coefficients" is set, the output is a polars Struct with feature names as fields and coefficients as values. Its output shape 'broadcasts' depending on context, see below:

coefficients = df.select(pl.col("y").least_squares.from_formula("x1 + x2", mode="coefficients")
                         .alias("coefficients"))

coefficients_group = df.select("group", pl.col("y").least_squares.from_formula("x1 + x2", mode="coefficients").over("group")
                        .alias("coefficients_group")).unique(maintain_order=True)

print(coefficients)
print(coefficients_group)
shape: (1, 1)
┌──────────────────────────────┐
│ coefficients                 │
│ ---                          │
│ struct[3]                    │
╞══════════════════════════════╡
│ {0.977375,0.987413,0.000757} │  # <--- coef for x1, x2, and intercept added by formula API
└──────────────────────────────┘
shape: (2, 2)
┌───────┬───────────────────────────────┐
│ group ┆ coefficients_group            │
│ ---   ┆ ---                           │
│ i64   ┆ struct[3]                     │
╞═══════╪═══════════════════════════════╡
│ 1     ┆ {0.995157,0.977495,0.014344}  │
│ 2     ┆ {0.939217,0.997441,-0.017599} │  # <--- (unique) coefficients per group
└───────┴───────────────────────────────┘

For dynamic models (like rolling_ols), or when used in a .over, .group_by, or .with_columns context, the coefficients take the shape of the data they are applied to. For example:

coefficients = df.with_columns(pl.col("y").least_squares.rls(pl.col("x1"), pl.col("x2"), mode="coefficients")
                         .over("group").alias("coefficients"))

print(coefficients.head())
shape: (5, 6)
┌───────┬───────┬───────┬───────┬─────────┬─────────────────────┐
│ y     ┆ x1    ┆ x2    ┆ group ┆ weights ┆ coefficients        │
│ ---   ┆ ---   ┆ ---   ┆ ---   ┆ ---     ┆ ---                 │
│ f64   ┆ f64   ┆ f64   ┆ i64   ┆ f64     ┆ struct[2]           │
╞═══════╪═══════╪═══════╪═══════╪═════════╪═════════════════════╡
│ 1.16  ┆ 0.72  ┆ 0.24  ┆ 1     ┆ 0.34    ┆ {1.235503,0.411834} │
│ -2.16 ┆ -2.43 ┆ 0.18  ┆ 1     ┆ 0.97    ┆ {0.963515,0.760769} │
│ -1.57 ┆ -0.63 ┆ -0.95 ┆ 1     ┆ 0.39    ┆ {0.975484,0.966029} │
│ 0.21  ┆ 0.05  ┆ 0.23  ┆ 1     ┆ 0.8     ┆ {0.975657,0.953735} │
│ 0.22  ┆ -0.07 ┆ 0.44  ┆ 1     ┆ 0.57    ┆ {0.97898,0.909793}  │
└───────┴───────┴───────┴───────┴─────────┴─────────────────────┘

Finally, for convenience, you can compute out-of-sample predictions with least_squares.{predict, predict_from_formula}. This saves you the effort of un-nesting the coefficients and computing the dot product in Python: it is done in Rust, as an expression. Usage is as follows:

df_test.select(pl.col("coefficients_train").least_squares.predict(pl.col("x1"), pl.col("x2")).alias("predictions_test"))

Supported Models

Currently, this extension package supports the following variants:

  • Ordinary Least Squares: least_squares.ols
  • Weighted Least Squares: least_squares.wls
  • Regularized Least Squares (Lasso / Ridge / Elastic Net): least_squares.{lasso, ridge, elastic_net}
  • Non-negative Least Squares: least_squares.nnls
  • Multi-target Least Squares: least_squares.multi_target_ols

As well as efficient implementations of moving window models:

  • Recursive Least Squares: least_squares.rls
  • Rolling / Expanding Window OLS: least_squares.{rolling_ols, expanding_ols}

An arbitrary combination of sample_weights, L1/L2 penalties, and non-negativity constraints can be specified with the least_squares.from_formula and least_squares.least_squares entry-points.

Solve Methods

polars-ols provides a choice of multiple supported numerical approaches per model (via the solve_method flag), with implications for performance vs numerical accuracy. These choices are exposed to the user for full control; however, if left unspecified, the package will choose a reasonable default depending on context.

For example, if you know you are dealing with highly collinear data and an unregularized OLS model, you may want to explicitly set solve_method="svd" so that the minimum-norm solution is obtained.

Benchmark

The usual caveats of benchmarks apply here, but the below should still be indicative of the type of performance improvements to expect when using this package.

This benchmark was run on randomly generated data with pyperf on an Apple M2 Max MacBook (32GB RAM, macOS Sonoma 14.2.1). See benchmark.py for the implementation.

n_samples=2_000, n_features=5

| Model | polars_ols | Python Benchmark | Benchmark Type | Speed-up vs Python Benchmark |
|---|---|---|---|---|
| Least Squares (QR) | 195 µs ± 6 µs | 466 µs ± 104 µs | Numpy (QR) | 2.4x |
| Least Squares (SVD) | 247 µs ± 5 µs | 395 µs ± 69 µs | Numpy (SVD) | 1.6x |
| Ridge (Cholesky) | 171 µs ± 8 µs | 1.02 ms ± 0.29 ms | Sklearn (Cholesky) | 5.9x |
| Ridge (SVD) | 238 µs ± 7 µs | 1.12 ms ± 0.41 ms | Sklearn (SVD) | 4.7x |
| Weighted Least Squares | 334 µs ± 13 µs | 2.04 ms ± 0.22 ms | Statsmodels | 6.1x |
| Elastic Net (CD) | 227 µs ± 7 µs | 1.18 ms ± 0.19 ms | Sklearn | 5.2x |
| Recursive Least Squares | 1.12 ms ± 0.23 ms | 18.2 ms ± 1.6 ms | Statsmodels | 16.2x |
| Rolling Least Squares | 1.99 ms ± 0.03 ms | 22.1 ms ± 0.2 ms | Statsmodels | 11.1x |

n_samples=10_000, n_features=100

| Model | polars_ols | Python Benchmark | Benchmark Type | Speed-up vs Python Benchmark |
|---|---|---|---|---|
| Least Squares (QR) | 17.6 ms ± 0.3 ms | 44.4 ms ± 9.3 ms | Numpy (QR) | 2.5x |
| Least Squares (SVD) | 23.8 ms ± 0.2 ms | 26.6 ms ± 5.5 ms | Numpy (SVD) | 1.1x |
| Ridge (Cholesky) | 5.36 ms ± 0.16 ms | 475 ms ± 71 ms | Sklearn (Cholesky) | 88.7x |
| Ridge (SVD) | 30.2 ms ± 0.4 ms | 400 ms ± 48 ms | Sklearn (SVD) | 13.2x |
| Weighted Least Squares | 18.8 ms ± 0.3 ms | 80.4 ms ± 12.4 ms | Statsmodels | 4.3x |
| Elastic Net (CD) | 22.7 ms ± 0.2 ms | 138 ms ± 27 ms | Sklearn | 6.1x |
| Recursive Least Squares | 270 ms ± 53 ms | 57.8 sec ± 43.7 sec | Statsmodels | 1017.0x |
| Rolling Least Squares | 371 ms ± 13 ms | 4.41 sec ± 0.17 sec | Statsmodels | 11.9x |
  • Numpy's lstsq (which uses divide-and-conquer SVD) is already a highly optimized call into LAPACK, so the scope for speed-up is relatively limited; the same applies to simple approaches like directly solving the normal equations with Cholesky.
  • However, even on such problems the polars-ols Rust implementations of matching numerical algorithms tend to outperform by ~2-3x.
  • More substantial speed-ups are achieved for the more complex models by working entirely in Rust and avoiding the overhead of crossing back and forth into Python.
  • Expect an additional order-of-magnitude relative speed-up in your workflow if it involves repeated re-estimation of models in (Python) loops.

Credits & Related Projects

  • Rust linear algebra libraries faer and ndarray support the implementations provided by this extension package
  • This package was templated around the very helpful: polars-plugin-tutorial
  • The python package patsy is used for (optionally) building models from formulae
  • Please check out the extension package polars-ds for general data-science functionality in polars

Future Work / TODOs

  • Support generic types in the Rust implementations, so that both f32 and f64 inputs are recognized. Right now, data is cast to f64 prior to estimation.
  • Add docs explaining supported models, signatures, and API
