Sketched matrix decompositions for PyTorch
skerch: Sketched linear operations for PyTorch
What is skerch?
skerch is a Python package to compute different sketched linear operations, such as SVD/EIGH, diagonal/triangular approximations and operator norms. See the documentation for more details and usage examples. Main features:
- Built on top of PyTorch, naturally supports CPU and CUDA, as well as complex datatypes. Very few dependencies otherwise
- Rich API for matrix-free linear operators, including matrix-free noise sources (Rademacher, Gaussian, SSRFT...)
- Efficient parallelized and distributed computations
- Support for out-of-core operations via HDF5
- A-posteriori verification tools to test accuracy of sketched approximations
- Modular and extensible design, for easy adaptation to new settings and operations
Why sketches?
Sketched methods are a good choice whenever we are dealing with large objects that can be approximated by smaller substructures (e.g. a low-rank approximation of a large matrix). Thanks to the random measurements (i.e. the "sketches"), we can directly obtain the small approximations, without having to store or compute the large object. This works with very few assumptions about what the smaller substructure looks like.
For example, if we have a large linear operator of dimensionality (N, N) that doesn't fit in memory, but has rank k, we can directly retrieve its top-k singular components with only O(Nk) storage, as opposed to the intractable O(N^2) (see picture below for an intuition). As a bonus, this technique is numerically stable and can be parallelized, which often results in substantial speedups.
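To make the O(Nk) storage claim concrete, here is a minimal sketch of the classic randomized range-finder SVD (Halko, Martinsson and Tropp), written in NumPy rather than PyTorch so that it stays self-contained. It is not skerch's implementation: the function name `randomized_svd` and the `oversample` parameter are illustrative assumptions, and the sketch only handles real-valued operators. The operator is touched exclusively through the two callables `matvec` and `rmatvec`, so the largest arrays ever stored are the thin sketches.

```python
import numpy as np


def randomized_svd(matvec, rmatvec, shape, k, oversample=5, seed=0):
    """Thin SVD of a real (h, w) operator A, accessed only via products.

    matvec(X) must return A @ X and rmatvec(Y) must return Y @ A.
    Peak storage is O((h + w) * (k + oversample)) instead of O(h * w).
    """
    w = shape[1]
    rng = np.random.default_rng(seed)
    omega = rng.standard_normal((w, k + oversample))  # random test matrix
    # Orthonormal basis of the sketched range of A, shape (h, k + oversample)
    Q, _ = np.linalg.qr(matvec(omega))
    B = rmatvec(Q.T)  # small projection Q^T A, shape (k + oversample, w)
    u, s, vh = np.linalg.svd(B, full_matrices=False)
    return (Q @ u)[:, :k], s[:k], vh[:k]  # keep the top-k components
```

Here a dense A would only be used to check the result; in a real setting `matvec` and `rmatvec` would apply a matrix-free operator, e.g. `randomized_svd(lambda X: A @ X, lambda Y: Y @ A, A.shape, k=10)`.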
Why skerch?
On paper, sketched methods only require our linear operators to satisfy the following bare-bones interface:
```python
class MyLinOp:
    def __init__(self, shape):
        self.shape = shape  # (h, w)

    def __matmul__(self, x):
        # implement and return A @ x
        raise NotImplementedError

    def __rmatmul__(self, x):
        # implement and return x @ A
        raise NotImplementedError
```
Anything beyond this is not strictly required. Many libraries, however, require more complex interfaces, which limits the application scenarios or adds substantial overhead for developers.
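As a minimal sketch of how little this interface demands, the hypothetical `LowRankLinOp` below (NumPy-based, not part of skerch) represents A = L @ R through its factors only, so the full (h, w) matrix is never formed. The `__array_ufunc__ = None` line is a NumPy-specific detail: without it, `ndarray @ lop` would be intercepted by NumPy itself (and typically fail) instead of falling back to `__rmatmul__`.

```python
import numpy as np


class LowRankLinOp:
    """Matrix-free operator A = L @ R, stored only through its factors.

    Implements the bare-bones interface: .shape, A @ x and x @ A.
    """

    # Make NumPy return NotImplemented for `ndarray @ LowRankLinOp`,
    # so that Python falls back to our __rmatmul__ below.
    __array_ufunc__ = None

    def __init__(self, L, R):
        self.L, self.R = L, R  # shapes (h, k) and (k, w)
        self.shape = (L.shape[0], R.shape[1])

    def __matmul__(self, x):
        return self.L @ (self.R @ x)  # A @ x without forming A

    def __rmatmul__(self, x):
        return (x @ self.L) @ self.R  # x @ A without forming A
```

Each product costs O((h + w) k) per vector instead of O(hw), which is the kind of structure sketched methods can exploit through this interface alone.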
skerch is specifically designed to work on this bare-bones interface. Furthermore, its highly modular architecture allows users to exchange and modify different components of the sketched methods.
Additionally, skerch is built on top of PyTorch with very few other dependencies, so it supports a broad variety of platforms and datatypes (including e.g. complex datatypes on GPU).
skerch also supports in-core and out-of-core parallelizations (e.g. via HDF5 tensor databases), providing good scalability in memory and runtime. The documentation examples illustrate all of the above points.
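Parallel and out-of-core sketching are possible because each random measurement A @ omega[:, j] is independent of the others. The following NumPy sketch illustrates only that principle: it computes the sketch block by block with a thread pool and concatenates the results. The helper `blocked_sketch` is hypothetical, and the thread pool is a stand-in; it does not reproduce skerch's distributed or HDF5-based machinery.

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor


def blocked_sketch(matvec, omega, block_size, max_workers=4):
    """Compute matvec(omega) block by block (e.g. one block per worker).

    Each sketch column is independent, so blocks can be computed
    concurrently (or written to disk as they arrive) and joined afterwards.
    """
    starts = range(0, omega.shape[1], block_size)
    blocks = [omega[:, i:i + block_size] for i in starts]
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        results = list(pool.map(matvec, blocks))  # map preserves block order
    return np.concatenate(results, axis=1)
```

For example, `blocked_sketch(lambda X: A @ X, omega, block_size=4)` returns the same array as `A @ omega`, while no single call ever sees more than four test vectors.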
In summary, skerch brings sketched methods to you with minimal overhead while retaining good performance, resulting in faster development and runtimes overall. Give it a try!
Installation and basic usage
Install via:
```shell
pip install skerch
```
Then, given any linear operator lop that implements the bare-minimum interface .shape = (h, w) and lop @ x, x @ lop, we can compute the sketched SVD as follows (skerch also provides functionality to estimate EIGH, norm, diagonals...):
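```python
import torch
from skerch.algorithms import ssvd
device, dtype, num_outer = torch.device("cpu"), torch.float64, 50  # example settings
U, S, Vh = ssvd(lop, device, dtype, num_outer, seed=12345, recovery_type="nystrom")
```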
With the nystrom recovery, this method requires a total of 2 * num_outer measurements, and yields a thin SVD estimation where lop is approximated by (U * S) @ Vh, and U.shape = (h, num_outer), S.shape = (num_outer,), Vh.shape = (num_outer, w).
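To illustrate where 2 * num_outer measurements come from, here is a minimal NumPy sketch of one generalized Nyström scheme: k right measurements AX = A @ X, k left measurements YA = Y^T @ A, and the recovery A ≈ AX @ pinv(Y^T A X) @ YA, where the small core Y^T A X is obtained from the existing sketches without touching A again. This is an assumption about the general technique, not skerch's code: its "nystrom" recovery may differ (e.g. oversampling, PyTorch, complex support), and the name `generalized_nystrom_svd` and the `rcond` cutoff are hypothetical.

```python
import numpy as np


def generalized_nystrom_svd(matvec, rmatvec, shape, k, seed=0, rcond=1e-10):
    """Thin SVD of a real (h, w) operator A from 2 * k random measurements.

    Right sketch:  AX = A @ X    (k measurements, shape (h, k))
    Left sketch:   YA = Y^T @ A  (k measurements, shape (k, w))
    Recovery:      A ~= AX @ pinv(Y^T @ AX) @ YA
    """
    h, w = shape
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((w, k))
    Y = rng.standard_normal((h, k))
    AX, YA = matvec(X), rmatvec(Y.T)
    core = Y.T @ AX  # (k, k) core, no extra products with A
    Q1, R1 = np.linalg.qr(AX)  # AX = Q1 @ R1
    Q2, R2 = np.linalg.qr(YA.T)  # YA = R2^T @ Q2^T
    middle = R1 @ np.linalg.pinv(core, rcond=rcond) @ R2.T
    u, s, vh = np.linalg.svd(middle)  # small (k, k) SVD
    return Q1 @ u, s, vh @ Q2.T  # shapes (h, k), (k,), (k, w)
```

The output follows the same convention as above: `(U * S)` scales the columns of U by S, so `(U * S) @ Vh` equals `U @ np.diag(S) @ Vh` without building the diagonal matrix.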
If num_outer is large enough to cover the rank of lop, this yields an accurate recovery (see documentation examples). But how can we make sure?
The library also implements cheap a-posteriori methods to estimate the error of the obtained sketched approximation, using only a few extra measurements of lop instead of its explicit entries:
```python
from skerch.a_posteriori import apost_error, apost_error_bounds
from skerch.linops import CompositeLinOp, DiagonalLinOp

lop_approx = CompositeLinOp([("U", U), ("S", DiagonalLinOp(S)), ("Vh", Vh)])
(lop_f2, approx_f2, err2), _ = apost_error(lop, lop_approx, device, dtype, num_meas=30, seed=54321)
print("Estimated Frob(op):", lop_f2**0.5)
print("Estimated Frob(sketched_op):", approx_f2**0.5)
print("Estimated Frobenius Error:", err2**0.5)
```
This technique uses a number of test measurements that must be independent of lop and of the sketch measurements (make sure to use a different seed). For 30 measurements and a complex-valued lop, the probabilities of err2**0.5 underestimating or overestimating the true error by 50% or more can be queried as follows:
```shell
python -m skerch post_bounds --apost_n=30 --apost_err=0.5 --is_complex
# returns {'LOWER: P(err<=0.5x)': 0.0030445096757934554, 'HIGHER: P(err>=1.5x)': 0.05865709397802224}
```
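The idea behind such a-posteriori checks can be sketched in a few lines: for a standard Gaussian test vector w, E[||M w||^2] = ||M||_F^2, so averaging over independent test vectors gives unbiased estimates of the squared Frobenius norms of lop, of its approximation, and of their difference. The helper `apost_frobenius` below is hypothetical and NumPy-based; it mirrors the three quantities returned by `apost_error`, but skerch's actual estimator and its probability bounds may be derived differently.

```python
import numpy as np


def apost_frobenius(matvec, approx_matvec, n_cols, num_meas=30, seed=54321):
    """Estimate ||A||_F^2, ||A_hat||_F^2 and ||A - A_hat||_F^2.

    Averages ||M w||^2 over num_meas independent standard Gaussian test
    vectors w. The seed must differ from the one used to build A_hat, so
    that the test vectors are independent of the sketch measurements.
    """
    rng = np.random.default_rng(seed)
    W = rng.standard_normal((n_cols, num_meas))  # one test vector per column
    AW, AhatW = matvec(W), approx_matvec(W)
    lop_f2 = np.sum(np.abs(AW) ** 2) / num_meas
    approx_f2 = np.sum(np.abs(AhatW) ** 2) / num_meas
    err2 = np.sum(np.abs(AW - AhatW) ** 2) / num_meas
    return lop_f2, approx_f2, err2
```

More test measurements shrink the spread of these estimates, which is what the probability bounds quantify for a given `num_meas`.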
See the documentation examples and API docs for more details.
Developers
Contributions are welcome under this repo's LICENSE.
Feel free to open an issue with bug reports, feature requests, etc.
The documentation also contains a For Developers section with useful guidelines to interact with this repo and propose pull requests.
Researchers
If this library is useful for your work, please consider citing it:
```bibtex
@manual{fernandez2024skerch,
  title={{S}kerch: Sketched matrix decompositions for {PyTorch}},
  author={Andres Fernandez},
  year={2024},
  url={https://github.com/andres-fr/skerch},
}
```
Download files
- Source distribution: skerch-1.2.0.tar.gz
- Built distribution: skerch-1.2.0-py3-none-any.whl
File details
Details for the file skerch-1.2.0.tar.gz.
File metadata
- Download URL: skerch-1.2.0.tar.gz
- Upload date:
- Size: 1.6 MB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5116bff8781c24b477dbaf311e8336c9f85b7d16c73d28db4368733ec93a48a4 |
| MD5 | d5282ac730dc061e83e092ad3136fafd |
| BLAKE2b-256 | 947cc7737287489704cee7a4edaf8db5907e80cba266adaf3be3ee3a1b97ff5f |
Provenance
The following attestation bundles were made for skerch-1.2.0.tar.gz:
Publisher: ci.yaml on andres-fr/skerch
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: skerch-1.2.0.tar.gz
- Subject digest: 5116bff8781c24b477dbaf311e8336c9f85b7d16c73d28db4368733ec93a48a4
- Sigstore transparency entry: 637357135
- Sigstore integration time:
- Permalink: andres-fr/skerch@fb9c6fba5856d4bb8fa91795ce1cd97595e93838
- Branch / Tag: refs/tags/1.2.0
- Owner: https://github.com/andres-fr
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: ci.yaml@fb9c6fba5856d4bb8fa91795ce1cd97595e93838
- Trigger Event: push
File details
Details for the file skerch-1.2.0-py3-none-any.whl.
File metadata
- Download URL: skerch-1.2.0-py3-none-any.whl
- Upload date:
- Size: 103.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 46e5bab4bbadc994a540a0ef04639f3a518f7877a12c6d9436adf10aae2b3bb1 |
| MD5 | 103a046fb6ea14987342d617438bbb70 |
| BLAKE2b-256 | f1ad931294d24fa3602d462761682623813abfd82f3500f43c3203abc0c35cc6 |
Provenance
The following attestation bundles were made for skerch-1.2.0-py3-none-any.whl:
Publisher: ci.yaml on andres-fr/skerch
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: skerch-1.2.0-py3-none-any.whl
- Subject digest: 46e5bab4bbadc994a540a0ef04639f3a518f7877a12c6d9436adf10aae2b3bb1
- Sigstore transparency entry: 637357154
- Sigstore integration time:
- Permalink: andres-fr/skerch@fb9c6fba5856d4bb8fa91795ce1cd97595e93838
- Branch / Tag: refs/tags/1.2.0
- Owner: https://github.com/andres-fr
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: ci.yaml@fb9c6fba5856d4bb8fa91795ce1cd97595e93838
- Trigger Event: push