scipy linear operator implementations of the GGN and Hessian in PyTorch
scipy linear operators of deep learning matrices in PyTorch
This library implements scipy.sparse.linalg.LinearOperators for deep learning matrices, such as
- the Hessian
- the Fisher/generalized Gauss-Newton (GGN)
Matrix-vector products are carried out in PyTorch, i.e. potentially on a GPU.
The library supports defining these matrices not only on a single mini-batch, but also on entire data sets (by looping over batches during a matvec operation).
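The batch-looping idea can be illustrated with a minimal sketch in plain PyTorch. It is not this library's API: the helper name hessian_vector_product, the toy model, and the random data are assumptions made for illustration. The loss uses reduction="sum" so that the per-batch contributions add up to the Hessian-vector product on the full data set.

```python
import torch


def hessian_vector_product(model, loss_func, data, vec):
    """Hypothetical helper: accumulate Hessian-vector products over batches.

    `data` is an iterable of (X, y) batches and `vec` is a list of tensors
    shaped like the model parameters.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    result = [torch.zeros_like(p) for p in params]

    for X, y in data:
        loss = loss_func(model(X), y)
        # First backward pass; keep the graph to differentiate a second time.
        grads = torch.autograd.grad(loss, params, create_graph=True)
        # Differentiating <grad, vec> w.r.t. the parameters yields H @ vec.
        grad_dot_vec = sum((g * v).sum() for g, v in zip(grads, vec))
        batch_hvp = torch.autograd.grad(grad_dot_vec, params)
        for r, h in zip(result, batch_hvp):
            r += h

    return result


torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(3, 4), torch.nn.Tanh(), torch.nn.Linear(4, 2)
).double()
loss_func = torch.nn.MSELoss(reduction="sum")

X = torch.rand(6, 3, dtype=torch.float64)
y = torch.rand(6, 2, dtype=torch.float64)
batches = [(X[:4], y[:4]), (X[4:], y[4:])]
vec = [torch.rand_like(p) for p in model.parameters()]

# Looping over two batches should match one pass over the full data set.
hvp_batched = hessian_vector_product(model, loss_func, batches, vec)
hvp_full = hessian_vector_product(model, loss_func, [(X, y)], vec)
```

With a mean-reduced loss, the per-batch results would need to be re-weighted by batch size before summing.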
You can plug these linear operators into scipy routines, while carrying out the heavy lifting (matrix-vector multiplies) in PyTorch on GPU. My favorite example for such a routine is scipy.sparse.linalg.eigsh, which lets you compute a subset of eigenpairs.
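To make the plumbing concrete, here is a rough sketch that wraps a hand-written PyTorch Hessian-vector product in a generic scipy LinearOperator and passes it to eigsh. It does not use this library's classes (see the documentation for those); the function hessian_matvec, the toy model, and the random data are assumptions for illustration, and everything runs on the CPU.

```python
import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator, eigsh

torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(3, 4), torch.nn.Tanh(), torch.nn.Linear(4, 1)
).double()
loss_func = torch.nn.MSELoss()
X = torch.rand(8, 3, dtype=torch.float64)
y = torch.rand(8, 1, dtype=torch.float64)

params = list(model.parameters())
num_params = sum(p.numel() for p in params)


def hessian_matvec(v):
    """Hypothetical matvec: numpy vector in, Hessian-vector product out."""
    # scipy may pass shape (N,) or (N, 1); flatten and copy into a tensor.
    v_torch = torch.tensor(np.asarray(v).reshape(-1), dtype=torch.float64)
    loss = loss_func(model(X), y)
    grads = torch.autograd.grad(loss, params, create_graph=True)
    flat_grad = torch.cat([g.reshape(-1) for g in grads])
    # Second backward pass through <grad, v> gives H @ v.
    hvp = torch.autograd.grad(flat_grad @ v_torch, params)
    return torch.cat([h.reshape(-1) for h in hvp]).detach().numpy()


op = LinearOperator(
    (num_params, num_params), matvec=hessian_matvec, dtype=np.float64
)

# Two largest (algebraic) eigenvalues and their eigenvectors.
evals, evecs = eigsh(op, k=2, which="LA")
```

Only matrix-vector products are ever evaluated, so the Hessian is never formed as a dense matrix. On a GPU, the vector would additionally have to be moved to the model's device and the result back to the CPU.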
- Documentation: https://curvlinops.readthedocs.io/en/latest/
- Bug reports & feature requests: https://github.com/f-dangel/curvlinops/issues
Installation
pip install curvlinops-for-pytorch
Examples
Future ideas
Other features that could be supported in the future include:
- Other matrices
  - the un-centered gradient covariance (aka empirical Fisher)
  - the centered gradient covariance
  - terms of the hierarchical GGN decomposition
- Block-diagonal approximations (via param_groups)
- Inverse matrix-vector products by solving a linear system via conjugate gradients
  - This could allow computing generalization metrics like the Takeuchi Information Criterion (TIC), using inverse matrix-vector products in combination with Hutchinson trace estimation
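The last idea can be sketched with scipy alone. The snippet below is only an illustration under stated assumptions, not a planned interface: a small random symmetric positive-definite matrix stands in for a damped curvature matrix, and the helper names inverse_matvec and hutchinson_inverse_trace are hypothetical. It estimates the trace of the inverse, which is simpler than the TIC (that would also involve a gradient covariance matrix).

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, cg

rng = np.random.default_rng(0)
n = 10
B = rng.standard_normal((n, n))
# Symmetric positive-definite stand-in for a damped curvature matrix.
A = B @ B.T / n + np.eye(n)
op = LinearOperator((n, n), matvec=lambda v: A @ v, dtype=A.dtype)


def inverse_matvec(op, v):
    """Approximate op^{-1} @ v by solving op @ x = v with conjugate gradients."""
    x, info = cg(op, v)
    if info != 0:
        raise RuntimeError(f"CG did not converge (info={info}).")
    return x


def hutchinson_inverse_trace(op, num_samples, rng):
    """Estimate trace(op^{-1}) as the average of v^T op^{-1} v."""
    dim = op.shape[0]
    total = 0.0
    for _ in range(num_samples):
        # Rademacher probe vector with independent +1/-1 entries.
        v = rng.choice([-1.0, 1.0], size=dim)
        total += v @ inverse_matvec(op, v)
    return total / num_samples


trace_estimate = hutchinson_inverse_trace(op, 400, rng)
trace_exact = np.trace(np.linalg.inv(A))  # only feasible for small matrices
```

The estimate is stochastic, so its accuracy depends on the number of probe vectors and on how much of the inverse's mass sits off the diagonal.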
Logo image credits
- SciPy logo: Unknown, CC BY-SA 4.0, via Wikimedia Commons
- PyTorch logo: https://github.com/soumith, CC BY-SA 4.0, via Wikimedia Commons