Polars Least Squares Extension

Polars OLS

Least squares extension in Polars

Supports linear model estimation in Polars.

This package provides efficient Rust implementations of common linear regression variants (OLS, WLS, Ridge, Elastic Net, Non-negative least squares, Recursive least squares) and exposes them as simple Polars expressions which can easily be integrated into your workflow.

Why?

  1. High Performance: implementations are written in Rust and make use of optimized Rust linear-algebra crates & LAPACK routines. See benchmark section.
  2. Polars Integration: avoids unnecessary conversions from lazy to eager mode and to external libraries (e.g. numpy, sklearn) to do simple linear regressions. Chain least squares formulae like any other expression in polars.
  3. Efficient Implementations:
    • Numerically stable algorithms are chosen where appropriate (e.g. QR, Cholesky).
    • Flexible model specification allows arbitrary combination of sample weighting, L1/L2 regularization, & non-negativity constraints on parameters.
    • Efficient rank-1 update algorithms used for moving window regressions.
  4. Easy Parallelism: Computing OLS predictions in parallel across groups could not be easier: call .over() or group_by just like with any other Polars expression and benefit from full Rust parallelism.
  5. Formula API: supports building models via patsy syntax: y ~ x1 + x2 + x3:x4 -1 (like statsmodels) which automatically converts to equivalent polars expressions.

Installation

First, you need to install Polars. Then run the below to install the polars-ols extension:

pip install polars-ols

API & Examples

Importing polars_ols will register the namespace least_squares provided by this package. You can build models either by specifying polars expressions (e.g. pl.col(...)) for your targets and features or by using the formula API (patsy syntax). All models support the following general (optional) arguments:

  • mode - a literal which determines the type of output produced by the model
  • null_policy - a literal which determines how to deal with missing data
  • add_intercept - a boolean specifying if an intercept feature should be added to the features
  • sample_weights - a column or expression providing non-negative weights applied to the samples

The remaining parameters are model specific, for example the alpha penalty parameter used by regularized least squares models.
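To make the add_intercept and sample_weights options concrete, here is a minimal pure-Python sketch for a single feature. It is not the package's Rust implementation: the helper name weighted_simple_regression is hypothetical, and the sketch assumes that the weights multiply the squared residuals and that the intercept behaves like an extra constant feature.

```python
from typing import Sequence, Tuple


def weighted_simple_regression(
    x: Sequence[float], y: Sequence[float], w: Sequence[float]
) -> Tuple[float, float]:
    """Minimise sum_i w_i * (y_i - intercept - slope * x_i)**2 (hypothetical helper)."""
    total = sum(w)
    x_bar = sum(wi * xi for wi, xi in zip(w, x)) / total  # weighted mean of x
    y_bar = sum(wi * yi for wi, yi in zip(w, y)) / total  # weighted mean of y
    s_xy = sum(wi * (xi - x_bar) * (yi - y_bar) for wi, xi, yi in zip(w, x, y))
    s_xx = sum(wi * (xi - x_bar) ** 2 for wi, xi in zip(w, x))
    slope = s_xy / s_xx
    intercept = y_bar - slope * x_bar
    return intercept, slope


# Noise-free data y = 1 + 2 * x: any positive weights recover the same line.
x = [0.0, 1.0, 2.0, 3.0]
y = [1.0 + 2.0 * xi for xi in x]
intercept, slope = weighted_simple_regression(x, y, w=[0.34, 0.97, 0.39, 0.8])
print(round(intercept, 6), round(slope, 6))
```

With noisy data the weights do change the fit, since samples with larger weights pull the line closer to themselves.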

See below for basic usage examples. Please refer to the tests or demo notebook for detailed examples.

import polars as pl
import polars_ols as pls  # registers 'least_squares' namespace

df = pl.DataFrame({"y": [1.16, -2.16, -1.57, 0.21, 0.22, 1.6, -2.11, -2.92, -0.86, 0.47],
                   "x1": [0.72, -2.43, -0.63, 0.05, -0.07, 0.65, -0.02, -1.64, -0.92, -0.27],
                   "x2": [0.24, 0.18, -0.95, 0.23, 0.44, 1.01, -2.08, -1.36, 0.01, 0.75],
                   "group": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
                   "weights": [0.34, 0.97, 0.39, 0.8, 0.57, 0.41, 0.19, 0.87, 0.06, 0.34],
                   })

lasso_expr = pl.col("y").least_squares.lasso(pl.col("x1"), pl.col("x2"), alpha=0.0001, add_intercept=True).over("group")
wls_expr = pls.compute_least_squares_from_formula("y ~ x1 + x2 -1", sample_weights=pl.col("weights"))

predictions = df.with_columns(lasso_expr.round(2).alias("predictions_lasso"),
                              wls_expr.round(2).alias("predictions_wls"))

print(predictions.head(5))
shape: (5, 7)
┌───────┬───────┬───────┬───────┬─────────┬───────────────────┬─────────────────┐
│ y     ┆ x1    ┆ x2    ┆ group ┆ weights ┆ predictions_lasso ┆ predictions_wls │
│ ---   ┆ ---   ┆ ---   ┆ ---   ┆ ---     ┆ ---               ┆ ---             │
│ f64   ┆ f64   ┆ f64   ┆ i64   ┆ f64     ┆ f64               ┆ f64             │
╞═══════╪═══════╪═══════╪═══════╪═════════╪═══════════════════╪═════════════════╡
│ 1.16  ┆ 0.72  ┆ 0.24  ┆ 1     ┆ 0.34    ┆ 0.97              ┆ 0.93            │
│ -2.16 ┆ -2.43 ┆ 0.18  ┆ 1     ┆ 0.97    ┆ -2.23             ┆ -2.18           │
│ -1.57 ┆ -0.63 ┆ -0.95 ┆ 1     ┆ 0.39    ┆ -1.54             ┆ -1.54           │
│ 0.21  ┆ 0.05  ┆ 0.23  ┆ 1     ┆ 0.8     ┆ 0.29              ┆ 0.27            │
│ 0.22  ┆ -0.07 ┆ 0.44  ┆ 1     ┆ 0.57    ┆ 0.37              ┆ 0.36            │
└───────┴───────┴───────┴───────┴─────────┴───────────────────┴─────────────────┘

The mode parameter is used to set the type of the output returned by all methods ("predictions", "residuals", "coefficients"). It defaults to returning predictions matching the input's length.
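The relationship between the three output types can be sketched in plain Python for a one-feature model without intercept. The function fit_no_intercept is a hypothetical illustration, not part of polars_ols; it only shows that predictions and residuals have one value per input row while the coefficients have one value per feature.

```python
from typing import Dict, List, Sequence


def fit_no_intercept(x: Sequence[float], y: Sequence[float]) -> Dict[str, List[float]]:
    """One-feature OLS without intercept, returning all three kinds of output."""
    beta = sum(xi * yi for xi, yi in zip(x, y)) / sum(xi * xi for xi in x)
    predictions = [beta * xi for xi in x]
    residuals = [yi - pi for yi, pi in zip(y, predictions)]
    return {"coefficients": [beta], "predictions": predictions, "residuals": residuals}


# First five rows of "x1" and "y" from the example data frame above.
x1 = [0.72, -2.43, -0.63, 0.05, -0.07]
y = [1.16, -2.16, -1.57, 0.21, 0.22]
out = fit_no_intercept(x1, y)
print(len(out["coefficients"]), len(out["predictions"]), len(out["residuals"]))  # prints: 1 5 5
```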

If "coefficients" is set, the output is a polars Struct with coefficients as values and feature names as fields. Its output shape 'broadcasts' depending on context, see below:

coefficients = df.select(pl.col("y").least_squares.from_formula("x1 + x2", mode="coefficients")
                         .alias("coefficients"))

coefficients_group = df.select("group", pl.col("y").least_squares.from_formula("x1 + x2", mode="coefficients").over("group")
                        .alias("coefficients_group")).unique(maintain_order=True)

print(coefficients)
print(coefficients_group)
shape: (1, 1)
┌──────────────────────────────┐
│ coefficients                 │
│ ---                          │
│ struct[3]                    │
╞══════════════════════════════╡
│ {0.977375,0.987413,0.000757} │  # <--- coef for x1, x2, and intercept added by formula API
└──────────────────────────────┘
shape: (2, 2)
┌───────┬───────────────────────────────┐
│ group ┆ coefficients_group            │
│ ---   ┆ ---                           │
│ i64   ┆ struct[3]                     │
╞═══════╪═══════════════════════════════╡
│ 1     ┆ {0.995157,0.977495,0.014344}  │
│ 2     ┆ {0.939217,0.997441,-0.017599} │  # <--- (unique) coefficients per group
└───────┴───────────────────────────────┘

For dynamic models (like rolling_ols), or when used in a .over, .group_by, or .with_columns context, the coefficients take the shape of the data they are applied to. For example:

coefficients = df.with_columns(pl.col("y").least_squares.rls(pl.col("x1"), pl.col("x2"), mode="coefficients")
                         .over("group").alias("coefficients"))

print(coefficients.head())
shape: (5, 6)
┌───────┬───────┬───────┬───────┬─────────┬─────────────────────┐
│ y     ┆ x1    ┆ x2    ┆ group ┆ weights ┆ coefficients        │
│ ---   ┆ ---   ┆ ---   ┆ ---   ┆ ---     ┆ ---                 │
│ f64   ┆ f64   ┆ f64   ┆ i64   ┆ f64     ┆ struct[2]           │
╞═══════╪═══════╪═══════╪═══════╪═════════╪═════════════════════╡
│ 1.16  ┆ 0.72  ┆ 0.24  ┆ 1     ┆ 0.34    ┆ {1.235503,0.411834} │
│ -2.16 ┆ -2.43 ┆ 0.18  ┆ 1     ┆ 0.97    ┆ {0.963515,0.760769} │
│ -1.57 ┆ -0.63 ┆ -0.95 ┆ 1     ┆ 0.39    ┆ {0.975484,0.966029} │
│ 0.21  ┆ 0.05  ┆ 0.23  ┆ 1     ┆ 0.8     ┆ {0.975657,0.953735} │
│ 0.22  ┆ -0.07 ┆ 0.44  ┆ 1     ┆ 0.57    ┆ {0.97898,0.909793}  │
└───────┴───────┴───────┴───────┴─────────┴─────────────────────┘

Finally, for convenience, you can compute out-of-sample predictions with least_squares.{predict, predict_from_formula}. This saves you the effort of un-nesting the coefficients and doing the dot product in Python; the computation is instead done in Rust, as an expression. Usage is as follows:

df_test.select(pl.col("coefficients_train").least_squares.predict(pl.col("x1"), pl.col("x2")).alias("predictions_test"))
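For intuition, the dot product that predict performs can be sketched in plain Python. This is an illustration only: predict_rows is a hypothetical helper, the field name "intercept" is an assumed label for the constant term, and the coefficient values are copied from the ungrouped formula example above.

```python
from typing import Dict, List


def predict_rows(coefficients: Dict[str, float], rows: List[Dict[str, float]]) -> List[float]:
    """Dot product of stored coefficients with each row's features.

    A field named "intercept" (hypothetical name) is treated as a constant feature of 1.0.
    """
    predictions = []
    for row in rows:
        total = 0.0
        for name, coef in coefficients.items():
            value = 1.0 if name == "intercept" else row[name]
            total += coef * value
        predictions.append(total)
    return predictions


# Coefficients for x1, x2 and the intercept, as printed in the formula example.
coefficients_train = {"x1": 0.977375, "x2": 0.987413, "intercept": 0.000757}
test_rows = [{"x1": 0.72, "x2": 0.24}, {"x1": -2.43, "x2": 0.18}]
predictions_test = predict_rows(coefficients_train, test_rows)
print([round(p, 2) for p in predictions_test])
```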

Supported Models

Currently, this extension package supports the following variants:

  • Ordinary Least Squares: least_squares.ols
  • Weighted Least Squares: least_squares.wls
  • Regularized Least Squares (Lasso / Ridge / Elastic Net): least_squares.{lasso, ridge, elastic_net}
  • Non-negative Least Squares: least_squares.nnls

As well as efficient implementations of moving window models:

  • Recursive Least Squares: least_squares.rls
  • Rolling / Expanding Window OLS: least_squares.{rolling_ols, expanding_ols}
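The idea behind rank-1 updates for these moving window models can be sketched with a textbook recursive least squares step in pure Python. This is an assumption-laden illustration rather than the package's Rust code: rls_update, the forgetting argument, and the initial covariance of 1e6 are all hypothetical choices made for this example.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def rls_update(
    beta: List[float], P: Matrix, x: Sequence[float], y: float, forgetting: float = 1.0
) -> Tuple[List[float], Matrix]:
    """One recursive least squares step for a new sample (x, y)."""
    n = len(beta)
    Px = [sum(P[i][j] * x[j] for j in range(n)) for i in range(n)]  # P @ x
    denom = forgetting + sum(x[i] * Px[i] for i in range(n))  # lambda + x' P x
    gain = [v / denom for v in Px]  # gain vector
    error = y - sum(b * xi for b, xi in zip(beta, x))  # one-step prediction error
    new_beta = [b + g * error for b, g in zip(beta, gain)]
    # Rank-1 update: P <- (P - gain (P x)') / lambda, valid because P is symmetric.
    new_P = [[(P[i][j] - gain[i] * Px[j]) / forgetting for j in range(n)] for i in range(n)]
    return new_beta, new_P


# Noise-free target y = 2 * x1 + 3 * x2 on the first five feature rows of the example data.
x1 = [0.72, -2.43, -0.63, 0.05, -0.07]
x2 = [0.24, 0.18, -0.95, 0.23, 0.44]
beta = [0.0, 0.0]
P = [[1e6, 0.0], [0.0, 1e6]]  # large initial covariance, i.e. a weak prior (assumed value)
for a, b in zip(x1, x2):
    beta, P = rls_update(beta, P, [a, b], 2.0 * a + 3.0 * b)
print([round(v, 3) for v in beta])
```

Each step costs a few matrix-vector products instead of a full re-fit, which is why this style of update suits per-row coefficient output.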

An arbitrary combination of sample_weights, L1/L2 penalties, and non-negativity constraints can be specified with the least_squares.from_formula and least_squares.least_squares entry-points.
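One way to read this combination is as a single penalized objective. The formulation below is a sketch under assumed notation: w_i are the sample weights, \alpha \ge 0 is the overall penalty strength, and \rho \in [0, 1] is an assumed mixing parameter between the L1 and L2 terms. The package's exact scaling conventions may differ.

```latex
\hat{\beta} \;=\; \arg\min_{\beta \,\ge\, 0} \;
\sum_{i=1}^{n} w_i \left( y_i - x_i^{\top} \beta \right)^2
\;+\; \alpha \left( \rho \, \lVert \beta \rVert_1 + (1 - \rho) \, \lVert \beta \rVert_2^2 \right)
```

Under this reading, w_i = 1 and \alpha = 0 without the constraint gives OLS, \rho = 1 gives a lasso-type penalty, \rho = 0 gives a ridge-type penalty, and the constraint \beta \ge 0 applies only when non-negativity is requested.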

Solve Methods

polars-ols provides a choice of multiple supported numerical approaches per model (via the solve_method flag), which trade off performance against numerical accuracy. These choices are exposed to the user for full control. If left unspecified, the package will choose a reasonable default depending on context.

For example, if you know you are dealing with highly collinear data and an unregularized OLS model, you may want to explicitly set solve_method="svd" so that the minimum norm solution is obtained.
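Why the minimum norm solution matters can be seen in the extreme case of a feature that appears twice. The sketch below does not use SVD and is not the package's solver: min_norm_duplicate_feature is a hypothetical helper that applies the closed form for this special case, where only the sum of the two coefficients is determined by the data.

```python
from typing import List, Tuple


def min_norm_duplicate_feature(x: List[float], y: List[float]) -> Tuple[float, float]:
    """Minimum-norm least squares solution when the same feature x appears twice."""
    # Only the sum s = b1 + b2 is identified by the data: s = (x . y) / (x . x).
    s = sum(xi * yi for xi, yi in zip(x, y)) / sum(xi * xi for xi in x)
    # Among all (b1, b2) with b1 + b2 = s, the squared norm b1**2 + b2**2
    # is smallest at the even split.
    return s / 2.0, s / 2.0


x = [0.72, -2.43, -0.63, 0.05, -0.07]
y = [1.16, -2.16, -1.57, 0.21, 0.22]
b1, b2 = min_norm_duplicate_feature(x, y)

# The normal equations X'X b = X'y are singular here: every entry of X'X equals x . x.
xx = sum(xi * xi for xi in x)
det = xx * xx - xx * xx
# Another exact minimiser of the residuals, but with a larger norm.
alt = (b1 + b2, 0.0)
print(det, b1 ** 2 + b2 ** 2 < alt[0] ** 2 + alt[1] ** 2)
```

Approaches that invert X'X directly have no unique answer in this situation, whereas an SVD-based solver can select the smallest-norm coefficients among the equally good fits.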

Benchmark

The usual caveats of benchmarks apply here, but the below should still be indicative of the type of performance improvements to expect when using this package.

This benchmark was run on randomly generated data with pyperf on my Apple M2 Max MacBook (32GB RAM, macOS Sonoma 14.2.1). See benchmark.py for implementation.

n_samples=2_000, n_features=5

| Model                   | polars_ols        | Python Benchmark  | Benchmark Type     | Speed-up vs Python Benchmark |
|-------------------------|-------------------|-------------------|--------------------|------------------------------|
| Least Squares (QR)      | 300 us ± 7 us     | 1.01 ms ± 0.81 ms | Numpy (QR)         | 3.4x                         |
| Least Squares (SVD)     | 351 us ± 4 us     | 853 us ± 417 us   | Numpy (SVD)        | 2.4x                         |
| Ridge (Cholesky)        | 279 us ± 6 us     | 1.63 ms ± 0.69 ms | Sklearn (Cholesky) | 5.8x                         |
| Ridge (SVD)             | 351 us ± 5 us     | 1.95 ms ± 1.12 ms | Sklearn (SVD)      | 5.6x                         |
| Weighted Least Squares  | 531 us ± 4 us     | 2.54 ms ± 0.40 ms | Statsmodels        | 4.8x                         |
| Elastic Net (CD)        | 339 us ± 5 us     | 2.17 ms ± 0.77 ms | Sklearn            | 6.4x                         |
| Recursive Least Squares | 1.42 ms ± 0.02 ms | 18.5 ms ± 1.4 ms  | Statsmodels        | 13.0x                        |
| Rolling Least Squares   | 2.78 ms ± 0.07 ms | 22.8 ms ± 0.2 ms  | Statsmodels        | 8.2x                         |

n_samples=10_000, n_features=100

| Model                   | polars_ols        | Python Benchmark    | Benchmark Type     | Speed-up vs Python Benchmark |
|-------------------------|-------------------|---------------------|--------------------|------------------------------|
| Least Squares (QR)      | 12.4 ms ± 0.2 ms  | 68.3 ms ± 13.7 ms   | Numpy (QR)         | 5.5x                         |
| Least Squares (SVD)     | 14.5 ms ± 0.5 ms  | 44.9 ms ± 10.3 ms   | Numpy (SVD)        | 3.1x                         |
| Ridge (Cholesky)        | 6.10 ms ± 0.14 ms | 9.91 ms ± 2.86 ms   | Sklearn (Cholesky) | 1.6x                         |
| Ridge (SVD)             | 24.9 ms ± 2.1 ms  | 390 ms ± 63 ms      | Sklearn (SVD)      | 15.7x                        |
| Weighted Least Squares  | 14.8 ms ± 2.4 ms  | 114 ms ± 35 ms      | Statsmodels        | 7.7x                         |
| Elastic Net (CD)        | 21.7 ms ± 1.2 ms  | 111 ms ± 54 ms      | Sklearn            | 5.1x                         |
| Recursive Least Squares | 163 ms ± 28 ms    | 65.7 sec ± 28.2 sec | Statsmodels        | 403.1x                       |
| Rolling Least Squares   | 390 ms ± 10 ms    | 3.99 sec ± 0.54 sec | Statsmodels        | 10.2x                        |

  • Numpy's lstsq (which uses divide-and-conquer SVD) is already a highly optimized call into LAPACK, so the scope for speed-up is relatively limited. The same applies to simple approaches such as directly solving the normal equations with Cholesky.
  • However, even in such problems the polars-ols Rust implementations of the matching numerical algorithms tend to outperform by ~2-3x.
  • More substantial speed-ups are achieved for the more complex models by working entirely in Rust and avoiding the overhead of going back and forth into Python.
  • Expect an additional, order-of-magnitude relative speed-up if your workflow involves repeated re-estimation of models in (Python) loops.

Credits & Related Projects

  • Rust linear algebra libraries faer and ndarray support the implementations provided by this extension package
  • This package was templated around the very helpful: polars-plugin-tutorial
  • The python package patsy is used for (optionally) building models from formulae
  • Please check out the extension package polars-ds for general data-science functionality in polars

Future Work / TODOs

  • Support generic types in the Rust implementations, so that both f32 and f64 types are recognized. Right now data is cast to f64 prior to estimation
  • Add more detailed documentation on supported models, signatures, and API
