Differentiable eigenvalue decomposition with JAX (CPU/GPU)

Differentiable Generalized Eigenvalue Decomposition

Standalone implementation of differentiable eigenvalue decomposition with CPU (LAPACK) and GPU (cuSOLVER) backends. Extracted from pyscfad.

Wheels on PyPI (https://pypi.org/project/eigh/): Linux (manylinux_2_28, x86_64) and macOS (x86_64, arm64), Python 3.10–3.13. The GPU path (cuSOLVER) is tested locally; CI runs CPU tests only.

Windows: no prebuilt wheel. The pure-JAX solvers in src/jax/ (e.g. safe_generalized_eigh, subspace_generalized_eigh, stable_generalized_eigh) work out-of-the-box — pip install jax numpy scipy and import directly from that module. The fast LAPACK/cuSOLVER-backed eigh / eigh_gen kernels (and therefore stable_eigh_pyscfad / stable_eigh_gen_pyscfad) require building from source against a local BLAS/LAPACK.

Features

  • Generalized Problems: A @ V = B @ V @ diag(W), etc.
  • JAX Integrated: Full support for jit, vmap, grad, and jvp.
  • High Performance: Optimized LAPACK (CPU) and cuSOLVER (GPU) kernels.
  • Precision: float32/64 and complex64/128.
  • Degeneracy Handling: Configurable deg_thresh for stable gradients.

Installation & Quick Start

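# Install a prebuilt wheel from PyPI (Linux, macOS)
pip install eigh

# Or install from a source checkout
pip install .

# For GPU support in this environment (quoted so the shell does not expand the brackets)
pip install ".[cuda-local]"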

Usage Example

import jax
import jax.numpy as jnp
# Generalized eigensolver extracted from PySCFAD
from eigh import eigh

jax.config.update("jax_enable_x64", True)  # use float64 throughout
A = jnp.array([[2., 1.], [1., 2.]])
w, v = eigh(A)  # Standard problem: eigenvalues w, eigenvectors v
grad = jax.grad(lambda A: eigh(A)[0].sum())(A) # Differentiable
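As a sanity check on the gradient in the example above: the eigenvalues of a symmetric matrix sum to its trace, so the gradient of the eigenvalue sum should be the identity matrix. The sketch below uses jnp.linalg.eigh as a stand-in so that it runs without the compiled kernels; it is assumed, not verified here, that the package's eigh gives the same result when swapped in. The helper name eigsum is hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def eigsum(a):
    # Stand-in solver (assumption): jnp.linalg.eigh instead of the package's eigh.
    w, _ = jnp.linalg.eigh(a)
    return w.sum()

A = jnp.array([[2.0, 1.0], [1.0, 2.0]])

# sum(w) equals trace(A), so d(sum(w))/dA is the identity matrix.
g = jax.jit(jax.grad(eigsum))(A)

# The same gradient, batched over a stack of matrices with vmap.
batch = jnp.stack([A, 2.0 * A, A + jnp.eye(2)])
g_batch = jax.vmap(jax.grad(eigsum))(batch)

print(g)
print(g_batch.shape)
```

Comparing a gradient against a closed-form result like this is a cheap way to confirm that a build (CPU or GPU) differentiates correctly before using it in a larger pipeline.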

Benchmarks

Forward/backward scaling vs. matrix size, and gradient stability as eigenvalues approach degeneracy — for the JAX eigensolvers in src/jax/. See benchmarks/suite/ for the scripts.

[Figures: forward-pass scaling; backward-pass (gradient) scaling]

API Reference

  • eigh(a, b=None, *, lower=True, eigvals_only=False, type=1, deg_thresh=1e-9): Scipy-compatible interface. type selects the problem: 1: A v = λ B v, 2: A B v = λ v, 3: B A v = λ v.
  • eigh_gen(a, b, *, lower=True, itype=1, deg_thresh=1e-9): Lower-level generalized solver; itype takes the same values as type.
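The three problem types can be checked numerically. Because the interface is described as Scipy-compatible, the sketch below uses scipy.linalg.eigh itself as the reference; that the package's eigh follows exactly the same convention is an assumption. The test matrices and the residual helper are made up for illustration.

```python
import numpy as np
from scipy.linalg import eigh as scipy_eigh

rng = np.random.default_rng(0)
n = 4
M = rng.standard_normal((n, n))
A = (M + M.T) / 2                      # symmetric
N = rng.standard_normal((n, n))
B = N @ N.T + n * np.eye(n)            # symmetric positive definite

def residual(problem_type):
    """Largest entry of |lhs - rhs| for the equation each type solves."""
    w, v = scipy_eigh(A, B, type=problem_type)
    if problem_type == 1:
        lhs, rhs = A @ v, (B @ v) * w      # A v = lambda B v
    elif problem_type == 2:
        lhs, rhs = A @ B @ v, v * w        # A B v = lambda v
    else:
        lhs, rhs = B @ A @ v, v * w        # B A v = lambda v
    return np.abs(lhs - rhs).max()

residuals = {t: residual(t) for t in (1, 2, 3)}
print(residuals)
```

Multiplying by w on the right scales column j of the eigenvector matrix by eigenvalue j, which is the column-wise form of each equation.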

Degenerate Eigenvalues & Gradients

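Gradients of individual eigenvalues and eigenvectors are ill-defined at degenerate (repeated) eigenvalues, because the backward pass divides by the gaps between eigenvalues. Functions that are symmetric in the degenerate eigenvalues (such as sum, var, or trace) still have stable gradients. The deg_thresh parameter (default 1e-9) masks divisions by near-zero gaps to keep the backward pass finite.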
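To make the masking concrete, here is a minimal pure-JAX sketch of a thresholded backward pass for the standard symmetric problem. It is not the package's implementation: the name masked_eigh, the constant DEG_THRESH, and the exact masking rule (drop any term whose gap is at most the threshold) are assumptions chosen to mirror the description above.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

DEG_THRESH = 1e-9  # mirrors the documented default; the rule below is an assumption

@jax.custom_vjp
def masked_eigh(a):
    return jnp.linalg.eigh(a)

def _fwd(a):
    w, v = jnp.linalg.eigh(a)
    return (w, v), (w, v)

def _bwd(res, cot):
    w, v = res
    w_bar, v_bar = cot
    gap = w[None, :] - w[:, None]                  # gap[i, j] = w_j - w_i
    keep = jnp.abs(gap) > DEG_THRESH
    # 1 / gap where the gap is resolved, 0 where it is (near-)degenerate.
    f = jnp.where(keep, 1.0 / jnp.where(keep, gap, 1.0), 0.0)
    inner = jnp.diag(w_bar) + f * (v.T @ v_bar)
    a_bar = v @ inner @ v.T
    return ((a_bar + a_bar.T) / 2,)                # symmetrized cotangent

masked_eigh.defvjp(_fwd, _bwd)

A = 2.0 * jnp.eye(3)  # fully degenerate spectrum

# Symmetric loss: the gradient of sum(w) is the identity.
grad_sum = jax.grad(lambda a: masked_eigh(a)[0].sum())(A)

# Eigenvector-dependent loss: finite thanks to the mask, but not meaningful.
grad_vec = jax.grad(lambda a: (masked_eigh(a)[1] ** 2 * jnp.arange(3.0)).sum())(A)

print(grad_sum)
print(grad_vec)
```

With masking, an eigenvector-dependent loss returns a finite gradient at a degeneracy, but the masked terms are simply dropped, so only losses that are symmetric within the degenerate subspace can be trusted there.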

JAX Eigensolvers (src/jax/)

A collection of differentiable generalized eigensolvers with different strategies for handling degenerate eigenvalues in reverse-mode gradients. Useful for training pipelines where degeneracies are common.

If you just want a working solver, use stable_eigh_pyscfad / stable_eigh_gen_pyscfad from generalized_eigensolver_pyscfad.py. They wrap the fast LAPACK/cuSOLVER kernels with a Lorentzian-broadened custom VJP, so gradients stay stable when eigenvalues are (nearly) degenerate.

On Windows, or if you cannot build the C++ kernels, use stable_generalized_eigh from generalized_eigensolver_stable.py instead — same gradient treatment, pure JAX.

The remaining solvers below are kept for benchmarking and for reproducing prior work; they are not recommended as defaults.

Recommended

Solver | File | Strategy
stable_eigh_pyscfad / stable_eigh_gen_pyscfad | generalized_eigensolver_pyscfad.py | LAPACK/cuSOLVER kernels + Lorentzian-broadened VJP [2]
stable_eigh / stable_generalized_eigh (pure-JAX) | generalized_eigensolver_stable.py | Pure-JAX Cholesky + Lorentzian-broadened VJP [2]

Alternative stable solvers

Solver | File | Strategy | Gradient notes
subspace_eigh | generalized_eigensolver.py | Custom VJP: Lorentzian broadening F/(F²+ε²) [2] | Stable
subspace_generalized_eigh | generalized_eigensolver.py | Symmetry-breaking perturbation + subspace_eigh [2,4] | Stable
degen_eigh | generalized_eigensolver.py | Custom VJP: mask degenerate F_ij by threshold [1,3] | Stable only for symmetric-subspace losses
safe_generalized_eigh | generalized_eigensolver.py | Cholesky + degen_eigh | Inherits degen_eigh caveat
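The two regularization strategies in the table differ only in how they replace the inverse eigenvalue gap. The sketch below compares them, reading F in F/(F²+ε²) as the gap between two eigenvalues; that reading, the function names, and the value of eps are assumptions for illustration and are not taken from the package.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def inverse_gap_masked(gap, thresh=1e-9):
    # Exact 1 / gap where the gap is resolved, hard zero otherwise.
    keep = jnp.abs(gap) > thresh
    return jnp.where(keep, 1.0 / jnp.where(keep, gap, 1.0), 0.0)

def inverse_gap_lorentzian(gap, eps=1e-6):
    # Smooth surrogate: close to 1 / gap for |gap| >> eps, goes to 0 as gap -> 0,
    # and never exceeds 1 / (2 * eps) in magnitude.
    return gap / (gap**2 + eps**2)

gaps = jnp.array([0.0, 1e-8, 1e-3, 1.0])
masked = inverse_gap_masked(gaps)
lorentz = inverse_gap_lorentzian(gaps)

print(masked)
print(lorentz)
```

The masked factor jumps from a very large value to zero at the threshold, while the Lorentzian factor is continuous in the gap and bounded, at the cost of a small bias for gaps comparable to eps.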

Baselines (not gradient-safe at degeneracies)

Solver | File | Strategy
standard_eig | generalized_eigensolver.py | scipy.linalg.eigh (non-differentiable reference)
jax_eig | generalized_eigensolver.py | Plain Cholesky + jnp.linalg.eigh, default VJP
generalized_eigh | generalized_eigensolver.py | Symmetrized Cholesky with SPD shift, default VJP
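For orientation, the "plain Cholesky + jnp.linalg.eigh" strategy can be sketched as follows for real symmetric A and symmetric positive definite B. This is a generic illustration of the reduction, not the code in generalized_eigensolver.py; the function name cholesky_generalized_eigh is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

jax.config.update("jax_enable_x64", True)

def cholesky_generalized_eigh(a, b):
    """Solve a @ v = b @ v @ diag(w) by reducing to a standard problem."""
    l = jnp.linalg.cholesky(b)                        # b = l @ l.T
    li_a = solve_triangular(l, a, lower=True)         # l^{-1} a
    c = solve_triangular(l, li_a.T, lower=True).T     # l^{-1} a l^{-T}
    w, y = jnp.linalg.eigh(c)                         # standard problem c y = w y
    v = solve_triangular(l.T, y, lower=False)         # back-transform: v = l^{-T} y
    return w, v

A = jnp.array([[2.0, 1.0], [1.0, 3.0]])
B = jnp.array([[2.0, 0.5], [0.5, 1.0]])

w, v = cholesky_generalized_eigh(A, B)
residual = jnp.abs(A @ v - (B @ v) * w).max()

# Differentiation uses JAX's default eigh rule, which divides by eigenvalue gaps.
grad_a = jax.grad(lambda a: cholesky_generalized_eigh(a, B)[0].sum())(A)

print(w, residual)
```

Because the backward pass here is JAX's default rule, nothing protects eigenvector-dependent losses when eigenvalues coincide, which is why this family is listed as a baseline rather than a recommended solver.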

References

Development & Testing

  • Requirements: CMake 3.18+, C++17, JAX, NumPy, LAPACK/CUDA.
  • Tests:
    pytest tests/test_eigh.py     # Core functionality
    pytest tests/test_eigh_gen.py # Generalized itypes
    pytest tests/test_eigh_jit.py # JIT & vmap
    
  • GPU Setup:
    source setup_gpu_env_clean.sh
    ./run_gpu.sh python example_simple.py
    

License & Citation

Apache License 2.0. If used in research, please cite:

@software{pyscfad,
  author = {Zhang, Xing},
  title = {PySCFad: Automatic Differentiation for PySCF},
  url = {https://github.com/fishjojo/pyscfad},
  year = {2021-2025}
}

@software{sokolov2026eigh,
  title={Eigh: Differentiable eigenvalue decomposition with jax (cpu/gpu)},
  author={Sokolov, Igor},
  url={https://github.com/Brogis1/eigh},
  year={2026}
}

@article{sokolov2026xc,
  title = {Quantum-enhanced neural exchange-correlation functionals},
  author = {Sokolov, Igor O. and Both, Gert-Jan and Bochevarov, Art D. and Dub, Pavel A. and Levine, Daniel S. and Brown, Christopher T. and Acheche, Shaheen and Barkoutsos, Panagiotis Kl. and Elfving, Vincent E.},
  journal = {Phys. Rev. A},
  volume = {113},
  issue = {1},
  pages = {012427},
  numpages = {24},
  year = {2026},
  month = {Jan},
  publisher = {American Physical Society},
  doi = {10.1103/m51l-fys2},
  url = {https://link.aps.org/doi/10.1103/m51l-fys2}
}
