Polars Least Squares Extension

Polars OLS

Least squares extension in Polars

Supports linear model estimation in Polars.

This package provides efficient Rust implementations of common linear regression variants (OLS, WLS, Ridge, Elastic Net, Non-negative Least Squares, Recursive Least Squares) and exposes them as simple Polars expressions that integrate easily into your workflow.

Why?

  1. High Performance: implementations are written in Rust and use optimized Rust linear-algebra crates & LAPACK routines. See the benchmark section.
  2. Polars Integration: avoids unnecessary conversions from lazy to eager mode and to external libraries (e.g. numpy, sklearn) for simple linear regressions. Chain least-squares formulae like any other expression in Polars.
  3. Efficient Implementations:
    • Numerically stable algorithms are chosen where appropriate (e.g. QR, Cholesky).
    • Flexible model specification allows arbitrary combination of sample weighting, L1/L2 regularization, & non-negativity constraints on parameters.
    • Efficient rank-1 update algorithms used for moving window regressions.
  4. Easy Parallelism: computing OLS predictions in parallel across groups could not be easier: call .over() or .group_by() just like with any other Polars expression and benefit from full Rust parallelism.
  5. Formula API: supports building models via patsy syntax: y ~ x1 + x2 + x3:x4 -1 (like statsmodels) which automatically converts to equivalent polars expressions.

Installation

First, install Polars. Then run the command below to install the polars-ols extension:

pip install polars-ols

API & Examples

Importing polars_ols will register the namespace least_squares provided by this package. You can build models either by specifying Polars expressions (e.g. pl.col(...)) for your targets and features, or by using the formula API (patsy syntax). All models support the following general (optional) arguments:

  • mode - a literal which determines the type of output produced by the model
  • null_policy - a literal which determines how to deal with missing data
  • add_intercept - a boolean specifying if an intercept feature should be added to the features
  • sample_weights - a column or expression providing non-negative weights applied to the samples

Remaining parameters are model specific, for example, the alpha penalty parameter used by regularized least-squares models.

See below for basic usage examples. Please refer to the tests or demo notebook for detailed examples.

import polars as pl
import polars_ols as pls  # registers 'least_squares' namespace

df = pl.DataFrame({"y": [1.16, -2.16, -1.57, 0.21, 0.22, 1.6, -2.11, -2.92, -0.86, 0.47],
                   "x1": [0.72, -2.43, -0.63, 0.05, -0.07, 0.65, -0.02, -1.64, -0.92, -0.27],
                   "x2": [0.24, 0.18, -0.95, 0.23, 0.44, 1.01, -2.08, -1.36, 0.01, 0.75],
                   "group": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
                   "weights": [0.34, 0.97, 0.39, 0.8, 0.57, 0.41, 0.19, 0.87, 0.06, 0.34],
                   })

lasso_expr = pl.col("y").least_squares.lasso(pl.col("x1"), pl.col("x2"), alpha=0.0001, add_intercept=True).over("group")
wls_expr = pls.compute_least_squares_from_formula("y ~ x1 + x2 -1", sample_weights=pl.col("weights"))

predictions = df.with_columns(lasso_expr.round(2).alias("predictions_lasso"),
                              wls_expr.round(2).alias("predictions_wls"))

print(predictions.head(5))
shape: (5, 7)
┌───────┬───────┬───────┬───────┬─────────┬───────────────────┬─────────────────┐
│ y     ┆ x1    ┆ x2    ┆ group ┆ weights ┆ predictions_lasso ┆ predictions_wls │
│ ---   ┆ ---   ┆ ---   ┆ ---   ┆ ---     ┆ ---               ┆ ---             │
│ f64   ┆ f64   ┆ f64   ┆ i64   ┆ f64     ┆ f64               ┆ f64             │
╞═══════╪═══════╪═══════╪═══════╪═════════╪═══════════════════╪═════════════════╡
│ 1.16  ┆ 0.72  ┆ 0.24  ┆ 1     ┆ 0.34    ┆ 0.97              ┆ 0.93            │
│ -2.16 ┆ -2.43 ┆ 0.18  ┆ 1     ┆ 0.97    ┆ -2.23             ┆ -2.18           │
│ -1.57 ┆ -0.63 ┆ -0.95 ┆ 1     ┆ 0.39    ┆ -1.54             ┆ -1.54           │
│ 0.21  ┆ 0.05  ┆ 0.23  ┆ 1     ┆ 0.8     ┆ 0.29              ┆ 0.27            │
│ 0.22  ┆ -0.07 ┆ 0.44  ┆ 1     ┆ 0.57    ┆ 0.37              ┆ 0.36            │
└───────┴───────┴───────┴───────┴─────────┴───────────────────┴─────────────────┘

