Differentiable eigenvalue decomposition with JAX — CUDA 12.0 (GPU) build
Differentiable Generalized Eigenvalue Decomposition
Standalone implementation of differentiable eigenvalue decomposition with CPU (LAPACK) and GPU (cuSOLVER) backends. Extracted from pyscfad.
CPU and GPU wheels on PyPI for Linux and macOS (Apple Silicon), Python 3.10–3.13, JAX 0.5–0.10+. See Installation and Compatibility.
New
- The core code was rewritten so that it also runs on older installations such as JAX 0.4.x, which are most likely to be found on GPU clusters.
- Prebuilt CUDA wheels are available, but I recommend building from source: it is fast and easy with this package, and the result matches your specific JAX version.
Features
- Generalized Problems: A @ V = B @ V @ diag(W), etc.
- JAX Integrated: Full support for jit, vmap, grad, and jvp.
- High Performance: Optimized LAPACK (CPU) and cuSOLVER (GPU) kernels.
- Precision: float32/64 and complex64/128.
- Degeneracy Handling: Configurable deg_thresh for stable gradients.
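As a rough illustration of the JAX integration listed above, the sketch below composes jit, vmap, grad, and jvp around a single-matrix call. It assumes only the eigh(a) call shape shown in the usage example; the fallback to jnp.linalg.eigh and the helper name lowest_eigenvalue are assumptions added here so the snippet runs even without the compiled package.

```python
import jax
import jax.numpy as jnp

try:
    from eigh import eigh  # compiled LAPACK/cuSOLVER kernel from this package
except ImportError:
    # Assumption: fall back to JAX's built-in solver so the sketch still runs.
    def eigh(a):
        w, v = jnp.linalg.eigh(a)
        return w, v


def lowest_eigenvalue(a):
    """Hypothetical helper: smallest eigenvalue of one symmetric matrix."""
    w, _ = eigh(a)
    return w[0]


# Batch of 4 symmetric 3x3 matrices with well-separated eigenvalues.
key = jax.random.PRNGKey(0)
m = 0.1 * jax.random.normal(key, (4, 3, 3))
batch = m + jnp.swapaxes(m, -1, -2) + jnp.diag(jnp.array([0.0, 5.0, 10.0]))

# jit + vmap: one compiled call over the whole batch.
batched_lowest = jax.jit(jax.vmap(lowest_eigenvalue))(batch)

# vmap + grad: one gradient matrix per batch element.
batched_grads = jax.vmap(jax.grad(lowest_eigenvalue))(batch)

# jvp: directional derivative along the identity matrix.
primal, tangent = jax.jvp(lowest_eigenvalue, (batch[0],), (jnp.eye(3),))
```

Shifting a matrix by the identity shifts every eigenvalue by one, so the jvp tangent along jnp.eye(3) should come out close to 1.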
Installation & Quick Start
CPU
pip install eigh
Prebuilt CPU-only wheels — Linux (x86_64) and macOS (Apple Silicon), Python 3.10–3.13, JAX 0.5+.
GPU - Build from source (Recommended)
First make sure that you have a JAX installation that runs correctly on your GPU.
Build from source for a jaxlib / CUDA / glibc combination the wheels don't
cover. The main case is an environment pinned to e.g., jaxlib 0.4.29 (the
prebuilt wheels require jaxlib ≥ 0.5 — the FFI binary ABI changed at 0.5, so a
0.5 wheel can't run on 0.4.x and vice-versa). The source builds against
whatever jaxlib is in your env (0.4.29 or 0.5–0.10+), CPU or GPU:
git clone https://github.com/Brogis1/eigh && cd eigh
pip install "scikit-build-core>=0.8" "nanobind>=1.0.0" cmake ninja
pip install . --no-build-isolation --no-deps
- --no-build-isolation compiles against the jaxlib already in your env.
- --no-deps keeps your pinned jax/jaxlib (essential on 0.4.29; otherwise pip would upgrade it to ≥0.5).
- For GPU, have nvcc on PATH (module load cuda/12.x); look for CUDA support enabled in the build log (CUDA not found ⇒ nvcc not on PATH). Plain jaxlib (no CUDA) yields a CPU-only build.
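Before building, it can help to confirm which jax/jaxlib versions are installed and whether JAX itself sees the GPU. The following is a minimal check using only standard JAX calls; it does not touch this package, and it is a sketch rather than the verification procedure described in docs/TECHNICAL_NOTES.md.

```python
from importlib.metadata import version

import jax

# Versions the source build will compile against.
print("jax:", version("jax"), "| jaxlib:", version("jaxlib"))

# "gpu" here means JAX found a CUDA device; "cpu" means a CPU-only jaxlib
# (or no visible GPU), in which case the build will be CPU-only as well.
backend = jax.default_backend()
devices = jax.devices()
print("default backend:", backend)
print("devices:", devices)
```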
Full details and the why — FFI ABI, pinned-jaxlib clusters, nvcc paths, GPU
verification — are in docs/TECHNICAL_NOTES.md.
GPU (CUDA 12, Linux x86_64)
The prebuilt wheels may work if your JAX and CUDA libraries happen to match them, but I strongly recommend building from source (see GPU - Build from source). Pick the package matching your cluster's CUDA version:
pip install eigh-cuda120 # CUDA 12.0+ (works through 12.8+); the safe default
pip install eigh-cuda128 # CUDA 12.8+ (newer toolchain / glibc 2.34)
Both bundle the cuSOLVER kernel + NVIDIA CUDA runtime libs; import eigh
auto-detects the GPU. They are separate packages from this same repo — import eigh is identical. See Compatibility.
Usage Example
import jax
import jax.numpy as jnp
# Gen. eigensolver from PySCFAD
from eigh import eigh, eigh_gen
jax.config.update("jax_enable_x64", True)
# Eigenvalue problem
A = jnp.array([[2., 1.], [1., 2.]])
# B must be symmetric positive definite
B = jnp.array([[1., 0.5], [0.5, 1.]])
w1, v1 = eigh(A)
w2, v2 = eigh_gen(A, B)
# With gradients
grad1 = jax.grad(lambda A: eigh(A)[0].sum())(A)
grad2 = jax.grad(lambda A: eigh_gen(A, B)[0].sum())(A)
print("Eigenvalues:", w1, w2)
print("Eigenvectors:", v1, v2)
print("Gradients computed:", grad1.shape, grad2.shape)
Benchmarks
Forward/backward scaling vs. matrix size, and gradient stability as eigenvalues approach degeneracy — for the JAX eigensolvers in src/jax/. See benchmarks/suite/ for the scripts.
API Reference
- eigh(a, b=None, *, lower=True, eigvals_only=False, type=1, deg_thresh=1e-9): Scipy-compatible interface. type supports 1: A@v=B@v@λ, 2: A@B@v=v@λ, 3: B@A@v=v@λ.
- eigh_gen(a, b, *, lower=True, itype=1, deg_thresh=1e-9): Lower-level generalized solver.
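To make the three problem types concrete, here is a NumPy reference sketch of the textbook Cholesky reduction for each of them (the same convention LAPACK's sygv routines use). The function name generalized_eigh_reference is hypothetical and this is not the package's kernel; it only illustrates what each type value is expected to solve, for real symmetric A and symmetric positive definite B.

```python
import numpy as np


def generalized_eigh_reference(a, b, itype=1):
    """Reduce a symmetric-definite problem to a standard one via Cholesky."""
    chol = np.linalg.cholesky(b)  # b = chol @ chol.T, chol lower triangular
    if itype == 1:
        # A v = w B v  ->  (chol^-1 A chol^-T) y = w y
        linv_a = np.linalg.solve(chol, a)
        c = np.linalg.solve(chol, linv_a.T).T
    else:
        # A B v = w v  or  B A v = w v  ->  (chol^T A chol) y = w y
        c = chol.T @ a @ chol
    w, y = np.linalg.eigh((c + c.T) / 2)
    # Back-transform: v = chol y for type 3, v = chol^-T y otherwise.
    v = chol @ y if itype == 3 else np.linalg.solve(chol.T, y)
    return w, v


a = np.array([[2.0, 1.0], [1.0, 3.0]])
b = np.array([[2.0, 0.5], [0.5, 1.0]])

w1, v1 = generalized_eigh_reference(a, b, itype=1)
w2, v2 = generalized_eigh_reference(a, b, itype=2)
w3, v3 = generalized_eigh_reference(a, b, itype=3)

# Residuals of the three defining equations (all should be ~0).
res1 = a @ v1 - b @ v1 * w1
res2 = a @ b @ v2 - v2 * w2
res3 = b @ a @ v3 - v3 * w3
```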
Degenerate Eigenvalues & Gradients
Individual eigenvalue gradients are ill-defined for degenerate (repeated) eigenvalues. However, symmetric functions (like sum, var, trace) have stable gradients. The deg_thresh parameter (default 1e-9) masks divisions by near-zero gaps to maintain stability.
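The masking idea can be sketched in a few lines of NumPy. The eigenvector part of the gradient contains factors 1/(w_j - w_i); the sketch below zeroes those factors whenever the gap is below a threshold. The helper name masked_inverse_gaps and the strict comparison are assumptions for illustration, not the package's actual code.

```python
import numpy as np


def masked_inverse_gaps(w, deg_thresh=1e-9):
    """Return F with F[i, j] = 1 / (w[j] - w[i]), zeroed for near-zero gaps."""
    gap = w[None, :] - w[:, None]
    keep = np.abs(gap) > deg_thresh
    safe_gap = np.where(keep, gap, 1.0)  # placeholder avoids dividing by ~0
    return np.where(keep, 1.0 / safe_gap, 0.0)


# Two nearly degenerate eigenvalues and one well-separated eigenvalue.
w = np.array([1.0, 1.0 + 1e-12, 3.0])
f = masked_inverse_gaps(w)
# The (0, 1) and (1, 0) entries are masked to 0 instead of about ±1e12;
# the entries coupling to the third eigenvalue stay close to ±0.5.
```

The gradient of a symmetric function such as the eigenvalue sum never needs these factors, which is why it stays well defined at a degeneracy.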
JAX Eigensolvers
A collection of differentiable generalized eigensolvers with different strategies for handling degenerate eigenvalues in reverse-mode gradients. Useful for training pipelines where degeneracies are common.
If you just want a working solver, use stable_eigh_pyscfad / stable_eigh_gen_pyscfad from generalized_eigensolver_pyscfad.py. They wrap the fast LAPACK/cuSOLVER kernels with a Lorentzian-broadened custom VJP, so gradients stay stable when eigenvalues are (nearly) degenerate.
On Windows, or if you cannot build the C++ kernels, use stable_generalized_eigh from generalized_eigensolver_stable.py instead — same gradient treatment, pure JAX.
The remaining solvers below are kept for benchmarking and for reproducing prior work; they are not recommended as defaults.
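For readers who want to see what a Lorentzian-broadened custom VJP looks like, here is a minimal pure-JAX sketch for the standard real symmetric problem. It replaces each 1/gap factor in the usual eigh backward pass with gap/(gap² + ε²), following the idea in [2]. The factory name make_broadened_eigh and the value of eps are hypothetical; the solvers shipped in this repository may differ in detail (complex support, generalized problems, kernel calls).

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def make_broadened_eigh(eps=1e-8):
    """Build an eigh whose backward pass uses Lorentzian-broadened gaps."""

    @jax.custom_vjp
    def solver(a):
        w, v = jnp.linalg.eigh(a)
        return w, v

    def fwd(a):
        w, v = jnp.linalg.eigh(a)
        return (w, v), (w, v)

    def bwd(res, cotangents):
        w, v = res
        gw, gv = cotangents
        gap = w[None, :] - w[:, None]        # gap[i, j] = w[j] - w[i]
        f_mat = gap / (gap**2 + eps**2)      # ~1/gap, but 0 when gap == 0
        inner = jnp.diag(gw) + f_mat * (v.T @ gv)
        ga = v @ inner @ v.T
        return ((ga + ga.T) / 2,)            # symmetric input, symmetric grad

    solver.defvjp(fwd, bwd)
    return solver


broadened_eigh = make_broadened_eigh()

# Exactly degenerate input: every eigenvalue of the identity equals 1.
a_deg = jnp.eye(3)
grad_sum = jax.grad(lambda a: broadened_eigh(a)[0].sum())(a_deg)
grad_vec = jax.grad(lambda a: (broadened_eigh(a)[1][:, 0] ** 2)[0])(a_deg)
```

Away from degeneracies the broadened factor is nearly identical to 1/gap, so the result should agree with the default gradient of jnp.linalg.eigh; at a degeneracy the factor goes to zero instead of diverging.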
Recommended
| Solver | File | Strategy |
|---|---|---|
| stable_eigh_pyscfad / stable_eigh_gen_pyscfad | generalized_eigensolver_pyscfad.py | LAPACK/cuSOLVER kernels + Lorentzian-broadened VJP [2] |
| stable_eigh / stable_generalized_eigh (pure-JAX) | generalized_eigensolver_stable.py | Pure-JAX Cholesky + Lorentzian-broadened VJP [2] |
Alternative stable solvers
| Solver | File | Strategy | Gradient notes |
|---|---|---|---|
| subspace_eigh | generalized_eigensolver.py | Custom VJP: Lorentzian broadening F/(F²+ε²) [2] | Stable |
| subspace_generalized_eigh | generalized_eigensolver.py | Symmetry-breaking perturbation + subspace_eigh [2,4] | Stable |
| degen_eigh | generalized_eigensolver.py | Custom VJP: mask degenerate F_ij by threshold [1,3] | Stable only for symmetric-subspace losses |
| safe_generalized_eigh | generalized_eigensolver.py | Cholesky + degen_eigh | Inherits degen_eigh caveat |
Baselines (not gradient-safe at degeneracies)
| Solver | File | Strategy |
|---|---|---|
| standard_eig | generalized_eigensolver.py | scipy.linalg.eigh (non-differentiable reference) |
| jax_eig | generalized_eigensolver.py | Plain Cholesky + jnp.linalg.eigh, default VJP |
| generalized_eigh | generalized_eigensolver.py | Symmetrized Cholesky with SPD shift, default VJP |
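For orientation, the "plain Cholesky + jnp.linalg.eigh, default VJP" strategy can be sketched as follows. This is an illustrative reimplementation under the assumption of real symmetric A and symmetric positive definite B, not the code in generalized_eigensolver.py; the function name cholesky_generalized_eigh is hypothetical. Because it relies on JAX's default eigh derivative, it is only meant for non-degenerate spectra.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

jax.config.update("jax_enable_x64", True)


def cholesky_generalized_eigh(a, b):
    """Solve A v = w B v by reducing to a standard symmetric problem."""
    chol = jnp.linalg.cholesky(b)                   # b = chol @ chol.T
    x = solve_triangular(chol, a, lower=True)       # chol^-1 A
    c = solve_triangular(chol, x.T, lower=True).T   # chol^-1 A chol^-T
    w, y = jnp.linalg.eigh(c)
    v = solve_triangular(chol.T, y, lower=False)    # v = chol^-T y
    return w, v


a = jnp.array([[2.0, 1.0], [1.0, 3.0]])
b = jnp.array([[2.0, 0.5], [0.5, 1.0]])

w, v = cholesky_generalized_eigh(a, b)

# Gradient of the eigenvalue sum with respect to A, via the default VJP.
grad_a = jax.grad(lambda a_: cholesky_generalized_eigh(a_, b)[0].sum())(a)
```

Since the eigenvalue sum equals trace(B⁻¹ A), grad_a should match the inverse of B, which gives a quick sanity check on any generalized solver.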
Compatibility
- Python: 3.10–3.13. JAX: 0.5 → 0.10+ for the prebuilt wheels; jax 0.4.29 via source build (its FFI ABI differs from 0.5+, so it needs its own build — see Build from source).
- CPU wheels: Linux x86_64 (manylinux_2_28, bundled OpenBLAS) and macOS arm64.
- GPU wheels: Linux x86_64, CUDA 12: eigh-cuda120 (CUDA 12.0+, glibc 2.17) and eigh-cuda128 (CUDA 12.8+, glibc 2.34).
- Windows: no compiled wheel; use the pure-JAX solvers in src/jax/, or build from source.
Full detail — the FFI binary-ABI rules, why GPU ships as separate packages, the
abi3 wheel matrix, HPC/cluster notes, and how to build on a pinned old jaxlib —
is in docs/TECHNICAL_NOTES.md.
References
- [1] Kasim, M. F., & Vinko, S. M. Learning the exchange–correlation functional from nature with fully differentiable density functional theory. Phys. Rev. Lett. 127, 126403 (2021). https://doi.org/10.1103/PhysRevLett.127.126403
- [2] Colburn, S., & Majumdar, A. Inverse design and flexible parameterization of meta-optics using algorithmic differentiation. Communications Physics 4, 54 (2021). https://doi.org/10.1038/s42005-021-00568-6
- [3] JAX Issue #2748: Differentiable eigh with degeneracies. https://github.com/jax-ml/jax/issues/2748
- [4] JAX Issue #5461: Stable generalized eigh. https://github.com/jax-ml/jax/issues/5461
Development & Testing
- Requirements: CMake 3.18+, C++17, JAX, NumPy, LAPACK/CUDA.
- Tests:
pytest tests/test_eigh.py # Core functionality
pytest tests/test_eigh_gen.py # Generalized itypes
pytest tests/test_eigh_jit.py # JIT & vmap
- GPU Setup:
source setup_gpu_env_clean.sh
./run_gpu.sh python example_simple.py
License & Citation
Apache License 2.0. If used in research, please cite:
@software{sokolov2026eigh,
author={Sokolov, Igor},
title={Eigh: Differentiable eigenvalue decomposition with jax (cpu/gpu)},
url={https://github.com/Brogis1/eigh},
year={2026}
}
@software{pyscfad,
author = {Zhang, Xing},
title = {PySCFad: Automatic Differentiation for PySCF},
url = {https://github.com/fishjojo/pyscfad},
year = {2021-2025}
}
@article{10.1063/5.0118200,
author = {Zhang, Xing and Chan, Garnet Kin-Lic},
title = {Differentiable quantum chemistry with PySCF for molecules and materials at the mean-field level and beyond},
journal = {The Journal of Chemical Physics},
volume = {157},
number = {20},
pages = {204801},
year = {2022},
month = {11},
issn = {0021-9606},
doi = {10.1063/5.0118200},
url = {https://doi.org/10.1063/5.0118200},
}
@article{sokolov2026xc,
title = {Quantum-enhanced neural exchange-correlation functionals},
author = {Sokolov, Igor O. and Both, Gert-Jan and Bochevarov, Art D. and Dub, Pavel A. and Levine, Daniel S. and Brown, Christopher T. and Acheche, Shaheen and Barkoutsos, Panagiotis Kl. and Elfving, Vincent E.},
journal = {Phys. Rev. A},
volume = {113},
issue = {1},
pages = {012427},
numpages = {24},
year = {2026},
month = {Jan},
publisher = {American Physical Society},
doi = {10.1103/m51l-fys2},
url = {https://link.aps.org/doi/10.1103/m51l-fys2}
}
Download files
Built Distributions
File details
Details for the file eigh_cuda120-0.4.1-cp312-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.
File metadata
- Download URL: eigh_cuda120-0.4.1-cp312-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- Upload date:
- Size: 10.5 MB
- Tags: CPython 3.12+, manylinux: glibc 2.17+ x86-64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a03ff4ca01fe6b46b469106533284fcbf0c1bc5313da64604e469d09e30d8c1d |
| MD5 | 219bd2e26ca3c2417aca425fb6e9dd9f |
| BLAKE2b-256 | c3f94d108af398f50c58a7afcd31486cd5d7da2f973c3e8b7e680f62e7d491e1 |
File details
Details for the file eigh_cuda120-0.4.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.
File metadata
- Download URL: eigh_cuda120-0.4.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- Upload date:
- Size: 10.5 MB
- Tags: CPython 3.11, manylinux: glibc 2.17+ x86-64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 21f7bedb909d6d78de8615081868958831872c1e35e8b436d72a134f43c42064 |
| MD5 | ecacd4b76bf94fda4d914bfc7458ebf5 |
| BLAKE2b-256 | ab13713fe7c4fe50dcc5ee2c2b0d052644e4f6a12ccccca40f4c33505c1d5eb9 |
File details
Details for the file eigh_cuda120-0.4.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.
File metadata
- Download URL: eigh_cuda120-0.4.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- Upload date:
- Size: 10.5 MB
- Tags: CPython 3.10, manylinux: glibc 2.17+ x86-64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d25d36018d721cc14f0dd5ff3390ece387ce420a66a13cdb3e32968514e6fff0 |
| MD5 | f86b161fac2e965ff2f85b24c7654a05 |
| BLAKE2b-256 | b6d4698c64fef9bc1f062233a7b123419ab4c993036739af9c81f8040b3f8bed |