
Sketched matrix decompositions for PyTorch



skerch: Sketched matrix decompositions for PyTorch


skerch is a Python package to compute diagonal decompositions (SVD, Hermitian Eigendecomposition) of linear operators via sketched methods.

  • Built on top of PyTorch, with natural support for CPU and CUDA interoperability, and very few dependencies otherwise
  • Works on matrices and matrix-free operators of potentially very large dimensionality
  • Support for sketched measurements in a fully distributed fashion via HDF5 databases
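To illustrate the general idea behind sketched decompositions, here is a minimal randomized SVD in plain PyTorch. It is the classical Gaussian range-finder scheme of Halko, Martinsson and Tropp, not necessarily the algorithm skerch implements, and it deliberately touches the operator only through right-matmuls (op @ X) and left-matmuls (X @ op), which is what makes sketched methods usable on matrix-free operators. The function `randomized_svd` and its parameters (`rank`, `oversample`, `seed`) are hypothetical names for this sketch.

```python
import torch


def randomized_svd(op, rank, oversample=10, dtype=torch.float64, device="cpu", seed=0):
    """Truncated SVD of op via a Gaussian range sketch, using matmuls only."""
    height, width = op.shape
    k = min(rank + oversample, height, width)
    gen = torch.Generator(device=device).manual_seed(seed)
    # Right sketch: Y = op @ Omega approximately spans the dominant column space
    omega = torch.randn(width, k, generator=gen, dtype=dtype, device=device)
    q, _ = torch.linalg.qr(op @ omega)  # orthonormal basis, shape (height, k)
    # Left matmul projects op onto that basis: B = Q^H @ op, shape (k, width)
    b = q.mH @ op
    # Small dense SVD of B; lift its left singular vectors back through Q
    ub, s, vh = torch.linalg.svd(b, full_matrices=False)
    return (q @ ub)[:, :rank], s[:rank], vh[:rank]
```

For an exactly rank-5 matrix, `randomized_svd(A, rank=5)` should recover A up to floating-point round-off, since the oversampled sketch captures its whole range. The `oversample` margin is the usual knob trading a few extra measurements for accuracy on matrices whose spectrum decays slowly.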

References:

See the documentation for more details.

Installation and basic usage

Install via:

pip install skerch

The sketched SVD of a linear operator op can then be computed via:

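import torch
from skerch.decompositions import ssvd

# Example settings (illustrative values, not library defaults)
DEVICE, DTYPE = "cpu", torch.float64
NUM_OUTER, NUM_INNER = 20, 60
# Any op with .shape, op @ v and v @ op works; a dense tensor is the simplest case
op = torch.randn(500, 400, dtype=DTYPE, device=DEVICE)

q, u, s, vt, pt = ssvd(
    op,
    op_device=DEVICE,
    op_dtype=DTYPE,
    outer_dim=NUM_OUTER,
    inner_dim=NUM_INNER,
)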

Here, outer_dim and inner_dim set the number of outer and inner measurements for the sketch, and q @ u @ diag(s) @ vt @ pt is a PyTorch matrix that approximates op. The op object only needs to satisfy the following criteria:

  • It must have an op.shape = (height, width) attribute
  • It must implement the w = op @ v right-matmul operator, receiving and returning PyTorch vectors/matrices
  • It must implement the w = v @ op left-matmul operator, receiving and returning PyTorch vectors/matrices
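A minimal sketch of a matrix-free operator meeting these three criteria: it represents the product a @ b of two thin factors without ever forming it. The class name `LowRankOp` and its factors are hypothetical; the point is only that `.shape`, `__matmul__` and `__rmatmul__` are all such an object needs. The left-matmul `v @ op` works because a PyTorch tensor's `@` defers to the right operand's `__rmatmul__` when it does not recognize that operand.

```python
import torch


class LowRankOp:
    """Hypothetical matrix-free operator representing a @ b without forming it."""

    def __init__(self, a, b):
        assert a.shape[1] == b.shape[0], "inner dimensions must match"
        self.a = a  # shape (height, rank)
        self.b = b  # shape (rank, width)
        self.shape = (a.shape[0], b.shape[1])  # required (height, width) attribute

    def __matmul__(self, x):
        # Right-matmul op @ x, for x of shape (width,) or (width, k)
        return self.a @ (self.b @ x)

    def __rmatmul__(self, x):
        # Left-matmul x @ op, for x of shape (height,) or (k, height)
        return (x @ self.a) @ self.b
```

Both matmuls cost O(rank * (height + width)) per vector instead of O(height * width), which is the kind of saving that makes sketching attractive for operators too large to store densely.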

skerch provides a convenience PyTorch wrapper for the cases where op interacts with NumPy arrays instead (e.g. SciPy linear operators like the ones used in CurvLinOps).
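The idea behind such a wrapper can be sketched as a round trip: convert the incoming PyTorch tensor to NumPy, apply the operator, and convert the result back on the caller's device. The class below is a hypothetical illustration, not skerch's own wrapper (see the API docs for that); it assumes the wrapped object supports `@` with NumPy arrays on both sides, which a plain `np.ndarray` does.

```python
import numpy as np
import torch


class NumpyOpTorchWrapper:
    """Hypothetical wrapper exposing a NumPy-based operator to PyTorch code."""

    def __init__(self, np_op):
        self.np_op = np_op
        self.shape = tuple(np_op.shape)

    def __matmul__(self, x):
        # torch -> NumPy (on CPU), apply op @ x, then NumPy -> torch on x's device
        result = self.np_op @ x.detach().cpu().numpy()
        return torch.from_numpy(np.asarray(result)).to(x.device)

    def __rmatmul__(self, x):
        # Same round trip for the left-matmul x @ op
        result = x.detach().cpu().numpy() @ self.np_op
        return torch.from_numpy(np.asarray(result)).to(x.device)
```

Note that the result dtype follows NumPy's promotion rules (e.g. a float64 operator applied to float32 input returns float64), and every call pays a host/device copy when the tensors live on a GPU.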

To get a suggested number of measurements for a given operator shape and measurement budget, run:

python -m skerch prio_hpars --shape=100,200 --budget=12345

The library also implements cheap a-posteriori methods to estimate the error of the obtained sketched approximation:

from skerch.a_posteriori import a_posteriori_error
from skerch.linops import CompositeLinOp, DiagonalLinOp

# (q, u, s, vt, pt) previously computed via ssvd
NUM_A_POSTERIORI = 30  # number of random a-posteriori measurements
sketched_op = CompositeLinOp(
    (
        ("Q", q),
        ("U", u),
        ("S", DiagonalLinOp(s)),
        ("Vt", vt),
        ("Pt", pt),
    )
)

(f1, f2, frob_err) = a_posteriori_error(
    op, sketched_op, NUM_A_POSTERIORI, dtype=DTYPE, device=DEVICE
)[0]
print("Estimated Frob(op):", f1**0.5)
print("Estimated Frob(sketched_op):", f2**0.5)
print("Estimated Frobenius Error:", frob_err**0.5)

Given NUM_A_POSTERIORI measurements (30 is generally enough), the probability that frob_err**0.5 deviates from the actual Frobenius error by more than a given amount can be queried as follows:

python -m skerch post_bounds --apost_n=30 --apost_err=0.5
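To build intuition for the kind of probability post_bounds reports, one can also measure it empirically on a known matrix: repeat a 30-probe estimate many times and count how often the square-rooted estimate misses the true Frobenius norm by more than a tolerance. This Monte Carlo sketch assumes Gaussian probes and reads the tolerance as a relative deviation; both are assumptions, and post_bounds may define apost_err differently. The helper `empirical_failure_prob` is hypothetical.

```python
import torch


def empirical_failure_prob(mat, num_meas=30, rel_tol=0.5, num_trials=2000, seed=0):
    """Fraction of trials where sqrt(estimated ||mat||_F^2) is off by > rel_tol."""
    gen = torch.Generator().manual_seed(seed)
    true_frob = torch.linalg.norm(mat).item()
    # Probes for all trials at once: shape (width, num_trials * num_meas)
    g = torch.randn(mat.shape[1], num_trials * num_meas, generator=gen, dtype=mat.dtype)
    sq_norms = (mat @ g).pow(2).sum(dim=0)  # ||mat @ g_i||^2 for every probe
    # Average each consecutive group of num_meas probes into one estimate
    est = sq_norms.reshape(num_trials, num_meas).mean(dim=1).sqrt()
    rel_dev = (est - true_frob).abs() / true_frob
    return (rel_dev > rel_tol).double().mean().item()
```

A rank-1 matrix is a natural stress test here: each probe's squared norm is then ||mat||_F² times a χ²₁ variable, which gives the largest variance of the estimate for a fixed Frobenius norm.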

See Getting Started, Examples, and API docs for more details.

Developers

Contributions are most welcome under this repo's LICENSE. Feel free to open an issue with bug reports, feature requests, etc.

The documentation contains a For Developers section with useful guidelines to interact with this repo.



Download files


Source Distribution

skerch-0.6.1.tar.gz (1.5 MB)

Uploaded Source

Built Distribution

skerch-0.6.1-py3-none-any.whl (61.7 kB)

Uploaded Python 3

File details

Details for the file skerch-0.6.1.tar.gz.

File metadata

  • Download URL: skerch-0.6.1.tar.gz
  • Upload date:
  • Size: 1.5 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.0.0 CPython/3.12.2

File hashes

Hashes for skerch-0.6.1.tar.gz
Algorithm Hash digest
SHA256 3931df392c838a5ed39c450cdbdcf3581a73e5af44fc40609c4c6e7f092d2902
MD5 0e84429aa2e02e6520d70e4b902638ba
BLAKE2b-256 0a4092f77820c450d40f426d5b4712435c6fe7bacd21a82d2492853e3c168cde


File details

Details for the file skerch-0.6.1-py3-none-any.whl.

File metadata

  • Download URL: skerch-0.6.1-py3-none-any.whl
  • Upload date:
  • Size: 61.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.0.0 CPython/3.12.2

File hashes

Hashes for skerch-0.6.1-py3-none-any.whl
Algorithm Hash digest
SHA256 4eba40af3b29d7296d13e94c657dca8effaede675e5d45750b7626aa133f1c96
MD5 32e77bb2376cd6b25d621362b27be1ff
BLAKE2b-256 d9ee1b722c9c6824c174e28f0be6a6470fed28f20a0f0da117999dfd1b30be9b

