A library of useful modules for data analysis.
torch_linear_regression
A very simple library containing closed-form linear regression models using PyTorch. Includes:
- Ordinary Least Squares (OLS) Linear Regression:
(X'X)^-1 X'Y
- Ridge Regression:
(X'X + λI)^-1 X'Y
- Reduced Rank Regression (RRR) with Ridge penalty: Ridge regression followed by SVD on the weights matrix
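The three formulas above can be sketched directly in PyTorch. This is a minimal illustration, not the library's own code: the function names are hypothetical, no intercept term is fitted, and the reduced-rank step assumes that "SVD on the weights matrix" means keeping only the top `rank` singular components of the ridge weights.

```python
import torch


def ols_weights(X, Y):
    # Solve (X'X) W = X'Y; numerically safer than forming the inverse explicitly.
    return torch.linalg.solve(X.T @ X, X.T @ Y)


def ridge_weights(X, Y, lam):
    # Solve (X'X + lam * I) W = X'Y.
    n_features = X.shape[1]
    eye = torch.eye(n_features, dtype=X.dtype, device=X.device)
    return torch.linalg.solve(X.T @ X + lam * eye, X.T @ Y)


def reduced_rank_weights(X, Y, lam, rank):
    # Assumption: truncate the SVD of the ridge weights to `rank` components.
    W = ridge_weights(X, Y, lam)
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    return (U[:, :rank] * S[:rank]) @ Vh[:rank, :]


torch.manual_seed(0)
X = torch.randn(200, 5, dtype=torch.float64)
W_true = torch.randn(5, 3, dtype=torch.float64)
Y = X @ W_true + 0.01 * torch.randn(200, 3, dtype=torch.float64)

W_ols = ols_weights(X, Y)
W_ridge = ridge_weights(X, Y, lam=1.0)
W_rrr = reduced_rank_weights(X, Y, lam=1.0, rank=2)
```

Here `Y` has three output columns so that the rank constraint is meaningful; with a single output column the weights matrix already has rank one.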
The closed-form approach gives fast and accurate results under most conditions. However, when n_features is large and/or the problem is underdetermined (n_samples <= n_features), the solution will start to diverge from gradient-based / sklearn solutions.
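One plausible reason for this divergence is that X'X becomes singular once there are fewer samples than features, so the normal equations no longer have a unique solution. The sketch below uses made-up random data to show the rank deficiency and two common ways of picking a solution anyway: the minimum-norm solution via the pseudoinverse, and a small ridge penalty. It illustrates the general issue and makes no claim about how the library handles this case.

```python
import torch

torch.manual_seed(0)
n_samples, n_features = 10, 20  # underdetermined: fewer samples than features
X = torch.randn(n_samples, n_features, dtype=torch.float64)
Y = torch.randn(n_samples, 1, dtype=torch.float64)

# The Gram matrix X'X is n_features x n_features, but its rank cannot
# exceed n_samples, so it is not invertible here.
gram = X.T @ X
rank = torch.linalg.matrix_rank(gram).item()

# Minimum-norm least-squares solution via the pseudoinverse.
W_min_norm = torch.linalg.pinv(X) @ Y

# A small ridge penalty makes the system solvable, but gives a different W.
lam = 1e-3
eye = torch.eye(n_features, dtype=torch.float64)
W_ridge = torch.linalg.solve(gram + lam * eye, X.T @ Y)

residual_min_norm = torch.linalg.norm(X @ W_min_norm - Y).item()
difference = torch.linalg.norm(W_min_norm - W_ridge).item()
print(f"rank of X'X: {rank} (n_features = {n_features})")
print(f"min-norm residual: {residual_min_norm:.2e}, ||W_min_norm - W_ridge||: {difference:.2e}")
```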
Each model also includes a model.prefit() method that can be used to precompute the inverse matrix and the ridge penalty matrix. This can be useful when the model is used multiple times with the same X input data.
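The idea behind precomputation can be sketched as follows. The class `PrefitRidgeSketch` and its method signatures are hypothetical and written only for illustration; the library's actual `prefit()` arguments and internals may differ. The sketch caches (X'X + λI)^-1 X' once, so that fitting each new target reduces to a single matrix product.

```python
import torch


class PrefitRidgeSketch:
    """Hypothetical illustration of caching; not the library's implementation."""

    def __init__(self, lam=1.0):
        self.lam = lam
        self.projector = None
        self.weights = None

    def prefit(self, X):
        # Compute (X'X + lam * I)^-1 X' once for a fixed X.
        n_features = X.shape[1]
        penalty = self.lam * torch.eye(n_features, dtype=X.dtype, device=X.device)
        self.projector = torch.linalg.solve(X.T @ X + penalty, X.T)
        return self

    def fit(self, Y):
        # With the projector cached, each fit is one matrix multiplication.
        self.weights = self.projector @ Y
        return self


torch.manual_seed(0)
X = torch.randn(50, 4, dtype=torch.float64)
targets = [torch.randn(50, 1, dtype=torch.float64) for _ in range(3)]

model = PrefitRidgeSketch(lam=0.5).prefit(X)
all_weights = [model.fit(Y).weights for Y in targets]
```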
Because the models are based on PyTorch, they are significantly faster than sklearn's models and can be further accelerated by using a GPU. The models can also be used in conjunction with PyTorch's autograd for gradient-based optimization.
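As a rough sketch of what combining a closed-form solution with autograd can look like, the example below writes out the ridge formula by hand and tunes the penalty λ by gradient descent on a validation loss. It does not use the library's classes; the data, variable names, and optimizer settings are all assumptions chosen for illustration. The tensors are placed on a GPU when one is available.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)

# Synthetic train / validation split (made-up data).
W_true = torch.randn(8, 1, dtype=torch.float64, device=device)
X_train = torch.randn(100, 8, dtype=torch.float64, device=device)
Y_train = X_train @ W_true + 0.5 * torch.randn(100, 1, dtype=torch.float64, device=device)
X_val = torch.randn(100, 8, dtype=torch.float64, device=device)
Y_val = X_val @ W_true + 0.5 * torch.randn(100, 1, dtype=torch.float64, device=device)

# Optimize log(lambda) so that lambda stays positive.
log_lam = torch.zeros((), dtype=torch.float64, device=device, requires_grad=True)
optimizer = torch.optim.Adam([log_lam], lr=0.1)
eye = torch.eye(8, dtype=torch.float64, device=device)

for step in range(50):
    optimizer.zero_grad()
    lam = log_lam.exp()
    # Closed-form ridge solution; autograd tracks the dependence on lam.
    W = torch.linalg.solve(X_train.T @ X_train + lam * eye, X_train.T @ Y_train)
    loss = torch.mean((X_val @ W - Y_val) ** 2)
    loss.backward()
    optimizer.step()

print(f"lambda after optimization: {log_lam.exp().item():.4f}, validation MSE: {loss.item():.4f}")
```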
Installation
Install stable version:
pip install torch_linear_regression
Install development version:
pip install git+https://github.com/RichieHakim/torch_linear_regression.git
Usage
See the notebook for more examples: demo notebook
import torch_linear_regression as tlr
import torch
import numpy as np
import sklearn
import sklearn.datasets
import matplotlib.pyplot as plt
## Generate data for regression
X, Y = sklearn.datasets.make_regression(
    n_samples=100,
    n_features=2,
    n_informative=2,  # cannot exceed n_features
    bias=2,
    noise=50,
    random_state=42,
)
## Create model
model_ols = tlr.OLS()
## Fit model
model_ols.fit(X=X, y=Y)
## Predict
Y_pred = model_ols.predict(X)
## Score model
score = model_ols.score(X=X, y=Y)
print(f"R^2: {score}")
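For reference, the R^2 value printed above is conventionally the coefficient of determination, 1 - SS_res / SS_tot. The helper below is a hypothetical stand-alone sketch of that formula, assuming `score` follows the usual definition; it pools all outputs together, which may differ from how the library or sklearn averages over multiple output columns.

```python
import torch


def r2_score_sketch(Y_true, Y_pred):
    # Residual sum of squares: error left over after prediction.
    ss_res = torch.sum((Y_true - Y_pred) ** 2)
    # Total sum of squares: variation around the per-column mean.
    ss_tot = torch.sum((Y_true - Y_true.mean(dim=0)) ** 2)
    return (1.0 - ss_res / ss_tot).item()


Y_true = torch.tensor([1.0, 2.0, 3.0, 4.0])

# Perfect predictions give R^2 = 1.
perfect = r2_score_sketch(Y_true, Y_true.clone())

# Always predicting the mean gives R^2 = 0.
mean_prediction = torch.full_like(Y_true, Y_true.mean().item())
baseline = r2_score_sketch(Y_true, mean_prediction)
```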