A library for defining abstract linear operators, together with the associated algebra and matrix-free algorithms, that works with PyTorch tensors.

Project description

An Abstract Linear Operator Library for PyTorch

This library implements a generic structure for abstract linear operators and enables a number of standard operations on them:

  • Arithmetic: A + B, A - B, -A, A @ B all work exactly as expected to combine linear operators.
  • Indexing: A[k:ell,m:n] works as expected.
  • Solves: Ax = b can be solved with CG for PSD matrices, MINRES for symmetric matrices, LSQR (to be implemented), or LSMR (to be implemented).
  • Trace estimation: The trace of a square matrix can be estimated via Hutch++ or Hutchinson's estimator.
  • Diamond-Boyd stochastic equilibration
  • Randomized Nyström Preconditioning
  • Automatic adjoint operator generation.
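
The arithmetic above can be understood through matrix-free products. The following is a minimal sketch in plain torch, not the library's implementation: the helper names as_matvec, add, and compose are hypothetical and exist only for illustration. It shows that a sum or product of operators can be applied to a vector without ever forming the combined matrix.

```python
import torch

def as_matvec(M):
    """Wrap a dense matrix as a matrix-free product v -> M @ v (hypothetical helper)."""
    return lambda v: M @ v

def add(f, g):
    """(A + B) v = A v + B v, evaluated lazily."""
    return lambda v: f(v) + g(v)

def compose(f, g):
    """(A @ B) v = A (B v), evaluated right to left."""
    return lambda v: f(g(v))

torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.float64)
B = torch.randn(4, 4, dtype=torch.float64)
v = torch.randn(4, dtype=torch.float64)

sum_op = add(as_matvec(A), as_matvec(B))
prod_op = compose(as_matvec(A), as_matvec(B))

# Compare against the dense results to confirm the lazy versions agree.
sum_matches = torch.allclose(sum_op(v), (A + B) @ v)
prod_matches = torch.allclose(prod_op(v), (A @ B) @ v)
```

The composed product costs two matrix-vector products instead of one matrix-matrix product, which is the usual reason to keep operators abstract.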

Using LinearOperators

The public API of the LinearOperator library is that every LinearOperator has the following properties and methods:

class LinearOperator:

    # Properties
    shape: tuple[int, int]
    T: LinearOperator
    supports_operator_matrix: bool
    device: torch.device

    # Matrix multiply
    def __matmul__(self, b: torch.Tensor) -> torch.Tensor: ...
    def __rmatmul__(self, b: torch.Tensor) -> torch.Tensor: ...

    def __matmul__(self, b: LinearOperator) -> LinearOperator: ...
    def __rmatmul__(self, b: LinearOperator) -> LinearOperator: ...
    
    # Linear Solve Methods
    def solve_I_p_lambda_AT_A_x_eq_b(self,
        lambda_: float,
        b: torch.Tensor,
        x0: torch.Tensor | None=None,
        *, precondition: None | Literal['nsytrom'], hot=False) -> torch.Tensor: ...

    def solve_A_x_eq_b(self,
        b: torch.Tensor,
        x0: torch.Tensor | None=None) -> torch.Tensor: ...

    # Transformations on LinearOperator
    def __mul__(self, c: float) -> LinearOperator: ...
    def __rmul__(self, c: float) -> LinearOperator: ...

    def __truediv__(self, c: float) -> LinearOperator: ...

    def __pow__(self, k: int) -> LinearOperator: ...

    def __add__(self, c: LinearOperator) -> LinearOperator: ...

    def __sub__(self, c: LinearOperator) -> LinearOperator: ...

    def __neg__(self) -> LinearOperator: ...

    def __pos__(self) -> LinearOperator: ...

    def __getitem__(self, key) -> LinearOperator: ...
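
The name solve_I_p_lambda_AT_A_x_eq_b suggests that the method solves (I + lambda AᵀA) x = b. As a rough illustration of that problem, and assuming nothing about the library's actual algorithm or its preconditioning, the sketch below runs textbook conjugate gradients in plain torch using only matrix-vector products. The functions conjugate_gradient and shifted_normal_matvec are hypothetical names introduced here.

```python
import torch

def conjugate_gradient(matvec, b, x0=None, tol=1e-10, max_iter=200):
    """Textbook CG for a symmetric positive definite operator given as a matvec."""
    x = torch.zeros_like(b) if x0 is None else x0.clone()
    r = b - matvec(x)          # initial residual
    p = r.clone()              # initial search direction
    rs_old = torch.dot(r, r)
    for _ in range(max_iter):
        if rs_old.sqrt() <= tol:
            break
        Ap = matvec(p)
        alpha = rs_old / torch.dot(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = torch.dot(r, r)
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x

torch.manual_seed(0)
A = torch.randn(6, 4, dtype=torch.float64)
b = torch.randn(4, dtype=torch.float64)
lambda_ = 0.5

def shifted_normal_matvec(v):
    # Applies (I + lambda * A^T A) to v without forming A^T A.
    return v + lambda_ * (A.T @ (A @ v))

x = conjugate_gradient(shifted_normal_matvec, b)

# Check the answer against the explicitly assembled matrix.
dense = torch.eye(4, dtype=torch.float64) + lambda_ * A.T @ A
residual_norm = torch.linalg.norm(dense @ x - b).item()
```

Because I + lambda AᵀA is symmetric positive definite for lambda >= 0, CG applies here even when A itself is rectangular.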

The following functions are available in the root of the library:

def operator_matrix_product(A: LinearOperator, M: torch.Tensor) -> torch.Tensor: ...
def aslinearoperator(A: torch.Tensor | LinearOperator) -> LinearOperator: ...
def vstack(ops: list[LinearOperator] | tuple[LinearOperator, ...]) -> LinearOperator: ...
def hstack(ops: list[LinearOperator] | tuple[LinearOperator, ...]) -> LinearOperator: ...

# To be implemented:
def bmat(ops: list[list[LinearOperator]]) -> LinearOperator: ... # Optimizes out ZeroOperator

The following functions are available in linops.trace for trace estimation:

def hutchpp(A: lo.LinearOperator, m: int) -> float: ...
def hutchinson(A: lo.LinearOperator, m: int) -> float: ...
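
For orientation, here is a minimal sketch of Hutchinson's estimator in plain torch. It is not the code in linops.trace: the function hutchinson_trace and its parameters are assumptions made for illustration, and Hutch++ is not shown. The estimator averages zᵀAz over m random sign vectors z, and its expectation is the trace of A.

```python
import torch

def hutchinson_trace(matvec, n, m, generator=None, dtype=torch.float64):
    """Estimate trace(A) from m products with random +/-1 (Rademacher) probes."""
    # Draw entries in {0, 1} and map them to {-1, +1}.
    Z = torch.randint(0, 2, (n, m), generator=generator).to(dtype) * 2 - 1
    total = 0.0
    for j in range(m):
        z = Z[:, j]
        total += torch.dot(z, matvec(z)).item()
    return total / m

gen = torch.Generator().manual_seed(0)

# Diagonal operator: every probe gives exactly the trace, since z_i^2 = 1.
d = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float64)
diag_estimate = hutchinson_trace(lambda z: d * z, n=5, m=10, generator=gen)

# Dense symmetric operator with small off-diagonal entries; its trace is also 15.
A = torch.diag(d) + 0.1 * (torch.ones(5, 5, dtype=torch.float64)
                           - torch.eye(5, dtype=torch.float64))
dense_estimate = hutchinson_trace(lambda z: A @ z, n=5, m=500, generator=gen)
```

The variance of the estimate comes entirely from the off-diagonal entries, which is why the diagonal case is exact and why larger m is needed for matrices with heavy off-diagonal mass.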

linops.equilibration contains equilibrate and symmetric_equilibrate. Their public API is not finalized; if you wish to use them, it is recommended that you read the source code.

Creating Linear Operators

Linear operators can be constructed in the following ways:

  • Creating a sub-class of LinearOperator
  • Calling one of the following constructors:
    • IdentityOperator(n: int)
    • DiagonalOperator(diag: torch.Tensor): where diag is a 1D torch tensor.
    • MatrixOperator(M: torch.Tensor): where M is a 2D torch tensor.
    • SelectionOperator(shape: tuple[int, int], idxs: slice | list[int | slice])
    • KKTOperator(H: LinearOperator, A: LinearOperator): where H is a square LinearOperator and A is a LinearOperator
    • VectorJacobianOperator(f: torch.Tensor, x: torch.Tensor): where f is the output of the function being differentiated (a tensor attached to the torch autograd graph) and x is the input tensor for which requires_grad was enabled.
    • ZeroOperator(shape: tuple[int, int])
  • Combining operators via:
    • A + B, A - B, A @ B for A, B linear operators
    • hstack, vstack
    • A, c * A, A / c, v * A, A / v for scalar c and vector v.
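
To make the idea behind a vector-Jacobian operator concrete, the following sketch computes uᵀJ with torch.autograd.grad directly, where J is the Jacobian of f with respect to x. It does not call VectorJacobianOperator, whose internals are not described here; the function vjp and the example f are assumptions chosen for illustration.

```python
import torch

# Input with gradient tracking enabled, and a small nonlinear map f: R^3 -> R^2.
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64, requires_grad=True)
f = torch.stack([x[0] * x[1], x[1] + x[2] ** 2])

def vjp(u):
    """Return J^T u, where J = df/dx at the current x, without forming J."""
    (g,) = torch.autograd.grad(f, x, grad_outputs=u, retain_graph=True)
    return g

u = torch.tensor([1.0, -2.0], dtype=torch.float64)
result = vjp(u)

# Hand-computed Jacobian at x = (1, 2, 3), used only to check the result.
J = torch.tensor([[2.0, 1.0, 0.0],
                  [0.0, 1.0, 6.0]], dtype=torch.float64)
expected = J.T @ u  # tensor([2., -1., -12.])
```

Passing retain_graph=True keeps the autograd graph alive so that the product can be evaluated repeatedly with different vectors u.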

Implementing LinearOperators

To implement a LinearOperator the following are mandatory:

  • Set _shape: tuple[int, int] to the shape of the operator.
  • Set device appropriately, if the operator requires vectors to be on a particular device.
  • Implement a method def _matmul_impl(self, v: torch.Tensor) -> torch.Tensor: ... that implements your matrix vector product.
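
The three requirements above can be sketched with a stand-alone class. Note the assumptions: ForwardDifference is a hypothetical example, it deliberately does not inherit from the library's LinearOperator (so that the snippet runs without the package), and its __matmul__ is a simplified stand-in for whatever dispatch the real base class performs.

```python
import torch

class ForwardDifference:
    """Stand-alone illustration: (D v)_i = v_{i+1} - v_i, with shape (n - 1, n)."""

    def __init__(self, n, device=None):
        self._shape = (n - 1, n)   # shape of the operator
        self.device = device       # None here means no device requirement

    @property
    def shape(self):
        return self._shape

    def _matmul_impl(self, v):
        # Slicing along dim 0 also works when v is a matrix of column vectors.
        return v[1:] - v[:-1]

    def __matmul__(self, v):
        if v.shape[0] != self._shape[1]:
            raise ValueError("dimension mismatch")
        return self._matmul_impl(v)

D = ForwardDifference(4)
vec_out = D @ torch.tensor([1.0, 4.0, 9.0, 16.0])   # differences: 3, 5, 7
mat_out = D @ torch.ones(4, 2)                       # matrix input, shape (3, 2)
```

With the real library, the same _shape, device, and _matmul_impl members would presumably be defined on a subclass of LinearOperator instead.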

The following are recommended to improve performance:

  • If your _matmul_impl method handles matrix inputs correctly, set supports_operator_matrix: bool to True.
  • If it is possible to describe the adjoint operator, set _adjoint: LinearOperator to point to the adjoint of your operator. If you do not set it, an adjoint is generated automatically by differentiating through your _matmul_impl.
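
Generating an adjoint by differentiation relies on a simple identity: for a linear map v -> Av, the gradient of <u, Av> with respect to v is Aᵀu. The sketch below shows this with torch.autograd.grad. It is an illustration of the principle under that assumption, not the library's code, and autograd_adjoint is a hypothetical name.

```python
import torch

def autograd_adjoint(matvec, n, dtype=torch.float64):
    """Build u -> A^T u from a linear matvec v -> A v using reverse-mode autodiff."""
    def adjoint(u):
        # For a linear map the result does not depend on where v is evaluated.
        v = torch.zeros(n, dtype=dtype, requires_grad=True)
        out = matvec(v)
        (g,) = torch.autograd.grad(out, v, grad_outputs=u)
        return g
    return adjoint

n = 5
forward_difference = lambda v: v[1:] - v[:-1]   # linear map from R^5 to R^4
adjoint = autograd_adjoint(forward_difference, n)

torch.manual_seed(0)
u = torch.randn(n - 1, dtype=torch.float64)
v = torch.randn(n, dtype=torch.float64)

# Dense matrix of the same map, built by applying it to the identity's columns.
D = forward_difference(torch.eye(n, dtype=torch.float64))
matches_dense = torch.allclose(adjoint(u), D.T @ u)

# Defining property of an adjoint: <A v, u> = <v, A^T u>.
inner_product_gap = abs(torch.dot(forward_difference(v), u)
                        - torch.dot(v, adjoint(u))).item()
```

Each adjoint product costs one forward pass plus one backward pass, so a hand-written adjoint is usually cheaper when it is available.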

It is suggested that, if possible, you replace any other methods with specialized implementations.
