Polars Least Squares Extension

Polars OLS

Least squares extension in Polars

Supports linear model estimation in Polars.

This package provides efficient Rust implementations of common linear regression variants (OLS, WLS, Ridge, Elastic Net, Non-negative least squares, Recursive least squares) and exposes them as simple Polars expressions which can easily be integrated into your workflow.

Why?

  1. High Performance: implementations are written in Rust and make use of optimized Rust linear-algebra crates & LAPACK routines. See the benchmark section.
  2. Polars Integration: avoids unnecessary conversions from lazy to eager mode and to external libraries (e.g. numpy, sklearn) for simple linear regressions. Chain least-squares formulae like any other expression in Polars.
  3. Efficient Implementations:
    • Numerically stable algorithms are chosen where appropriate (e.g. QR, Cholesky).
    • Flexible model specification allows arbitrary combinations of sample weighting, L1/L2 regularization, & non-negativity constraints on parameters.
    • Efficient rank-1 update algorithms are used for moving-window regressions.
  4. Easy Parallelism: computing OLS predictions in parallel across groups could not be easier: call .over() or .group_by() just like with any other Polars expression and benefit from full Rust parallelism.
  5. Formula API: supports building models via patsy syntax: y ~ x1 + x2 + x3:x4 -1 (like statsmodels), which is automatically converted to equivalent Polars expressions.

Installation

First, you need to install Polars. Then run the below to install the polars-ols extension:

pip install polars-ols

API & Examples

Importing polars_ols registers the least_squares namespace provided by this package. You can build models either by specifying polars expressions (e.g. pl.col(...)) for your targets and features, or by using the formula API (patsy syntax). All models support the following general (optional) arguments:

  • mode - a literal which determines the type of output produced by the model
  • null_policy - a literal which determines how to deal with missing data
  • add_intercept - a boolean specifying if an intercept feature should be added to the features
  • sample_weights - a column or expression providing non-negative weights applied to the samples

The remaining parameters are model specific: for example, the alpha penalty parameter used by the regularized least-squares models.

See below for basic usage examples. Please refer to the tests or demo notebook for detailed examples.

import polars as pl
import polars_ols as pls  # registers 'least_squares' namespace

df = pl.DataFrame({"y": [1.16, -2.16, -1.57, 0.21, 0.22, 1.6, -2.11, -2.92, -0.86, 0.47],
                   "x1": [0.72, -2.43, -0.63, 0.05, -0.07, 0.65, -0.02, -1.64, -0.92, -0.27],
                   "x2": [0.24, 0.18, -0.95, 0.23, 0.44, 1.01, -2.08, -1.36, 0.01, 0.75],
                   "group": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
                   "weights": [0.34, 0.97, 0.39, 0.8, 0.57, 0.41, 0.19, 0.87, 0.06, 0.34],
                   })

lasso_expr = pl.col("y").least_squares.lasso(pl.col("x1"), pl.col("x2"), alpha=0.0001, add_intercept=True).over("group")
wls_expr = pls.compute_least_squares_from_formula("y ~ x1 + x2 -1", sample_weights=pl.col("weights"))

predictions = df.with_columns(lasso_expr.round(2).alias("predictions_lasso"),
                              wls_expr.round(2).alias("predictions_wls"))

print(predictions.head(5))
shape: (5, 7)
┌───────┬───────┬───────┬───────┬─────────┬───────────────────┬─────────────────┐
│ y     ┆ x1    ┆ x2    ┆ group ┆ weights ┆ predictions_lasso ┆ predictions_wls │
│ ---   ┆ ---   ┆ ---   ┆ ---   ┆ ---     ┆ ---               ┆ ---             │
│ f64   ┆ f64   ┆ f64   ┆ i64   ┆ f64     ┆ f64               ┆ f64             │
╞═══════╪═══════╪═══════╪═══════╪═════════╪═══════════════════╪═════════════════╡
│ 1.16  ┆ 0.72  ┆ 0.24  ┆ 1     ┆ 0.34    ┆ 0.97              ┆ 0.93            │
│ -2.16 ┆ -2.43 ┆ 0.18  ┆ 1     ┆ 0.97    ┆ -2.23             ┆ -2.18           │
│ -1.57 ┆ -0.63 ┆ -0.95 ┆ 1     ┆ 0.39    ┆ -1.54             ┆ -1.54           │
│ 0.21  ┆ 0.05  ┆ 0.23  ┆ 1     ┆ 0.8     ┆ 0.29              ┆ 0.27            │
│ 0.22  ┆ -0.07 ┆ 0.44  ┆ 1     ┆ 0.57    ┆ 0.37              ┆ 0.36            │
└───────┴───────┴───────┴───────┴─────────┴───────────────────┴─────────────────┘