The mode parameter sets the type of output returned by all methods ("predictions", "residuals", "coefficients"). It defaults to returning predictions matching the input's length.

If "coefficients" is set, the output is a Polars struct with coefficients as values and feature names as fields. Its output shape 'broadcasts' depending on context; see below:

coefficients = df.select(pl.col("y").least_squares.from_formula("x1 + x2", mode="coefficients")
                         .alias("coefficients"))

coefficients_group = df.select("group", pl.col("y").least_squares.from_formula("x1 + x2", mode="coefficients").over("group")
                        .alias("coefficients_group")).unique(maintain_order=True)

print(coefficients)
print(coefficients_group)
shape: (1, 1)
┌──────────────────────────────┐
│ coefficients                 │
│ ---                          │
│ struct[3]                    │
╞══════════════════════════════╡
│ {0.977375,0.987413,0.000757} │  # <--- coef for x1, x2, and intercept added by formula API
└──────────────────────────────┘
shape: (2, 2)
┌───────┬───────────────────────────────┐
│ group ┆ coefficients_group            │
│ ---   ┆ ---                           │
│ i64   ┆ struct[3]                     │
╞═══════╪═══════════════════════════════╡
│ 1     ┆ {0.995157,0.977495,0.014344}  │
│ 2     ┆ {0.939217,0.997441,-0.017599} │  # <--- (unique) coefficients per group
└───────┴───────────────────────────────┘

For dynamic models (like rolling_ols), or within a .over, .group_by, or .with_columns context, the coefficients take the shape of the data they are applied on. For example:

coefficients = df.with_columns(pl.col("y").least_squares.rls(pl.col("x1"), pl.col("x2"), mode="coefficients")
                         .over("group").alias("coefficients"))

print(coefficients.head())
shape: (5, 6)
┌───────┬───────┬───────┬───────┬─────────┬─────────────────────┐
│ y     ┆ x1    ┆ x2    ┆ group ┆ weights ┆ coefficients        │
│ ---   ┆ ---   ┆ ---   ┆ ---   ┆ ---     ┆ ---                 │
│ f64   ┆ f64   ┆ f64   ┆ i64   ┆ f64     ┆ struct[2]           │
╞═══════╪═══════╪═══════╪═══════╪═════════╪═════════════════════╡
│ 1.16  ┆ 0.72  ┆ 0.24  ┆ 1     ┆ 0.34    ┆ {1.235503,0.411834} │
│ -2.16 ┆ -2.43 ┆ 0.18  ┆ 1     ┆ 0.97    ┆ {0.963515,0.760769} │
│ -1.57 ┆ -0.63 ┆ -0.95 ┆ 1     ┆ 0.39    ┆ {0.975484,0.966029} │
│ 0.21  ┆ 0.05  ┆ 0.23  ┆ 1     ┆ 0.8     ┆ {0.975657,0.953735} │
│ 0.22  ┆ -0.07 ┆ 0.44  ┆ 1     ┆ 0.57    ┆ {0.97898,0.909793}  │
└───────┴───────┴───────┴───────┴─────────┴─────────────────────┘

Finally, for convenience, you can compute out-of-sample predictions with least_squares.{predict, predict_from_formula}. This saves you the effort of un-nesting the coefficients and doing the dot product in Python; instead it is done in Rust, as an expression. Usage is as follows:

df_test.select(pl.col("coefficients_train").least_squares.predict(pl.col("x1"), pl.col("x2")).alias("predictions_test"))

Supported Models

Currently, this extension package supports the following variants:

  • Ordinary Least Squares: least_squares.ols
  • Weighted Least Squares: least_squares.wls
  • Regularized Least Squares (Lasso / Ridge / Elastic Net) least_squares.{lasso, ridge, elastic_net}
  • Non-negative Least Squares: least_squares.nnls

As well as efficient implementations of moving window models:

  • Recursive Least Squares: least_squares.rls
  • Rolling / Expanding Window OLS: least_squares.{rolling_ols, expanding_ols}

An arbitrary combination of sample_weights, L1/L2 penalties, and non-negativity constraints can be specified with the least_squares.from_formula and least_squares.least_squares entry-points.

