
scipy linear operators for curvature matrices in PyTorch

Project description


This library implements scipy.sparse.linalg.LinearOperators for deep learning matrices, such as

  • the Hessian
  • the Fisher/generalized Gauss-Newton (GGN)
  • the Monte-Carlo approximated Fisher
  • the Fisher/GGN's KFAC approximation (Kronecker-Factored Approximate Curvature)
  • the uncentered gradient covariance (aka empirical Fisher)
  • the output-parameter Jacobian of a neural net and its transpose
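To make one entry of this list concrete, here is a minimal sketch of the uncentered gradient covariance (empirical Fisher) as a scipy LinearOperator, written directly with torch.autograd rather than with this library's classes. It assumes a toy linear model, a summed squared-error loss, and no 1/N normalization; `per_sample_gradients` and `ef_matvec` are hypothetical helper names. The matvec computes G^T (G v) from the N x D matrix G of per-sample gradients, so the D x D matrix itself is never formed.

```python
import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator

torch.manual_seed(0)

# Hypothetical toy setup (assumption): a linear model and a summed squared-error loss.
model = torch.nn.Linear(3, 2)
params = [p for p in model.parameters() if p.requires_grad]
X = torch.randn(8, 3)
y = torch.randn(8, 2)
loss_fn = torch.nn.MSELoss(reduction="sum")


def per_sample_gradients():
    """Return an (N, D) matrix whose n-th row is the flattened gradient of sample n's loss."""
    rows = []
    for x_n, y_n in zip(X, y):
        loss = loss_fn(model(x_n.unsqueeze(0)), y_n.unsqueeze(0))
        grads = torch.autograd.grad(loss, params)
        rows.append(torch.cat([g.flatten() for g in grads]))
    return torch.stack(rows)


G = per_sample_gradients().detach()  # shape (N, D) with D = 2 * 3 + 2 = 8
D = G.shape[1]


def ef_matvec(v):
    # EF @ v = sum_n g_n (g_n^T v) = G^T (G v); no D x D matrix is materialized.
    v_t = torch.as_tensor(np.asarray(v).reshape(-1), dtype=G.dtype)
    return (G.T @ (G @ v_t)).numpy()


EF = LinearOperator((D, D), matvec=ef_matvec, dtype=np.float32)
```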

Matrix-vector products are carried out in PyTorch, i.e. potentially on a GPU. The library supports defining these matrices not only on a mini-batch, but also on entire data sets, in which case each matrix-vector product loops over the mini-batches.
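The following sketch illustrates what looping over a data set during a matvec can look like: it accumulates Hessian-vector products of a summed loss batch by batch, using double backward (differentiating g^T v with respect to the parameters). The toy data, the linear model, and the helper names `batch_loss` and `hvp_over_dataset` are assumptions for illustration, not necessarily how curvlinops implements it.

```python
import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator

torch.manual_seed(0)

# Hypothetical data set of N = 12 samples, split into mini-batches of 4.
X = torch.randn(12, 3, dtype=torch.float64)
y = torch.randn(12, dtype=torch.float64)
batches = [(X[i : i + 4], y[i : i + 4]) for i in range(0, 12, 4)]

# Parameters of a hypothetical linear model.
w = torch.randn(3, dtype=torch.float64, requires_grad=True)


def batch_loss(w, X_b, y_b):
    # Summed squared error, so the data-set loss is the sum of the batch losses.
    return ((X_b @ w - y_b) ** 2).sum()


def hvp_over_dataset(v):
    """Hessian-vector product of the full data-set loss, accumulated batch by batch."""
    v_t = torch.as_tensor(np.asarray(v).reshape(-1), dtype=w.dtype)
    result = torch.zeros_like(w)
    for X_b, y_b in batches:
        loss = batch_loss(w, X_b, y_b)
        # create_graph=True keeps the graph so the gradient can be differentiated again.
        (grad,) = torch.autograd.grad(loss, w, create_graph=True)
        # Differentiating g^T v with respect to w yields H_b v for this batch.
        (hv,) = torch.autograd.grad(grad @ v_t, w)
        result += hv
    return result.detach().numpy()


H = LinearOperator((3, 3), matvec=hvp_over_dataset, dtype=np.float64)
```

Since only one mini-batch enters the computation graph at a time, memory scales with the batch size rather than the data set size; the price is one pass over the data per matvec.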

You can plug these linear operators into scipy routines while the heavy lifting (matrix-vector multiplies) is carried out in PyTorch on a GPU. My favorite example of such a routine is scipy.sparse.linalg.eigsh, which lets you compute a subset of eigenpairs.
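A minimal sketch of this scipy/PyTorch hand-off with eigsh: the operator below wraps an explicit random positive semi-definite matrix that stands in for a curvature matrix (an assumption that keeps the example checkable). Each matvec converts scipy's NumPy vector into a PyTorch tensor, multiplies on the GPU if one is available, and converts the result back.

```python
import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator, eigsh

torch.manual_seed(0)

# Hypothetical stand-in for a curvature matrix: a random symmetric PSD matrix held in PyTorch.
D = 50
B = torch.randn(D, D, dtype=torch.float64)
device = "cuda" if torch.cuda.is_available() else "cpu"
A = (B @ B.T).to(device)


def matvec(v):
    # scipy passes a NumPy array; move it to PyTorch (and the GPU, if available),
    # multiply there, and return a NumPy array on the CPU.
    v_t = torch.as_tensor(np.asarray(v).reshape(-1), dtype=A.dtype, device=device)
    return (A @ v_t).cpu().numpy()


op = LinearOperator((D, D), matvec=matvec, dtype=np.float64)

# Three largest (algebraic) eigenpairs, computed from matrix-vector products only.
eigvals, eigvecs = eigsh(op, k=3, which="LA")
```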

The library also provides linear operator transformations, like taking the inverse (inverse matrix-vector product via conjugate gradients) or slicing out sub-matrices.
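Both transformations can be sketched with plain scipy. `cg_inverse` and `submatrix` are hypothetical names, not the library's API. The example assumes a symmetric positive definite operator, since conjugate gradients requires one; the `D * np.eye(D)` term plays the role of damping. The submatrix operator only needs matvecs of the full operator: it zero-pads the input, multiplies, and keeps the selected rows.

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, aslinearoperator, cg

rng = np.random.default_rng(0)

# Hypothetical SPD matrix standing in for a damped curvature matrix.
D = 20
B = rng.standard_normal((D, D))
A = aslinearoperator(B @ B.T + D * np.eye(D))


def cg_inverse(op):
    """Linear operator whose matvec approximates op^{-1} @ v via conjugate gradients."""

    def matvec(v):
        x, info = cg(op, np.asarray(v).reshape(-1))
        if info != 0:
            raise RuntimeError(f"CG did not converge (info={info})")
        return x

    return LinearOperator(op.shape, matvec=matvec, dtype=op.dtype)


def submatrix(op, row_idxs, col_idxs):
    """Linear operator representing op[row_idxs][:, col_idxs], using only op's matvec."""
    row_idxs, col_idxs = np.asarray(row_idxs), np.asarray(col_idxs)

    def matvec(v):
        # Embed v into a full-size vector that is zero outside the selected columns...
        full = np.zeros(op.shape[1], dtype=op.dtype)
        full[col_idxs] = np.asarray(v).reshape(-1)
        # ...multiply with the full operator, then keep only the selected rows.
        return (op @ full)[row_idxs]

    return LinearOperator((len(row_idxs), len(col_idxs)), matvec=matvec, dtype=op.dtype)


A_inv = cg_inverse(A)
A_sub = submatrix(A, row_idxs=[0, 2, 5], col_idxs=[1, 3])
```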

Finally, it offers functionality to probe properties of the represented matrices, like their spectral density, trace, or diagonal.
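Trace and diagonal can be probed with matvecs alone using Hutchinson's estimator: for random vectors v with E[v v^T] = I (here Rademacher vectors with entries ±1), E[v^T A v] = tr(A) and E[v ⊙ (A v)] = diag(A). The sketch below is a plain-NumPy illustration with hypothetical function names, not the library's estimators, and it does not cover spectral density estimation.

```python
import numpy as np
from scipy.sparse.linalg import aslinearoperator


def hutchinson_trace(op, num_samples, rng):
    """Estimate tr(op) as the average of v^T (op v) over Rademacher vectors v."""
    estimates = []
    for _ in range(num_samples):
        v = rng.choice([-1.0, 1.0], size=op.shape[1])
        estimates.append(v @ (op @ v))
    return float(np.mean(estimates))


def hutchinson_diagonal(op, num_samples, rng):
    """Estimate diag(op) as the average of v * (op v) over Rademacher vectors v."""
    total = np.zeros(op.shape[0])
    for _ in range(num_samples):
        v = rng.choice([-1.0, 1.0], size=op.shape[1])
        total += v * (op @ v)
    return total / num_samples


# Usage on a hypothetical random PSD matrix.
rng = np.random.default_rng(0)
B = rng.standard_normal((10, 10))
A = aslinearoperator(B @ B.T)
trace_estimate = hutchinson_trace(A, num_samples=2000, rng=rng)
diag_estimate = hutchinson_diagonal(A, num_samples=2000, rng=rng)
```

With Rademacher vectors both estimates are exact for diagonal matrices, because v_i² = 1; in general the variance grows with the off-diagonal mass of A.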

Installation

pip install curvlinops-for-pytorch

Examples

Future ideas

Other features that could be supported in the future include:

Logo image credits

Download files

Source Distribution

curvlinops-for-pytorch-1.2.0.tar.gz (112.0 kB)

Built Distribution

curvlinops_for_pytorch-1.2.0-py3-none-any.whl (69.5 kB, Python 3)

File details

Details for the file curvlinops-for-pytorch-1.2.0.tar.gz.

Hashes for curvlinops-for-pytorch-1.2.0.tar.gz
Algorithm Hash digest
SHA256 d095d598397397be65e1000332854158b68fac6f992c4b396333c31924a2427d
MD5 81342248231e2c36e92e55ce96d9bce1
BLAKE2b-256 9ea23f2afec17f1c72164aae500440f8796b5c9d7229c44588ab7244689ea004

File details

Details for the file curvlinops_for_pytorch-1.2.0-py3-none-any.whl.

Hashes for curvlinops_for_pytorch-1.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 4e216d5cd33f882d6166a518fa81c195ae328affcb8b092cc87b2a16eb045442
MD5 7f148162882aa8b6796f860b35650a75
BLAKE2b-256 cbf2dccac6ee3ca5310076c4d4a28274dad5e86254c3a99a29cec0345a12be3e