The mode parameter sets the type of output returned by all methods ("predictions", "residuals", or "coefficients"). It defaults to returning predictions matching the input's length.

When "coefficients" is set, the output is a polars Struct with coefficients as values and feature names as fields. Its output shape 'broadcasts' depending on context, see below:

coefficients = df.select(pl.col("y").least_squares.from_formula("x1 + x2", mode="coefficients")
                         .alias("coefficients"))

coefficients_group = df.select("group", pl.col("y").least_squares.from_formula("x1 + x2", mode="coefficients").over("group")
                        .alias("coefficients_group")).unique(maintain_order=True)

print(coefficients)
print(coefficients_group)
shape: (1, 1)
┌──────────────────────────────┐
│ coefficients                 │
│ ---                          │
│ struct[3]                    │
╞══════════════════════════════╡
│ {0.977375,0.987413,0.000757} │  # <--- coef for x1, x2, and intercept added by formula API
└──────────────────────────────┘
shape: (2, 2)
┌───────┬───────────────────────────────┐
│ group ┆ coefficients_group            │
│ ---   ┆ ---                           │
│ i64   ┆ struct[3]                     │
╞═══════╪═══════════════════════════════╡
│ 1     ┆ {0.995157,0.977495,0.014344}  │
│ 2     ┆ {0.939217,0.997441,-0.017599} │  # <--- (unique) coefficients per group
└───────┴───────────────────────────────┘

For dynamic models (like rolling_ols), or when applied in an .over, .group_by, or .with_columns context, the coefficients take the shape of the data they are applied on. For example:

coefficients = df.with_columns(pl.col("y").least_squares.rls(pl.col("x1"), pl.col("x2"), mode="coefficients")
                         .over("group").alias("coefficients"))

print(coefficients.head())
shape: (5, 6)
┌───────┬───────┬───────┬───────┬─────────┬─────────────────────┐
│ y     ┆ x1    ┆ x2    ┆ group ┆ weights ┆ coefficients        │
│ ---   ┆ ---   ┆ ---   ┆ ---   ┆ ---     ┆ ---                 │
│ f64   ┆ f64   ┆ f64   ┆ i64   ┆ f64     ┆ struct[2]           │
╞═══════╪═══════╪═══════╪═══════╪═════════╪═════════════════════╡
│ 1.16  ┆ 0.72  ┆ 0.24  ┆ 1     ┆ 0.34    ┆ {1.235503,0.411834} │
│ -2.16 ┆ -2.43 ┆ 0.18  ┆ 1     ┆ 0.97    ┆ {0.963515,0.760769} │
│ -1.57 ┆ -0.63 ┆ -0.95 ┆ 1     ┆ 0.39    ┆ {0.975484,0.966029} │
│ 0.21  ┆ 0.05  ┆ 0.23  ┆ 1     ┆ 0.8     ┆ {0.975657,0.953735} │
│ 0.22  ┆ -0.07 ┆ 0.44  ┆ 1     ┆ 0.57    ┆ {0.97898,0.909793}  │
└───────┴───────┴───────┴───────┴─────────┴─────────────────────┘

Finally, for convenience, you can compute out-of-sample predictions with least_squares.{predict, predict_from_formula}. This saves you the effort of un-nesting the coefficients and computing the dot product in Python; instead this is done in Rust, as an expression. Usage is as follows:

df_test.select(pl.col("coefficients_train").least_squares.predict(pl.col("x1"), pl.col("x2")).alias("predictions_test"))

Supported Models

Currently, this extension package supports the following variants:

  • Ordinary Least Squares: least_squares.ols
  • Weighted Least Squares: least_squares.wls
  • Regularized Least Squares (Lasso / Ridge / Elastic Net): least_squares.{lasso, ridge, elastic_net}
  • Non-negative Least Squares: least_squares.nnls

As well as efficient implementations of moving window models:

  • Recursive Least Squares: least_squares.rls
  • Rolling / Expanding Window OLS: least_squares.{rolling_ols, expanding_ols}

An arbitrary combination of sample_weights, L1/L2 penalties, and non-negativity constraints can be specified with the least_squares.from_formula and least_squares.least_squares entry-points.

Solve Methods

