
A package that provides a JAX-based Hartree–Fock optimization solver for simple continuum models.


jax_hf — JAX Hartree–Fock on k‑grids


jax_hf provides two JAX-jitted solvers for the Hartree–Fock free-energy minimisation problem on 2D k-meshes:

  • Direct minimisation (primary): preconditioned Riemannian conjugate gradient (CG) on the Stiefel × capped-simplex manifold, with an eigen-free inner loop, a Cayley retraction, and one Fock build per iteration.
  • Reference SCF (baseline / fallback): standard Roothaan iteration with linear mixing.
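A minimal sketch of the two geometric pieces named in the direct-minimisation bullet, written from the standard formulas rather than from jax_hf's source. It assumes (not confirmed by this page) that each k-point's density is parametrised as P = Q diag(p) Q^H, with Q a unitary orbital frame (the Stiefel factor, cf. result.Q) and p occupations on the capped simplex {0 ≤ p ≤ 1, Σ w p = N} (cf. result.p). The helper names and the weighted-projection formulation are hypothetical. The Cayley update needs only a linear solve, which is consistent with an eigen-free inner loop.

```python
import jax
import jax.numpy as jnp


def dagger(X):
    # Conjugate transpose of the last two axes (works on batched (..., nb, nb) arrays).
    return jnp.conj(jnp.swapaxes(X, -1, -2))


def cayley_retract(Q, G, step):
    """Move the unitary frame Q by `step` in a descent direction built from G.

    Q: (..., nb, nb) unitary frame; G: (..., nb, nb) Euclidean gradient w.r.t. Q.
    """
    # Skew-Hermitian generator W = G Q^H - Q G^H (so W^H = -W).
    W = G @ dagger(Q) - Q @ dagger(G)
    eye = jnp.eye(Q.shape[-1], dtype=Q.dtype)
    # (I + s/2 W)^{-1} (I - s/2 W) is unitary for skew-Hermitian W, so Q stays unitary.
    return jnp.linalg.solve(eye + 0.5 * step * W, (eye - 0.5 * step * W) @ Q)


def project_capped_simplex(y, w, n_target, iters=60):
    """Weighted Euclidean projection of y onto {0 <= p <= 1, sum(w * p) = n_target}.

    The KKT conditions give p = clip(y - tau, 0, 1) for one scalar tau, and
    sum(w * p) is non-increasing in tau, so tau is found by bisection.
    Requires w > 0 and 0 <= n_target <= sum(w).
    """
    def body(_, bounds):
        lo, hi = bounds
        tau = 0.5 * (lo + hi)
        too_full = jnp.sum(w * jnp.clip(y - tau, 0.0, 1.0)) > n_target
        # Too many electrons: raise tau; otherwise lower it.
        return jnp.where(too_full, tau, lo), jnp.where(too_full, hi, tau)

    # At tau = min(y) - 1 every p is 1; at tau = max(y) every p is 0.
    lo, hi = jax.lax.fori_loop(0, iters, body, (jnp.min(y) - 1.0, jnp.max(y)))
    return jnp.clip(y - 0.5 * (lo + hi), 0.0, 1.0)


def density_from_frame(Q, p):
    """Assumed parametrisation P = Q diag(p) Q^H, orbitals stored as columns of Q."""
    return jnp.einsum("...ij,...j,...kj->...ik", Q, p, jnp.conj(Q))
```

The Cayley map is a common retraction choice because it keeps Q unitary (up to round-off) for any step size, so no re-orthonormalisation or eigendecomposition is needed after a step. This sketch covers only the retraction and the occupation projection, not the preconditioned CG directions.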

Exchange and Hartree can both be included, and the exchange kernel may be layer-resolved. See examples/ for density-scan scripts on a bilayer graphene model.
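The exchange term is a convolution over the k-grid, which is presumably why the kernel precomputes the FFT of the interaction (see the comment in the minimal example). The sketch below assumes the textbook form Σ_ab(k) = -Σ_k' V_ab(k - k') P_ab(k') on a periodic grid and omits the k-point weights and area prefactor; exchange_fft is a hypothetical helper, not the jax_hf.fock API. A scalar kernel of shape (nk1, nk2, 1, 1) broadcasts over the band indices, while a layer-resolved (nk1, nk2, nb, nb) kernel multiplies each matrix element separately.

```python
import jax.numpy as jnp


def exchange_fft(P, Vq):
    """Sigma_ab(k) = -sum_{k'} V_ab(k - k') P_ab(k') on a periodic (nk1, nk2) grid.

    P:  (nk1, nk2, nb, nb) density matrix.
    Vq: (nk1, nk2, 1, 1) scalar or (nk1, nk2, nb, nb) layer-resolved kernel,
        indexed by the momentum transfer q = k - k' modulo the grid.
    Weights and area normalisation are deliberately left out of this sketch.
    """
    # A circular convolution over the two k axes is a product in Fourier space.
    Vr = jnp.fft.fft2(Vq, axes=(0, 1))
    Pr = jnp.fft.fft2(P, axes=(0, 1))
    # Band indices are multiplied elementwise, so a (1, 1) kernel broadcasts.
    return -jnp.fft.ifft2(Vr * Pr, axes=(0, 1))
```

Precomputing fft2(Vq) once leaves two FFTs per Fock build, O(N_k log N_k) work instead of the O(N_k^2) direct double sum over k and k'.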

v2.0.0 note: This release is a clean-slate rewrite. The entire public API has changed relative to the deprecated v1.x line (which was already a skeleton in v1.1.0). See MIGRATION.md for the migration guide.

Install

pip install jax-hf

Minimal example

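import jax.numpy as jnp
import jax_hf

# Assumed to exist already: weights, hamiltonian, coulomb_q (shapes as annotated
# below) and the total electron count N.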

# Build a HartreeFockKernel: precomputes the FFT of the interaction kernel,
# the Hartree matrix, etc., ready for JIT.
kernel = jax_hf.HartreeFockKernel(
    weights=weights,          # (nk1, nk2) k-point weights
    hamiltonian=hamiltonian,  # (nk1, nk2, nb, nb) single-particle Hamiltonian
    coulomb_q=coulomb_q,      # (nk1, nk2, 1, 1) scalar or (nk1, nk2, nb, nb) layer-resolved
    T=0.1,
    include_hartree=False,    # set True for Hartree; also pass reference_density + hartree_matrix
    include_exchange=True,
)

# Solve (direct minimisation, default)
result = jax_hf.solve(kernel, P0=jnp.zeros_like(hamiltonian), n_electrons=N)
print(result.energy, result.converged, result.n_iter)
# result.density, result.fock, result.Q, result.p, result.mu, result.history

# Or use SCF as a fallback baseline
result_scf = jax_hf.solve_scf(kernel, P0=jnp.zeros_like(hamiltonian), n_electrons=N)
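For comparison with the direct minimiser, here is a toy version of the Roothaan iteration with linear mixing that the reference SCF is described as using: diagonalise the Fock matrix, fill the lowest orbitals, then blend the new density into the old one with weight `mixing` (0.3 matches the SCFConfig default listed under Config). The on-site Fock builder toy_fock, the single k-point, and the zero-temperature filling are simplifying assumptions; this is not jax_hf's solve_scf.

```python
import jax.numpy as jnp


def toy_fock(P, h0, U):
    # Hypothetical on-site mean field F = h0 + U diag(P); stands in for a real Fock build.
    return h0 + U * jnp.diag(jnp.diag(P))


def aufbau_density(F, n_occ):
    # Roothaan step at zero temperature: diagonalise F and fill the n_occ lowest orbitals.
    _, vecs = jnp.linalg.eigh(F)  # eigenvalues come back in ascending order
    C = vecs[:, :n_occ]
    return C @ C.conj().T


def scf_linear_mixing(h0, U, n_occ, mixing=0.3, max_iter=200, tol=1e-5):
    P = aufbau_density(h0, n_occ)  # non-interacting starting guess
    for it in range(max_iter):
        P_out = aufbau_density(toy_fock(P, h0, U), n_occ)
        change = float(jnp.max(jnp.abs(P_out - P)))
        # Linear mixing: keep (1 - mixing) of the input density.
        P = (1.0 - mixing) * P + mixing * P_out
        if change < tol:
            return P, it + 1
    return P, max_iter
```

A converged density should also commute with its Fock matrix, [F, P] ≈ 0, which is the kind of condition a commutator tolerance such as SCFConfig's comm_tol would express alongside the density-change criterion.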

Config

Each solver takes its own config dataclass with sensible defaults:

jax_hf.SolverConfig(max_iter=200, tol_E=1e-7, max_step=0.6, project_fn=None, ...)
jax_hf.SCFConfig(max_iter=200, mixing=0.3, density_tol=1e-7, comm_tol=1e-6, ...)

project_fn lets you enforce symmetry constraints (spin, valley, time reversal, spatial) on the density and Fock at every iteration. See jax_hf.symmetry.make_project_fn.
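For unitary, k-local symmetries, one standard way to build such a projection is a group average, P → (1/|G|) Σ_g U_g P U_g^H, which is idempotent and returns a matrix that commutes with every U_g. The sketch assumes a project_fn that maps a (..., nb, nb) array to one of the same shape; the actual signature produced by jax_hf.symmetry.make_project_fn is not documented on this page, and make_group_average is a hypothetical name. Antiunitary symmetries such as time reversal also need complex conjugation, and symmetries that move k (time reversal, spatial) also require permuting the k-grid; neither is covered here.

```python
import jax.numpy as jnp


def make_group_average(unitaries):
    """Hypothetical project_fn factory: P -> (1/|G|) sum_g U_g P U_g^H.

    `unitaries` must list every element of a finite group of unitary (nb, nb)
    matrices, identity included; otherwise the map is not a projection.
    """
    def project(P):
        # P may carry leading k-grid axes, shape (..., nb, nb); U broadcasts over them.
        total = jnp.zeros_like(P)
        for U in unitaries:
            total = total + U @ P @ U.conj().T
        return total / len(unitaries)

    return project
```

For example, averaging over {I, σ_x} in a two-component basis forces equal occupation of the two components and symmetric off-diagonal elements at every k-point.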

Public API

| Name | Purpose |
| --- | --- |
| HartreeFockKernel | Problem + precomputed arrays |
| solve (alias solve_direct_minimization), SolverConfig, SolveResult | Primary solver |
| solve_scf, SCFConfig, SCFResult | Reference SCF solver |
| build_fock, hf_energy, free_energy, occupation_entropy | HF objective building blocks |
| solve_continuation, ContinuationResult, resample_kgrid | Coarse → fine multigrid driver + k-grid resampler |

Lower-level modules (jax_hf.utils, jax_hf.symmetry, jax_hf.linalg, jax_hf.fock) expose the individual pieces for users who need them.
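The free_energy and occupation_entropy building blocks suggest a finite-temperature objective of the form F = E - T S. Their exact signatures are not given here, so the sketch below uses standalone functions with the textbook fermionic entropy S = -Σ_k w_k [p ln p + (1 - p) ln(1 - p)] in units with k_B = 1; the names mirror the table but the argument lists are assumptions. It also shows that, for fixed levels ε, Fermi–Dirac occupations minimise Σ w p (ε - μ) - T S, the stationarity condition a free-energy minimiser drives the occupations toward.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import xlogy


def occupation_entropy(p, w):
    """Fermionic entropy S = -sum_k w_k [p ln p + (1 - p) ln(1 - p)], with k_B = 1."""
    # xlogy(x, x) = x * log(x) and returns 0 at x = 0, so empty or full states are safe.
    return -jnp.sum(w * (xlogy(p, p) + xlogy(1.0 - p, 1.0 - p)))


def free_energy(energy, p, w, T):
    """F = E - T S for a given energy E and occupations p."""
    return energy - T * occupation_entropy(p, w)


def fermi_dirac(eps, mu, T):
    # Occupations that minimise sum w p (eps - mu) - T S(p) for fixed levels eps.
    return jax.nn.sigmoid((mu - eps) / T)
```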

Coarse → fine continuation

For large fine grids, solve_continuation runs a cheap coarse solve first and uses its density to seed the fine solve. The two stages can mix and match direct minimisation and SCF:

import jax.numpy as jnp
import jax_hf
from jax_hf import SCFConfig, SolverConfig

coarse = jax_hf.HartreeFockKernel(weights_c, h_c, Vq_c, T=0.1)
fine   = jax_hf.HartreeFockKernel(weights_f, h_f, Vq_f, T=0.1)

result = jax_hf.solve_continuation(
    coarse, fine, P0_coarse=jnp.zeros_like(h_c),
    n_electrons_coarse=N, n_electrons_fine=N,
    coarse_config=SCFConfig(max_iter=50, mixing=0.5),   # robust coarse
    fine_config=SolverConfig(max_iter=200, tol_E=1e-8), # fast fine
)
# result.coarse, result.fine (each a SolveResult or SCFResult)
# result.P0_fine (resampled coarse density used to seed the fine solve)

The driver is intentionally algorithm-agnostic: it resamples the coarse density onto the fine grid via resample_kgrid and hands off. Callers that need physics-aware seeding (reference-density interpolation, self-energy seeds, filling-consistent electron counts across grids) should construct both kernels themselves.
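How resample_kgrid interpolates is not described on this page. One common choice for smooth, periodic functions of k is trigonometric (Fourier) interpolation: transform the coarse data over the two k axes, pad the spectrum with zeros for the new high frequencies, and transform back on the fine grid. The sketch assumes odd coarse grid sizes (to avoid splitting the Nyquist term) and both grids spanning the same Brillouin zone with the origin at index 0; fourier_resample is a hypothetical helper, not the library function.

```python
import jax.numpy as jnp


def _upsample_axis(x, m, axis):
    """Trigonometric interpolation of x along `axis` from n to m >= n points (n odd)."""
    n = x.shape[axis]
    if n % 2 == 0 or m < n:
        raise ValueError("sketch assumes an odd coarse size n and m >= n")
    X = jnp.fft.fft(x, axis=axis)
    h = (n + 1) // 2  # zero frequency plus the (n - 1) // 2 positive ones
    pos = jnp.take(X, jnp.arange(h), axis=axis)
    neg = jnp.take(X, jnp.arange(h, n), axis=axis)  # the (n - 1) // 2 negative ones
    pad_shape = list(X.shape)
    pad_shape[axis] = m - n
    # Zeros for the new high frequencies sit between the positive and negative halves.
    Xf = jnp.concatenate([pos, jnp.zeros(tuple(pad_shape), dtype=X.dtype), neg], axis=axis)
    # Rescale so ifft's 1/m normalisation matches the coarse grid's 1/n.
    return jnp.fft.ifft(Xf, axis=axis) * (m / n)


def fourier_resample(P, fine_shape):
    """Resample (nk1, nk2, ...) data onto an (mk1, mk2) grid over the same zone."""
    out = _upsample_axis(P, fine_shape[0], axis=0)
    return _upsample_axis(out, fine_shape[1], axis=1)
```

Because the zero-frequency coefficient carries over unchanged, the grid average (and hence the electron count under uniform weights with the same total weight on both grids) is preserved exactly in this sketch. Properties such as idempotency of a zero-temperature density are not, which is one reason a resampled density serves only as a seed for the fine solve.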

Examples

  • examples/multilayer_graphene_density_scan.py — PM/SVP density scan for bilayer graphene, direct minimisation, Fock only
  • examples/multilayer_graphene_density_scan_extended.py — adds spin-polarised and "SVP flipped" branches (4 total)
  • examples/multilayer_graphene_density_scan_hartree.py — same four branches with layer-resolved Coulomb and Hartree included
  • examples/multilayer_graphene_reference_scf_scan.py — SCF baseline scan for side-by-side comparison

Running tests

pytest tests/

The bilayer regression tests (tests/test_bilayer_regression.py) require contimod and contimod_graphene and will be skipped otherwise.



Download files


Source Distribution

jax_hf-2.1.0.tar.gz (117.3 kB)

Uploaded Source

Built Distribution


jax_hf-2.1.0-py3-none-any.whl (32.4 kB)

Uploaded Python 3

File details

Details for the file jax_hf-2.1.0.tar.gz.

File metadata

  • Download URL: jax_hf-2.1.0.tar.gz
  • Upload date:
  • Size: 117.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jax_hf-2.1.0.tar.gz
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 3452d6163758756869aa3a02355d5bca056bdbf277684bb88ddbc561431860ed |
| MD5 | 25165d28ebe026d015dd909802e40dcb |
| BLAKE2b-256 | 22b6a1db4e02641b0465d2a9048856b8c1fe1a274795e56cf203025a86890aed |


Provenance

The following attestation bundles were made for jax_hf-2.1.0.tar.gz:

Publisher: release.yml on skilledwolf/jax_hf

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file jax_hf-2.1.0-py3-none-any.whl.

File metadata

  • Download URL: jax_hf-2.1.0-py3-none-any.whl
  • Upload date:
  • Size: 32.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jax_hf-2.1.0-py3-none-any.whl
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 3cbdb67be26a75e926699907d7f5d47263e85324815765972b3c11311c291a41 |
| MD5 | 404cec7ee12c5f04809dd7c6ac6f993b |
| BLAKE2b-256 | 6b792483763b77a6305489f60d6f904b74a41d9435fe25b2e14c65416ce8e6d5 |


Provenance

The following attestation bundles were made for jax_hf-2.1.0-py3-none-any.whl:

Publisher: release.yml on skilledwolf/jax_hf

