scipy linear operator implementations of the GGN and Hessian in PyTorch
scipy linear operators of deep learning matrices in PyTorch
This library implements `scipy.sparse.linalg.LinearOperator`s for deep learning matrices, such as
- the Hessian
- the Fisher/generalized Gauss-Newton (GGN)
- the Monte-Carlo approximated Fisher
- the uncentered gradient covariance (aka empirical Fisher)
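Under the hood, such an operator only needs a function that multiplies the matrix onto a vector. The following is a minimal sketch, not curvlinops' actual implementation. It computes a Hessian-vector product in PyTorch by differentiating the gradient a second time (double backpropagation) and wraps it in a `scipy.sparse.linalg.LinearOperator`. The toy model, loss, data, and the helper `hessian_matvec` are hypothetical, and float64 is used only to make numerical checks easier.

```python
import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator

torch.manual_seed(0)

# Hypothetical toy setup: a small network, a loss, and one mini-batch (float64 throughout).
model = torch.nn.Sequential(
    torch.nn.Linear(3, 4), torch.nn.Tanh(), torch.nn.Linear(4, 2)
).double()
loss_func = torch.nn.MSELoss()
X = torch.randn(8, 3, dtype=torch.float64)
y = torch.randn(8, 2, dtype=torch.float64)
params = [p for p in model.parameters() if p.requires_grad]
sizes = [p.numel() for p in params]
dim = sum(sizes)


def hessian_matvec(v: np.ndarray) -> np.ndarray:
    """Multiply the Hessian of the loss w.r.t. `params` onto a flat NumPy vector."""
    v_torch = torch.as_tensor(np.ravel(v), dtype=torch.float64)
    # Split the flat vector into pieces shaped like the parameters.
    v_list = [chunk.reshape(p.shape) for chunk, p in zip(v_torch.split(sizes), params)]
    loss = loss_func(model(X), y)
    # First backward pass, keeping the graph so it can be differentiated again.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    # Differentiating <grad, v> w.r.t. the parameters yields H @ v.
    grad_dot_v = sum((g * vi).sum() for g, vi in zip(grads, v_list))
    hv = torch.autograd.grad(grad_dot_v, params)
    return torch.cat([h.flatten() for h in hv]).detach().numpy()


H = LinearOperator(shape=(dim, dim), matvec=hessian_matvec, dtype=np.float64)
```

The operator never forms the `dim × dim` Hessian; each `H @ v` costs a small multiple of one gradient evaluation.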
Matrix-vector products are carried out in PyTorch, i.e. potentially on a GPU.
The library supports defining these matrices not only on a mini-batch, but also on data sets (looping over batches during a matvec operation).
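Looping over batches works because, for a loss averaged over the data, the data set-level matrix is a weighted sum of the mini-batch matrices. Here is a minimal sketch, assuming a hypothetical mean-reduced least-squares loss `L(w) = 1/(2N) ||Xw - y||²`. Its Hessian `Xᵀ X / N` has a closed form, so the per-batch product needs no autograd. The helpers `batch_hessian_matvec` and `dataset_matvec` are illustrative names, not library API.

```python
import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator

torch.manual_seed(0)

# Hypothetical data set, split into mini-batches of unequal size (4, 4, 2).
X_full = torch.randn(10, 3, dtype=torch.float64)
batches = list(torch.split(X_full, 4))
num_data = sum(Xb.shape[0] for Xb in batches)


def batch_hessian_matvec(Xb: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """H_b @ v for the mean-reduced least-squares loss on one batch.

    For L_b(w) = 1/(2 N_b) * ||Xb w - yb||^2 the Hessian is Xb^T Xb / N_b,
    independent of w and yb.
    """
    return Xb.T @ (Xb @ v) / Xb.shape[0]


def dataset_matvec(v: np.ndarray) -> np.ndarray:
    """Accumulate per-batch products, weighted by N_b / N, into the data set-level H @ v."""
    v_torch = torch.as_tensor(np.ravel(v), dtype=torch.float64)
    result = torch.zeros_like(v_torch)
    for Xb in batches:
        result += Xb.shape[0] / num_data * batch_hessian_matvec(Xb, v_torch)
    return result.numpy()


H_data = LinearOperator(shape=(3, 3), matvec=dataset_matvec, dtype=np.float64)
```

Weighting each batch by `N_b / N` matters when batches have unequal sizes, as the last one does here; with a sum-reduced loss the weights would simply be 1. Only one batch needs to be processed at a time, so memory stays bounded by the batch size.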
You can plug these linear operators into SciPy routines while carrying out the heavy lifting (matrix-vector multiplies) in PyTorch on a GPU. My favorite example of such a routine is `scipy.sparse.linalg.eigsh`, which lets you compute a subset of eigenpairs.
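The hand-off between SciPy and PyTorch can be sketched as follows. `eigsh` only ever passes NumPy vectors to `matvec`, which moves them to PyTorch (on a GPU if one is available), multiplies there, and returns a NumPy array. A hypothetical random symmetric matrix `A` stands in for a curvature matrix; in practice the matvec would be a Hessian or GGN product.

```python
import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator, eigsh

# Use a GPU if one is available; the SciPy side only ever sees NumPy arrays.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical symmetric matrix standing in for a curvature matrix.
torch.manual_seed(0)
n = 50
A = torch.randn(n, n, dtype=torch.float64, device=device)
A = A + A.T


def matvec(v: np.ndarray) -> np.ndarray:
    # Move the vector to PyTorch, multiply there, and hand the result back to SciPy.
    v_torch = torch.as_tensor(np.ravel(v), dtype=torch.float64, device=device)
    return (A @ v_torch).cpu().numpy()


op = LinearOperator(shape=(n, n), matvec=matvec, dtype=np.float64)

# Compute the 3 eigenpairs with the largest algebraic eigenvalues.
eigvals, eigvecs = eigsh(op, k=3, which="LA")
```

`which="LA"` requests the largest algebraic eigenvalues; for a Hessian with negative curvature, `which="LM"` (largest magnitude) may be more appropriate.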
- Documentation: https://curvlinops.readthedocs.io/en/latest/
- Bug reports & feature requests: https://github.com/f-dangel/curvlinops/issues
Installation
pip install curvlinops-for-pytorch
Examples
Future ideas
Other features that could be supported in the future include:
- Other matrices
  - the centered gradient covariance
  - terms of the hierarchical GGN decomposition
- Block-diagonal approximations (via `param_groups`)
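For intuition on the first item: the uncentered gradient covariance and the centered one differ only by the outer product of the mean gradient, `(1/N) Σ (g_n − ḡ)(g_n − ḡ)ᵀ = (1/N) Σ g_n g_nᵀ − ḡ ḡᵀ`. A centered matvec could therefore reuse the uncentered one. Below is a minimal NumPy sketch, assuming per-example gradients are available as rows of a hypothetical array `G` and a `1/N` normalization (the library's own scaling convention may differ).

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical per-example gradients g_n, stacked as rows: N examples, D parameters.
G = rng.standard_normal((100, 5)) + 0.5
N, D = G.shape
g_mean = G.mean(axis=0)


def uncentered_matvec(v: np.ndarray) -> np.ndarray:
    """(1/N) * sum_n g_n g_n^T v, without forming the D x D matrix."""
    return G.T @ (G @ v) / N


def centered_matvec(v: np.ndarray) -> np.ndarray:
    """(1/N) * sum_n (g_n - g_mean)(g_n - g_mean)^T v, via the uncentered product."""
    return uncentered_matvec(v) - g_mean * (g_mean @ v)
```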
Logo image credits
- SciPy logo: Unknown, CC BY-SA 4.0, via Wikimedia Commons
- PyTorch logo: https://github.com/soumith, CC BY-SA 4.0, via Wikimedia Commons
Download files
Hashes for curvlinops-for-pytorch-1.1.0.tar.gz

| Algorithm | Hash digest |
|---|---|
| SHA256 | b061ababf8755ff2a00a7bdb5e056433ea73b66fbf19ac84e1ed4893209542c3 |
| MD5 | e5dcb94de66d52d62b7c1b23ee05d6b4 |
| BLAKE2b-256 | 421e9ca7e34c3a72d0512eff7a9f632a724ae6443039ad80eca0d38d131f009b |
Hashes for curvlinops_for_pytorch-1.1.0-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | 61a630705a7093eb5c236728285ea6fc0bb79371f94cc61ebaf5ccf1cabd084a |
| MD5 | e39205477a04f6183add19699ddd542a |
| BLAKE2b-256 | 8a11fd230ba54d84933deb9520db9681144ff252494137fe09c9c9b899b188ea |