scipy linear operator implementations of the GGN and Hessian in PyTorch
scipy linear operators of deep learning matrices in PyTorch
This library implements `scipy.sparse.linalg.LinearOperator`s for deep learning matrices, such as
- the Hessian
- the Fisher/generalized Gauss-Newton (GGN)
Matrix-vector products are carried out in PyTorch, i.e. potentially on a GPU.
The library supports defining these matrices not only on a mini-batch, but on data sets (looping over batches during a matvec operation).
You can plug these linear operators into scipy routines, while carrying out the heavy lifting (matrix-vector multiplies) in PyTorch on GPU. My favorite example for such a routine is `scipy.sparse.linalg.eigsh`, which lets you compute a subset of eigenpairs.
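The hand-off between SciPy and PyTorch can be sketched as follows. This is a minimal sketch in which a random symmetric matrix held by PyTorch stands in for a curvature matrix; the names `A`, `matvec`, and `op` are hypothetical. `eigsh` only ever calls the matvec, so the multiply runs wherever `A` lives (a GPU if one is available), and only vectors cross the NumPy/PyTorch boundary.

```python
import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator, eigsh

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)

# Hypothetical stand-in for a curvature matrix: a random symmetric matrix in PyTorch.
n = 50
A = torch.randn(n, n, dtype=torch.float64, device=device)
A = (A + A.T) / 2


def matvec(x):
    """Move x into PyTorch, multiply there, and hand a numpy array back to SciPy."""
    x_t = torch.as_tensor(np.asarray(x).ravel(), dtype=torch.float64, device=device)
    return (A @ x_t).cpu().numpy()


op = LinearOperator((n, n), matvec=matvec, dtype=np.float64)

# The three eigenpairs with the largest (algebraic) eigenvalues, computed from matvecs only.
eigvals, eigvecs = eigsh(op, k=3, which="LA")
```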
- Documentation: https://curvlinops.readthedocs.io/en/latest/
- Bug reports & feature requests: https://github.com/f-dangel/curvlinops/issues
Installation
pip install curvlinops-for-pytorch
Examples
Future ideas
Other features that could be supported in the future include:
- Other matrices
  - the un-centered gradient covariance (aka empirical Fisher)
  - the centered gradient covariance
  - terms of the hierarchical GGN decomposition
- Block-diagonal approximations (via `param_groups`)
- Inverse matrix-vector products by solving a linear system via conjugate gradients
  - This could allow computing generalization metrics like the Takeuchi Information Criterion (TIC), using inverse matrix-vector products in combination with Hutchinson trace estimation
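The last two ideas fit together as sketched below with plain SciPy. This assumes a TIC-style penalty of the form tr(H⁻¹C), where H is the Hessian and C the un-centered gradient covariance (1/N) Σₙ gₙgₙᵀ. Hutchinson's estimator averages vᵀH⁻¹Cv over random Rademacher vectors v, and each application of H⁻¹ is a conjugate-gradient solve. The dense stand-in matrices and the helpers `inverse_matvec` and `hutchinson_trace_inv_product` are hypothetical; in practice both H and C would be matrix-free curvature operators.

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, cg

rng = np.random.default_rng(0)

# Hypothetical stand-ins: an SPD "Hessian" and N per-sample gradients (rows of G).
D, N = 20, 100
Q = rng.standard_normal((D, D))
H_dense = Q @ Q.T / D + np.eye(D)  # eigenvalues >= 1, so H is positive definite
G = rng.standard_normal((N, D))

H = LinearOperator((D, D), matvec=lambda v: H_dense @ v, dtype=np.float64)
# Un-centered gradient covariance C = (1/N) sum_n g_n g_n^T, applied without forming it.
C = LinearOperator((D, D), matvec=lambda v: G.T @ (G @ v) / N, dtype=np.float64)


def inverse_matvec(A, b):
    """Return x with A x = b via conjugate gradients (A must be symmetric positive definite)."""
    x, info = cg(A, b)
    if info != 0:
        raise RuntimeError(f"CG did not converge (info={info})")
    return x


def hutchinson_trace_inv_product(H, C, num_samples, rng):
    """Estimate tr(H^{-1} C) as the mean of v^T H^{-1} C v over Rademacher vectors v."""
    estimates = []
    for _ in range(num_samples):
        v = rng.choice([-1.0, 1.0], size=H.shape[0])
        estimates.append(v @ inverse_matvec(H, C @ v))
    return float(np.mean(estimates))
```

Each sample costs one C matvec plus one CG solve, so the estimate's accuracy is traded against the number of samples rather than against storing any matrix.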
Logo image credits
- SciPy logo: Unknown, CC BY-SA 4.0, via Wikimedia Commons
- PyTorch logo: https://github.com/soumith, CC BY-SA 4.0, via Wikimedia Commons
Download files
Source Distribution

Hashes for curvlinops-for-pytorch-1.0.0.tar.gz

| Algorithm | Hash digest |
|---|---|
| SHA256 | def7d58fbc157f5d1a5d3375729363997d0f1d80bfd97849b4122fd61bdc76bc |
| MD5 | 721acaff2357ef56a03ab9ecbbff22af |
| BLAKE2b-256 | 3299a19ceb7d342604e9555019b85614ee875e8c4b7eb34bb36e05b56286ae06 |

Built Distribution

Hashes for curvlinops_for_pytorch-1.0.0-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | bf125870d1bdfde3d6e8303f5d651027d79ea3809e02eb2723eb88dfead827c4 |
| MD5 | 68227580bb96da411c40890f502500b3 |
| BLAKE2b-256 | f1abb6493c3018a8f766b3d82edbecee8b799c983daaa1e5ba6437c46d4bb4c5 |