
A library of useful modules for data analysis.


torch_linear_regression

A very simple library of closed-form linear regression models built on PyTorch. It includes:

  • Ordinary Least Squares (OLS) Linear Regression: (X'X)^-1 X'Y
  • Ridge Regression: (X'X + λI)^-1 X'Y
  • Reduced Rank Regression (RRR) with Ridge penalty: Ridge regression followed by SVD on the weights matrix
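The three closed-form solutions listed above can be sketched directly in PyTorch. This is a minimal illustration under stated assumptions, not the library's implementation: the function names `ols`, `ridge`, and `reduced_rank_ridge` are hypothetical, no intercept is fitted (X is assumed to already contain any bias column), and `torch.linalg.solve` is used instead of forming an explicit inverse because it is numerically better behaved. For the RRR step, the sketch simply truncates the SVD of the ridge weight matrix to a chosen rank; textbook reduced rank regression instead truncates using the SVD of the fitted values XW, so the library's exact procedure may differ.

```python
import torch

def ols(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """OLS weights (X'X)^-1 X'Y, computed with a linear solve instead of an explicit inverse."""
    return torch.linalg.solve(X.T @ X, X.T @ Y)

def ridge(X: torch.Tensor, Y: torch.Tensor, lam: float) -> torch.Tensor:
    """Ridge weights (X'X + lam*I)^-1 X'Y."""
    I = torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
    return torch.linalg.solve(X.T @ X + lam * I, X.T @ Y)

def reduced_rank_ridge(X: torch.Tensor, Y: torch.Tensor, lam: float, rank: int) -> torch.Tensor:
    """Ridge weights truncated to `rank` via SVD (an assumed reading of 'SVD on the weights matrix')."""
    W = ridge(X, Y, lam)  # shape (n_features, n_targets)
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    # Keep only the top `rank` singular triplets of the weight matrix
    return U[:, :rank] @ torch.diag(S[:rank]) @ Vh[:rank, :]

# Example: 200 samples, 5 features, 3 targets (random data for illustration)
torch.manual_seed(0)
X = torch.randn(200, 5, dtype=torch.float64)
Y = torch.randn(200, 3, dtype=torch.float64)
W_ols = ols(X, Y)
W_ridge = ridge(X, Y, lam=10.0)
W_rrr = reduced_rank_ridge(X, Y, lam=10.0, rank=2)
```

With lam=0 the ridge solution reduces to OLS, and any lam > 0 shrinks the weights toward zero.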

The closed-form approach is fast and accurate under most conditions. However, when n_features is large and/or the problem is underdetermined (n_samples <= n_features), X'X becomes ill-conditioned or singular, and the solution will start to diverge from gradient-based / sklearn solutions.
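To see why the closed form degrades in the underdetermined case, note that X'X has rank at most min(n_samples, n_features), so with fewer samples than features it is singular and (X'X)^-1 does not exist. The sketch below uses random data with arbitrarily chosen sizes to check this numerically, and shows that adding a ridge penalty λI restores an invertible, well-conditioned matrix.

```python
import torch

torch.manual_seed(0)
n_samples, n_features = 20, 50  # underdetermined: fewer samples than features
X = torch.randn(n_samples, n_features, dtype=torch.float64)

gram = X.T @ X  # X'X, shape (n_features, n_features)
rank = torch.linalg.matrix_rank(gram).item()
# rank <= n_samples < n_features, so X'X is singular and cannot be inverted
print(f"rank of X'X: {rank} of {n_features}")

# A ridge penalty shifts every eigenvalue of X'X up by lam, making the matrix invertible
lam = 1.0
ridge_gram = gram + lam * torch.eye(n_features, dtype=torch.float64)
cond_ridge = torch.linalg.cond(ridge_gram).item()
print(f"condition number of X'X + lam*I: {cond_ridge:.1f}")
```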

Each model also includes a model.prefit() method that can be used to precompute the inverse matrix and the ridge penalty matrix. This is useful when the model is fit multiple times with the same X input data (for example, against several different Y targets).
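The idea behind prefitting can be sketched as caching the part of the solution that depends only on X. The class below is hypothetical and is not the library's `prefit()` API; it assumes a ridge model and caches P = (X'X + λI)^-1 X', so that fitting each new target Y reduces to a single matrix multiplication W = P Y.

```python
import torch

class PrefitRidge:
    """Hypothetical sketch (not the library's API): cache (X'X + lam*I)^-1 X' for a fixed X."""

    def __init__(self, lam: float = 1.0):
        self.lam = lam
        self.projection = None  # cached (X'X + lam*I)^-1 X', shape (n_features, n_samples)

    def prefit(self, X: torch.Tensor) -> "PrefitRidge":
        I = torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
        # Solve (X'X + lam*I) P = X' once; P depends only on X and lam
        self.projection = torch.linalg.solve(X.T @ X + self.lam * I, X.T)
        return self

    def fit(self, Y: torch.Tensor) -> torch.Tensor:
        # Each new target costs only one matrix multiplication
        return self.projection @ Y

# Reuse one prefit X for several targets (random data for illustration)
torch.manual_seed(0)
X = torch.randn(100, 4, dtype=torch.float64)
model = PrefitRidge(lam=0.5).prefit(X)
targets = [torch.randn(100, 2, dtype=torch.float64) for _ in range(3)]
weights = [model.fit(Y) for Y in targets]
```

The expensive solve happens once in `prefit`; the per-target cost is then a single (n_features x n_samples) by (n_samples x n_targets) product.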

Because the models are built on PyTorch, they are significantly faster than sklearn's models and can be further accelerated on a GPU. The models can also be used with PyTorch's autograd for gradient-based optimization.
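As a sketch of the GPU and autograd points, using only standard PyTorch rather than the library's own classes: the closed-form ridge solution is differentiable end to end because `torch.linalg.solve` supports autograd. Here the penalty λ is tuned by gradient descent on a validation loss; the synthetic data, the log-parameterization of λ, and the optimizer settings are illustrative assumptions. The code runs on a GPU if one is available and falls back to the CPU otherwise.

```python
import torch

# Use a GPU when one is available; everything below also runs on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)

def make_data(n_samples: int, W_true: torch.Tensor):
    """Synthetic linear data with Gaussian noise (illustrative only)."""
    X = torch.randn(n_samples, W_true.shape[0], dtype=torch.float64, device=device)
    noise = 0.5 * torch.randn(n_samples, W_true.shape[1], dtype=torch.float64, device=device)
    return X, X @ W_true + noise

W_true = torch.randn(10, 1, dtype=torch.float64, device=device)
X_train, Y_train = make_data(40, W_true)
X_val, Y_val = make_data(40, W_true)

def ridge_weights(X, Y, lam):
    """Closed-form ridge solution (X'X + lam*I)^-1 X'Y; differentiable in X, Y, and lam."""
    I = torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
    return torch.linalg.solve(X.T @ X + lam * I, X.T @ Y)

# Optimize log(lambda) so that lambda stays positive
log_lam = torch.zeros((), dtype=torch.float64, device=device, requires_grad=True)
optimizer = torch.optim.Adam([log_lam], lr=0.1)

for step in range(50):
    optimizer.zero_grad()
    W = ridge_weights(X_train, Y_train, log_lam.exp())
    val_loss = ((X_val @ W - Y_val) ** 2).mean()
    val_loss.backward()  # gradients flow back through torch.linalg.solve to log_lam
    optimizer.step()

print(f"tuned lambda: {log_lam.exp().item():.3f}, val MSE: {val_loss.item():.3f}")
```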

Installation

Install stable version:

pip install torch_linear_regression

Install development version:

pip install git+https://github.com/RichieHakim/torch_linear_regression.git

Usage

See the demo notebook for more examples.

import torch_linear_regression as tlr

import torch
import numpy as np
import sklearn
import sklearn.datasets
import matplotlib.pyplot as plt


## Generate data for regression
X, Y = sklearn.datasets.make_regression(
    n_samples=100,
    n_features=2,
    n_informative=2,  # cannot exceed n_features (sklearn caps it at n_features)
    bias=2,
    noise=50,
    random_state=42,
)

## Create model
model_ols = tlr.OLS()
## Fit model
model_ols.fit(X=X, y=Y)
## Predict
Y_pred = model_ols.predict(X)
## Score model
score = model_ols.score(X=X, y=Y)
print(f"R^2: {score}")
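The printed value is labeled R^2. Assuming `score` follows the scikit-learn convention for regressors (the coefficient of determination), it can be reproduced for a single target as R^2 = 1 - SS_res / SS_tot. The helper `r2_score_1d` below is a hypothetical illustration, not part of torch_linear_regression; note that for multiple targets scikit-learn's `r2_score` averages the per-target values by default.

```python
import torch

def r2_score_1d(y_true: torch.Tensor, y_pred: torch.Tensor) -> float:
    """Coefficient of determination for a single target: 1 - SS_res / SS_tot."""
    ss_res = ((y_true - y_pred) ** 2).sum()         # residual sum of squares
    ss_tot = ((y_true - y_true.mean()) ** 2).sum()  # total sum of squares around the mean
    return (1.0 - ss_res / ss_tot).item()

y_true = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64)
y_pred = torch.tensor([1.0, 2.0, 3.0, 5.0], dtype=torch.float64)
print(r2_score_1d(y_true, y_true))  # a perfect prediction gives R^2 = 1
print(r2_score_1d(y_true, y_pred))  # SS_res = 1, SS_tot = 5, so R^2 = 1 - 1/5
```

A model that always predicts the mean of y scores 0, and worse-than-mean predictions score below 0.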



Download files


Source Distribution

torch_linear_regression-0.1.0.tar.gz (9.8 kB)

Uploaded Source

Built Distribution

torch_linear_regression-0.1.0-py3-none-any.whl (8.5 kB)

Uploaded Python 3

File details

Details for the file torch_linear_regression-0.1.0.tar.gz.


File hashes

Hashes for torch_linear_regression-0.1.0.tar.gz
Algorithm Hash digest
SHA256 7e6e2bd9bd67073946c73af34705aee093964ef32b1e91f7b23fb30ab58cdbfc
MD5 822a2bbebfb3861af2592e8ea4430b3c
BLAKE2b-256 6d27fb369afbdf41f4f153dcb2d902b0f275f31d5594464253f8728823fec8a5


File details

Details for the file torch_linear_regression-0.1.0-py3-none-any.whl.


File hashes

Hashes for torch_linear_regression-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 2a123359f7df08e595babd664dbc4e30fd93330af286c0f752740c4ad460c9cb
MD5 7f40b75f3c975f30fd98f7db8fa9be31
BLAKE2b-256 ae21fe7394c25e961dadd16f1d006e1e71e2cc31755a13f94d3eae33e46ba229

