RandALO: fast randomized risk estimation for high-dimensional data
This repository contains a software package implementing RandALO, a fast randomized method for risk estimation of machine learning models, as described in the paper,
P. T. Nobel, D. LeJeune, E. J. Candès. RandALO: Out-of-sample risk estimation in no time flat. 2024.
Installation
In a folder of your choice, run the following:
git clone git@github.com:cvxgrp/randalo.git
cd randalo
# create and activate a new environment with Python >= 3.10 (could also use venv or similar)
conda create -n randalo python=3.12
conda activate randalo
# install requirements and randalo
pip install -r requirements.txt
Usage
Scikit-learn
The simplest way to use RandALO is with linear models from scikit-learn. See a longer demonstration in a notebook here.
from torch import nn
from sklearn.linear_model import Lasso
from randalo import RandALO
X, y = ... # load data as np.ndarrays as usual
model = Lasso(1.0).fit(X, y) # fit the model
alo = RandALO.from_sklearn(model, X, y) # set up the Jacobian
mse_estimate = alo.evaluate(nn.MSELoss()) # estimate risk
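To make "estimate risk" concrete, the sketch below computes the kind of quantity a leave-one-out risk estimate targets, using one-feature ridge regression in plain Python. It does not use randalo and is not the RandALO algorithm (there is no randomization here). The function names and the toy data are made up for illustration. It shows only that, for ridge, the leave-one-out error obtained by refitting n times can also be recovered from a single fit.

```python
def ridge_fit_1d(xs, ys, lam):
    """Closed-form ridge slope (one feature, no intercept):
    argmin_w sum_i (y_i - x_i * w)^2 + lam * w^2."""
    sxx = sum(x * x for x in xs)
    sxy = sum(x * y for x, y in zip(xs, ys))
    return sxy / (sxx + lam)


def loo_mse_refit(xs, ys, lam):
    """Leave-one-out MSE by brute force: refit once per held-out sample."""
    n = len(xs)
    total = 0.0
    for i in range(n):
        w = ridge_fit_1d(xs[:i] + xs[i + 1:], ys[:i] + ys[i + 1:], lam)
        total += (ys[i] - xs[i] * w) ** 2
    return total / n


def loo_mse_single_fit(xs, ys, lam):
    """Same quantity from one fit, dividing each training residual by
    (1 - h_i), where h_i = x_i^2 / (sum_j x_j^2 + lam) is the leverage."""
    n = len(xs)
    sxx = sum(x * x for x in xs)
    w = ridge_fit_1d(xs, ys, lam)
    total = 0.0
    for x, y in zip(xs, ys):
        h = x * x / (sxx + lam)
        total += ((y - x * w) / (1.0 - h)) ** 2
    return total / n


# Hypothetical toy data, for illustration only.
xs = [0.5, -1.0, 2.0, 1.5, -0.3]
ys = [1.1, -1.9, 4.2, 2.8, -0.5]
print(loo_mse_refit(xs, ys, lam=1.0), loo_mse_single_fit(xs, ys, lam=1.0))
```

The two printed values agree up to floating-point rounding. The brute-force version costs n fits, which is the expense that single-fit estimators avoid.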
We currently support the following models:
LinearRegression
Ridge
Lasso
LassoLars
ElasticNet
LogisticRegression
Linear models with any solver
If you prefer to fit your models with a solver other than scikit-learn, or if you wish to extend RandALO to models beyond those listed above, you can still use RandALO by instantiating the Jacobian yourself. Take care to scale the regularizer correctly for your problem formulation.
from torch import nn
from sklearn.linear_model import Lasso
from randalo import RandALO, MSELoss, L1Regularizer, Jacobian
X, y = ... # load data as np.ndarrays as usual
model = Lasso(1.0).fit(X, y) # fit the model
# instantiate RandALO by creating a Jacobian object
loss = MSELoss()
reg = 2.0 * model.alpha * L1Regularizer() # scale the regularizer appropriately
y_hat = model.predict(X)
solution_func = lambda: model.coef_
jac = Jacobian(y, X, solution_func, loss, reg)
alo = RandALO(loss, jac, y, y_hat)
mse_estimate = alo.evaluate(nn.MSELoss()) # estimate risk
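One way to see where the factor of 2.0 on the regularizer in the example above may come from: scikit-learn documents the Lasso objective as ||y - Xw||^2 / (2n) + alpha * ||w||_1. Assuming the loss used by RandALO is the plain mean squared error without the 1/2 factor (an assumption, not confirmed here), multiplying scikit-learn's objective by 2 gives the mean squared error plus 2 * alpha * ||w||_1, which has the same minimizer. The sketch below checks that identity numerically in plain Python. The function names and the numbers are hypothetical and are not part of randalo.

```python
def residuals(X, y, w):
    """Residuals y_i - <x_i, w> for a list-of-rows design matrix X."""
    return [yi - sum(xij * wj for xij, wj in zip(xi, w)) for xi, yi in zip(X, y)]


def sklearn_lasso_objective(X, y, w, alpha):
    """Lasso objective as documented by scikit-learn:
    ||y - Xw||^2 / (2n) + alpha * ||w||_1."""
    n = len(y)
    r = residuals(X, y, w)
    return sum(ri * ri for ri in r) / (2 * n) + alpha * sum(abs(wj) for wj in w)


def mse_plus_l1_objective(X, y, w, lam):
    """Assumed target form: mean squared error + lam * ||w||_1."""
    n = len(y)
    r = residuals(X, y, w)
    return sum(ri * ri for ri in r) / n + lam * sum(abs(wj) for wj in w)


# Hypothetical values, for illustration only.
X = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]
y = [1.0, -2.0, 0.5]
w = [0.3, -0.7]
alpha = 1.0

lhs = 2.0 * sklearn_lasso_objective(X, y, w, alpha)
rhs = mse_plus_l1_objective(X, y, w, lam=2.0 * alpha)
print(lhs, rhs)  # equal up to floating-point rounding
```

If your solver uses a different normalization (for example, a sum of squares rather than a mean), the same comparison indicates the factor to apply to the regularizer.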
Please refer to our scikit-learn integration source code for more examples.