scipy linear operators for curvature matrices in PyTorch

Project description

This library implements scipy.sparse.linalg.LinearOperators for deep learning matrices, such as

  • the Hessian
  • the Fisher/generalized Gauss-Newton (GGN)
  • the Monte-Carlo approximated Fisher
  • the Fisher/GGN's KFAC approximation (Kronecker-Factored Approximate Curvature)
  • the uncentered gradient covariance (aka empirical Fisher)
  • the output-parameter Jacobian of a neural net and its transpose
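To illustrate the basic mechanism behind these operators, here is a minimal hand-rolled sketch (not the library's own classes) that wraps a Hessian-vector product, computed in PyTorch by double backpropagation, as a scipy LinearOperator. The toy model, the loss, the random mini-batch, and the helper name hessian_matvec are all hypothetical choices made for this example.

```python
import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator

torch.manual_seed(0)

# Hypothetical toy setup: a small linear model, MSE loss, one mini-batch.
model = torch.nn.Linear(3, 2)
loss_func = torch.nn.MSELoss()
X, y = torch.randn(8, 3), torch.randn(8, 2)
params = [p for p in model.parameters() if p.requires_grad]
sizes = [p.numel() for p in params]
num_params = sum(sizes)


def hessian_matvec(v: np.ndarray) -> np.ndarray:
    """Hessian-vector product of the mini-batch loss via double backpropagation."""
    v_flat = torch.as_tensor(np.ravel(v), dtype=torch.float32)
    # Split the flat vector into chunks shaped like the parameters.
    v_list = [c.reshape(p.shape) for c, p in zip(v_flat.split(sizes), params)]
    loss = loss_func(model(X), y)
    # First backward pass: keep the graph so the gradient can be differentiated again.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    # Second backward pass: d/dtheta <grad, v> = H v.
    hv = torch.autograd.grad(grads, params, grad_outputs=v_list)
    return torch.cat([h.reshape(-1) for h in hv]).detach().numpy().astype(np.float64)


# scipy only ever sees NumPy arrays; the products themselves run in PyTorch.
H = LinearOperator((num_params, num_params), matvec=hessian_matvec, dtype=np.float64)
```

The matrix is never stored: each multiplication costs roughly two backward passes, which is what makes this approach feasible for networks with millions of parameters.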

Matrix-vector products are carried out in PyTorch, i.e., potentially on a GPU. The library supports defining these matrices not only on a mini-batch but also on entire data sets, looping over the batches during each matrix-vector product.
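A minimal sketch of how a data-set-level matrix-vector product can be accumulated batch by batch. It assumes the loss is a mean over samples, so each batch's contribution is weighted by its share of the data. The toy model, the list-of-batches "data set", and the helper names are hypothetical; the library's own data handling may differ in its details.

```python
import torch

torch.manual_seed(0)

model = torch.nn.Linear(3, 1)
params = list(model.parameters())
# Hypothetical "data set": three mini-batches of different sizes.
data = [(torch.randn(n, 3), torch.randn(n, 1)) for n in (4, 5, 7)]
num_data = sum(X.shape[0] for X, _ in data)


def hvp_on_batch(X, y, v_list):
    """Hessian-vector product of the mean-reduced MSE loss on one batch."""
    loss = torch.nn.functional.mse_loss(model(X), y)
    grads = torch.autograd.grad(loss, params, create_graph=True)
    return torch.autograd.grad(grads, params, grad_outputs=v_list)


def dataset_hvp(v_list):
    """Hessian-vector product of the loss averaged over the whole data set."""
    result = [torch.zeros_like(p) for p in params]
    for X, y in data:
        # Re-weight each batch's mean loss by the fraction of samples it holds.
        scale = X.shape[0] / num_data
        for r, hv in zip(result, hvp_on_batch(X, y, v_list)):
            r.add_(hv, alpha=scale)
    return result
```

Only one batch needs to be in memory at a time, so the represented matrix can be defined on far more data than fits on the GPU at once.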

You can plug these linear operators into scipy while the heavy lifting (matrix-vector multiplies) happens in PyTorch on a GPU. My favorite example of such a routine is scipy.sparse.linalg.eigsh, which lets you compute a subset of eigenpairs.
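As a sketch of this workflow, the example below feeds a LinearOperator whose matvec runs in PyTorch (on a GPU if one is available) to scipy.sparse.linalg.eigsh. The random positive semi-definite matrix A is only a stand-in for a curvature matrix; with the library you would pass one of its operators instead.

```python
import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator, eigsh

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in for a curvature matrix: a random symmetric PSD matrix kept in PyTorch.
J = torch.randn(50, 20, dtype=torch.float64, device=device)
A = J.T @ J


def matvec(x):
    # Move the NumPy vector to the device, multiply there, and bring the result back.
    x_t = torch.as_tensor(x, dtype=torch.float64, device=device).reshape(-1)
    return (A @ x_t).cpu().numpy()


op = LinearOperator(tuple(A.shape), matvec=matvec, dtype=np.float64)
# Three largest (algebraic) eigenvalues and eigenvectors, using only matvecs.
evals, evecs = eigsh(op, k=3, which="LA")
```

eigsh never sees the matrix entries; it only calls matvec, so the same code works for operators that are far too large to materialize.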

The library also provides linear operator transformations, like taking the inverse (inverse matrix-vector product via conjugate gradients) or slicing out sub-matrices.
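The two transformations can be sketched on top of any LinearOperator using only its matvec. The helpers inverse_operator and submatrix_operator below are hypothetical illustrations, not the library's classes: the first solves op x = v with scipy's conjugate gradient solver (which requires a symmetric positive definite operator, hence the damping term in the stand-in matrix), and the second embeds the input into the full space and keeps only the requested output rows.

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, aslinearoperator, cg

rng = np.random.default_rng(0)
B = rng.standard_normal((30, 10))
# SPD stand-in for a damped curvature matrix.
A = aslinearoperator(B.T @ B + 0.1 * np.eye(10))


def inverse_operator(op):
    """Approximate op^{-1}: each matvec solves op x = v with conjugate gradients."""
    def matvec(v):
        x, info = cg(op, np.ravel(v))
        if info != 0:
            raise RuntimeError(f"CG did not converge (info={info})")
        return x
    return LinearOperator(op.shape, matvec=matvec, dtype=op.dtype)


def submatrix_operator(op, row_idxs, col_idxs):
    """Represent op[row_idxs][:, col_idxs] using only matvecs with op."""
    n = op.shape[1]
    def matvec(v):
        full = np.zeros(n, dtype=op.dtype)
        full[col_idxs] = np.ravel(v)  # embed v into the full input space
        return (op @ full)[row_idxs]  # keep only the requested output rows
    return LinearOperator((len(row_idxs), len(col_idxs)), matvec=matvec, dtype=op.dtype)
```

Because each inverse matvec runs a full CG solve, its accuracy is governed by CG's convergence tolerance rather than by exact arithmetic.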

Finally, it offers functionality to probe properties of the represented matrices, like their spectral density, trace, or diagonal.
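Trace and diagonal can be probed with matrix-vector products alone. The sketch below uses plain Hutchinson estimation with Rademacher vectors v: E[vᵀAv] = tr(A) and E[v ⊙ Av] = diag(A). The function names are hypothetical, and the library's own estimators may use different sampling distributions or variance-reduction schemes.

```python
import numpy as np
from scipy.sparse.linalg import aslinearoperator


def hutchinson_trace(op, num_samples, rng):
    """Estimate tr(op) as the mean of v^T op v over Rademacher vectors v."""
    n = op.shape[0]
    total = 0.0
    for _ in range(num_samples):
        v = rng.choice([-1.0, 1.0], size=n)
        total += v @ (op @ v)
    return total / num_samples


def hutchinson_diagonal(op, num_samples, rng):
    """Estimate diag(op) as the mean of v * (op v) over Rademacher vectors v."""
    n = op.shape[0]
    total = np.zeros(n)
    for _ in range(num_samples):
        v = rng.choice([-1.0, 1.0], size=n)
        total += v * (op @ v)
    return total / num_samples


# Usage on a random symmetric stand-in matrix.
rng = np.random.default_rng(0)
M = rng.standard_normal((20, 20))
op = aslinearoperator(M + M.T)
trace_estimate = hutchinson_trace(op, num_samples=1000, rng=rng)
diag_estimate = hutchinson_diagonal(op, num_samples=1000, rng=rng)
```

Both estimators are unbiased; their error shrinks like one over the square root of the number of samples, and it vanishes entirely for diagonal matrices.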

Installation

pip install curvlinops-for-pytorch

Examples

Future ideas

Other features that could be supported in the future include:

Logo image credits

Download files

Source Distribution

curvlinops_for_pytorch-2.0.1.tar.gz (144.6 kB)

Built Distribution

curvlinops_for_pytorch-2.0.1-py3-none-any.whl (67.4 kB, Python 3)

File details

Details for the file curvlinops_for_pytorch-2.0.1.tar.gz.

File hashes

Hashes for curvlinops_for_pytorch-2.0.1.tar.gz
Algorithm     Hash digest
SHA256        2028a0542f50c40e687137930180dbb1ff87f0b798adab5d9e62b2da81b82da3
MD5           dc98c6e650cc8e23c00627daf11fad1e
BLAKE2b-256   182aa75ee625297e07080051c4c5424b5e2298cef14245cd24b65a35337b61ee


File details

Details for the file curvlinops_for_pytorch-2.0.1-py3-none-any.whl.

File hashes

Hashes for curvlinops_for_pytorch-2.0.1-py3-none-any.whl
Algorithm     Hash digest
SHA256        a54dca3614352d2ec78fc57cd45fcdbd29d2fd3b793d39f9e7a5bf9e52be06c5
MD5           d0f404255545bb6a7fa41b7e72f20269
BLAKE2b-256   fc4423c972a229d3be41094d4ef6493156fba63095fd151366f5ac1480cb8557