Solve Methods

polars-ols provides a choice of several supported numerical approaches per model (via the solve_method flag), which trade off performance against numerical accuracy. These choices are exposed to the user for full control; if left unspecified, the package chooses a reasonable default depending on context.

For example, if you know you are dealing with highly collinear data and an unregularized OLS model, you may want to set solve_method="svd" explicitly so that the minimum-norm solution is obtained.

Benchmark

The usual caveats of benchmarks apply here, but the below should still be indicative of the type of performance improvements to expect when using this package.

This benchmark was run on randomly generated data with pyperf on my Apple M2 Max MacBook (32GB RAM, macOS Sonoma 14.2.1). See benchmark.py for the implementation.

n_samples=2_000, n_features=5

Model                   | polars_ols        | Python Benchmark  | Benchmark Type     | Speed-up vs Python Benchmark
Least Squares (QR)      | 195 µs ± 6 µs     | 466 µs ± 104 µs   | Numpy (QR)         | 2.4x
Least Squares (SVD)     | 247 µs ± 5 µs     | 395 µs ± 69 µs    | Numpy (SVD)        | 1.6x
Ridge (Cholesky)        | 171 µs ± 8 µs     | 1.02 ms ± 0.29 ms | Sklearn (Cholesky) | 5.9x
Ridge (SVD)             | 238 µs ± 7 µs     | 1.12 ms ± 0.41 ms | Sklearn (SVD)      | 4.7x
Weighted Least Squares  | 334 µs ± 13 µs    | 2.04 ms ± 0.22 ms | Statsmodels        | 6.1x
Elastic Net (CD)        | 227 µs ± 7 µs     | 1.18 ms ± 0.19 ms | Sklearn            | 5.2x
Recursive Least Squares | 1.12 ms ± 0.23 ms | 18.2 ms ± 1.6 ms  | Statsmodels        | 16.2x
Rolling Least Squares   | 1.99 ms ± 0.03 ms | 22.1 ms ± 0.2 ms  | Statsmodels        | 11.1x

n_samples=10_000, n_features=100

Model                   | polars_ols        | Python Benchmark    | Benchmark Type     | Speed-up vs Python Benchmark
Least Squares (QR)      | 17.6 ms ± 0.3 ms  | 44.4 ms ± 9.3 ms    | Numpy (QR)         | 2.5x
Least Squares (SVD)     | 23.8 ms ± 0.2 ms  | 26.6 ms ± 5.5 ms    | Numpy (SVD)        | 1.1x
Ridge (Cholesky)        | 5.36 ms ± 0.16 ms | 475 ms ± 71 ms      | Sklearn (Cholesky) | 88.7x
Ridge (SVD)             | 30.2 ms ± 0.4 ms  | 400 ms ± 48 ms      | Sklearn (SVD)      | 13.2x
Weighted Least Squares  | 18.8 ms ± 0.3 ms  | 80.4 ms ± 12.4 ms   | Statsmodels        | 4.3x
Elastic Net (CD)        | 22.7 ms ± 0.2 ms  | 138 ms ± 27 ms      | Sklearn            | 6.1x
Recursive Least Squares | 270 ms ± 53 ms    | 57.8 sec ± 43.7 sec | Statsmodels        | 1017.0x
Rolling Least Squares   | 371 ms ± 13 ms    | 4.41 sec ± 0.17 sec | Statsmodels        | 11.9x
  • Numpy's lstsq (which uses a divide-and-conquer SVD) is already a highly optimized call into LAPACK, so the scope for speed-up is relatively limited; the same applies to simple approaches like directly solving the normal equations with Cholesky.
  • Even in such problems, the matching polars-ols Rust implementations tend to outperform by ~2-3x.
  • More substantial speed-ups are achieved for the more complex models by working entirely in Rust and avoiding the overhead of moving back and forth into Python.
  • Expect an additional order-of-magnitude relative speed-up if your workflow involves repeated re-estimation of models in (Python) loops.

Credits & Related Projects

  • Rust linear algebra libraries faer and ndarray support the implementations provided by this extension package
  • This package was templated around the very helpful: polars-plugin-tutorial
  • The python package patsy is used for (optionally) building models from formulae
  • Please check out the extension package polars-ds for general data-science functionality in polars

Future Work / TODOs

  • Support generic types in the Rust implementations so that both f32 and f64 inputs are recognized; currently data is cast to f64 prior to estimation
  • Add docs explaining supported models, signatures, and API
