Polars Least Squares Extension

Polars OLS

Least squares extension in Polars: supports linear model estimation directly in Polars.

This package provides efficient Rust implementations of common linear regression variants (OLS, WLS, Ridge, Elastic Net, Non-negative least squares, Recursive least squares) and exposes them as simple polars expressions which can easily be integrated into your workflow.

Why?

  1. High Performance: implementations are written in Rust and make use of optimized Rust linear-algebra crates & LAPACK routines. See the benchmark section.
  2. Polars Integration: avoids unnecessary conversions from lazy to eager mode and to external libraries (e.g. numpy, sklearn) to do simple linear regressions. Chain least squares formulae like any other expression in polars.
  3. Efficient Implementations:
    • Numerically stable algorithms are chosen where appropriate (e.g. QR, Cholesky).
    • Flexible model specification allows arbitrary combination of sample weighting, L1/L2 regularization, & non-negativity constraints on parameters.
    • Efficient rank-1 update algorithms used for moving window regressions.
  4. Easy Parallelism: computing OLS predictions in parallel across groups could not be easier: call .over() or group_by just like with any other polars expression and benefit from full Rust parallelism.
  5. Formula API: supports building models via patsy syntax: y ~ x1 + x2 + x3:x4 -1 (like statsmodels) which automatically converts to equivalent polars expressions.
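For intuition on the QR approach mentioned above, here is a plain numpy sketch (illustrative only, not this package's Rust code) of solving OLS via a QR factorization rather than the less numerically stable normal equations:

```python
# Sketch of the numerically stable QR approach to OLS: factor X = Q R
# (Q orthonormal, R upper triangular), then solve R @ beta = Q.T @ y,
# avoiding the ill-conditioned product X.T @ X of the normal equations.
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
beta_true = np.array([1.0, -2.0, 0.5])
y = X @ beta_true + 0.01 * rng.normal(size=100)

Q, R = np.linalg.qr(X)                  # reduced QR factorization
beta_qr = np.linalg.solve(R, Q.T @ y)   # back-substitution on triangular R

# agrees with numpy's reference least-squares solver
beta_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)
assert np.allclose(beta_qr, beta_lstsq)
```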

Installation

First, install Polars. Then run the following to install the polars-ols extension:

pip install polars-ols

API & Examples

Importing polars_ols registers the namespace least_squares provided by this package. You can build models either by specifying polars expressions (e.g. pl.col(...)) for your targets and features, or by using the formula API (patsy syntax). All models support the following general (optional) arguments:

  • mode - a literal which determines the type of output produced by the model
  • null_policy - a literal which determines how to deal with missing data
  • add_intercept - a boolean specifying if an intercept feature should be added to the features
  • sample_weights - a column or expression providing non-negative weights applied to the samples

The remaining parameters are model-specific, for example the alpha penalty parameter used by regularized least-squares models.

See below for basic usage examples. Please refer to the tests or demo notebook for detailed examples.

import polars as pl
import polars_ols as pls  # registers 'least_squares' namespace

df = pl.DataFrame({"y": [1.16, -2.16, -1.57, 0.21, 0.22, 1.6, -2.11, -2.92, -0.86, 0.47],
                   "x1": [0.72, -2.43, -0.63, 0.05, -0.07, 0.65, -0.02, -1.64, -0.92, -0.27],
                   "x2": [0.24, 0.18, -0.95, 0.23, 0.44, 1.01, -2.08, -1.36, 0.01, 0.75],
                   "group": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
                   "weights": [0.34, 0.97, 0.39, 0.8, 0.57, 0.41, 0.19, 0.87, 0.06, 0.34],
                   })

lasso_expr = pl.col("y").least_squares.lasso(pl.col("x1"), pl.col("x2"), alpha=0.0001, add_intercept=True).over("group")
wls_expr = pls.compute_least_squares_from_formula("y ~ x1 + x2 -1", sample_weights=pl.col("weights"))

predictions = df.with_columns(lasso_expr.round(2).alias("predictions_lasso"),
                              wls_expr.round(2).alias("predictions_wls"))

print(predictions.head(5))
shape: (5, 7)
┌───────┬───────┬───────┬───────┬─────────┬───────────────────┬─────────────────┐
│ y     ┆ x1    ┆ x2    ┆ group ┆ weights ┆ predictions_lasso ┆ predictions_wls │
│ ---   ┆ ---   ┆ ---   ┆ ---   ┆ ---     ┆ ---               ┆ ---             │
│ f64   ┆ f64   ┆ f64   ┆ i64   ┆ f64     ┆ f32               ┆ f32             │
╞═══════╪═══════╪═══════╪═══════╪═════════╪═══════════════════╪═════════════════╡
│ 1.16  ┆ 0.72  ┆ 0.24  ┆ 1     ┆ 0.34    ┆ 0.97              ┆ 0.93            │
│ -2.16 ┆ -2.43 ┆ 0.18  ┆ 1     ┆ 0.97    ┆ -2.23             ┆ -2.18           │
│ -1.57 ┆ -0.63 ┆ -0.95 ┆ 1     ┆ 0.39    ┆ -1.54             ┆ -1.54           │
│ 0.21  ┆ 0.05  ┆ 0.23  ┆ 1     ┆ 0.8     ┆ 0.29              ┆ 0.27            │
│ 0.22  ┆ -0.07 ┆ 0.44  ┆ 1     ┆ 0.57    ┆ 0.37              ┆ 0.36            │
└───────┴───────┴───────┴───────┴─────────┴───────────────────┴─────────────────┘

The mode parameter sets the type of output returned by all methods ("predictions", "residuals", "coefficients"). It defaults to returning predictions matching the input's length.

If "coefficients" is set, the output is a polars struct with coefficients as values and feature names as fields. Its output shape "broadcasts" depending on context; see below:

coefficients = df.select(pl.col("y").least_squares.from_formula("x1 + x2", mode="coefficients")
                         .alias("coefficients"))

coefficients_group = df.select("group", pl.col("y").least_squares.from_formula("x1 + x2", mode="coefficients").over("group")
                        .alias("coefficients_group")).unique(maintain_order=True)

print(coefficients)
print(coefficients_group)
shape: (1, 1)
┌──────────────────────────────┐
│ coefficients                 │
│ ---                          │
│ struct[3]                    │
╞══════════════════════════════╡
│ {0.977375,0.987413,0.000757} │  # <--- coef for x1, x2, and intercept added by formula API
└──────────────────────────────┘
shape: (2, 2)
┌───────┬───────────────────────────────┐
│ group ┆ coefficients_group            │
│ ---   ┆ ---                           │
│ i64   ┆ struct[3]                     │
╞═══════╪═══════════════════════════════╡
│ 1     ┆ {0.995157,0.977495,0.014344}  │
│ 2     ┆ {0.939217,0.997441,-0.017599} │  # <--- (unique) coefficients per group
└───────┴───────────────────────────────┘

For dynamic models (like rolling_ols), or in a .over, .group_by, or .with_columns context, the coefficients take the shape of the data they are applied to. For example:

coefficients = df.with_columns(pl.col("y").least_squares.rls(pl.col("x1"), pl.col("x2"), mode="coefficients")
                         .over("group").alias("coefficients"))

print(coefficients.head())
shape: (5, 6)
┌───────┬───────┬───────┬───────┬─────────┬─────────────────────┐
│ y     ┆ x1    ┆ x2    ┆ group ┆ weights ┆ coefficients        │
│ ---   ┆ ---   ┆ ---   ┆ ---   ┆ ---     ┆ ---                 │
│ f64   ┆ f64   ┆ f64   ┆ i64   ┆ f64     ┆ struct[2]           │
╞═══════╪═══════╪═══════╪═══════╪═════════╪═════════════════════╡
│ 1.16  ┆ 0.72  ┆ 0.24  ┆ 1     ┆ 0.34    ┆ {1.235503,0.411834} │
│ -2.16 ┆ -2.43 ┆ 0.18  ┆ 1     ┆ 0.97    ┆ {0.963515,0.760769} │
│ -1.57 ┆ -0.63 ┆ -0.95 ┆ 1     ┆ 0.39    ┆ {0.975484,0.966029} │
│ 0.21  ┆ 0.05  ┆ 0.23  ┆ 1     ┆ 0.8     ┆ {0.975657,0.953735} │
│ 0.22  ┆ -0.07 ┆ 0.44  ┆ 1     ┆ 0.57    ┆ {0.97898,0.909793}  │
└───────┴───────┴───────┴───────┴─────────┴─────────────────────┘

Finally, for convenience, to compute out-of-sample predictions you can use least_squares.{predict, predict_from_formula}. This saves you the effort of un-nesting the coefficients and doing the dot product in Python; instead it is done in Rust, as an expression. Usage is as follows:

df_test.select(pl.col("coefficients_train").least_squares.predict(pl.col("x1"), pl.col("x2")).alias("predictions_test"))
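Conceptually, what predict computes is just this dot product. A numpy sketch of the equivalent arithmetic (the coefficient values and test data below are made up for illustration):

```python
# An out-of-sample prediction is the dot product of the trained
# coefficients with the test-set features (plus the intercept, if fitted).
import numpy as np

# hypothetical fitted coefficients, as if un-nested from the polars struct
coefficients_train = {"x1": 0.98, "x2": 0.99, "const": 0.0}

x1_test = np.array([0.5, -1.2])
x2_test = np.array([0.3, 0.7])

# stack features in the same order as the coefficients, intercept last
X_test = np.column_stack([x1_test, x2_test, np.ones_like(x1_test)])
beta = np.array([coefficients_train["x1"],
                 coefficients_train["x2"],
                 coefficients_train["const"]])
predictions_test = X_test @ beta
```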

Supported Models

Currently, this extension package supports the following variants:

  • Ordinary Least Squares: least_squares.ols
  • Weighted Least Squares: least_squares.wls
  • Regularized Least Squares (Lasso / Ridge / Elastic Net) least_squares.{lasso, ridge, elastic_net}
  • Non-negative Least Squares: least_squares.nnls

As well as efficient implementations of moving window models:

  • Recursive Least Squares: least_squares.rls
  • Rolling / Expanding Window OLS: least_squares.{rolling_ols, expanding_ols}
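As an illustration of the rank-1 update idea behind these moving-window models, here is a minimal recursive least-squares loop in numpy (a sketch of the textbook algorithm, not this package's Rust implementation):

```python
# Recursive least squares via the Sherman-Morrison rank-1 update: each new
# sample revises the inverse covariance P and the coefficients in O(k^2)
# per row, instead of refitting the model from scratch at every step.
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(200, 2))
y = X @ np.array([1.5, -0.7]) + 0.01 * rng.normal(size=200)

k = X.shape[1]
P = np.eye(k) * 1e6          # large prior variance ~ diffuse prior
beta = np.zeros(k)
for x_t, y_t in zip(X, y):
    Px = P @ x_t
    gain = Px / (1.0 + x_t @ Px)       # Kalman-style gain vector
    beta = beta + gain * (y_t - x_t @ beta)
    P = P - np.outer(gain, Px)         # Sherman-Morrison rank-1 downdate

# with a diffuse prior, RLS converges to the batch OLS solution
beta_ols, *_ = np.linalg.lstsq(X, y, rcond=None)
assert np.allclose(beta, beta_ols, atol=1e-3)
```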

An arbitrary combination of sample_weights, L1/L2 penalties, and non-negativity constraints can be specified with the least_squares.from_formula and least_squares.least_squares entry-points.
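Mathematically, one such combination (sample weights plus an L2 penalty, i.e. weighted ridge) reduces to solving penalized weighted normal equations, (XᵀWX + αI)β = XᵀWy. A numpy sketch with made-up data:

```python
# Weighted ridge regression solved directly from the penalized weighted
# normal equations: (X^T W X + alpha * I) beta = X^T W y.
import numpy as np

rng = np.random.default_rng(2)
X = rng.normal(size=(50, 2))
y = X @ np.array([2.0, -1.0]) + 0.1 * rng.normal(size=50)
w = rng.uniform(0.1, 1.0, size=50)    # non-negative sample weights
alpha = 0.5                           # L2 penalty strength

XtW = X.T * w                         # broadcasts w across columns of X.T
beta = np.linalg.solve(XtW @ X + alpha * np.eye(2), XtW @ y)
```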

Solve Methods

polars-ols provides a choice of multiple supported numerical approaches per model (via the solve_method flag), with implications for performance vs. numerical accuracy. These choices are exposed to the user for full control; however, if left unspecified, the package chooses a reasonable default depending on context.

For example, if you know you are dealing with highly collinear data and an unregularized OLS model, you may want to explicitly set solve_method="svd" so that the minimum-norm solution is obtained.
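A numpy illustration of why this matters (not this package's code): with perfectly collinear features the least-squares problem has infinitely many solutions, and the SVD-based pseudoinverse picks the unique minimum-norm one:

```python
# With rank-deficient X, QR/Cholesky-style solves are ill-defined, while
# the SVD-based pseudoinverse returns the minimum-norm least-squares solution.
import numpy as np

rng = np.random.default_rng(3)
x1 = rng.normal(size=100)
X = np.column_stack([x1, 2.0 * x1])   # perfectly collinear: rank 1
y = 3.0 * x1

beta_svd = np.linalg.pinv(X) @ y      # minimum-norm solution
# infinitely many exact solutions exist (e.g. [3, 0] or [1, 1]);
# the SVD picks the one with smallest Euclidean norm: [0.6, 1.2]
```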

Benchmark

The usual caveats of benchmarks apply here, but the below should still be indicative of the type of performance improvements to expect when using this package.

This benchmark was run on randomly generated data with pyperf on my Apple M2 Max MacBook (32GB RAM, macOS Sonoma 14.2.1). See benchmark.py for the implementation.

n_samples=2_000, n_features=5

| Model | polars_ols | Python Benchmark | Benchmark Type | Speed-up vs Python Benchmark |
|---|---|---|---|---|
| Least Squares (QR) | 300 us ± 7 us | 1.01 ms ± 0.81 ms | Numpy (QR) | 3.4x |
| Least Squares (SVD) | 351 us ± 4 us | 853 us ± 417 us | Numpy (SVD) | 2.4x |
| Ridge (Cholesky) | 279 us ± 6 us | 1.63 ms ± 0.69 ms | Sklearn (Cholesky) | 5.8x |
| Ridge (SVD) | 351 us ± 5 us | 1.95 ms ± 1.12 ms | Sklearn (SVD) | 5.6x |
| Weighted Least Squares | 531 us ± 4 us | 2.54 ms ± 0.40 ms | Statsmodels | 4.8x |
| Elastic Net (CD) | 339 us ± 5 us | 2.17 ms ± 0.77 ms | Sklearn | 6.4x |
| Recursive Least Squares | 1.42 ms ± 0.02 ms | 18.5 ms ± 1.4 ms | Statsmodels | 13.0x |
| Rolling Least Squares | 2.78 ms ± 0.07 ms | 22.8 ms ± 0.2 ms | Statsmodels | 8.2x |

n_samples=10_000, n_features=100

| Model | polars_ols | Python Benchmark | Benchmark Type | Speed-up vs Python Benchmark |
|---|---|---|---|---|
| Least Squares (QR) | 12.4 ms ± 0.2 ms | 68.3 ms ± 13.7 ms | Numpy (QR) | 5.5x |
| Least Squares (SVD) | 14.5 ms ± 0.5 ms | 44.9 ms ± 10.3 ms | Numpy (SVD) | 3.1x |
| Ridge (Cholesky) | 6.10 ms ± 0.14 ms | 9.91 ms ± 2.86 ms | Sklearn (Cholesky) | 1.6x |
| Ridge (SVD) | 24.9 ms ± 2.1 ms | 390 ms ± 63 ms | Sklearn (SVD) | 15.7x |
| Weighted Least Squares | 14.8 ms ± 2.4 ms | 114 ms ± 35 ms | Statsmodels | 7.7x |
| Elastic Net (CD) | 21.7 ms ± 1.2 ms | 111 ms ± 54 ms | Sklearn | 5.1x |
| Recursive Least Squares | 163 ms ± 28 ms | 65.7 sec ± 28.2 sec | Statsmodels | 403.1x |
| Rolling Least Squares | 390 ms ± 10 ms | 3.99 sec ± 0.54 sec | Statsmodels | 10.2x |

  • Numpy's lstsq (which uses divide-and-conquer SVD) is already a highly optimized call into LAPACK, so the scope for speed-up is relatively limited; the same applies to simple approaches like directly solving the normal equations with Cholesky.
  • However, even in such problems, the polars-ols Rust implementations of matching numerical algorithms tend to outperform by ~2-3x.
  • More substantial speed-ups are achieved for the more complex models by working entirely in Rust and avoiding the overhead of crossing back and forth into Python.
  • Expect an additional order-of-magnitude speed-up in your workflow if it involves repeated re-estimation of models in (Python) loops.

Credits & Related Projects

  • The Rust linear-algebra libraries faer and ndarray underpin the implementations provided by this extension package
  • This package was templated on the very helpful polars-plugin-tutorial
  • The Python package patsy is used for (optionally) building models from formulae
  • Please check out the extension package polars-ds for general data-science functionality in polars

Future Work / TODOs

  • Support generic types in the Rust implementations so that both f32 and f64 are handled natively; right now, data is cast to f32 prior to estimation
  • Add more detailed documentation on supported models, signatures, and the API
