Differentiable eigenvalue decomposition with JAX (CPU/GPU)

Differentiable Generalized Eigenvalue Decomposition


Standalone implementation of differentiable eigenvalue decomposition with CPU (LAPACK) and GPU (cuSOLVER) backends. Extracted from pyscfad.

Wheels on PyPI: https://pypi.org/project/eigh/ for Linux (manylinux_2_28, x86_64) and macOS (x86_64, arm64), Python 3.10–3.13. The GPU path (cuSOLVER) is tested locally; CI runs CPU tests only.

Windows: no prebuilt wheel. The pure-JAX solvers in src/jax/ (e.g. safe_generalized_eigh, subspace_generalized_eigh, stable_generalized_eigh) work out of the box: pip install jax numpy scipy, then import them directly from their modules in that directory. The fast LAPACK/cuSOLVER-backed eigh / eigh_gen kernels (and therefore stable_eigh_pyscfad / stable_eigh_gen_pyscfad) require building from source against a local BLAS/LAPACK.

Features

  • Generalized Problems: A @ V = B @ V @ diag(W) (type 1), A @ B @ V = V @ diag(W) (type 2), and B @ A @ V = V @ diag(W) (type 3).
  • JAX Integrated: Full support for jit, vmap, grad, and jvp.
  • High Performance: Optimized LAPACK (CPU) and cuSOLVER (GPU) kernels.
  • Precision: float32/64 and complex64/128.
  • Degeneracy Handling: Configurable deg_thresh for stable gradients.

Installation & Quick Start

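# Install the prebuilt wheel from PyPI (Linux, macOS)
pip install eigh

# Or install from source
pip install .

# From source, with GPU (cuSOLVER) support in the current environment
pip install ".[cuda-local]"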

Usage Example

import jax
import jax.numpy as jnp
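# Differentiable (generalized) eigensolver extracted from PySCFAD
from eigh import eigh

jax.config.update("jax_enable_x64", True)  # use float64 throughout
A = jnp.array([[2., 1.], [1., 2.]])
B = jnp.array([[2., 0.], [0., 1.]])  # symmetric positive definite
w, v = eigh(A)  # standard problem: A @ v = v @ diag(w)
wg, vg = eigh(A, B)  # generalized problem (type 1): A @ vg = B @ vg @ diag(wg)
grad = jax.grad(lambda A: eigh(A)[0].sum())(A)  # gradient of the eigenvalue sum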
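The Features list claims support for jit, vmap, and grad. A minimal sketch that exercises all three on a batch of symmetric matrices, assuming the package's eigh accepts a single symmetric matrix and returns (w, v) as in the example above; if the package is not installed, the sketch falls back to jax.numpy.linalg.eigh (an assumption made only so it runs anywhere). Because the eigenvalue sum equals trace(A), its gradient should be the identity matrix for every batch element, which makes a convenient sanity check.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

try:
    from eigh import eigh  # the package's solver, if installed
except ImportError:
    # Fallback (assumption): same (w, v) convention for the standard problem
    eigh = jnp.linalg.eigh


def eigval_sum(a):
    # The eigenvalue sum equals trace(a), so its gradient should be the identity.
    w, _ = eigh(a)
    return w.sum()


# grad for one matrix, vmap over a batch, then jit the whole pipeline
batched_grad = jax.jit(jax.vmap(jax.grad(eigval_sum)))

m = jax.random.normal(jax.random.PRNGKey(0), (4, 3, 3))
a_batch = 0.5 * (m + jnp.swapaxes(m, -1, -2))  # symmetrize each matrix

g = batched_grad(a_batch)
print(g.shape)  # (4, 3, 3)
```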

Benchmarks

Forward/backward scaling vs. matrix size, and gradient stability as eigenvalues approach degeneracy — for the JAX eigensolvers in src/jax/. See benchmarks/suite/ for the scripts.

Figures: forward-pass scaling; backward-pass (gradient) scaling.
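To illustrate what the degeneracy benchmark probes, here is a small sketch (not one of the scripts in benchmarks/suite/) that closes the gap between the two lowest eigenvalues of a 3×3 matrix and records the gradient norm of an eigenvector-dependent loss, using plain jax.numpy.linalg.eigh. The basis q, the vector x, and the loss (v0 · x)² are arbitrary illustrative choices. The norm grows roughly like 1/delta, which is the instability the degeneracy-aware solvers are meant to control.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Fixed orthogonal basis from the QR factorization of an invertible matrix.
q, _ = jnp.linalg.qr(jnp.arange(1.0, 10.0).reshape(3, 3) + jnp.eye(3))
x = jnp.ones(3)


def make_matrix(delta):
    # Eigenvalues 1, 1 + delta, 3: the lowest gap closes as delta -> 0.
    return q @ jnp.diag(jnp.array([1.0, 1.0 + delta, 3.0])) @ q.T


def eigvec_loss(a):
    # (v0 . x)^2 is invariant to the sign of v0 but not to rotations
    # inside a degenerate subspace, so its gradient carries 1/gap terms.
    _, v = jnp.linalg.eigh(a)
    return (v[:, 0] @ x) ** 2


grad_fn = jax.jit(jax.grad(eigvec_loss))
deltas = [1e-2, 1e-3, 1e-4, 1e-5]
norms = [float(jnp.linalg.norm(grad_fn(make_matrix(d)))) for d in deltas]
for d, n in zip(deltas, norms):
    print(f"delta={d:.0e}  |grad|={n:.3e}")
```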

API Reference

  • eigh(a, b=None, *, lower=True, eigvals_only=False, type=1, deg_thresh=1e-9): SciPy-compatible interface. type selects the generalized problem: 1: A @ v = B @ v @ diag(w), 2: A @ B @ v = v @ diag(w), 3: B @ A @ v = v @ diag(w).
  • eigh_gen(a, b, *, lower=True, itype=1, deg_thresh=1e-9): lower-level generalized solver; itype takes the same values as type.
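The three type/itype values are the standard LAPACK symmetric-definite problems, and each reduces to a standard symmetric problem through the Cholesky factor B = L Lᵀ. Below is a pure-JAX sketch of that reduction for real symmetric A and positive-definite B. generalized_eigh_sketch is a hypothetical helper for illustration only, not the package's eigh_gen, which calls the LAPACK/cuSOLVER kernels directly.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cholesky, solve_triangular

jax.config.update("jax_enable_x64", True)


def generalized_eigh_sketch(a, b, itype=1):
    """Hypothetical helper: reduce a symmetric-definite problem to standard form.

    itype=1: A v = w B v;  itype=2: A B v = w v;  itype=3: B A v = w v.
    Assumes real symmetric a and symmetric positive-definite b.
    """
    l = cholesky(b, lower=True)  # b = l @ l.T
    if itype == 1:
        # C = L^{-1} A L^{-T}, built with two triangular solves
        c = solve_triangular(l, solve_triangular(l, a, lower=True).T, lower=True)
    elif itype in (2, 3):
        c = l.T @ a @ l  # C = L^T A L
    else:
        raise ValueError("itype must be 1, 2, or 3")
    w, y = jnp.linalg.eigh(c)  # standard symmetric problem C y = w y
    if itype in (1, 2):
        v = solve_triangular(l.T, y, lower=False)  # back-transform: v = L^{-T} y
    else:
        v = l @ y  # back-transform: v = L y
    return w, v


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
m = jax.random.normal(key_a, (4, 4))
a = 0.5 * (m + m.T)
n = jax.random.normal(key_b, (4, 4))
b = n @ n.T + 4.0 * jnp.eye(4)  # symmetric positive definite

residuals = {}
for itype in (1, 2, 3):
    w, v = generalized_eigh_sketch(a, b, itype)
    lhs = {1: a @ v, 2: a @ b @ v, 3: b @ a @ v}[itype]
    rhs = {1: (b @ v) * w, 2: v * w, 3: v * w}[itype]  # scale column j by w[j]
    residuals[itype] = float(jnp.max(jnp.abs(lhs - rhs)))
print(residuals)
```

Types 1 and 2 share the back-transform v = L⁻ᵀ y, while type 3 uses v = L y, mirroring the LAPACK convention.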

Degenerate Eigenvalues & Gradients

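Individual eigenvalue and eigenvector gradients are ill-defined for degenerate (repeated) eigenvalues: the eigenvector derivative contains terms 1/(w_j - w_i) that diverge as two eigenvalues merge. However, functions that are symmetric in the degenerate eigenvalues (like their sum, var, or a trace) have stable gradients. The deg_thresh parameter (default 1e-9) masks these divisions whenever the gap |w_j - w_i| falls below the threshold, keeping the gradient finite.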
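A sketch of the masking idea, assuming the usual eigenvector-derivative form in which the reverse pass contracts the eigenvector cotangent with a matrix F[i, j] = 1/(w[j] - w[i]). masked_inverse_gaps is a hypothetical helper, and pyscfad's internal construction may differ in detail. The double jnp.where keeps both the value and its own gradient free of inf/NaN.

```python
import jax.numpy as jnp


def masked_inverse_gaps(w, deg_thresh=1e-9):
    """Hypothetical helper: F[i, j] = 1 / (w[j] - w[i]), or 0 where |gap| <= deg_thresh."""
    gaps = w[None, :] - w[:, None]        # gaps[i, j] = w[j] - w[i]
    keep = jnp.abs(gaps) > deg_thresh     # False on the diagonal and for near-degenerate pairs
    safe = jnp.where(keep, gaps, 1.0)     # never divide by a (near-)zero gap
    return jnp.where(keep, 1.0 / safe, 0.0)


w = jnp.array([1.0, 1.0 + 1e-12, 3.0])   # pair (0, 1) is numerically degenerate
F = masked_inverse_gaps(w)
print(F)  # F[0, 1] and F[1, 0] are 0; F[0, 2] = 0.5, F[2, 0] = -0.5
```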

JAX Eigensolvers (src/jax/)

A collection of differentiable generalized eigensolvers with different strategies for handling degenerate eigenvalues in reverse-mode gradients. Useful for training pipelines where degeneracies are common.

Recommended: stable_eigh_pyscfad / stable_eigh_gen_pyscfad from generalized_eigensolver_pyscfad.py. These are the solvers promoted by this package — they wrap the fast LAPACK/cuSOLVER-backed eigh / eigh_gen kernels with a Lorentzian-broadened custom VJP for stable gradients at (near-)degeneracies. The other solvers below are provided for comparison and experimentation.

| Solver | File | Strategy | Gradient-safe at degeneracies |
|---|---|---|---|
| stable_eigh_pyscfad / stable_eigh_gen_pyscfad | generalized_eigensolver_pyscfad.py | Wraps eigh / eigh_gen (LAPACK/cuSOLVER) with a Lorentzian-broadened VJP [2] | Yes (recommended) |
| standard_eig | generalized_eigensolver.py | scipy.linalg.eigh (CPU, not differentiable) | N/A |
| jax_eig | generalized_eigensolver.py | Plain Cholesky + jnp.linalg.eigh | No |
| generalized_eigh | generalized_eigensolver.py | Symmetrized Cholesky with SPD shift | No (standard VJP) |
| degen_eigh | generalized_eigensolver.py | Custom VJP: mask degenerate F_ij via thresholding [1,3] | Yes, for symmetric-subspace losses |
| safe_generalized_eigh | generalized_eigensolver.py | Cholesky + degen_eigh | Yes |
| subspace_eigh | generalized_eigensolver.py | Custom VJP: Lorentzian broadening F/(F²+ε²) [2] | Yes |
| subspace_generalized_eigh | generalized_eigensolver.py | Symmetry-breaking perturbation + subspace_eigh [2,4] | Yes |
| stable_eigh / stable_generalized_eigh | generalized_eigensolver_stable.py | Pure-JAX Cholesky + Lorentzian-broadened VJP [2] | Yes |
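A minimal pure-JAX sketch of the Lorentzian-broadened VJP idea behind subspace_eigh and the stable_* solvers, restricted to the standard real-symmetric problem. Each 1/gap factor is replaced by gap/(gap² + ε²), which matches 1/gap when |gap| ≫ ε and stays bounded by 1/(2ε) as the gap closes. lorentzian_eigh and EPS are hypothetical names; the value of ε, the generalized (Cholesky) handling, and other details in the package's solvers may differ.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

EPS = 1e-6  # hypothetical broadening width, not a package default


@jax.custom_vjp
def lorentzian_eigh(a):
    """Real symmetric eigendecomposition with a Lorentzian-broadened VJP."""
    w, v = jnp.linalg.eigh(a)
    return w, v


def _fwd(a):
    w, v = jnp.linalg.eigh(a)
    return (w, v), (w, v)  # primal outputs, residuals for the backward pass


def _bwd(res, cotangents):
    w, v = res
    w_bar, v_bar = cotangents
    gaps = w[None, :] - w[:, None]       # gaps[i, j] = w[j] - w[i]
    f = gaps / (gaps**2 + EPS**2)        # ~1/gap for |gap| >> EPS, at most 1/(2*EPS), 0 on diagonal
    inner = jnp.diag(w_bar) + f * (v.T @ v_bar)
    a_bar = v @ inner @ v.T
    return (0.5 * (a_bar + a_bar.T),)    # keep the cotangent symmetric


lorentzian_eigh.defvjp(_fwd, _bwd)


def make_loss(eigh_fn):
    x = jnp.array([1.0, 2.0, 3.0])

    def loss(a):
        w, v = eigh_fn(a)
        return w.sum() + (v[:, 0] @ x) ** 2  # sign-invariant eigenvector term
    return loss


q, _ = jnp.linalg.qr(jnp.arange(1.0, 10.0).reshape(3, 3) + jnp.eye(3))

# Well-separated spectrum: broadening is negligible, so this should match jax's own eigh.
a_split = q @ jnp.diag(jnp.array([1.0, 2.0, 3.0])) @ q.T
g_custom = jax.grad(make_loss(lorentzian_eigh))(a_split)
g_ref = jax.grad(make_loss(jnp.linalg.eigh))(a_split)
print(float(jnp.max(jnp.abs(g_custom - g_ref))))


# Exactly degenerate pair with a loss invariant inside that subspace: gradient stays finite.
def subspace_loss(a):
    _, v = lorentzian_eigh(a)
    return jnp.sum((v[:, :2].T @ jnp.ones(3)) ** 2)


a_degen = q @ jnp.diag(jnp.array([1.0, 1.0, 3.0])) @ q.T
g_degen = jax.grad(subspace_loss)(a_degen)
print(bool(jnp.all(jnp.isfinite(g_degen))))
```

The broadening trades a small bias near degeneracies (entries with |gap| ≲ ε are damped toward zero) for gradients that never blow up, in contrast to the hard cutoff of the deg_thresh mask.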

References (verified)

Development & Testing

  • Requirements: CMake 3.18+, C++17, JAX, NumPy, LAPACK/CUDA.
  • Tests:
    pytest tests/test_eigh.py     # Core functionality
    pytest tests/test_eigh_gen.py # Generalized itypes
    pytest tests/test_eigh_jit.py # JIT & vmap
    
  • GPU Setup:
    source setup_gpu_env_clean.sh
    ./run_gpu.sh python example_simple.py
    

License & Citation

Apache License 2.0. If used in research, please cite:

@software{pyscfad,
  author = {Zhang, Xing},
  title = {PySCFad: Automatic Differentiation for PySCF},
  url = {https://github.com/fishjojo/pyscfad},
  year = {2021-2025}
}

@software{sokolov2026eigh,
  author = {Sokolov, Igor},
  title = {Eigh: Differentiable eigenvalue decomposition with {JAX} ({CPU}/{GPU})},
  url = {https://github.com/Brogis1/eigh},
  year = {2026}
}

@article{sokolov2026xc,
  title = {Quantum-enhanced neural exchange-correlation functionals},
  author = {Sokolov, Igor O. and Both, Gert-Jan and Bochevarov, Art D. and Dub, Pavel A. and Levine, Daniel S. and Brown, Christopher T. and Acheche, Shaheen and Barkoutsos, Panagiotis Kl. and Elfving, Vincent E.},
  journal = {Phys. Rev. A},
  volume = {113},
  issue = {1},
  pages = {012427},
  numpages = {24},
  year = {2026},
  month = {Jan},
  publisher = {American Physical Society},
  doi = {10.1103/m51l-fys2},
  url = {https://link.aps.org/doi/10.1103/m51l-fys2}
}

Download files

Source distribution

  • eigh-0.2.0.tar.gz (3.5 MB)

Built distributions

  • eigh-0.2.0-cp313-cp313-manylinux_2_28_x86_64.whl (12.0 MB): CPython 3.13, manylinux glibc 2.28+, x86-64
  • eigh-0.2.0-cp313-cp313-macosx_11_0_arm64.whl (60.1 kB): CPython 3.13, macOS 11.0+, ARM64
  • eigh-0.2.0-cp313-cp313-macosx_10_15_x86_64.whl (63.7 kB): CPython 3.13, macOS 10.15+, x86-64
  • eigh-0.2.0-cp312-cp312-manylinux_2_28_x86_64.whl (12.0 MB): CPython 3.12, manylinux glibc 2.28+, x86-64
  • eigh-0.2.0-cp312-cp312-macosx_11_0_arm64.whl (60.1 kB): CPython 3.12, macOS 11.0+, ARM64
  • eigh-0.2.0-cp312-cp312-macosx_10_15_x86_64.whl (63.7 kB): CPython 3.12, macOS 10.15+, x86-64
  • eigh-0.2.0-cp311-cp311-manylinux_2_28_x86_64.whl (12.0 MB): CPython 3.11, manylinux glibc 2.28+, x86-64
  • eigh-0.2.0-cp311-cp311-macosx_11_0_arm64.whl (60.5 kB): CPython 3.11, macOS 11.0+, ARM64
  • eigh-0.2.0-cp311-cp311-macosx_10_15_x86_64.whl (64.0 kB): CPython 3.11, macOS 10.15+, x86-64
  • eigh-0.2.0-cp310-cp310-manylinux_2_28_x86_64.whl (12.0 MB): CPython 3.10, manylinux glibc 2.28+, x86-64
  • eigh-0.2.0-cp310-cp310-macosx_11_0_arm64.whl (60.2 kB): CPython 3.10, macOS 11.0+, ARM64
  • eigh-0.2.0-cp310-cp310-macosx_10_15_x86_64.whl (63.7 kB): CPython 3.10, macOS 10.15+, x86-64
