A library of useful modules for data analysis.
Project description
torch_linear_regression
A very simple library containing closed-form linear regression models using PyTorch. Includes:
- Ordinary Least Squares (OLS) Linear Regression: $(X^\top X)^{-1} X^\top Y$
- Ridge Regression: $(X^\top X + \lambda I)^{-1} X^\top Y$
- Reduced Rank Regression (RRR) with Ridge penalty: Ridge regression followed by SVD on the weights matrix
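The three formulas in the list above can be written directly with `torch.linalg`. This is a minimal sketch in plain PyTorch, not the package's implementation: the function names `ols`, `ridge` and `reduced_rank_ridge` are hypothetical, and the RRR step is a literal reading of "ridge followed by SVD on the weights matrix" (truncating the SVD of the ridge weights to a chosen `rank`). The package may implement RRR differently.

```python
import torch


def ols(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    # Solve (X'X) W = X'Y instead of forming the inverse explicitly;
    # the result equals (X'X)^-1 X'Y but is numerically better behaved.
    return torch.linalg.solve(X.T @ X, X.T @ Y)


def ridge(X: torch.Tensor, Y: torch.Tensor, lam: float) -> torch.Tensor:
    # Adding lam * I to X'X shrinks the weights and keeps the system invertible.
    I = torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
    return torch.linalg.solve(X.T @ X + lam * I, X.T @ Y)


def reduced_rank_ridge(X: torch.Tensor, Y: torch.Tensor, lam: float, rank: int) -> torch.Tensor:
    # Hypothetical reading of "ridge followed by SVD on the weights matrix":
    # keep only the `rank` leading singular components of the ridge weights.
    W = ridge(X, Y, lam)
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    return U[:, :rank] @ torch.diag(S[:rank]) @ Vh[:rank, :]


# Example: 200 samples, 5 features, 3 targets
torch.manual_seed(0)
X = torch.randn(200, 5, dtype=torch.float64)
W_true = torch.randn(5, 3, dtype=torch.float64)
Y = X @ W_true + 0.01 * torch.randn(200, 3, dtype=torch.float64)

W_ols = ols(X, Y)
W_ridge = ridge(X, Y, lam=1.0)
W_rrr = reduced_rank_ridge(X, Y, lam=1.0, rank=2)
print(torch.linalg.svdvals(W_rrr))  # third singular value is ~0 after truncation
```

Using `torch.linalg.solve` rather than `torch.linalg.inv` is a design choice of this sketch; both express the same closed form.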
The closed-form approach gives fast and accurate results under most
conditions. However, when n_features is large and/or the problem is underdetermined
(n_samples <= n_features), the solution will start to diverge from
gradient-based / sklearn solutions.
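One concrete reason for the divergence in the underdetermined case: when n_samples <= n_features, the matrix X'X has rank at most n_samples, so it is singular and the OLS inverse does not exist. The sketch below (plain PyTorch, illustrative random data, not the package's code) checks that rank and computes the minimum-norm least-squares solution via `torch.linalg.pinv`, one common fallback in this regime.

```python
import torch

torch.manual_seed(0)
n_samples, n_features = 20, 50  # underdetermined: n_samples <= n_features
X = torch.randn(n_samples, n_features, dtype=torch.float64)
y = torch.randn(n_samples, 1, dtype=torch.float64)

XtX = X.T @ X
# X'X is n_features x n_features, but its rank is at most n_samples,
# so it is singular and (X'X)^-1 does not exist.
rank = torch.linalg.matrix_rank(XtX)
print(rank.item(), "of", n_features)

# Fallback: the minimum-norm least-squares solution via the pseudoinverse.
w_min_norm = torch.linalg.pinv(X) @ y
# With more features than samples, this solution fits the training data exactly.
print(torch.allclose(X @ w_min_norm, y))
```

Different solvers pick different solutions among the infinitely many exact fits, which is why results from different methods can disagree here.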
Each model also includes a model.prefit() method that can be used to precompute
the inverse matrix and the ridge penalty matrix. This can be useful when the model
is used multiple times with the same X input data.
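The idea behind `prefit()` can be illustrated for Ridge: the factor (X'X + λI)^-1 X' depends only on X, so it can be computed once and reused for many targets. This is a hypothetical sketch; the class `PrefitRidge` and its attribute `proj` are illustrative names, not the package's API, and the package may cache a different quantity.

```python
import torch


class PrefitRidge:
    """Hypothetical sketch, not the package's class: caches (X'X + lam*I)^-1 X'."""

    def __init__(self, lam: float = 1.0):
        self.lam = lam
        self.proj = None

    def prefit(self, X: torch.Tensor) -> "PrefitRidge":
        I = torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
        # The expensive solve depends only on X, so it is done once here.
        self.proj = torch.linalg.solve(X.T @ X + self.lam * I, X.T)
        return self

    def fit(self, Y: torch.Tensor) -> torch.Tensor:
        if self.proj is None:
            raise RuntimeError("call prefit(X) first")
        # Each new Y costs a single matrix multiplication.
        return self.proj @ Y


torch.manual_seed(0)
X = torch.randn(500, 20, dtype=torch.float64)
model = PrefitRidge(lam=10.0).prefit(X)
Ys = [torch.randn(500, 1, dtype=torch.float64) for _ in range(3)]
weights = [model.fit(Y) for Y in Ys]
```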
Because the models are built on PyTorch, they can be significantly faster than sklearn's models, and they can be further accelerated on a GPU. The models can also be used with PyTorch's autograd for gradient-based optimization.
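To illustrate the autograd point, here is a sketch in plain PyTorch rather than the package's classes: the closed-form ridge solution is built from differentiable operations (`torch.linalg.solve`), so autograd can differentiate a validation loss with respect to the penalty λ, and λ can then be tuned by gradient descent. The synthetic data, the Adam settings, and the choice to optimize log λ (which keeps λ positive) are all illustrative assumptions.

```python
import torch

torch.manual_seed(0)
X_train, X_val = torch.randn(100, 10), torch.randn(50, 10)
w_true = torch.randn(10, 1)
y_train = X_train @ w_true + 0.5 * torch.randn(100, 1)
y_val = X_val @ w_true + 0.5 * torch.randn(50, 1)

log_lam = torch.zeros((), requires_grad=True)  # optimize log(lambda) so lambda stays positive
opt = torch.optim.Adam([log_lam], lr=0.1)
I = torch.eye(10)

for step in range(100):
    opt.zero_grad()
    lam = log_lam.exp()
    # Closed-form ridge fit; autograd differentiates through linalg.solve.
    w = torch.linalg.solve(X_train.T @ X_train + lam * I, X_train.T @ y_train)
    val_loss = torch.mean((X_val @ w - y_val) ** 2)
    val_loss.backward()
    opt.step()

print(f"lambda={log_lam.exp().item():.4f}  val MSE={val_loss.item():.4f}")
```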
Installation
Install stable version:
```bash
pip install torch_linear_regression
```
Install development version:
```bash
pip install git+https://github.com/RichieHakim/torch_linear_regression.git
```
Usage
See the notebook for more examples: demo notebook
```python
import sklearn.datasets

import torch_linear_regression as tlr

# Generate data for regression
X, Y = sklearn.datasets.make_regression(
    n_samples=100,
    n_features=2,
    n_informative=2,  # cannot exceed n_features
    bias=2,
    noise=50,
    random_state=42,
)

# Create model
model_ols = tlr.OLS()

# Fit model
model_ols.fit(X=X, y=Y)

# Predict
Y_pred = model_ols.predict(X)

# Score model
score = model_ols.score(X=X, y=Y)
print(f"R^2: {score}")
```
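The example above reports an R^2 score. As a sketch of what that number means, assuming `score()` follows the usual coefficient-of-determination convention (1 - SS_res / SS_tot, as in sklearn) and noting that `make_regression(..., bias=2)` adds a constant offset, the snippet below fits OLS in plain PyTorch with an explicit column of ones for the intercept and computes R^2 by hand. Whether `tlr.OLS` fits an intercept internally is not stated on this page; the data here are synthetic and `r2_score` is a hypothetical helper.

```python
import torch


def r2_score(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
    # Coefficient of determination: 1 - SS_res / SS_tot
    ss_res = torch.sum((y_true - y_pred) ** 2)
    ss_tot = torch.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot


torch.manual_seed(42)
n_samples = 100
X = torch.randn(n_samples, 2, dtype=torch.float64)
coef = torch.tensor([3.0, -1.5], dtype=torch.float64)
y = X @ coef + 2.0 + 0.5 * torch.randn(n_samples, dtype=torch.float64)

# Append a column of ones so the constant offset (bias) is fit as an extra weight.
X1 = torch.cat([X, torch.ones(n_samples, 1, dtype=X.dtype)], dim=1)
w = torch.linalg.lstsq(X1, y.unsqueeze(1)).solution
y_pred = (X1 @ w).squeeze(1)
print(f"R^2: {r2_score(y, y_pred).item():.3f}")
```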
Download files
File details
Details for the file torch_linear_regression-0.1.0.tar.gz.
File metadata
- Download URL: torch_linear_regression-0.1.0.tar.gz
- Upload date:
- Size: 9.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.0 CPython/3.12.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7e6e2bd9bd67073946c73af34705aee093964ef32b1e91f7b23fb30ab58cdbfc |
| MD5 | 822a2bbebfb3861af2592e8ea4430b3c |
| BLAKE2b-256 | 6d27fb369afbdf41f4f153dcb2d902b0f275f31d5594464253f8728823fec8a5 |
File details
Details for the file torch_linear_regression-0.1.0-py3-none-any.whl.
File metadata
- Download URL: torch_linear_regression-0.1.0-py3-none-any.whl
- Upload date:
- Size: 8.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.0 CPython/3.12.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 2a123359f7df08e595babd664dbc4e30fd93330af286c0f752740c4ad460c9cb |
| MD5 | 7f40b75f3c975f30fd98f7db8fa9be31 |
| BLAKE2b-256 | ae21fe7394c25e961dadd16f1d006e1e71e2cc31755a13f94d3eae33e46ba229 |
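The SHA256 digests in the "File hashes" tables can be used to check a downloaded file. A minimal sketch using only the standard library (`hashlib`); the helper names `sha256_of` and `verify` are illustrative, and the files are assumed to keep their original names.

```python
import hashlib
from pathlib import Path

# SHA256 digests from the "File hashes" tables for version 0.1.0
EXPECTED_SHA256 = {
    "torch_linear_regression-0.1.0.tar.gz":
        "7e6e2bd9bd67073946c73af34705aee093964ef32b1e91f7b23fb30ab58cdbfc",
    "torch_linear_regression-0.1.0-py3-none-any.whl":
        "2a123359f7df08e595babd664dbc4e30fd93330af286c0f752740c4ad460c9cb",
}


def sha256_of(path, chunk_size: int = 1 << 16) -> str:
    # Hash the file in chunks so large files need not fit in memory.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path) -> bool:
    # Look up the expected digest by file name and compare.
    return sha256_of(path) == EXPECTED_SHA256[Path(path).name]
```

pip can also enforce digests itself through its hash-checking mode (`--require-hashes`, with `--hash=sha256:...` entries in a requirements file).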