Differentiable eigenvalue decomposition with JAX (CPU/GPU)
Differentiable Generalized Eigenvalue Decomposition
Standalone implementation of differentiable eigenvalue decomposition with CPU (LAPACK) and GPU (cuSOLVER) backends. Extracted from pyscfad.
Wheels are on PyPI (https://pypi.org/project/eigh/) for Linux (manylinux_2_28, x86_64) and macOS (x86_64, arm64), Python 3.10–3.13. The GPU path (cuSOLVER) is tested locally; CI runs CPU tests only.
Windows: no prebuilt wheel. The pure-JAX solvers in src/jax/ (e.g. safe_generalized_eigh, subspace_generalized_eigh, stable_generalized_eigh) work out-of-the-box — pip install jax numpy scipy and import directly from that module. The fast LAPACK/cuSOLVER-backed eigh / eigh_gen kernels (and therefore stable_eigh_pyscfad / stable_eigh_gen_pyscfad) require building from source against a local BLAS/LAPACK.
Features
- Generalized Problems: A @ V = B @ V @ diag(W), etc.
- JAX Integrated: Full support for jit, vmap, grad, and jvp.
- High Performance: Optimized LAPACK (CPU) and cuSOLVER (GPU) kernels.
- Precision: float32/64 and complex64/128.
- Degeneracy Handling: Configurable deg_thresh for stable gradients.
Installation & Quick Start
# Install from source
pip install .
# For GPU support in this environment
pip install ".[cuda-local]"  # quoted so shells such as zsh do not expand the brackets
Usage Example
import jax
import jax.numpy as jnp
# Gen. eigensolver from PySCFAD
from eigh import eigh
jax.config.update("jax_enable_x64", True)
A = jnp.array([[2., 1.], [1., 2.]])
w, v = eigh(A) # Standard
grad = jax.grad(lambda A: eigh(A)[0].sum())(A) # Differentiable
Benchmarks
Forward/backward scaling vs. matrix size, and gradient stability as eigenvalues approach degeneracy — for the JAX eigensolvers in src/jax/. See benchmarks/suite/ for the scripts.
API Reference
- eigh(a, b=None, *, lower=True, eigvals_only=False, type=1, deg_thresh=1e-9): SciPy-compatible interface. type supports 1: A@v = B@v@λ, 2: A@B@v = v@λ, 3: B@A@v = v@λ.
- eigh_gen(a, b, *, lower=True, itype=1, deg_thresh=1e-9): Lower-level generalized solver.
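Because the interface is described as SciPy-compatible, the three type conventions can be checked against scipy.linalg.eigh itself. The sketch below uses only NumPy and SciPy, not this package; the matrices and the residual helper are illustrative assumptions.

```python
import numpy as np
from scipy.linalg import eigh as scipy_eigh

# Small symmetric A and symmetric positive-definite B (illustrative values).
A = np.array([[2.0, 1.0], [1.0, 3.0]])
B = np.array([[2.0, 0.5], [0.5, 1.0]])

def residual(problem_type):
    """Max absolute residual of the eigen-equation selected by `type`."""
    w, v = scipy_eigh(A, B, type=problem_type)
    # `v * w` scales column j of v by w[j], i.e. v @ diag(w).
    lhs = {1: A @ v, 2: A @ B @ v, 3: B @ A @ v}[problem_type]
    rhs = {1: (B @ v) * w, 2: v * w, 3: v * w}[problem_type]
    return np.max(np.abs(lhs - rhs))

residuals = {t: residual(t) for t in (1, 2, 3)}
print(residuals)
```

Each residual should sit at round-off level, which confirms which equation each type value refers to.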
Degenerate Eigenvalues & Gradients
Individual eigenvalue gradients are ill-defined for degenerate (repeated) eigenvalues. However, symmetric functions (like sum, var, trace) have stable gradients. The deg_thresh parameter (default 1e-9) masks divisions by near-zero gaps to maintain stability.
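As a rough illustration of threshold masking, the sketch below builds the matrix of inverse eigenvalue gaps that appears in the eigenvector gradient and zeroes the entries whose gap is below a threshold. The function name masked_inverse_gaps is hypothetical, and the package's actual kernels may implement the masking differently.

```python
import jax.numpy as jnp

def masked_inverse_gaps(w, deg_thresh=1e-9):
    """F[i, j] = 1 / (w[j] - w[i]) where |gap| > deg_thresh, else 0."""
    gap = w[None, :] - w[:, None]
    keep = jnp.abs(gap) > deg_thresh
    # Replace masked gaps by 1 before dividing so no inf/nan is ever formed.
    safe_gap = jnp.where(keep, gap, 1.0)
    return jnp.where(keep, 1.0 / safe_gap, 0.0)

w = jnp.array([1.0, 1.0, 3.0])  # the first two eigenvalues are degenerate
F = masked_inverse_gaps(w)
print(F)
```

The entries coupling the two degenerate eigenvalues are set to zero, while the well-separated pairs keep their usual 1/gap weight.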
JAX Eigensolvers (src/jax/)
A collection of differentiable generalized eigensolvers with different strategies for handling degenerate eigenvalues in reverse-mode gradients. Useful for training pipelines where degeneracies are common.
If you just want a working solver, use stable_eigh_pyscfad / stable_eigh_gen_pyscfad from generalized_eigensolver_pyscfad.py. They wrap the fast LAPACK/cuSOLVER kernels with a Lorentzian-broadened custom VJP, so gradients stay stable when eigenvalues are (nearly) degenerate.
On Windows, or if you cannot build the C++ kernels, use stable_generalized_eigh from generalized_eigensolver_stable.py instead — same gradient treatment, pure JAX.
The remaining solvers below are kept for benchmarking and for reproducing prior work; they are not recommended as defaults.
Recommended
| Solver | File | Strategy |
|---|---|---|
| stable_eigh_pyscfad / stable_eigh_gen_pyscfad | generalized_eigensolver_pyscfad.py | LAPACK/cuSOLVER kernels + Lorentzian-broadened VJP [2] |
| stable_eigh / stable_generalized_eigh (pure-JAX) | generalized_eigensolver_stable.py | Pure-JAX Cholesky + Lorentzian-broadened VJP [2] |
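To make the idea of a Lorentzian-broadened VJP concrete, here is a minimal pure-JAX sketch for the standard real-symmetric problem. It assumes the broadening replaces 1/Δ by Δ/(Δ² + ε²), with Δ = w_j − w_i the eigenvalue gap; the names broadened_eigh and EPS are hypothetical, and this is not the package's implementation.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

EPS = 1e-10  # broadening width (assumed value, for illustration only)

@jax.custom_vjp
def broadened_eigh(a):
    """Eigendecomposition of a real symmetric matrix with a broadened VJP."""
    return jnp.linalg.eigh(a)

def _fwd(a):
    w, v = jnp.linalg.eigh(a)
    return (w, v), (w, v)

def _bwd(res, cotangents):
    w, v = res
    gw, gv = cotangents
    gap = w[None, :] - w[:, None]       # gap[i, j] = w[j] - w[i]
    f = gap / (gap**2 + EPS**2)         # ~1/gap when |gap| >> EPS, 0 when gap == 0
    inner = jnp.diag(gw) + f * (v.T @ gv)
    ga = v @ inner @ v.T
    return (0.5 * (ga + ga.T),)         # keep the cotangent symmetric

broadened_eigh.defvjp(_fwd, _bwd)

c = jnp.array([1.0, 2.0, 3.0])

def loss(a):
    w, v = broadened_eigh(a)
    return (v[:, 0] @ c) ** 2 + w.sum()  # invariant to the sign of v[:, 0]

# Fully degenerate input: the plain 1/gap formula would divide by zero here.
grad_degenerate = jax.grad(loss)(2.0 * jnp.eye(3))
print(grad_degenerate)
```

For well-separated eigenvalues the broadened weights are practically identical to 1/Δ, so the gradient should match the default jnp.linalg.eigh gradient; at exact degeneracy the weight goes to zero instead of diverging.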
Alternative stable solvers
| Solver | File | Strategy | Gradient notes |
|---|---|---|---|
| subspace_eigh | generalized_eigensolver.py | Custom VJP: Lorentzian broadening F/(F²+ε²) [2] | Stable |
| subspace_generalized_eigh | generalized_eigensolver.py | Symmetry-breaking perturbation + subspace_eigh [2,4] | Stable |
| degen_eigh | generalized_eigensolver.py | Custom VJP: mask degenerate F_ij by threshold [1,3] | Stable only for symmetric-subspace losses |
| safe_generalized_eigh | generalized_eigensolver.py | Cholesky + degen_eigh | Inherits degen_eigh caveat |
Baselines (not gradient-safe at degeneracies)
| Solver | File | Strategy |
|---|---|---|
| standard_eig | generalized_eigensolver.py | scipy.linalg.eigh (non-differentiable reference) |
| jax_eig | generalized_eigensolver.py | Plain Cholesky + jnp.linalg.eigh, default VJP |
| generalized_eigh | generalized_eigensolver.py | Symmetrized Cholesky with SPD shift, default VJP |
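For orientation, the plain Cholesky reduction behind a baseline of this kind can be sketched as follows. The function cholesky_generalized_eigh is a hypothetical stand-in written for this example, not the jax_eig source; it relies on JAX's default eigh VJP, so it is not gradient-safe at degeneracies.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

def cholesky_generalized_eigh(a, b):
    """Solve A v = w B v for real symmetric A and SPD B via B = L L^T."""
    l = jnp.linalg.cholesky(b)
    # C = L^{-1} A L^{-T}, formed with two triangular solves.
    y = solve_triangular(l, a, lower=True)      # L^{-1} A
    c = solve_triangular(l, y.T, lower=True)    # L^{-1} A L^{-T} (A symmetric)
    w, u = jnp.linalg.eigh(c)                   # standard problem C u = w u
    v = solve_triangular(l.T, u, lower=False)   # back-transform: v = L^{-T} u
    return w, v

A = jnp.array([[2.0, 1.0], [1.0, 3.0]])
B = jnp.array([[2.0, 0.5], [0.5, 1.0]])
w, v = cholesky_generalized_eigh(A, B)
print(jnp.max(jnp.abs(A @ v - (B @ v) * w)))  # residual of A V = B V diag(W)
```

The returned eigenvectors are B-orthonormal (V^T B V = I), which is the usual normalization for the type 1 problem.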
References
- [1] Kasim, M. F., & Vinko, S. M. Learning the exchange–correlation functional from nature with fully differentiable density functional theory. Phys. Rev. Lett. 127, 126403 (2021). https://doi.org/10.1103/PhysRevLett.127.126403
- [2] Colburn, S., & Majumdar, A. Inverse design and flexible parameterization of meta-optics using algorithmic differentiation. Communications Physics 4, 54 (2021). https://doi.org/10.1038/s42005-021-00568-6
- [3] JAX Issue #2748: Differentiable eigh with degeneracies. https://github.com/jax-ml/jax/issues/2748
- [4] JAX Issue #5461: Stable generalized eigh. https://github.com/jax-ml/jax/issues/5461
Development & Testing
- Requirements: CMake 3.18+, C++17, JAX, NumPy, LAPACK/CUDA.
- Tests:
pytest tests/test_eigh.py # Core functionality
pytest tests/test_eigh_gen.py # Generalized itypes
pytest tests/test_eigh_jit.py # JIT & vmap
- GPU Setup:
source setup_gpu_env_clean.sh
./run_gpu.sh python example_simple.py
License & Citation
Apache License 2.0. If used in research, please cite:
@software{pyscfad,
author = {Zhang, Xing},
title = {PySCFad: Automatic Differentiation for PySCF},
url = {https://github.com/fishjojo/pyscfad},
year = {2021-2025}
}
@software{sokolov2026eigh,
title={Eigh: Differentiable eigenvalue decomposition with jax (cpu/gpu)},
author={Sokolov, Igor},
url={https://github.com/Brogis1/eigh},
year={2026}
}
@article{sokolov2026xc,
title = {Quantum-enhanced neural exchange-correlation functionals},
author = {Sokolov, Igor O. and Both, Gert-Jan and Bochevarov, Art D. and Dub, Pavel A. and Levine, Daniel S. and Brown, Christopher T. and Acheche, Shaheen and Barkoutsos, Panagiotis Kl. and Elfving, Vincent E.},
journal = {Phys. Rev. A},
volume = {113},
issue = {1},
pages = {012427},
numpages = {24},
year = {2026},
month = {Jan},
publisher = {American Physical Society},
doi = {10.1103/m51l-fys2},
url = {https://link.aps.org/doi/10.1103/m51l-fys2}
}