Differentiable eigenvalue decomposition with JAX — CUDA 12.0 (GPU) build

Differentiable Generalized Eigenvalue Decomposition

Standalone implementation of differentiable eigenvalue decomposition with CPU (LAPACK) and GPU (cuSOLVER) backends. Extracted from pyscfad.

CPU and GPU wheels on PyPI for Linux and macOS (Apple Silicon), Python 3.10–3.13, JAX 0.5–0.10+. See Installation and Compatibility.

New

  • Core code rewritten so that it also runs on older clusters (typically GPU clusters) pinned to JAX 0.4.x.
  • Prebuilt CUDA wheels are available, but I recommend building from source: it is fast and easy with this package, and the result matches your specific JAX version.

Features

  • Generalized Problems: A @ V = B @ V @ diag(W), etc.
  • JAX Integrated: Full support for jit, vmap, grad, and jvp.
  • High Performance: Optimized LAPACK (CPU) and cuSOLVER (GPU) kernels.
  • Precision: float32/64 and complex64/128.
  • Degeneracy Handling: Configurable deg_thresh for stable gradients.

Installation & Quick Start

CPU

pip install eigh

Prebuilt CPU-only wheels — Linux (x86_64) and macOS (Apple Silicon), Python 3.10–3.13, JAX 0.5+.

GPU - Build from source (Recommended)

First make sure that your installed JAX runs correctly on your GPU.

Build from source when the wheels do not cover your jaxlib / CUDA / glibc combination. The main case is an environment pinned to an older release such as jaxlib 0.4.29: the prebuilt wheels require jaxlib ≥ 0.5 because the FFI binary ABI changed at 0.5, so a 0.5 wheel cannot run on 0.4.x and vice versa. A source build compiles against whatever jaxlib is in your environment (0.4.29 or 0.5–0.10+), CPU or GPU:

git clone https://github.com/Brogis1/eigh && cd eigh
pip install "scikit-build-core>=0.8" "nanobind>=1.0.0" cmake ninja
pip install . --no-build-isolation --no-deps
  • --no-build-isolation compiles against the jaxlib already in your env.
  • --no-deps keeps your pinned jax/jaxlib (essential on 0.4.29 — otherwise pip would upgrade it to ≥0.5).
  • For GPU, have nvcc on PATH (module load cuda/12.x); look for "CUDA support enabled" in the build log ("CUDA not found" means nvcc is not on PATH). Plain jaxlib (no CUDA) yields a CPU-only build.
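
The wheel-versus-source rule above can be written as a small check. This is a minimal sketch, not part of the package: `parse_version` and `needs_source_build` are hypothetical helpers, and the only rule they encode is the one stated here (the prebuilt wheels require jaxlib ≥ 0.5).

```python
import re

def parse_version(version: str) -> tuple:
    """Return the leading (major, minor) of a version string such as '0.4.29'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))

def needs_source_build(jaxlib_version: str) -> bool:
    """True if jaxlib predates 0.5, the minimum the prebuilt wheels require."""
    return parse_version(jaxlib_version) < (0, 5)

print(needs_source_build("0.4.29"))  # True
print(needs_source_build("0.5.3"))   # False
print(needs_source_build("0.10.0"))  # False
```

Comparing integer tuples rather than strings matters here: as strings, "0.10" would sort before "0.5". In a real environment the argument could come from `importlib.metadata.version("jaxlib")`.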

Full details and the why — FFI ABI, pinned-jaxlib clusters, nvcc paths, GPU verification — are in docs/TECHNICAL_NOTES.md.

GPU (CUDA 12, Linux x86_64)

The prebuilt GPU wheels work only if your JAX and CUDA libraries happen to match them, so I strongly recommend building from source instead (see GPU - Build from source). If you do use a wheel, pick the package matching your cluster's CUDA version:

pip install eigh-cuda120   # CUDA 12.0+ (works through 12.8+); the safe default
pip install eigh-cuda128   # CUDA 12.8+ (newer toolchain / glibc 2.34)

Both bundle the cuSOLVER kernel + NVIDIA CUDA runtime libs; import eigh auto-detects the GPU. They are separate packages from this same repo — import eigh is identical. See Compatibility.
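
As a rough aid for choosing between the two packages, the sketch below lists which prebuilt GPU packages a system meets the minimums for. It is illustrative only: `compatible_gpu_packages` is a hypothetical helper, the minimum CUDA and glibc versions are the ones quoted in this README, and treating any CUDA major version other than 12 as unsupported is an assumption.

```python
# Minimum (CUDA, glibc) versions per package, as quoted in this README.
GPU_PACKAGES = {
    "eigh-cuda120": ((12, 0), (2, 17)),
    "eigh-cuda128": ((12, 8), (2, 34)),
}

def compatible_gpu_packages(cuda: tuple, glibc: tuple) -> list:
    """Return the prebuilt GPU packages whose minimum versions the system meets."""
    if cuda[0] != 12:  # assumption: the wheels target CUDA 12 only
        return []
    return [
        name
        for name, (min_cuda, min_glibc) in GPU_PACKAGES.items()
        if cuda >= min_cuda and glibc >= min_glibc
    ]

print(compatible_gpu_packages((12, 4), (2, 28)))  # ['eigh-cuda120']
print(compatible_gpu_packages((12, 8), (2, 35)))  # ['eigh-cuda120', 'eigh-cuda128']
print(compatible_gpu_packages((11, 8), (2, 35)))  # []
```

An empty result is the case where a source build is the only option. On Linux, the glibc version can be read with the standard-library call `platform.libc_ver()`.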

Usage Example

import jax
import jax.numpy as jnp
# Gen. eigensolver from PySCFAD
from eigh import eigh, eigh_gen

jax.config.update("jax_enable_x64", True)
# Eigenvalue problem (B must be symmetric positive definite)
A = jnp.array([[2., 1.], [1., 2.]])
B = jnp.array([[1., 0.5], [0.5, 1.]])
w1, v1 = eigh(A)         # standard problem
w2, v2 = eigh_gen(A, B)  # generalized problem

# With gradients
grad1 = jax.grad(lambda A: eigh(A)[0].sum())(A)
grad2 = jax.grad(lambda A: eigh_gen(A, B)[0].sum())(A)
print("Eigenvalues:", w1, w2)
print("Eigenvectors:", v1, v2)
print("Gradients computed:", grad1.shape, grad2.shape)
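
An analytic reference is handy when sanity-checking gradients like the ones above. For symmetric A and symmetric positive definite B, the sum of the generalized eigenvalues equals tr(B⁻¹A), so its derivative along a symmetric direction S is tr(B⁻¹S); for the standard problem (B = I) this reduces to tr(S). The sketch below checks that identity by finite differences in plain NumPy. It does not call this package: `generalized_eigvals` is a hypothetical stand-in based on a Cholesky reduction, the matrices are arbitrary illustrative values, and how the package splits the gradient between the two triangles of A is not covered here.

```python
import numpy as np

def generalized_eigvals(A, B):
    """Eigenvalues of A v = w B v via Cholesky reduction (B must be SPD)."""
    L = np.linalg.cholesky(B)          # B = L @ L.T
    X = np.linalg.solve(L, A)          # L^{-1} A
    C = np.linalg.solve(L, X.T).T      # L^{-1} A L^{-T}, symmetric
    return np.linalg.eigvalsh(C)

def eigval_sum(A, B):
    return generalized_eigvals(A, B).sum()

A = np.array([[2.0, 1.0], [1.0, 2.0]])
B = np.array([[2.0, 0.5], [0.5, 1.0]])
S = np.array([[0.3, -0.2], [-0.2, 0.1]])   # symmetric perturbation direction

h = 1e-6
finite_diff = (eigval_sum(A + h * S, B) - eigval_sum(A - h * S, B)) / (2 * h)
analytic = np.sum(np.linalg.inv(B) * S)    # tr(B^{-1} S) for symmetric B, S

print("finite difference:", finite_diff)
print("analytic         :", analytic)
```

Because the eigenvalue sum is linear in A, the central difference agrees with the analytic value up to round-off.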

Benchmarks

Forward/backward scaling vs. matrix size, and gradient stability as eigenvalues approach degeneracy — for the JAX eigensolvers in src/jax/. See benchmarks/suite/ for the scripts.

[Figures: forward-pass scaling; backward-pass (gradient) scaling]

API Reference

  • eigh(a, b=None, *, lower=True, eigvals_only=False, type=1, deg_thresh=1e-9): SciPy-compatible interface. type selects the problem: 1 for A@v = B@v@λ, 2 for A@B@v = v@λ, 3 for B@A@v = v@λ.
  • eigh_gen(a, b, *, lower=True, itype=1, deg_thresh=1e-9): lower-level generalized solver.
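
To make the three problem types concrete, here is a NumPy-only reference that solves each one by Cholesky reduction to a standard symmetric problem and checks the residual. It is a sketch for understanding the conventions, not the package's kernel: `reference_eigh` and `problem_type` are hypothetical names, and the matrices are arbitrary illustrative values.

```python
import numpy as np

def reference_eigh(A, B, problem_type=1):
    """Reference solver for the three generalized problem types (B must be SPD)."""
    L = np.linalg.cholesky(B)                 # B = L @ L.T
    if problem_type == 1:
        X = np.linalg.solve(L, A)             # L^{-1} A
        C = np.linalg.solve(L, X.T).T         # L^{-1} A L^{-T}
    elif problem_type in (2, 3):
        C = L.T @ A @ L
    else:
        raise ValueError("problem_type must be 1, 2, or 3")
    w, Y = np.linalg.eigh(C)                  # standard symmetric problem
    # Back-transform the eigenvectors: L y for type 3, L^{-T} y otherwise.
    V = L @ Y if problem_type == 3 else np.linalg.solve(L.T, Y)
    return w, V

A = np.array([[2.0, 1.0], [1.0, 2.0]])
B = np.array([[2.0, 0.5], [0.5, 1.0]])

# Left-hand side of each problem; the right-hand side is scaled by the eigenvalues.
lhs = {1: lambda V: A @ V, 2: lambda V: A @ B @ V, 3: lambda V: B @ A @ V}
rhs = {1: lambda V: B @ V, 2: lambda V: V, 3: lambda V: V}

residuals = {}
for t in (1, 2, 3):
    w, V = reference_eigh(A, B, t)
    # Multiplying by w scales column j by w[j], i.e. a right-multiply by diag(w).
    residuals[t] = np.abs(lhs[t](V) - rhs[t](V) * w).max()
    print(f"type {t}: eigenvalues {w}, max residual {residuals[t]:.1e}")
```

Types 2 and 3 share the reduced matrix Lᵀ A L and differ only in the eigenvector back-transform.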

Degenerate Eigenvalues & Gradients

Individual eigenvalue gradients are ill-defined for degenerate (repeated) eigenvalues. However, symmetric functions (like sum, var, trace) have stable gradients. The deg_thresh parameter (default 1e-9) masks divisions by near-zero gaps to maintain stability.
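
The masking idea can be illustrated in a few lines of NumPy. The eigenvector part of an eigendecomposition derivative contains factors of the form 1/(w_j − w_i), which blow up as two eigenvalues approach each other. The sketch below zeroes those factors whenever the gap is within a threshold. It is an illustration under stated assumptions, not the package's code: `inverse_gap_matrix` is a hypothetical name, the sign convention F[i, j] = 1/(w[j] − w[i]) is one common choice, and setting masked entries to exactly zero is assumed.

```python
import numpy as np

def inverse_gap_matrix(w, deg_thresh=1e-9):
    """F[i, j] = 1 / (w[j] - w[i]) where |gap| exceeds deg_thresh, else 0."""
    gap = w[None, :] - w[:, None]
    mask = np.abs(gap) > deg_thresh
    safe_gap = np.where(mask, gap, 1.0)    # placeholder avoids dividing by ~0
    return np.where(mask, 1.0 / safe_gap, 0.0)

w = np.array([1.0, 1.0 + 1e-12, 3.0])      # first two eigenvalues nearly degenerate
F = inverse_gap_matrix(w)
print(F)
```

Without the mask, the entries coupling the first two eigenvalues would be of order 1e12. A loss that depends on the eigenvalues only, such as their sum, carries no eigenvector cotangent, so it never multiplies these factors at all, which is why such symmetric functions stay well behaved.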

JAX Eigensolvers

A collection of differentiable generalized eigensolvers with different strategies for handling degenerate eigenvalues in reverse-mode gradients. Useful for training pipelines where degeneracies are common.

If you just want a working solver, use stable_eigh_pyscfad / stable_eigh_gen_pyscfad from generalized_eigensolver_pyscfad.py. They wrap the fast LAPACK/cuSOLVER kernels with a Lorentzian-broadened custom VJP, so gradients stay stable when eigenvalues are (nearly) degenerate.

On Windows, or if you cannot build the C++ kernels, use stable_generalized_eigh from generalized_eigensolver_stable.py instead — same gradient treatment, pure JAX.

The remaining solvers below are kept for benchmarking and for reproducing prior work; they are not recommended as defaults.

Recommended

| Solver | File | Strategy |
| --- | --- | --- |
| stable_eigh_pyscfad / stable_eigh_gen_pyscfad | generalized_eigensolver_pyscfad.py | LAPACK/cuSOLVER kernels + Lorentzian-broadened VJP [2] |
| stable_eigh / stable_generalized_eigh (pure-JAX) | generalized_eigensolver_stable.py | Pure-JAX Cholesky + Lorentzian-broadened VJP [2] |

Alternative stable solvers

| Solver | File | Strategy | Gradient notes |
| --- | --- | --- | --- |
| subspace_eigh | generalized_eigensolver.py | Custom VJP: Lorentzian broadening F/(F²+ε²) [2] | Stable |
| subspace_generalized_eigh | generalized_eigensolver.py | Symmetry-breaking perturbation + subspace_eigh [2,4] | Stable |
| degen_eigh | generalized_eigensolver.py | Custom VJP: mask degenerate F_ij by threshold [1,3] | Stable only for symmetric-subspace losses |
| safe_generalized_eigh | generalized_eigensolver.py | Cholesky + degen_eigh | Inherits degen_eigh caveat |
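
The Lorentzian broadening listed for subspace_eigh can be compared with the plain inverse gap numerically. The sketch assumes that F in F/(F²+ε²) denotes the eigenvalue gap w_j − w_i; that reading, the name `lorentzian_inverse_gap`, and the value of ε are assumptions made for illustration and are not taken from the solver files.

```python
import numpy as np

def lorentzian_inverse_gap(gap, eps=1e-6):
    """Smooth replacement for 1/gap: gap / (gap**2 + eps**2)."""
    return gap / (gap**2 + eps**2)

eps = 1e-6
gaps = np.array([1.0, 1e-3, 1e-6, 1e-9, 0.0])
broadened = lorentzian_inverse_gap(gaps, eps)

for g, b in zip(gaps, broadened):
    print(f"gap = {g:.0e}   broadened factor = {b:.6e}")
```

For gaps much larger than ε the factor approaches 1/gap, it peaks at 1/(2ε) when the gap equals ε, and it goes smoothly to zero at exact degeneracy. Threshold masking, by contrast, switches the factor off abruptly once the gap falls below the threshold.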

Baselines (not gradient-safe at degeneracies)

| Solver | File | Strategy |
| --- | --- | --- |
| standard_eig | generalized_eigensolver.py | scipy.linalg.eigh, non-differentiable reference |
| jax_eig | generalized_eigensolver.py | Plain Cholesky + jnp.linalg.eigh, default VJP |
| generalized_eigh | generalized_eigensolver.py | Symmetrized Cholesky with SPD shift, default VJP |

Compatibility

  • Python: 3.10–3.13. JAX: 0.5 → 0.10+ for the prebuilt wheels; jax 0.4.29 via source build (its FFI ABI differs from 0.5+, so it needs its own build — see Build from source).
  • CPU wheels: Linux x86_64 (manylinux_2_28, bundled OpenBLAS) and macOS arm64.
  • GPU wheels: Linux x86_64, CUDA 12 — eigh-cuda120 (CUDA 12.0+, glibc 2.17) and eigh-cuda128 (CUDA 12.8+, glibc 2.34).
  • Windows: no compiled wheel — use the pure-JAX solvers in src/jax/, or build from source.

Full detail — the FFI binary-ABI rules, why GPU ships as separate packages, the abi3 wheel matrix, HPC/cluster notes, and how to build on a pinned old jaxlib — is in docs/TECHNICAL_NOTES.md.

References

Development & Testing

  • Requirements: CMake 3.18+, C++17, JAX, NumPy, LAPACK/CUDA.
  • Tests:
    pytest tests/test_eigh.py     # Core functionality
    pytest tests/test_eigh_gen.py # Generalized itypes
    pytest tests/test_eigh_jit.py # JIT & vmap
    
  • GPU Setup:
    source setup_gpu_env_clean.sh
    ./run_gpu.sh python example_simple.py
    

License & Citation

Apache License 2.0. If used in research, please cite:

@software{sokolov2026eigh,
  author={Sokolov, Igor},
  title={Eigh: Differentiable eigenvalue decomposition with jax (cpu/gpu)},
  url={https://github.com/Brogis1/eigh},
  year={2026}
}

@software{pyscfad,
  author = {Zhang, Xing},
  title = {PySCFad: Automatic Differentiation for PySCF},
  url = {https://github.com/fishjojo/pyscfad},
  year = {2021-2025}
}

@article{10.1063/5.0118200,
    author = {Zhang, Xing and Chan, Garnet Kin-Lic},
    title = {Differentiable quantum chemistry with PySCF for molecules and materials at the mean-field level and beyond},
    journal = {The Journal of Chemical Physics},
    volume = {157},
    number = {20},
    pages = {204801},
    year = {2022},
    month = {11},
    issn = {0021-9606},
    doi = {10.1063/5.0118200},
    url = {https://doi.org/10.1063/5.0118200},
}

@article{sokolov2026xc,
  title = {Quantum-enhanced neural exchange-correlation functionals},
  author = {Sokolov, Igor O. and Both, Gert-Jan and Bochevarov, Art D. and Dub, Pavel A. and Levine, Daniel S. and Brown, Christopher T. and Acheche, Shaheen and Barkoutsos, Panagiotis Kl. and Elfving, Vincent E.},
  journal = {Phys. Rev. A},
  volume = {113},
  issue = {1},
  pages = {012427},
  numpages = {24},
  year = {2026},
  month = {Jan},
  publisher = {American Physical Society},
  doi = {10.1103/m51l-fys2},
  url = {https://link.aps.org/doi/10.1103/m51l-fys2}
}
