scipy Linear operators for curvature matrices in PyTorch
scipy linear operators of deep learning matrices in PyTorch
This library implements scipy.sparse.linalg.LinearOperators for deep learning matrices, such as
- the Hessian
- the Fisher/generalized Gauss-Newton (GGN)
- the Monte-Carlo approximated Fisher
- the Fisher/GGN's KFAC approximation (Kronecker-Factored Approximate Curvature)
- the uncentered gradient covariance (aka empirical Fisher)
- the output-parameter Jacobian of a neural net and its transpose
Matrix-vector products are carried out in PyTorch, i.e. potentially on a GPU.
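As an illustration of what "a linear operator whose matvec runs in PyTorch" means for the output-parameter Jacobian and its transpose, here is a minimal sketch. It does not use curvlinops' own classes. It wraps a hypothetical toy function f(w) = tanh(X w) (5 outputs, 3 parameters) in a plain scipy LinearOperator: matvec is a Jacobian-vector product and rmatvec is a vector-Jacobian product, both computed with torch.autograd.functional. The names f, w0, jac_matvec, and jac_rmatvec are assumptions made for this example.

```python
import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator
from torch.autograd.functional import jvp, vjp

torch.manual_seed(0)
X = torch.randn(5, 3, dtype=torch.float64)


def f(w):
    # Hypothetical toy "network": 5 outputs as a function of 3 parameters.
    return torch.tanh(X @ w)


w0 = torch.randn(3, dtype=torch.float64)  # point at which the Jacobian is taken


def jac_matvec(v):
    # J @ v via a forward-mode style Jacobian-vector product in PyTorch.
    v = torch.as_tensor(np.asarray(v).ravel(), dtype=w0.dtype)
    _, Jv = jvp(f, w0, v)
    return Jv.detach().numpy()


def jac_rmatvec(u):
    # J^T @ u via a reverse-mode vector-Jacobian product in PyTorch.
    u = torch.as_tensor(np.asarray(u).ravel(), dtype=w0.dtype)
    _, JTu = vjp(f, w0, u)
    return JTu.detach().numpy()


J = LinearOperator(shape=(5, 3), matvec=jac_matvec, rmatvec=jac_rmatvec, dtype=np.float64)

Jv = J @ np.ones(3)          # Jacobian-vector product, shape (5,)
JTu = J.rmatvec(np.ones(5))  # transpose-Jacobian-vector product, shape (3,)
```

The Jacobian is never materialized; scipy only ever sees NumPy vectors going in and out.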
The library supports defining these matrices not only on a mini-batch, but also on entire data sets (looping over batches during a matvec operation).
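A minimal sketch of the data-set idea for the Hessian, assuming a loss with reduction="sum" so that the data-set Hessian is simply the sum of the per-batch Hessians. This is not curvlinops' implementation; the helper make_hessian_operator is hypothetical. Each matvec converts scipy's NumPy vector into PyTorch tensors, loops over all batches, and accumulates Hessian-vector products computed by double backpropagation on the parameters' device.

```python
import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator


def make_hessian_operator(model, loss_func, data):
    """Hypothetical helper: the Hessian of the summed loss over all batches in
    `data` w.r.t. the model's trainable parameters, as a scipy LinearOperator."""
    params = [p for p in model.parameters() if p.requires_grad]
    sizes = [p.numel() for p in params]
    n = sum(sizes)
    dtype, device = params[0].dtype, params[0].device

    def matvec(v):
        # Move scipy's NumPy vector into PyTorch and split it per parameter.
        v = torch.as_tensor(np.asarray(v).ravel(), dtype=dtype, device=device)
        v_list = [c.view_as(p) for c, p in zip(v.split(sizes), params)]
        result = torch.zeros(n, dtype=dtype, device=device)
        for X, y in data:  # loop over batches during a single matvec
            loss = loss_func(model(X), y)
            grads = torch.autograd.grad(loss, params, create_graph=True)
            grad_dot_v = sum((g * vi).sum() for g, vi in zip(grads, v_list))
            hv = torch.autograd.grad(grad_dot_v, params)  # Hessian-vector product
            result += torch.cat([h.reshape(-1) for h in hv])
        return result.detach().cpu().numpy()

    np_dtype = torch.empty(0, dtype=dtype).numpy().dtype
    return LinearOperator(shape=(n, n), matvec=matvec, dtype=np_dtype)


# Usage: a linear regression model and a data set split into two mini-batches.
torch.manual_seed(0)
model = torch.nn.Linear(3, 1).double()
loss_func = torch.nn.MSELoss(reduction="sum")
X = torch.randn(8, 3, dtype=torch.float64)
y = torch.randn(8, 1, dtype=torch.float64)
data = [(X[:4], y[:4]), (X[4:], y[4:])]

H = make_hessian_operator(model, loss_func, data)
Hv = H @ np.ones(4)  # one matvec touches every batch
```

For this model the Hessian is known in closed form (twice the Gram matrix of the inputs augmented with a column of ones), which makes it easy to check the operator.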
You can plug these linear operators into scipy routines, while carrying out the heavy lifting (matrix-vector multiplies) in PyTorch on a GPU. My favorite example of such a routine is scipy.sparse.linalg.eigsh, which lets you compute a subset of eigenpairs.
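A minimal sketch of this pattern, assuming a random symmetric positive semi-definite matrix M as a stand-in for a curvature matrix (M and matvec are hypothetical names). The matvec runs in PyTorch, on a GPU if one is available, while scipy.sparse.linalg.eigsh drives the Lanczos iteration and only needs matrix-vector products.

```python
import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator, eigsh

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
n = 50
B = torch.randn(n, n, dtype=torch.float64, device=device)
M = B @ B.T  # hypothetical symmetric PSD stand-in for a curvature matrix


def matvec(v):
    # NumPy in, heavy lifting in PyTorch (possibly on GPU), NumPy out.
    v_torch = torch.as_tensor(np.asarray(v).ravel(), dtype=M.dtype, device=device)
    return (M @ v_torch).cpu().numpy()


op = LinearOperator(shape=(n, n), matvec=matvec, dtype=np.float64)

# Three largest (algebraic) eigenvalues and their eigenvectors.
eigvals, eigvecs = eigsh(op, k=3, which="LA")
```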
The library also provides linear operator transformations, like taking the inverse (inverse matrix-vector product via conjugate gradients) or slicing out sub-matrices.
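Both transformations can be expressed with matvecs alone. The sketch below is not the library's API; inverse_operator and submatrix_operator are hypothetical helpers. The inverse uses scipy.sparse.linalg.cg with its default tolerance, so it assumes a symmetric positive definite operator and returns only an approximate inverse-matrix-vector product. The sub-matrix embeds the input into a zero vector at the selected columns and keeps the selected rows of the result.

```python
import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator, cg

torch.manual_seed(0)
n = 20
B = torch.randn(n, n, dtype=torch.float64)
M = B @ B.T + n * torch.eye(n, dtype=torch.float64)  # symmetric positive definite


def torch_matvec(v):
    return (M @ torch.as_tensor(np.asarray(v).ravel(), dtype=M.dtype)).numpy()


A = LinearOperator(shape=(n, n), matvec=torch_matvec, dtype=np.float64)


def inverse_operator(A):
    """Hypothetical helper: approximate A^{-1} @ b via conjugate gradients."""
    def matvec(b):
        x, info = cg(A, np.asarray(b).ravel())
        if info != 0:
            raise RuntimeError(f"CG did not converge (info={info})")
        return x

    return LinearOperator(shape=A.shape, matvec=matvec, dtype=A.dtype)


def submatrix_operator(A, row_idxs, col_idxs):
    """Hypothetical helper: the sub-matrix A[row_idxs][:, col_idxs]."""
    def matvec(x):
        x_full = np.zeros(A.shape[1], dtype=A.dtype)
        x_full[col_idxs] = np.asarray(x).ravel()  # zero outside selected columns
        return (A @ x_full)[row_idxs]             # keep only selected rows

    return LinearOperator(shape=(len(row_idxs), len(col_idxs)), matvec=matvec, dtype=A.dtype)


b = np.ones(n)
x = inverse_operator(A) @ b                   # approximately M^{-1} b
S = submatrix_operator(A, [0, 2, 5], [1, 3])  # represents a 3x2 block of M
```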
Finally, it offers functionality to probe properties of the represented matrices, like their spectral density, trace, or diagonal.
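Trace and diagonal can be probed with matvecs only. The sketch below uses Hutchinson's randomized estimator with Rademacher vectors; it is a standard technique shown for illustration, not a description of how the library implements these probes, and the helper names are hypothetical. For a probe v with independent entries of ±1, the expectation of v^T A v is tr(A), and the expectation of v * (A v) (element-wise) is diag(A).

```python
import numpy as np
from scipy.sparse.linalg import aslinearoperator


def rademacher(n, num_samples, rng):
    # Columns are independent probe vectors with entries -1 or +1.
    return rng.choice([-1.0, 1.0], size=(n, num_samples))


def hutchinson_trace(A, num_samples, rng):
    """Hypothetical helper: estimate tr(A) from num_samples matvecs."""
    V = rademacher(A.shape[1], num_samples, rng)
    AV = A @ V  # one matvec per probe column
    return float(np.sum(V * AV) / num_samples)


def hutchinson_diagonal(A, num_samples, rng):
    """Hypothetical helper: estimate diag(A) from num_samples matvecs."""
    V = rademacher(A.shape[1], num_samples, rng)
    AV = A @ V
    return np.mean(V * AV, axis=1)


rng = np.random.default_rng(0)
D = aslinearoperator(np.diag([1.0, 2.0, 3.0]))
trace_est = hutchinson_trace(D, 10, rng)
diag_est = hutchinson_diagonal(D, 10, rng)
```

For a diagonal matrix both estimates are exact for any probes, because the squared probe entries are all 1; for general matrices the error shrinks as the number of samples grows.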
- Documentation: https://curvlinops.readthedocs.io/en/latest/
- Bug reports & feature requests: https://github.com/f-dangel/curvlinops/issues
Installation
pip install curvlinops-for-pytorch
Examples
Future ideas
Other features that could be supported in the future include:
- Other matrices:
  - the centered gradient covariance
  - terms of the hierarchical GGN decomposition
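To make the difference between the centered gradient covariance listed above and the already supported uncentered one (empirical Fisher) concrete, here is a small dense sketch. It assumes a 1/N normalization (some definitions use 1/(N-1)) and a hypothetical helper per_sample_gradients that stacks per-sample gradients by looping over the data; it is not a planned API. Centering subtracts the outer product of the mean gradient.

```python
import torch


def per_sample_gradients(model, loss_func, X, y):
    """Hypothetical helper: stack each sample's flattened gradient into [N, D]."""
    params = [p for p in model.parameters() if p.requires_grad]
    rows = []
    for x_n, y_n in zip(X, y):
        loss = loss_func(model(x_n.unsqueeze(0)), y_n.unsqueeze(0))
        grads = torch.autograd.grad(loss, params)
        rows.append(torch.cat([g.reshape(-1) for g in grads]))
    return torch.stack(rows)


torch.manual_seed(0)
model = torch.nn.Linear(3, 2).double()
loss_func = torch.nn.MSELoss()
X = torch.randn(10, 3, dtype=torch.float64)
y = torch.randn(10, 2, dtype=torch.float64)

G = per_sample_gradients(model, loss_func, X, y)  # shape [N, D] = [10, 8]
N = G.shape[0]
g_mean = G.mean(dim=0)
uncentered = G.T @ G / N                        # uncentered gradient covariance
centered = (G - g_mean).T @ (G - g_mean) / N    # centered gradient covariance
```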
Logo image credits
- SciPy logo: Unknown, CC BY-SA 4.0, via Wikimedia Commons
- PyTorch logo: https://github.com/soumith, CC BY-SA 4.0, via Wikimedia Commons
Download files
Source Distribution
Built Distribution
Hashes for curvlinops-for-pytorch-1.2.0.tar.gz
Algorithm | Hash digest
---|---
SHA256 | d095d598397397be65e1000332854158b68fac6f992c4b396333c31924a2427d
MD5 | 81342248231e2c36e92e55ce96d9bce1
BLAKE2b-256 | 9ea23f2afec17f1c72164aae500440f8796b5c9d7229c44588ab7244689ea004
Hashes for curvlinops_for_pytorch-1.2.0-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 4e216d5cd33f882d6166a518fa81c195ae328affcb8b092cc87b2a16eb045442
MD5 | 7f148162882aa8b6796f860b35650a75
BLAKE2b-256 | cbf2dccac6ee3ca5310076c4d4a28274dad5e86254c3a99a29cec0345a12be3e