Skip to main content

A library of useful modules for data analysis.

Project description

PyPI version Downloads repo size

torch_linear_regression

A very simple library containing closed-form linear regression models using PyTorch. Includes:

  • Ordinary Least Squares (OLS) Linear Regression: (X'X)^-1 X'Y
  • Ridge Regression: (X'X + λI)^-1 X'Y
  • Reduced Rank Regression (RRR) with Ridge penalty: Ridge regression followed by SVD on the weights matrix

The closed-form approach results in fast and accurate results under most conditions. However, when n_features is large and/or underdetermined (n_samples <= n_features), the solution will start to diverge from gradient-based / sklearn solutions.

Each model also includes a model.prefit() method that can be used to precompute the inverse matrix and the ridge penalty matrix. This can be useful when the model is used multiple times with the same X input data.

Because the models are based on PyTorch, they are significantly faster than sklearn's models, and can be further accelerated by using GPU. Also the models can be used in conjunction with PyTorch's autograd for gradient-based optimization.

Installation

Install stable version:

pip install torch_linear_regression

Install development version:

pip install git+https://github.com/RichieHakim/torch_linear_regression.git

Usage

See the notebook for more examples: demo notebook

import torch_linear_regression as tlr

import torch
import numpy as np
import sklearn
import sklearn.datasets
import matplotlib.pyplot as plt


## Generate data for regression
X, Y = sklearn.datasets.make_regression(
    n_samples=100,
    n_features=2,
    n_informative=10,
    bias=2,
    noise=50,
    random_state=42,
)

## Create model
model_ols = tlr.OLS()
## Fit model
model_ols.fit(X=X, y=Y)
## Predict
Y_pred = model_ols.predict(X)
## Score model
score = model_ols.score(X=X, y=Y)
print(f"R^2: {score}")

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torch_linear_regression-0.1.0.tar.gz (9.8 kB view hashes)

Uploaded Source

Built Distribution

torch_linear_regression-0.1.0-py3-none-any.whl (8.5 kB view hashes)

Uploaded Python 3

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page