SURE-CR
This package enables tractable evaluation of Stein's Unbiased Risk Estimate on convexly regularized estimators.
For an estimator $\hat\mu$ of the mean of a normally distributed random vector $y \in \mathbb R^d$ with known covariance matrix $\sigma^2 I$, given by
$$ \hat\mu(y) = \mathcal A \, \underset{b}{\text{argmin}} \left( \frac{1}{2} \lVert\mathcal A b - y\rVert_2^2 + r(b) \right), $$
where $r: \mathbb R^p \to \mathbb R$ is a convex function and $\mathcal A: \mathbb R^p \to \mathbb R^d$ is a linear operator, this package provides methods to compute Stein's Unbiased Risk Estimate of $\hat\mu$:
$$ SURE(\hat\mu, y) = -d \sigma^2 + \lVert\hat\mu(y) - y\rVert_2^2 + 2 \sigma^2 \nabla \cdot \hat\mu(y), $$
where $\nabla \cdot \hat\mu(y) = \sum_{i=1}^d \partial \hat\mu_i(y) / \partial y_i$ is the divergence of $\hat\mu$.
$SURE(\hat\mu, y)$ is an unbiased estimate of the $\ell_2$ risk $\mathbb E \lVert \hat\mu(y) - \mu \rVert_2^2$ of $\hat\mu$, where $\mu$ is the mean of $y$, and it is especially accurate for high-dimensional problems.
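As a quick sanity check of the formula, take the simplest case $\mathcal A = I$ and $r(b) = \lambda \lVert b \rVert_1$. Then $\hat\mu$ is soft thresholding at level $\lambda$, which is piecewise linear, so its divergence is simply the number of nonzero entries of $\hat\mu(y)$. The NumPy sketch below does not use SURE-CR; the sparse true mean $\mu$, $\lambda = 2$, and the number of trials are illustrative assumptions. It averages SURE and the actual squared error $\lVert \hat\mu(y) - \mu \rVert_2^2$ over repeated draws of $y$, and the two averages should come out close.

```python
import numpy as np


def soft_threshold(y: np.ndarray, lam: float) -> np.ndarray:
    """Solution of min_b 1/2 ||b - y||_2^2 + lam * ||b||_1 (the case A = I)."""
    return np.sign(y) * np.maximum(np.abs(y) - lam, 0.0)


def sure_soft_threshold(y: np.ndarray, lam: float, sigma: float) -> float:
    """SURE = -d sigma^2 + ||mu_hat(y) - y||^2 + 2 sigma^2 div mu_hat(y)."""
    d = y.size
    mu_hat = soft_threshold(y, lam)
    # Each coordinate has derivative 1 where |y_i| > lam and 0 elsewhere,
    # so the divergence is the number of nonzero entries of mu_hat.
    divergence = np.count_nonzero(mu_hat)
    return -d * sigma**2 + np.sum((mu_hat - y) ** 2) + 2 * sigma**2 * divergence


rng = np.random.default_rng(0)
d, sigma, lam = 1000, 1.0, 2.0
mu = np.zeros(d)
mu[:50] = 5.0  # illustrative sparse true mean

sure_values, losses = [], []
for _ in range(200):
    y = mu + sigma * rng.standard_normal(d)
    sure_values.append(sure_soft_threshold(y, lam, sigma))
    losses.append(np.sum((soft_threshold(y, lam) - mu) ** 2))

print(f"mean SURE: {np.mean(sure_values):.1f}, mean loss: {np.mean(losses):.1f}")
```

In SURE-CR the divergence is generally not available in closed form, which is why the package estimates it numerically.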
Installation
SURE-CR can be installed with pip or conda. By default, the conda installation does not support `CVXPYSolver`; to use it, the `cvxpylayers` library must be installed separately.
To install with pip:
```shell
$ pip install surecr
```
To install with conda:
```shell
$ conda install -c stanfordcvxgrp -c conda-forge sure-cr
```
Examples
The easiest way to start using SURE-CR may be to read the examples:
- LASSO with `FISTASolver`
  - Shows how to define a proximal operator.
- LASSO with `CVXPYSolver`
- Matrix completion with `ADMMSolver`
  - Includes defining a custom linear operator.
- Robust PCA with `ADMMSolver`
  - Example of separable regularizers.
Usage
There are three key things in this package:
- The `SURE` class
- The `Solver` class and its subclasses `CVXPYSolver`, `FISTASolver`, and `ADMMSolver`
- The `prox_lib` helper library
The `SURE` Class
The `SURE` class has the following API:
```python
from typing import TypedDict

import torch


class SURE:
    def __init__(self, variance: float, solver: "Solver"): ...

    def compute(self, y: torch.Tensor, divergence_parameters={}) -> float:
        """
        Computes and returns SURE for the estimator computed by the solver
        at the point y.

        Currently, divergence_parameters can contain the key "m" to indicate
        how many samples to use during the divergence estimation (which
        dominates the runtime at high dimensions). The default is m = 102.

        In the future we may switch to A-Hutch++ and may change which options
        divergence_parameters specifies.
        """

    @property
    def solution(self) -> torch.Tensor:
        """
        Returns solver.solve(y) from the last compute call.
        """

    def runtimes(self) -> TypedDict("Runtimes", {"solver": float, "divergence": float}):
        """
        Returns how long the solver and the divergence estimator each took
        to run during the last compute call.
        """
```
The `Solver` class
Most uses of the library should rely on one of the existing `Solver` subclasses. The three subclasses provided by this library have the following constructors:
```python
from typing import Callable

import linops  # torch_linops: https://github.com/cvxgrp/torch_linops
import torch


class FISTASolver(Solver):
    def __init__(self, A: linops.LinearOperator,
                 prox_R: Callable[[torch.Tensor, float | torch.Tensor], torch.Tensor],
                 x0: torch.Tensor,
                 device=None,
                 lipschitz_iterations=20,
                 lipschitz_vec=None,
                 *, max_iters=5000, eps=1e-3):
        """
        This solver uses a variant of FISTA to solve problems of the form
            min. 1/2 ||A b - y||_2^2 + r(b)
        and estimates the mean of y with A b^*, where b^* is the optimal b.

        A is a linear operator defined using <https://github.com/cvxgrp/torch_linops>.
        prox_R is a function, differentiable with respect to its first argument,
        that finds the optimal point b for a (v, t) pair of
            min. t r(b) + 1/2 ||b - v||_2^2
        x0 is the point where iterations begin; it must be chosen
        independently of y.
        lipschitz_iterations is how many iterations of the power method to use
        to approximate the largest eigenvalue of A^T A.
        lipschitz_vec is the vector that starts the power method. By default, a
        vector of all 1s is used. If that vector is orthogonal to the leading
        eigenvector of A^T A, this argument is mandatory.
        max_iters and eps control when iterations stop.
        """
```
```python
class ADMMSolver(Solver):
    def __init__(self, A: linops.LinearOperator,
                 prox_R: Callable[[torch.Tensor, float | torch.Tensor], torch.Tensor],
                 x0: torch.Tensor,
                 device=None,
                 *, max_iters=1000, eps_rel=1e-3, eps_abs=1e-6):
        """
        This solver uses a variant of ADMM to solve problems of the form
            min. 1/2 ||A b - y||_2^2 + r(b)
        and estimates the mean of y with A b^*, where b^* is the optimal b.

        A is a linear operator defined using <https://github.com/cvxgrp/torch_linops>.
        prox_R is a function, differentiable with respect to its first argument,
        that finds the optimal point b for a (v, t) pair of
            min. t r(b) + 1/2 ||b - v||_2^2
        x0 is the point where iterations begin; it must be chosen
        independently of y.
        max_iters, eps_rel, and eps_abs control when iterations stop.
        """
```
```python
import cvxpy as cp


class CVXPYSolver(Solver):
    def __init__(self, problem: cp.Problem,
                 y_parameter: cp.Parameter,
                 variables: list[cp.Variable],
                 estimate: Callable[[list[torch.Tensor]], torch.Tensor]):
        """
        problem must be a CVXPY problem whose only parameter is y_parameter
        and whose variables are listed in variables.
        estimate must be a function that takes tensors holding the value of
        each variable and returns the estimate.

        WARNING: This solver performs poorly on large problems and can
        have undetected poor accuracy on some moderately sized problems.
        """
```
If you wish to implement your own `Solver`, it has the following API, where `T` is any type of the implementation's choice:
```python
from typing import TypeVar

import torch

T = TypeVar("T")  # any type of the implementation's choice


class Solver:
    def solve(self, y: torch.Tensor) -> T:
        """
        Returns an intermediate value used to estimate the mean of the
        distribution y is sampled from.
        """

    def estimate(self, beta: T) -> torch.Tensor:
        """
        Given the output of a solve call, returns the estimate of the mean of
        the distribution y was sampled from.
        """
```
Note that for a given instance `s` of a solver class, `s.estimate(s.solve(y))` must be differentiable via torch's backpropagation.
The `prox_lib` library
Since `FISTASolver` and `ADMMSolver` both require a proximal operator for the regularizer, `surecr.prox_lib` provides helper functions for constructing proximal operators, including:
- `prox_l1_norm(v, t)`: the $\ell_1$ norm's proximal operator.
- `prox_l2_norm(v, t)`: the $\ell_2$ norm's proximal operator.
- `make_scaled_prox_nuc_norm(shape: tuple[int, int], t_scale: float)`: generates the proximal operator $\text{prox}_{r}: \mathbb R^{\mathtt{shape}} \to \mathbb R^{\mathtt{shape}}$ of $r(b) = \mathtt{t\_scale} \sum_i \sigma_i(b)$, where $\sigma_i(b)$ are the singular values of $b$ (a scaled nuclear norm).
- `combine_proxs(shape: list[int], proxs: list)`: if the problem's regularizer is separable, $r(b, b') = r_1(b) + r_2(b')$, then this function should be called as `combine_proxs([dim(b), dim(b')], [prox_r_1, prox_r_2])`.
- `scale_prox(prox, t_scale)`: takes a proximal operator of $r$ and returns the proximal operator of $\mathtt{t\_scale} \, r$.
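The identities behind these helpers are easy to check on small inputs. The following NumPy sketch re-implements them under hypothetical names (`prox_l1`, `prox_l2`, `make_nuc_prox`, `scale`, `combine`) purely to illustrate the math; it is not `surecr.prox_lib`, which is built on torch and may differ in details such as how matrices are flattened. Each function follows the convention above: `prox(v, t)` returns $\text{argmin}_b \; t\, r(b) + \frac{1}{2}\lVert b - v\rVert_2^2$.

```python
from typing import Callable, Sequence

import numpy as np

Prox = Callable[[np.ndarray, float], np.ndarray]


def prox_l1(v: np.ndarray, t: float) -> np.ndarray:
    """prox of t * ||.||_1: elementwise soft thresholding."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def prox_l2(v: np.ndarray, t: float) -> np.ndarray:
    """prox of t * ||.||_2: shrink the whole vector toward 0 by t."""
    norm = np.linalg.norm(v)
    return np.zeros_like(v) if norm <= t else (1.0 - t / norm) * v


def make_nuc_prox(shape: tuple[int, int], t_scale: float) -> Prox:
    """prox of b -> t_scale * (sum of singular values of b)."""
    def prox(v: np.ndarray, t: float) -> np.ndarray:
        U, s, Vt = np.linalg.svd(v.reshape(shape), full_matrices=False)
        # Soft-threshold the singular values and keep the singular vectors.
        return ((U * np.maximum(s - t * t_scale, 0.0)) @ Vt).reshape(v.shape)
    return prox


def scale(prox: Prox, t_scale: float) -> Prox:
    """prox of t_scale * r at step t equals prox of r at step t * t_scale."""
    return lambda v, t: prox(v, t * t_scale)


def combine(sizes: Sequence[int], proxs: Sequence[Prox]) -> Prox:
    """prox of a separable r(b, b') = r_1(b) + r_2(b') acts block by block."""
    cuts = np.cumsum(sizes)[:-1]

    def prox(v: np.ndarray, t: float) -> np.ndarray:
        blocks = np.split(v, cuts)
        return np.concatenate([p(block, t) for p, block in zip(proxs, blocks)])
    return prox
```

For example, `combine([4, 4], [prox_l1, make_nuc_prox((2, 2), 0.5)])` is the proximal operator of $\lVert b \rVert_1 + 0.5 \lVert B' \rVert_*$ on a stacked vector $(b, \operatorname{vec}(B'))$, the kind of separable regularizer used in robust PCA.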
Citing
If you use this code in a research project, please cite the associated paper.
```bibtex
@article{nobel2022tractable,
  title={Tractable evaluation of {S}tein's {U}nbiased {R}isk {E}stimate with convex regularizers},
  author={Parth Nobel and Emmanuel Cand\`es and Stephen Boyd},
  publisher = {arXiv},
  year = {2022},
  note = {arXiv:2211.05947 [math.ST]},
  url = {https://arxiv.org/abs/2211.05947},
}
```