
Sketched matrix decompositions for PyTorch

Project description


skerch: Sketched linear operations for PyTorch


What is skerch?

skerch is a Python package for computing a variety of sketched linear operations, such as SVD/EIGH, diagonal/triangular approximations, and operator norms. See the documentation for more details and usage examples. Main features:

  • Built on top of PyTorch, so it natively supports CPU and CUDA as well as complex datatypes, with very few other dependencies
  • Rich API for matrix-free linear operators, including matrix-free noise sources (Rademacher, Gaussian, SSRFT...)
  • Efficient parallelized and distributed computations
  • Support for out-of-core operations via HDF5
  • A-posteriori verification tools to test accuracy of sketched approximations
  • Modular and extensible design, for easy adaptation to new settings and operations

Why sketches?

Sketched methods are a good choice whenever we are dealing with large objects that can be approximated by smaller substructures (e.g. a low-rank approximation of a large matrix). Thanks to the random measurements (i.e. the "sketches"), we can obtain the small approximations directly, without having to store or compute the large object. This works under very few assumptions about what the smaller substructure looks like.

For example, if we have a large linear operator of shape (N, N) that doesn't fit in memory, but has rank k, we can directly retrieve its top-k singular components with only O(Nk) storage, as opposed to the intractable O(N^2) (see the picture below for an intuition, and the toy sketch that follows it). As a bonus, this technique is numerically stable and can be parallelized, which often results in substantial speedups.

Intuition for low-rank sketch-and-solve.
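
To make the O(Nk) claim concrete, here is a toy sketch of the idea in plain PyTorch, using a standard randomized range-finder. This is only an illustration of the principle, not necessarily the exact scheme skerch implements, and all names and sizes below are made up for the example:

import torch

# Toy illustration: a rank-k operator A of shape (N, N), kept only as two thin
# factors (O(Nk) storage) and never materialized as a full N x N matrix.
N, k = 5000, 10
left, right = torch.randn(N, k), torch.randn(k, N)   # A = left @ right

Omega = torch.randn(N, k + 5)          # random test matrix (slight oversampling)
Y = left @ (right @ Omega)             # k + 5 sketched measurements of A
Q, _ = torch.linalg.qr(Y)              # orthonormal basis for the range of A
B = (Q.T @ left) @ right               # small projection of A, shape (k + 5, N)
U_small, S, Vh = torch.linalg.svd(B, full_matrices=False)
U = Q @ U_small                        # approximate top singular vectors of A

print(torch.dist(left @ right, (U * S) @ Vh))   # ~0, since rank(A) <= k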

Why skerch?

On paper, sketched methods only require our linear operators to satisfy the following bare-bones interface:

class MyLinOp:
    def __init__(self, shape):
        self.shape = shape  # (height, width)

    def __matmul__(self, x):
        raise NotImplementedError("implement A @ x")

    def __rmatmul__(self, x):
        raise NotImplementedError("implement x @ A")

Nothing beyond this is strictly required. In practice, however, most libraries demand more complicated interfaces, which limits their application scenarios or introduces substantial overhead for developers.

skerch is specifically designed to work on this bare-bones interface. Furthermore, its highly modular architecture allows users to exchange and modify different components of the sketched methods.
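
For instance, a low-rank operator can be kept as two thin factors and exposed through this interface, so the full matrix is never formed. The following is a minimal hypothetical sketch (the class and factor names are made up for illustration):

import torch

class LowRankLinOp:
    """Matrix-free A = left @ right, stored only as its thin factors."""

    def __init__(self, left, right):
        self.left, self.right = left, right           # shapes (h, k) and (k, w)
        self.shape = (left.shape[0], right.shape[1])

    def __matmul__(self, x):
        return self.left @ (self.right @ x)           # A @ x without forming A

    def __rmatmul__(self, x):
        return (x @ self.left) @ self.right           # x @ A without forming A

lop = LowRankLinOp(torch.randn(10_000, 20), torch.randn(20, 8_000))

An object like this can then be handed directly to the sketched routines shown below.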

As a bonus, skerch is built on top of PyTorch with very few other dependencies, so it supports a broad variety of platforms and datatypes (including e.g. complex datatypes on GPU). skerch also supports in-core and out-of-core parallelization (e.g. via HDF5 tensor databases), providing good scalability in memory and runtime. The documentation examples illustrate all of the above points.

In summary, skerch brings sketched methods to you with minimal overhead while retaining good performance, resulting in faster development and runtimes overall. Give it a try!

Installation and basic usage

Install via:

pip install skerch

Then, given any linear operator lop that implements the bare-minimum interface (.shape = (h, w), lop @ x and x @ lop), we can compute the sketched SVD as follows (skerch also provides functionality to estimate EIGH, norms, diagonals...):

from skerch.algorithms import ssvd

U, S, Vh = ssvd(lop, device, dtype, num_outer, seed=12345, recovery_type="nystrom")

With the nystrom recovery, this method requires a total of 2 * num_outer measurements, and yields a thin SVD estimate in which lop is approximated by (U * S) @ Vh, with U.shape = (h, num_outer), S.shape = (num_outer,), and Vh.shape = (num_outer, w).
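
Putting the pieces together, a minimal end-to-end sketch might look as follows. This assumes a plain torch.Tensor works as lop (it already provides .shape, lop @ x and x @ lop); the device string, shapes, and parameter values are arbitrary choices for the example:

import torch
from skerch.algorithms import ssvd

# Toy low-rank operator used in place of a truly matrix-free one.
h, w, rank = 2000, 1500, 20
lop = torch.randn(h, rank, dtype=torch.float64) @ torch.randn(rank, w, dtype=torch.float64)

device, dtype, num_outer = "cpu", torch.float64, 50   # num_outer comfortably above rank
U, S, Vh = ssvd(lop, device, dtype, num_outer, seed=12345, recovery_type="nystrom")

print(torch.dist(lop, (U * S) @ Vh))                  # small, since num_outer >= rank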

If num_outer is large enough to cover the rank of lop, this yields an accurate recovery (see documentation examples). But how can we make sure?

The library also implements cheap a-posteriori methods to estimate the error of the obtained sketched approximation, without requiring explicit knowledge of lop (a few extra matrix-free measurements suffice):

from skerch.a_posteriori import apost_error, apost_error_bounds
from skerch.linops import CompositeLinOp, DiagonalLinOp

# Assemble the matrix-free approximation U @ diag(S) @ Vh
lop_approx = CompositeLinOp([("U", U), ("S", DiagonalLinOp(S)), ("Vh", Vh)])

# Estimate the squared Frobenius norms of lop, lop_approx, and their difference
(lop_f2, approx_f2, err2), _ = apost_error(lop, lop_approx, device, dtype, num_meas=30, seed=54321)

print("Estimated Frob(op):", lop_f2**0.5)
print("Estimated Frob(sketched_op):", approx_f2**0.5)
print("Estimated Frobenius Error:", err2**0.5)

This technique makes use of a number of test measurements that must be independent from lop and from the sketch measurements (make sure to use a different seed). For 30 measurements and a complex-valued lop, the probability of err2**0.5 being off by 50% or more in either direction can be queried as follows:

python -m skerch post_bounds --apost_n=30 --apost_err=0.5 --is_complex
# returns {'LOWER: P(err<=0.5x)': 0.0030445096757934554, 'HIGHER: P(err>=1.5x)': 0.05865709397802224}
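
For intuition on why a handful of independent test measurements suffice, here is a toy illustration of the underlying principle in plain PyTorch (a basic Gaussian Frobenius-norm estimator; this shows the general idea only, not skerch's exact implementation):

import torch

# For Gaussian test vectors w, E[||M @ w||^2] equals ||M||_F^2, so averaging a few
# such measurements estimates the squared Frobenius norm of M (e.g. of a residual).
M = torch.randn(300, 200)                 # stand-in for lop - lop_approx
num_meas = 30
W = torch.randn(200, num_meas)            # independent Gaussian test measurements
est_f2 = (M @ W).pow(2).sum() / num_meas

print(est_f2.item(), M.pow(2).sum().item())   # estimate vs. exact squared Frobenius norm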

See the documentation examples and API docs for more details.

Developers

Contributions are welcome under this repo's LICENSE.

Feel free to open an issue with bug reports, feature requests, etc.

The documentation also contains a For Developers section with useful guidelines to interact with this repo and propose pull requests.

Researchers

If this library is useful for your work, please consider citing it:

@manual{fernandez2024skerch,
  title={{S}kerch: Sketched matrix decompositions for {PyTorch}},
  author={Andres Fernandez},
  year={2024},
  url={https://github.com/andres-fr/skerch},
}



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

skerch-1.2.0.tar.gz (1.6 MB)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

skerch-1.2.0-py3-none-any.whl (103.7 kB)

Uploaded Python 3

File details

Details for the file skerch-1.2.0.tar.gz.

File metadata

  • Download URL: skerch-1.2.0.tar.gz
  • Size: 1.6 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for skerch-1.2.0.tar.gz
Algorithm Hash digest
SHA256 5116bff8781c24b477dbaf311e8336c9f85b7d16c73d28db4368733ec93a48a4
MD5 d5282ac730dc061e83e092ad3136fafd
BLAKE2b-256 947cc7737287489704cee7a4edaf8db5907e80cba266adaf3be3ee3a1b97ff5f

See more details on using hashes here.

Provenance

The following attestation bundles were made for skerch-1.2.0.tar.gz:

Publisher: ci.yaml on andres-fr/skerch

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file skerch-1.2.0-py3-none-any.whl.

File metadata

  • Download URL: skerch-1.2.0-py3-none-any.whl
  • Size: 103.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for skerch-1.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 46e5bab4bbadc994a540a0ef04639f3a518f7877a12c6d9436adf10aae2b3bb1
MD5 103a046fb6ea14987342d617438bbb70
BLAKE2b-256 f1ad931294d24fa3602d462761682623813abfd82f3500f43c3203abc0c35cc6

See more details on using hashes here.

Provenance

The following attestation bundles were made for skerch-1.2.0-py3-none-any.whl:

Publisher: ci.yaml on andres-fr/skerch

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