polars-ols provides a choice of multiple supported numerical approaches per model (via the solve_method flag), with implications for performance versus numerical accuracy. These choices are exposed to the user for full control; however, if left unspecified, the package chooses a reasonable default depending on context.

For example, if you know you are dealing with highly collinear data and an unregularized OLS model, you may want to explicitly set solve_method="svd" so that the minimum-norm solution is obtained.

Benchmark

The usual caveats of benchmarks apply here, but the below should still be indicative of the type of performance improvements to expect when using this package.

This benchmark was run on randomly generated data with pyperf on my Apple M2 Max MacBook (32GB RAM, macOS Sonoma 14.2.1). See benchmark.py for the implementation.

n_samples=2_000, n_features=5

| Model | polars_ols | Python Benchmark | Benchmark Type | Speed-up vs Python Benchmark |
|---|---|---|---|---|
| Least Squares (QR) | 195 µs ± 6 µs | 466 µs ± 104 µs | Numpy (QR) | 2.4x |
| Least Squares (SVD) | 247 µs ± 5 µs | 395 µs ± 69 µs | Numpy (SVD) | 1.6x |
| Ridge (Cholesky) | 171 µs ± 8 µs | 1.02 ms ± 0.29 ms | Sklearn (Cholesky) | 5.9x |
| Ridge (SVD) | 238 µs ± 7 µs | 1.12 ms ± 0.41 ms | Sklearn (SVD) | 4.7x |
| Weighted Least Squares | 334 µs ± 13 µs | 2.04 ms ± 0.22 ms | Statsmodels | 6.1x |
| Elastic Net (CD) | 227 µs ± 7 µs | 1.18 ms ± 0.19 ms | Sklearn | 5.2x |
| Recursive Least Squares | 1.12 ms ± 0.23 ms | 18.2 ms ± 1.6 ms | Statsmodels | 16.2x |
| Rolling Least Squares | 1.99 ms ± 0.03 ms | 22.1 ms ± 0.2 ms | Statsmodels | 11.1x |

n_samples=10_000, n_features=100

| Model | polars_ols | Python Benchmark | Benchmark Type | Speed-up vs Python Benchmark |
|---|---|---|---|---|
| Least Squares (QR) | 17.6 ms ± 0.3 ms | 44.4 ms ± 9.3 ms | Numpy (QR) | 2.5x |
| Least Squares (SVD) | 23.8 ms ± 0.2 ms | 26.6 ms ± 5.5 ms | Numpy (SVD) | 1.1x |
| Ridge (Cholesky) | 5.36 ms ± 0.16 ms | 475 ms ± 71 ms | Sklearn (Cholesky) | 88.7x |
| Ridge (SVD) | 30.2 ms ± 0.4 ms | 400 ms ± 48 ms | Sklearn (SVD) | 13.2x |
| Weighted Least Squares | 18.8 ms ± 0.3 ms | 80.4 ms ± 12.4 ms | Statsmodels | 4.3x |
| Elastic Net (CD) | 22.7 ms ± 0.2 ms | 138 ms ± 27 ms | Sklearn | 6.1x |
| Recursive Least Squares | 270 ms ± 53 ms | 57.8 sec ± 43.7 sec | Statsmodels | 1017.0x |
| Rolling Least Squares | 371 ms ± 13 ms | 4.41 sec ± 0.17 sec | Statsmodels | 11.9x |
  • Numpy's lstsq (which uses divide-and-conquer SVD) is already a highly optimized call into LAPACK, so the scope for speed-up is relatively limited; the same applies to simple approaches like directly solving the normal equations with Cholesky.
  • However, even on such problems, the polars-ols Rust implementations of the matching numerical algorithms tend to outperform by ~2-3x.
  • More substantial speed-ups are achieved for the more complex models by working entirely in Rust and avoiding the overhead of crossing back and forth into Python.
  • Expect an additional order-of-magnitude relative speed-up in your workflow if it involves repeated re-estimation of models in (Python) loops.

Credits & Related Projects

  • The Rust linear-algebra libraries faer and ndarray support the implementations provided by this extension package
  • This package was templated around the very helpful: polars-plugin-tutorial
  • The python package patsy is used for (optionally) building models from formulae
  • Please check out the extension package polars-ds for general data-science functionality in polars

Future Work / TODOs

  • Support generic types in the Rust implementations, so that both f32 and f64 types are recognized. Right now, data is cast to f64 prior to estimation
  • Add docs explaining supported models, signatures, and API
