
SURE-CR

This package enables tractable evaluation of Stein's Unbiased Risk Estimate on convexly regularized estimators.

For an estimator of the mean of a normally distributed random vector $y \in \mathbb R^d$ with known covariance matrix $\sigma^2 I$, given by

$$ \hat\mu(y) = \mathcal A \, \underset{b}{\text{argmin}} \left( \frac{1}{2} \lVert\mathcal A b - y\rVert_2^2 + r(b) \right), $$

where $r: \mathbb R^p \to \mathbb R$ is a convex function and $\mathcal A: \mathbb R^p \to \mathbb R^d$ is a linear operator, this package provides methods to compute Stein's Unbiased Risk Estimate of $\hat\mu$:

$$ SURE(\hat\mu, y) = -d \sigma^2 + \lVert\hat\mu(y) - y\rVert_2^2 + 2 \sigma^2 \nabla \cdot \hat\mu(y), $$

where $\nabla \cdot \hat\mu(y) = \sum_{i=1}^d \partial \hat\mu_i(y) / \partial y_i$ is the divergence of $\hat\mu$.

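$SURE(\hat\mu, y)$ is an unbiased estimate of the $\ell_2$ risk $\mathbb E \lVert\hat\mu(y) - \mu\rVert_2^2$ of $\hat\mu$, where $\mu$ is the true mean of $y$, and it tends to be especially accurate for high-dimensional problems.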
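As a worked example (not taken from the package), take $\mathcal A = I$ and $r(b) = \lambda \lVert b \rVert_1$. Then $\hat\mu$ is coordinatewise soft thresholding, $\hat\mu_i(y) = \text{sign}(y_i) \max(|y_i| - \lambda, 0)$, and its divergence is the number of coordinates with $|y_i| > \lambda$, so SURE has the closed form

$$ SURE(\hat\mu, y) = -d \sigma^2 + \sum_{i=1}^d \min(y_i^2, \lambda^2) + 2 \sigma^2 \left\lvert \{ i : |y_i| > \lambda \} \right\rvert. $$

The NumPy sketch below evaluates this formula and checks by Monte Carlo that its average matches the average squared error. The sparse mean, $\lambda = 1.5$, and the trial counts are arbitrary illustrative choices. For a general $\mathcal A$ and $r$ the divergence usually has no such simple form, which is the case SURE-CR is built for.

```python
import numpy as np

def soft_threshold(y, lam):
    """mu_hat(y) for A = I and r(b) = lam * ||b||_1 (coordinatewise soft thresholding)."""
    return np.sign(y) * np.maximum(np.abs(y) - lam, 0.0)

def sure_soft_threshold(y, lam, sigma):
    """SURE = -d sigma^2 + ||mu_hat(y) - y||^2 + 2 sigma^2 div mu_hat(y)."""
    d = y.size
    residual = soft_threshold(y, lam) - y
    divergence = np.count_nonzero(np.abs(y) > lam)  # d mu_hat_i / d y_i is 1 or 0
    return -d * sigma**2 + np.sum(residual**2) + 2 * sigma**2 * divergence

# Monte Carlo check of unbiasedness on a hypothetical sparse mean.
rng = np.random.default_rng(0)
d, sigma, lam = 1000, 1.0, 1.5
mu = np.zeros(d)
mu[:100] = 5.0
sures, losses = [], []
for _ in range(200):
    y = mu + sigma * rng.standard_normal(d)
    sures.append(sure_soft_threshold(y, lam, sigma))
    losses.append(np.sum((soft_threshold(y, lam) - mu) ** 2))

print(f"mean SURE: {np.mean(sures):.1f}, mean loss: {np.mean(losses):.1f}")
```

Note that SURE uses only $y$ and $\sigma$, never $\mu$; the true mean appears only in the loss used for the comparison.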

Installation

SURE-CR can be installed with pip or conda. The conda package does not support CVXPYSolver by default; to use it, install the cvxpylayers library separately.

To install with pip:

$ pip install surecr

To install with conda:

$ conda install -c stanfordcvxgrp -c conda-forge sure-cr

Examples

The easiest way to start using SURE-CR may be to read the examples:

Usage

The package has three key components:

  • The SURE class
  • The Solver class and its subclasses CVXPYSolver, FISTASolver, and ADMMSolver
  • The prox_lib helper library

The SURE Class

The SURE class has the following API:

class SURE:
    def __init__(self, variance: float, solver: Solver): ...

    def compute(self, y: torch.Tensor, divergence_parameters={}) -> float:
        """
        Computes and returns SURE for the estimator computed by the solver
        at the point y.

        Currently, divergence_parameters can contain the key "m" to indicate
        how many samples to use during the divergence estimation (which
        dominates the runtime at high dimensions). The default is for m to be
        102.

        In the future we may switch to A-Hutch++ and may change what options
        the divergence_parameters specifies.
        """

    @property
    def solution(self) -> torch.Tensor:
        """
        Returns solver.solve(y) from the last compute call.
        """

    def runtimes(self) -> TypedDict('Runtimes', solver=float, divergence=float):
        """
        Returns how long it took for the solver to run and how long it took
        the divergence estimator to run during the last compute call.
        """

The Solver class

Most uses of the library should rely on one of the three provided Solver subclasses, whose constructors are listed below:

class FISTASolver(Solver):
    def __init__(self, A: linops.LinearOperator,
                       prox_R: Callable[[torch.Tensor, float | torch.Tensor], torch.Tensor],
                       x0: torch.Tensor,
                       device=None,
                       lipschitz_iterations=20,
                       lipschitz_vec=None,
                       *, max_iters=5000, eps=1e-3):
        """
        This solver uses a variant of FISTA to solve problems of the form
              min. 1/2 ||A b - y||_2^2 + r(b)
        and estimates the mean of y with A b^* where b^* is the optimal b.

        A is a linear operator defined using <https://github.com/cvxgrp/torch_linops>

        prox_R is a differentiable-with-respect-to-its-first-argument function to
            find the optimal point b for a (v, t) pair of
              min. t r(b) + 1/2 ||b - v||_2^2

        x0 is the point where iterations begin; it must be chosen
            independently of y.

        lipschitz_iterations is how many iterations of the power method to use
        to approximate the largest eigenvalue of A^T A.

        lipschitz_vec is the starting vector for the power method. By default,
        the all-ones vector is used; if that vector is orthogonal to the leading
        eigenvector of A^T A, this argument must be supplied.

        max_iters and eps control when iterations stop.

        """

class ADMMSolver(Solver):
    def __init__(self, A: linops.LinearOperator,
                       prox_R: Callable[[torch.Tensor, float | torch.Tensor], torch.Tensor],
                       x0: torch.Tensor,
                       device=None,
                       *, max_iters=1000, eps_rel=1e-3, eps_abs=1e-6):
        """
        This solver uses a variant of ADMM to solve problems of the form
              min. 1/2 ||A b - y||_2^2 + r(b)
        and estimates the mean of y with A b^* where b^* is the optimal b.

        A is a linear operator defined using <https://github.com/cvxgrp/torch_linops>

        prox_R is a differentiable-with-respect-to-its-first-argument function to
            find the optimal point b for a (v, t) pair of
              min. t r(b) + 1/2 ||b - v||_2^2

        x0 is the point where iterations begin; it must be chosen
            independently of y.

        max_iters, eps_rel, and eps_abs control when iterations stop.
        """

class CVXPYSolver(Solver):
    def __init__(self, problem: cp.Problem,
                       y_parameter: cp.Parameter, 
                       variables: list[cp.Variable], 
                       estimate: Callable[[list[torch.Tensor]], torch.Tensor]):
        """
        problem must be a CVXPY problem whose only parameter is y_parameter
            and whose variables are those listed in variables.

        estimate must be a function that takes tensors holding the value of
            each variable and returns the estimate.

        WARNING: This solver has poor performance on large problems, and can
        have undetected poor accuracy on some moderately-sized problems.
        """

If you wish to implement your own Solver, it must have the following API, where T is any type of the implementation's choice:

class Solver:

    def solve(self, y: torch.Tensor) -> T:
        """
        Returns intermediate value used to estimate the mean of the distribution
        y is sampled from.
        """

    def estimate(self, beta: T) -> torch.Tensor:
        """
        Given the output of a solve call, returns the estimate of the mean of the
        distribution y was sampled from.
        """

Note that for a given instance s of a solver class, s.estimate(s.solve(y)) must be differentiable via torch's backpropagation.
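To illustrate the solve/estimate split, the sketch below is a hypothetical solver for ridge regularization, $r(b) = \frac{\lambda}{2} \lVert b \rVert_2^2$. Here solve has a closed form and $\hat\mu(y) = \mathcal A (\mathcal A^T \mathcal A + \lambda I)^{-1} \mathcal A^T y$ is linear in $y$, so its divergence is the trace of that hat matrix. The sketch is written in NumPy for brevity and does not subclass surecr's Solver; a real implementation must use torch tensors so that estimate(solve(y)) can be backpropagated through. The class name RidgeSolver is made up for this example.

```python
import numpy as np

class RidgeSolver:
    """Hypothetical solver for min 1/2 ||A b - y||^2 + (lam / 2) ||b||^2."""

    def __init__(self, A, lam):
        self.A = A
        self.lam = lam

    def solve(self, y):
        # Intermediate value T: the coefficient vector b^* from the normal equations.
        p = self.A.shape[1]
        return np.linalg.solve(self.A.T @ self.A + self.lam * np.eye(p), self.A.T @ y)

    def estimate(self, beta):
        # Mean estimate A b^*.
        return self.A @ beta

rng = np.random.default_rng(1)
A = rng.standard_normal((30, 10))
s = RidgeSolver(A, lam=2.0)
y = rng.standard_normal(30)

# Divergence of y -> s.estimate(s.solve(y)), once in closed form and once numerically.
exact_div = np.trace(A @ np.linalg.solve(A.T @ A + 2.0 * np.eye(10), A.T))
eps = 1e-6
numeric_div = sum(
    (s.estimate(s.solve(y + eps * e)) - s.estimate(s.solve(y - eps * e)))[i] / (2 * eps)
    for i, e in enumerate(np.eye(30))
)
print(exact_div, numeric_div)  # the two values should agree closely
```

Returning b^* from solve and applying $\mathcal A$ only in estimate keeps the expensive optimization separate from the cheap map to the mean estimate.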

The prox_lib library

Since FISTASolver and ADMMSolver both require a proximal operator for the regularizer, surecr.prox_lib provides helper functions for constructing proximal operators, including:

  • prox_l1_norm(v, t): the $\ell_1$ norm's proximal operator.
  • prox_l2_norm(v, t): the $\ell_2$ norm's proximal operator.
  • make_scaled_prox_nuc_norm(shape: tuple[int, int], t_scale: float): generates the proximal operator $\text{prox}_{r}: \mathbb R^{\mathtt{shape}} \to \mathbb R^{\mathtt{shape}}$ of the scaled nuclear norm $r(b) = \mathtt{t\_scale} \sum_i \sigma_i(b)$, where $\sigma_i(b)$ are the singular values of $b$.
  • combine_proxs(shape: list[int], proxs: list): if the regularizer is separable, $r(b, b') = r_1(b) + r_2(b')$, call combine_proxs([dim(b), dim(b')], [prox_r_1, prox_r_2]) to obtain its proximal operator.
  • scale_prox(prox, t_scale): takes the proximal operator of $r$ and returns the proximal operator of $\mathtt{t\_scale} \, r$.
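The math behind these helpers is standard: the $\ell_1$ prox is soft thresholding, the $\ell_2$ prox shrinks the whole vector toward zero, scaling $r$ by a constant is equivalent to scaling the prox step $t$ by that constant, a separable regularizer's prox acts block by block, and the nuclear-norm prox soft-thresholds singular values. The NumPy functions below are illustrative re-derivations of that math under hypothetical names (l1_prox, l2_prox, scaled, separable, nuc_norm_prox); they are not the package's torch implementations, whose argument handling may differ.

```python
import numpy as np

def l1_prox(v, t):
    # argmin_b t ||b||_1 + 1/2 ||b - v||^2: coordinatewise soft thresholding.
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)

def l2_prox(v, t):
    # argmin_b t ||b||_2 + 1/2 ||b - v||^2: shrink the whole vector toward zero.
    norm = np.linalg.norm(v)
    return np.zeros_like(v) if norm <= t else (1.0 - t / norm) * v

def scaled(prox, t_scale):
    # The prox of t_scale * r at step t is the prox of r at step t_scale * t.
    return lambda v, t: prox(v, t_scale * t)

def separable(sizes, proxs):
    # r(b, b') = r_1(b) + r_2(b') decouples, so each prox acts on its own block of v.
    cuts = np.cumsum(sizes)[:-1]
    def prox(v, t):
        return np.concatenate([p(block, t) for p, block in zip(proxs, np.split(v, cuts))])
    return prox

def nuc_norm_prox(shape, t_scale):
    # prox of t_scale * (sum of singular values): soft-threshold the singular values.
    def prox(v, t):
        U, s, Vt = np.linalg.svd(v.reshape(shape), full_matrices=False)
        return ((U * np.maximum(s - t_scale * t, 0.0)) @ Vt).reshape(v.shape)
    return prox

prox = separable([3, 2], [l1_prox, scaled(l2_prox, 2.0)])
print(prox(np.array([3.0, -0.5, 1.0, 3.0, 4.0]), 1.0))  # approximately [2, 0, 0, 1.8, 2.4]
```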

Citing

If you use this code in a research project, please cite the associated paper.

@article{nobel2022tractable,
    title={Tractable evaluation of {S}tein's {U}nbiased {R}isk {E}stimate with convex regularizers},
    author={Parth Nobel and Emmanuel Cand{\`e}s and Stephen Boyd},
    publisher = {arXiv},
    year = {2022},
    note = {arXiv:2211.05947 [math.ST]},
    url = {https://arxiv.org/abs/2211.05947},
}

Download files


Source Distribution

surecr-0.1.2.tar.gz (13.5 kB)

Uploaded Source

Built Distribution

surecr-0.1.2-py3-none-any.whl (16.2 kB)

Uploaded Python 3

File details

Details for the file surecr-0.1.2.tar.gz.

File metadata

  • Download URL: surecr-0.1.2.tar.gz
  • Upload date:
  • Size: 13.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.8.10

File hashes

Hashes for surecr-0.1.2.tar.gz
Algorithm Hash digest
SHA256 60551913bdc47dff598841eba4763cc1e5b6dfd945230ac86463bcf4cecbeaa2
MD5 3c17b203b8741c5d5c0cd1519acd74fe
BLAKE2b-256 d6b913503cd7fef7d1344194d02b149125b6c4855e0673120f782e62024c50d0


File details

Details for the file surecr-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: surecr-0.1.2-py3-none-any.whl
  • Upload date:
  • Size: 16.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.8.10

File hashes

Hashes for surecr-0.1.2-py3-none-any.whl
Algorithm Hash digest
SHA256 c1c859e5eb2a75f9f805acdc004c2e751e88b28e40e86469f48606af80763245
MD5 53dd40c35c3bb6593ae20fa5f501d357
BLAKE2b-256 5c2c082545dd14ddf87fba9d8ea6b73cc321719b99dd864ef9dd470f980150a7

