A library for defining abstract linear operators, with an associated algebra and matrix-free algorithms, that works with PyTorch tensors.

An Abstract Linear Operator Library for PyTorch

This library implements a generic structure for abstract linear operators and enables a number of standard operations on them:

  • Arithmetic: A + B, A - B, -A, A @ B all work exactly as expected to combine linear operators.
  • Indexing: A[k:ell,m:n] works as expected.
  • Solves: Ax = b can be solved with CG for PSD matrices, MINRES for symmetric matrices, LSQR (to be implemented), or LSMR (to be implemented).
  • Trace estimation: The trace of a square operator can be estimated via Hutch++ or Hutchinson's estimator.
  • Diamond-Boyd stochastic equilibration
  • Randomized Nyström Preconditioning
  • Automatic adjoint operator generation.
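
For instance, composing and applying operators might look like the following sketch (assuming the package is imported as linops and using the constructors described below):

import torch
import linops as lo

A = lo.MatrixOperator(torch.randn(4, 4))
D = lo.DiagonalOperator(torch.ones(4))

C = A + 2.0 * D         # operator algebra, evaluated lazily
y = C @ torch.randn(4)  # matrix-vector product
B = C[0:2, 0:4]         # indexing yields a new LinearOperator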

Using LinearOperators

Every LinearOperator in the library exposes the following properties and methods as its public API:

class LinearOperator:

    # Properties
    shape: tuple[int, int]
    T: LinearOperator
    supports_operator_matrix: bool
    device: torch.device

    # Matrix multiply: Tensor inputs return Tensors,
    # LinearOperator inputs return LinearOperators.
    def __matmul__(self, b: torch.Tensor | LinearOperator) -> torch.Tensor | LinearOperator: ...
    def __rmatmul__(self, b: torch.Tensor | LinearOperator) -> torch.Tensor | LinearOperator: ...
    
    # Linear Solve Methods

    # Solves (I + lambda_ * A^T A) x = b.
    def solve_I_p_lambda_AT_A_x_eq_b(self,
        lambda_: float,
        b: torch.Tensor,
        x0: torch.Tensor | None = None,
        *, precondition: None | Literal['nystrom'], hot: bool = False) -> torch.Tensor: ...

    # Solves A x = b.
    def solve_A_x_eq_b(self,
        b: torch.Tensor,
        x0: torch.Tensor | None = None) -> torch.Tensor: ...

    # Transformations on LinearOperator
    def __mul__(self, c: float) -> LinearOperator: ...
    def __rmul__(self, c: float) -> LinearOperator: ...

    def __truediv__(self, c: float) -> LinearOperator: ...

    def __pow__(self, k: int) -> LinearOperator: ...

    def __add__(self, c: LinearOperator) -> LinearOperator: ...

    def __sub__(self, c: LinearOperator) -> LinearOperator: ...

    def __neg__(self) -> LinearOperator: ...

    def __pos__(self) -> LinearOperator: ...

    def __getitem__(self, key) -> LinearOperator: ...
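
For example, the solve interface might be used as follows (a sketch; the operator is chosen PSD so that CG applies, and precondition is passed explicitly since its default is not documented):

import torch
import linops as lo

A = lo.MatrixOperator(2.0 * torch.eye(3))  # PSD, so CG is applicable
b = torch.ones(3)

x = A.solve_A_x_eq_b(b)  # solves A x = b
z = A.solve_I_p_lambda_AT_A_x_eq_b(0.1, b, precondition=None)  # (I + 0.1 A^T A) z = b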

The following functions are available in the root of the library:

def operator_matrix_product(A: LinearOperator, M: torch.Tensor) -> torch.Tensor: ...
def aslinearoperator(A: torch.Tensor | LinearOperator) -> LinearOperator: ...
def vstack(ops: list[LinearOperator] | tuple[LinearOperator, ...]) -> LinearOperator: ...
def hstack(ops: list[LinearOperator] | tuple[LinearOperator, ...]) -> LinearOperator: ...

# To be implemented:
def bmat(ops: list[list[LinearOperator]]) -> LinearOperator: ... # Optimizes out ZeroOperator
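
A short sketch of the stacking and conversion helpers (again assuming the package is imported as linops):

import torch
import linops as lo

A = lo.aslinearoperator(torch.randn(2, 3))
B = lo.aslinearoperator(torch.randn(2, 3))

V = lo.vstack([A, B])  # a (4, 3) operator
H = lo.hstack([A, B])  # a (2, 6) operator
M = lo.operator_matrix_product(V, torch.randn(3, 5))  # dense (4, 5) result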

The following functions are available in linops.trace for trace estimation:

def hutchpp(A: lo.LinearOperator, m: int) -> float: ...
def hutchinson(A: lo.LinearOperator, m: int) -> float: ...
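
For example (a sketch; the composition A.T @ A keeps the operator square, as both estimators require):

import torch
import linops as lo
from linops.trace import hutchpp, hutchinson

A = lo.MatrixOperator(torch.randn(100, 100))
S = A.T @ A  # square operator, never formed densely

t1 = hutchpp(S, 30)     # Hutch++ estimate with m = 30 queries
t2 = hutchinson(S, 30)  # Hutchinson estimate with m = 30 queries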

linops.equilibration contains equilibrate and symmetric_equilibrate. Their public API is not finalized; if you wish to use them, it is recommended that you read the source code.

Creating Linear Operators

Linear operators can be constructed in the following way:

  • Creating a sub-class of LinearOperator
  • Calling one of the following constructors:
    • IdentityOperator(n: int)
    • DiagonalOperator(diag: torch.Tensor): where diag is a 1D torch tensor.
    • MatrixOperator(M: torch.Tensor): where M is a 2D torch tensor.
    • SelectionOperator(shape: tuple[int, int], idxs: slice | list[int | slice])
    • KKTOperator(H: LinearOperator, A: LinearOperator): where H is a square LinearOperator and A is a LinearOperator of compatible shape.
    • VectorJacobianOperator(f: torch.Tensor, x: torch.Tensor): where f is the output of the function being differentiated (a tensor carrying autograd history) and x is the input tensor on which requires_grad was set.
    • ZeroOperator(shape: tuple[int, int])
  • Combining operators via:
    • A + B, A - B, A @ B for A, B linear operators
    • hstack, vstack
    • A * c, c * A, A / c, v * A, A / v for a scalar c and vector v.
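
Putting these together, construction and combination might look like the following sketch (assuming the package is imported as linops):

import torch
import linops as lo

n = 5
I = lo.IdentityOperator(n)
D = lo.DiagonalOperator(torch.arange(1.0, n + 1.0))

T = 2.0 * I - D        # scalar multiples and differences stay lazy
Q = lo.vstack((T, I))  # a (2n, n) stacked operator
y = Q @ torch.ones(n)  # apply without ever forming a dense matrix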

Implementing LinearOperators

To implement a LinearOperator the following are mandatory:

  • Set _shape: tuple[int, int] to the shape of the operator.
  • Set device appropriately, if the operator requires vectors to be on a particular device.
  • Implement a method def _matmul_impl(self, v: torch.Tensor) -> torch.Tensor: ... that implements your matrix vector product.

The following are recommended to improve performance:

  • If your _matmul_impl method handles matrix inputs correctly, set supports_operator_matrix: bool to True.
  • If it is possible to describe the adjoint operator, set _adjoint: LinearOperator to point to the adjoint of your operator. If you do not compute this, then one will be autogenerated by differentiating through your _matmul_impl.

Where possible, it is also suggested that you override other methods with specialized implementations.
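
As a minimal sketch, a custom operator might look like the following (a hypothetical ScalingOperator; this assumes the base class imposes no further constructor requirements):

import torch
import linops as lo

class ScalingOperator(lo.LinearOperator):
    # Hypothetical operator computing v -> alpha * v.
    def __init__(self, n: int, alpha: float, device: str = 'cpu'):
        self._shape = (n, n)                  # mandatory: operator shape
        self.device = device                  # mandatory if inputs must live on a device
        self.supports_operator_matrix = True  # elementwise scaling handles matrix inputs too
        self._adjoint = self                  # optional: a real scaling is self-adjoint
        self._alpha = alpha

    def _matmul_impl(self, v: torch.Tensor) -> torch.Tensor:
        # Works for vectors and, column by column, for matrices.
        return self._alpha * v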
