Various eigendecomposition implementations wrapped for jax.
jeig - Eigendecompositions wrapped for jax
v0.0.0
Overview
This package wraps the eigendecomposition routines provided by jax, numpy, scipy, and torch for use with jax. Depending on your system and on the installed versions of these packages, the backends can differ significantly in speed.
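One way to call a NumPy routine from jax code is jax.pure_callback. It runs a Python function on the host and declares the shapes and dtypes of its results, so the call also works under jax.jit. The sketch below wraps numpy.linalg.eig this way. It illustrates the general idea only and is not necessarily how jeig implements its backends. The name numpy_eig is hypothetical, and the wrapper has no differentiation rule of its own.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_eig(matrix):
    """Hypothetical wrapper: numpy.linalg.eig called from JAX via a host callback."""
    # numpy.linalg.eig may return real or complex arrays for a real input, so
    # both results are cast to a fixed complex dtype matching the declared shapes.
    complex_dtype = jnp.result_type(matrix.dtype, jnp.complex64)
    result_shapes = (
        jax.ShapeDtypeStruct(matrix.shape[:-1], complex_dtype),  # eigenvalues
        jax.ShapeDtypeStruct(matrix.shape, complex_dtype),  # eigenvectors
    )

    def _host_eig(m):
        # numpy.linalg.eig broadcasts over leading batch dimensions.
        w, v = np.linalg.eig(np.asarray(m))
        return w.astype(complex_dtype), v.astype(complex_dtype)

    return jax.pure_callback(_host_eig, result_shapes, matrix)
```

Because numpy.linalg.eig handles stacks of matrices, a batch such as the (8, 320, 320) array below can be passed directly, for example as jax.jit(numpy_eig)(matrix).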
The wrapped eig function also includes a custom vector-Jacobian product (vjp) rule, so gradients with respect to both eigenvalues and eigenvectors can be computed.
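To show what such a rule involves, here is a minimal sketch restricted to the eigenvalues of a single real square matrix. It uses jax.numpy.linalg.eig for the forward pass, which assumes a CPU backend. This is not jeig's rule, which also covers eigenvectors; the names eigvals, _eigvals_fwd and _eigvals_bwd are hypothetical. The backward pass follows first-order perturbation theory. If A = V Λ V^{-1} has distinct eigenvalues, then dλ = diag(V^{-1} dA V). The transpose of this map sends a cotangent λ̄ to V^{-T} diag(λ̄) V^T.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def eigvals(a):
    """Hypothetical example: eigenvalues of a real square matrix with a custom vjp."""
    w, _ = jnp.linalg.eig(a)
    return w


def _eigvals_fwd(a):
    w, v = jnp.linalg.eig(a)
    # Keep the right eigenvectors as residuals for the backward pass.
    return w, v


def _eigvals_bwd(v, w_bar):
    # dλ = diag(V^{-1} dA V), so the transposed map is V^{-T} diag(w_bar) V^T.
    v_inv_t = jnp.linalg.inv(v).T
    a_bar = v_inv_t @ jnp.diag(w_bar) @ v.T
    # The input matrix is real, so only the real part of the cotangent is returned.
    return (jnp.real(a_bar),)


eigvals.defvjp(_eigvals_fwd, _eigvals_bwd)
```

As a sanity check, the sum of the eigenvalues equals the trace, so jax.grad(lambda m: jnp.real(jnp.sum(eigvals(m))))(a) should be close to the identity matrix.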
Example usage
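# Run in IPython or Jupyter: %timeit is an IPython magic command.
import jax
import jeig.eig as jeig

# A batch of 8 random 320 x 320 real matrices.
matrix = jax.random.normal(jax.random.PRNGKey(0), (8, 320, 320))

# Select each backend in turn and time the eigendecomposition.
jeig.BACKEND_EIG = jeig.JAX
%timeit jeig.eig(matrix)
jeig.BACKEND_EIG = jeig.NUMPY
%timeit jeig.eig(matrix)
jeig.BACKEND_EIG = jeig.SCIPY
%timeit jeig.eig(matrix)
jeig.BACKEND_EIG = jeig.TORCH
%timeit jeig.eig(matrix)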
Timings for the JAX, NumPy, SciPy, and torch backends, respectively:
376 ms ± 11.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
689 ms ± 11.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
414 ms ± 19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
136 ms ± 4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
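%timeit requires IPython. As a plain-Python alternative, the sketch below uses the standard timeit module to time jax.numpy.linalg.eig against numpy.linalg.eig directly, without going through jeig. JAX dispatches work asynchronously, so the JAX call is wrapped in jax.block_until_ready to make the measurement include the computation itself. The helper time_backends is hypothetical, and the sketch assumes a CPU backend where jax.numpy.linalg.eig is available.

```python
import timeit

import jax
import jax.numpy as jnp
import numpy as np


def time_backends(matrix, number=3):
    """Hypothetical helper: mean wall-clock seconds per eig call for two backends."""
    matrix_np = np.asarray(matrix)
    backends = {
        # Block until the asynchronously dispatched JAX result is ready.
        "jax": lambda: jax.block_until_ready(jnp.linalg.eig(matrix)),
        "numpy": lambda: np.linalg.eig(matrix_np),
    }
    timings = {}
    for name, fn in backends.items():
        fn()  # Warm-up call, so one-time compilation is not included.
        timings[name] = timeit.timeit(fn, number=number) / number
    return timings
```

For example, time_backends(jax.random.normal(jax.random.PRNGKey(0), (8, 320, 320))) returns a dict of seconds per call. The numbers depend on your hardware and library versions.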
Credit
The high-level eig function and the tests are adapted from fmmax. The torch implementation of eigendecomposition is due to a comment by @YouJiacheng.
Download files
Source distribution: jeig-0.0.0.tar.gz (7.5 kB)
Built distribution: jeig-0.0.0-py3-none-any.whl (6.5 kB)