
End-to-end differentiable JAX implementation of VMEC2000 for fixed and free-boundary equilibria.


vmec-jax


End-to-end differentiable JAX implementation of VMEC2000 for fixed-boundary and free-boundary ideal-MHD equilibria.

Install

From PyPI

pip install vmec-jax

QI optimization uses booz_xform_jax for the differentiable Boozer transform; install it through the qi extra:

pip install "vmec-jax[qi]"

From conda-forge

vmec-jax is available as a conda package on conda-forge. To add it to a particular project with Pixi:

pixi add vmec-jax

or, to install it into a conda environment with conda:

conda install --channel conda-forge vmec-jax

From source

Developer (editable) install:

git clone https://github.com/uwplasma/vmec_jax
pip install -e "vmec_jax[qi]"

Usage

Run the solver (VMEC2000-style CLI):

vmec_jax input.nfp4_QH_warm_start        # → wout_nfp4_QH_warm_start.nc

Generate diagnostic plots from any wout_*.nc file (four output files; replicates vmecPlot2.py):

vmec_jax --plot wout_nfp4_QH_warm_start.nc           # saves in same directory
vmec_jax --plot wout_nfp4_QH_warm_start.nc --outdir figures/
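When scripting the solve-then-plot workflow, it can help to derive the wout name from the deck name. The sketch below assumes only the input.<case> → wout_<case>.nc naming shown above; wout_name_for is a hypothetical helper, not part of vmec_jax, and it does not assume which directory the solver writes to.

```python
from pathlib import Path


def wout_name_for(input_path: str) -> str:
    """Return the wout file name for a VMEC input deck (hypothetical helper).

    Follows the naming shown above: input.<case> -> wout_<case>.nc.
    Only the file name is derived; the output directory is not assumed.
    """
    name = Path(input_path).name
    prefix = "input."
    if not name.startswith(prefix):
        raise ValueError(f"expected a deck named 'input.*', got {name!r}")
    return f"wout_{name[len(prefix):]}.nc"


if __name__ == "__main__":
    print(wout_name_for("input.nfp4_QH_warm_start"))  # wout_nfp4_QH_warm_start.nc
```

The derived name can then be passed to vmec_jax --plot once the solve has finished.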

From Python:

import vmec_jax as vj

# Run a fixed-boundary solve
run = vj.run_fixed_boundary("input.nfp4_QH_warm_start")

# Run a free-boundary solve
freeb = vj.run_free_boundary("input.cth_like_free_bdy_lasym_small")

# Plot any wout file (produces *_VMECparams.pdf, *_poloidal_plot.png, *_VMECsurfaces.pdf, *_VMEC_3Dplot.png)
vj.plot_wout("wout_nfp4_QH_warm_start.nc", outdir="figures/")

Choosing CPU or GPU

vmec_jax follows the JAX backend you select. If you installed CPU-only JAX, runs use CPU. If you installed GPU-enabled JAX and select a GPU backend, runs use GPU; vmec_jax does not silently force those runs back to CPU.

# Check what JAX will use.
python -c "import jax; print(jax.default_backend()); print(jax.devices())"

# Force CPU for one command.
JAX_PLATFORMS=cpu vmec_jax input.nfp4_QH_warm_start

# Force an accelerator backend after installing GPU-enabled JAX.
JAX_PLATFORM_NAME=gpu vmec_jax input.nfp4_QH_warm_start

# For NVIDIA CUDA specifically, this is also valid.
JAX_PLATFORMS=cuda vmec_jax input.nfp4_QH_warm_start

From Python, leave solver_device unset to inherit JAX's default backend, or pass solver_device="cpu" / solver_device="gpu" explicitly:

import vmec_jax as vj

run_gpu = vj.run_fixed_boundary("input.nfp4_QH_warm_start", solver_device="gpu")
run_cpu = vj.run_fixed_boundary("input.nfp4_QH_warm_start", solver_device="cpu")

For GPU runs, vmec_jax sets XLA_PYTHON_CLIENT_PREALLOCATE=false by default before JAX is imported, so the GPU allocator grows on demand instead of reserving a large fraction of device memory up front. This avoids GPU memory contention between optimization workers and was also faster in the exact-Jacobian GPU profile. To keep JAX's default preallocation behavior, set XLA_PYTHON_CLIENT_PREALLOCATE=true before import.
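The "default unless you set it" behavior described above maps naturally onto dict.setdefault. This is a hypothetical sketch of the documented rule, not vmec_jax's code; apply_preallocate_default is an illustrative name.

```python
import os
from typing import MutableMapping


def apply_preallocate_default(environ: MutableMapping[str, str] = os.environ) -> str:
    """Sketch of the documented default (hypothetical helper, not vmec_jax code).

    setdefault writes "false" only when the variable is unset, so a user's
    explicit XLA_PYTHON_CLIENT_PREALLOCATE=true is preserved. The value must be
    in place before JAX creates its GPU client; doing this before `import jax`
    is the simplest way to guarantee that.
    """
    return environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")


if __name__ == "__main__":
    # To opt back into preallocation, set the variable yourself first, e.g.
    # os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true"  (before importing vmec_jax)
    print(apply_preallocate_default({}))  # false
```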

vmec_jax enables JAX's persistent compilation cache automatically for accelerator-selected runs, including runs where CUDA_VISIBLE_DEVICES or the ROCm equivalents expose an accelerator before import. CPU cache use is explicit opt-in because XLA:CPU AOT cache hits can emit host-feature mismatch errors on some JAX versions. Set VMEC_JAX_COMPILATION_CACHE=1 to enable the default cache for CPU runs, set VMEC_JAX_COMPILATION_CACHE=0 to disable it, or set VMEC_JAX_COMPILATION_CACHE_DIR=/path/to/cache to choose a custom location. The default cache path is scoped by machine, CPU features, Python version, and JAX/JAXLIB versions.
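The cache rules above can be summarized as a small decision function. This is a hypothetical restatement for clarity, not vmec_jax's implementation; in particular it assumes VMEC_JAX_COMPILATION_CACHE_DIR only chooses the location and does not by itself enable the cache for CPU runs.

```python
from typing import Mapping, Optional


def resolve_compilation_cache(backend: str, environ: Mapping[str, str],
                              default_dir: str) -> Optional[str]:
    """Return the cache directory to use, or None when caching is off.

    Hypothetical restatement of the documented policy:
      * VMEC_JAX_COMPILATION_CACHE=0 disables the cache on every backend;
      * accelerator backends enable it by default;
      * CPU enables it only with VMEC_JAX_COMPILATION_CACHE=1;
      * VMEC_JAX_COMPILATION_CACHE_DIR overrides the default location.
    """
    flag = environ.get("VMEC_JAX_COMPILATION_CACHE")
    if flag == "0":
        return None
    enabled = flag == "1" or backend != "cpu"
    if not enabled:
        return None
    return environ.get("VMEC_JAX_COMPILATION_CACHE_DIR", default_dir)
```

A resolved directory would typically be handed to JAX with jax.config.update("jax_compilation_cache_dir", path); the text above does not say whether vmec_jax uses exactly this call.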

For the current small and medium fixed-boundary examples, CPU is often faster after JIT warmup. GPU support is production-enabled and worth profiling. On both CPU and GPU, the exact optimizer computes accepted-point Jacobians with the discrete-adjoint tape path by default; the scan exact path is available as an explicit diagnostic override via VMEC_JAX_OPT_EXACT_PATH=scan. Relaxed trial residuals use the scan forward path by default. See the performance guide for current CPU/GPU timings and profiling commands.

Showcase (single-grid)

All figures below use the same single-grid run settings: NS_ARRAY=151, NITER_ARRAY=5000, FTOL_ARRAY=1e-14, NSTEP=500.

Figure panels (ITERModel and LandremanPaul2021_QA_lowres): cross-section and iota profile (VMEC2000 vs vmec_jax), 3D LCFS, and |B| on the LCFS.
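To run your own deck with the same single-grid settings as the showcase, the four INDATA keys can be rewritten in place. This is a minimal, hypothetical sketch using regular expressions; it assumes each key already appears on its own line and is not a general Fortran namelist parser.

```python
import re
from typing import Mapping

# Single-grid showcase settings quoted in the text (VMEC2000 INDATA keys).
SINGLE_GRID = {"NS_ARRAY": "151", "NITER_ARRAY": "5000",
               "FTOL_ARRAY": "1e-14", "NSTEP": "500"}


def apply_single_grid(deck_text: str, settings: Mapping[str, str] = SINGLE_GRID) -> str:
    """Overwrite existing `KEY = ...` lines in an INDATA deck (hypothetical helper).

    Assumes each key is already present and assigned on one line. Every value
    after `=` (up to a `!` comment) is replaced by the single setting, so a
    multigrid NS_ARRAY sequence collapses to one grid. Missing keys are not added.
    """
    for key, value in settings.items():
        pattern = re.compile(rf"^([ \t]*{re.escape(key)}[ \t]*=)[^!\n]*",
                             re.IGNORECASE | re.MULTILINE)
        # Keep the user's "KEY =" text (group 1) and swap in the new value.
        deck_text = pattern.sub(lambda m, v=value: f"{m.group(1)} {v}", deck_text)
    return deck_text
```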

Cold vs warm runtime: the cold bar includes XLA JIT compilation on the first call (a one-time cost per process); the warm bar is the steady-state solve time for later calls in the same process. VMEC2000 has no compilation step, so its runs are always effectively cold. For accelerator-selected runs, vmec_jax stores JAX's persistent compilation cache under ~/.cache/vmec_jax/jax_cache/<machine-fingerprint>. On CPU the cache is opt-in (VMEC_JAX_COMPILATION_CACHE=1) to avoid XLA:CPU AOT host-feature mismatch warnings on some JAX versions.
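To measure the cold/warm split for your own inputs, time the first call separately from later calls in the same process. time_cold_warm is a hypothetical helper; the commented usage line assumes vj.run_fixed_boundary from the Python example above is the call you want to time.

```python
import time
from typing import Callable, List, Tuple


def time_cold_warm(fn: Callable[[], object], warm_repeats: int = 3) -> Tuple[float, List[float]]:
    """Time the first call (cold) and several later calls (warm) in one process.

    With jitted JAX code the cold number includes tracing and XLA compilation;
    warm numbers reuse the compiled executable. If `fn` returns JAX arrays,
    have it call jax.block_until_ready on them so asynchronous dispatch does
    not hide device time.
    """
    start = time.perf_counter()
    fn()
    cold = time.perf_counter() - start
    warm = []
    for _ in range(warm_repeats):
        start = time.perf_counter()
        fn()
        warm.append(time.perf_counter() - start)
    return cold, warm


# Example (assumes vmec_jax is installed; not executed here):
# cold, warm = time_cold_warm(lambda: vj.run_fixed_boundary("input.nfp4_QH_warm_start"))
```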

Best Stellarator-Symmetric Optimizations

The fixed-boundary optimization examples solve VMEC equilibria and differentiate the objective with the exact discrete-adjoint/tape path. This README shows only the current best LASYM = F result for each target; the full CPU/GPU policy matrix, LASYM panels, finite-beta examples, QI constraint sweep, and all tables are in the optimization guide and the optimization sweep results.

Each row below shows the original deck LCFS (before any max_mode=1 optimization), the final LCFS, the per-stage objective history, and the final outer-surface |B| in Boozer coordinates computed with booz_xform_jax. This sweep uses NFP=2 seeds for QA/QP/QI and the standard bundled NFP=4 warm start for QH. The objective currently prioritizes primary symmetry/QI quality and rotational-transform control. QA follows the reference omnigenity QA deck, with aspect ratio near 5 and a signed mean iota target of 0.42; QH/QP/QI use the bound abs(mean_iota) >= 0.41. QI uses a higher aspect-ratio target of 10, so that reaching precise QI with an acceptable mirror ratio and elongation is less overconstrained. LgradB remains available as an optional script-level term, but it is not active in the default README examples or in best-row selection.
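To make the difference between an equality target and the one-sided iota bound concrete, here is a hypothetical sketch of both written as least-squares residuals. The function names and weights are illustrative assumptions rather than the vmec_jax objective, the symmetry/QI residual is omitted, and the table's "Final J" is not claimed to use this exact definition.

```python
import math
from typing import List


def target_residuals(aspect: float, mean_iota: float, *, aspect_target: float,
                     min_abs_iota: float = 0.41, w_aspect: float = 1.0,
                     w_iota: float = 1.0) -> List[float]:
    """Hypothetical residuals for an aspect target and an |iota| lower bound.

    The aspect term is an equality target. The iota term is one-sided: it is
    zero whenever abs(mean_iota) >= min_abs_iota and grows linearly below it,
    so either sign of iota can satisfy the bound.
    """
    r_aspect = w_aspect * (aspect - aspect_target)
    r_iota = w_iota * max(0.0, min_abs_iota - abs(mean_iota))
    return [r_aspect, r_iota]


def objective(residuals: List[float]) -> float:
    """Sum of squared residuals, the usual least-squares objective."""
    return math.fsum(r * r for r in residuals)
```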

The QP and QI rows both start from the bundled NFP=2 QI seed. QP is a quasi-poloidal-symmetry target using that same input deck; the current best QI row uses the dedicated mirror-aware QI_optimization.py lane at max_mode=3 without a QP preseed. The bundled NFP=2 seed is projected to each active max_mode, so max_mode=1 zeroes the seed's mode-2 boundary harmonics before optimizing. For QI, the listed wall time includes all repeated stages using the same constrained least-squares residual definition.
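The max_mode projection can be pictured as zeroing every boundary harmonic outside the active (m, n) box. This is a hypothetical sketch, assuming coefficients are keyed by (m, n) and that max_mode bounds both |m| and |n|, which is consistent with max_mode=1 zeroing the seed's mode-2 harmonics.

```python
from typing import Dict, Tuple

Mode = Tuple[int, int]  # (m, n): poloidal and toroidal mode numbers


def project_to_max_mode(coeffs: Dict[Mode, float], max_mode: int) -> Dict[Mode, float]:
    """Zero boundary harmonics outside |m| <= max_mode and |n| <= max_mode.

    Hypothetical sketch: out-of-box modes are set to zero rather than dropped,
    so the dictionary keeps every key and a later stage with a larger
    max_mode can free those modes again.
    """
    return {
        (m, n): (value if abs(m) <= max_mode and abs(n) <= max_mode else 0.0)
        for (m, n), value in coeffs.items()
    }
```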

Target  Backend  Policy        max_mode  ESS  QP preseed  Final J   QI legacy  Mirror  Elong.  Aspect  Iota     Wall time
QA      CPU      continuation  3         yes  -           2.33e-04  -          -       -       5.000   0.4200   6.3 min
QH      CPU      continuation  3         yes  -           9.68e-03  -          -       -       4.999   -1.6595  4.0 min
QP      CPU      continuation  3         no   -           6.76e-02  -          -       -       5.019   -0.6255  3.7 min
QI      CPU      qi_default    3         yes  no          1.17e-02  3.09e-04   0.225   6.43    9.999   -0.5043  10.1 min

Recreate the four displayed runs:

PYTHONPATH=. JAX_PLATFORMS=cpu python examples/optimization/generate_qs_ess_sweep.py --backend-label cpu --solver-device cpu --policy continuation --problems qa --modes 3 --ess on
PYTHONPATH=. JAX_PLATFORMS=cpu python examples/optimization/generate_qs_ess_sweep.py --backend-label cpu --solver-device cpu --policy continuation --problems qh --modes 3 --ess on
PYTHONPATH=. JAX_PLATFORMS=cpu python examples/optimization/generate_qs_ess_sweep.py --backend-label cpu --solver-device cpu --policy continuation --problems qp --modes 3 --ess off
PYTHONPATH=. JAX_PLATFORMS=cpu VMEC_JAX_QI_RUN_CASE=nfp2_qi python examples/optimization/QI_optimization.py
PYTHONPATH=. python examples/optimization/render_qi_constrained_sweep.py

For QI seed-robustness probes, set VMEC_JAX_QI_RUN_CASE=qi_stel_seed_3127 when running examples/optimization/QI_optimization.py, or change the top-level RUN_CASE to nfp1_qi, nfp2_qi, qi_stel_seed_3127, nfp4_qh_warm_to_qi, or a new QI_CASES entry for another VMEC input deck. The NFP=4 QH-warm case is currently a diagnostic stress test: it exercises the same machinery, but it is not yet a validated route to a precise NFP=4 QI state. Before promoting such a result, run examples/optimization/audit_qi_seed_suitability.py --quick and check the legacy QI, smooth QI, mirror ratio, elongation, iota, and Boozer |B| line-contour diagnostics. For the qi_stel_seed_3127 far-seed lane, use the same gates as the optimization case: --smooth-qi-max 5e-3 --legacy-qi-max 2e-3. Use the prefine manifest path for reviewed expensive probes rather than launching ad hoc far-seed jobs.
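The far-seed gates reduce to threshold checks in which any single failure rejects a candidate. In the hypothetical helper below, the QI thresholds come from the text, the mirror-ratio and elongation limits are the values passed to qi_boundary_interpolation_scan.py later in this section, and the aspect and iota gates are omitted because their thresholds are not stated here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class QIDiagnostics:
    """Per-candidate diagnostics named in the text (hypothetical container)."""
    smooth_qi: float
    legacy_qi: float
    mirror_ratio: float
    elongation: float


def passes_far_seed_gates(d: QIDiagnostics, *, smooth_qi_max: float = 5e-3,
                          legacy_qi_max: float = 2e-3,
                          max_mirror_ratio: float = 0.35,
                          max_elongation: float = 8.0) -> bool:
    """True only if every gate passes; any single failure rejects the candidate."""
    return (d.smooth_qi <= smooth_qi_max
            and d.legacy_qi <= legacy_qi_max
            and d.mirror_ratio <= max_mirror_ratio
            and d.elongation <= max_elongation)
```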

The input.QI_stel_seed_3127 robustness lane is intentionally harder than the default NFP=2 QI seed. Purely local boundary moves still get trapped, but the current QI_optimization.py case now includes a deterministic same-NFP reference-family preconditioner: it interpolates the seed boundary toward the bundled NFP=3 QI reference, audits each candidate with the independent smooth/legacy QI, mirror, elongation, aspect, and iota gates, and then starts local QI cleanup from the lowest-mirror accepted non-endpoint candidate when one exists. That candidate is recorded as the accepted baseline, so later cleanup stages cannot replace it unless exact diagnostics improve. For this far-seed case the legacy Goodman-style QI gate is 2e-3, while the smooth differentiable proxy gate is 5e-3 because it is the optimization surrogate and is more conservative on the six-surface audit. Guarded local cleanup can now use anisotropic boundary stages, for example unlocking max_m=1, max_n=4 before the full max_m=max_n=4 boundary. The script promotes such stages only when independent exact diagnostics improve, so a mirror-heavy local solve cannot replace a precise-QI baseline if it damages legacy QI. The diagnostic below scans two boundary coefficients around the raw seed and shows why this larger global-to-local move is needed.

Recreate that landscape plot:

PYTHONPATH=. JAX_PLATFORMS=cpu python tools/diagnostics/qi_landscape_scan.py \
  --input examples/data/input.QI_stel_seed_3127 \
  --output-dir results/diagnostics/qi_landscape_seed3127 \
  --max-mode 3 --min-vmec-mode 6 --dofs rc01,zs01 --points 3 \
  --span 0.03 --span2 0.03 --surfaces 0.35,0.65 \
  --nphi 51 --nalpha 11 --n-bounce 15 \
  --mirror-ntheta 32 --mirror-nphi 32 \
  --elongation-ntheta 24 --elongation-nphi 8

The landscape command defaults to trial solves for speed. Add --exact-solve before using the scanned QI, mirror, elongation, or iota values as promotion evidence.
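One plausible reading of the landscape flags, stated as an assumption: --dofs names two boundary coefficients, --span and --span2 are symmetric half-widths for the first and second, and --points is the number of samples per coefficient, which gives a 3 x 3 grid for the command above. The hypothetical helper below enumerates such a grid; the real script's sampling may differ.

```python
from itertools import product
from typing import Dict, List, Sequence


def landscape_grid(seed: Dict[str, float], dofs: Sequence[str],
                   spans: Sequence[float], points: int) -> List[Dict[str, float]]:
    """Tensor-product grid of perturbations around a seed (hypothetical helper).

    Assumes each span is a symmetric half-width: with points=3 a dof takes
    seed - span, seed, and seed + span. Coefficients not in `dofs` stay fixed.
    """
    if len(dofs) != len(spans):
        raise ValueError("need one span per dof")
    if points < 2:
        raise ValueError("need at least two points per dof")
    axes = []
    for dof, span in zip(dofs, spans):
        step = 2.0 * span / (points - 1)
        axes.append([seed[dof] - span + i * step for i in range(points)])
    grid = []
    for values in product(*axes):
        point = dict(seed)               # copy so the seed is never modified
        point.update(zip(dofs, values))  # overwrite only the scanned dofs
        grid.append(point)
    return grid
```

Each grid point still needs its own equilibrium solve, which is why the trial-solve default matters for speed.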

Run the current reference-family preconditioner directly:

PYTHONPATH=. JAX_PLATFORMS=cpu VMEC_JAX_QI_RUN_CASE=qi_stel_seed_3127 \
  python examples/optimization/QI_optimization.py
PYTHONPATH=. JAX_PLATFORMS=cpu python tools/diagnostics/qi_boundary_interpolation_scan.py \
  --seed-input examples/data/input.QI_stel_seed_3127 \
  --reference-input examples/data/input.nfp3_QI_fixed_resolution_final \
  --out-root results/diagnostics/qi_seed3127_boundary_interpolation \
  --lambdas 0.994,0.995,0.996,0.997,0.998,0.999,1.0,1.001,1.002 \
  --max-mode 4 --max-iter 80 --target-aspect 4.0 \
  --surfaces 0.1,0.28,0.46,0.64,0.82,1.0 \
  --mboz 18 --nboz 18 --nphi 151 --nalpha 31 --n-bounce 51 \
  --smooth-qi-max 5e-3 --legacy-qi-max 2e-3 \
  --max-mirror-ratio 0.35 --max-elongation 8.0
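The reference-family preconditioner described earlier can be sketched in two pieces: a linear blend of boundary Fourier coefficients between the seed and the reference deck, and a rule that starts local cleanup from the lowest-mirror candidate that passed the gates and is not an endpoint. This is a hypothetical illustration: it assumes lam=0 is the seed and lam=1 the reference (so 1.001 and 1.002 extrapolate slightly past the reference) and treats those two values as the endpoints; QI_optimization.py may parametrize this differently.

```python
from typing import Dict, Optional, Sequence, Tuple

Coeffs = Dict[Tuple[int, int], float]  # (m, n) -> boundary coefficient


def interpolate_boundary(seed: Coeffs, ref: Coeffs, lam: float) -> Coeffs:
    """Linear blend (1 - lam) * seed + lam * ref over the union of modes.

    Modes missing from one deck count as zero; lam > 1 extrapolates past ref.
    """
    modes = set(seed) | set(ref)
    return {k: (1.0 - lam) * seed.get(k, 0.0) + lam * ref.get(k, 0.0) for k in modes}


def pick_start(candidates: Sequence[Tuple[float, float, bool]],
               endpoints: Tuple[float, float] = (0.0, 1.0)) -> Optional[float]:
    """Return the lambda of the lowest-mirror candidate that passed the gates.

    candidates holds (lam, mirror_ratio, passed_gates) tuples; candidates at an
    endpoint are skipped. Returns None when nothing qualifies, in which case
    the caller would fall back to the unmodified seed.
    """
    eligible = [(mirror, lam) for lam, mirror, ok in candidates
                if ok and lam not in endpoints]
    if not eligible:
        return None
    return min(eligible)[1]
```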

Regenerate the README panels and the compact CSV used for the table:

PYTHONPATH=. python examples/optimization/render_readme_best_optimizations.py

Performance vs parity

  • Default runs select the fastest stable path for each input automatically.
  • Use --parity (or performance_mode=False in Python) to force the conservative VMEC2000 loop.
  • Use --solver-mode accelerated to force the optimized fixed-boundary controller.
  • For GPU benchmarking, separate raw solver throughput from public policy overhead. For example, use tools/diagnostics/profile_fixed_boundary.py --no-auto-cli-policy --solver-mode accelerated --no-multigrid --use-scan --solver-device gpu.
  • Compare both first-process and in-process warm timings. The first GPU process pays XLA/runtime setup; persistent cache effectiveness depends on the JAX version, backend, and machine features.

Details, profiling guidance, and parity methodology:

  • docs/performance.rst
  • docs/validation.rst
  • tools/diagnostics/parity_manifest.toml + tools/diagnostics/parity_sweep_manifest.py

CLI reference

vmec_jax input.*                run the equilibrium solver → wout_*.nc
vmec_jax --plot wout.nc         generate diagnostic plots (4 output files)
vmec_jax --parity input.*       force conservative VMEC2000 loop
vmec_jax --help                 full option list

VMEC++ notes

The current runtime benchmark compares vmec_jax against VMEC2000. VMEC++ is not included in this benchmark.

When VMEC++ is available, it can be added to the runtime plot via --cpu-summary entries with backend=vmecpp. Some inputs are not supported or do not converge under the same single-grid settings:

VMEC++ unsupported inputs (lasym=True):

  • LandremanSenguptaPlunk_section5p3_low_res
  • basic_non_stellsym_pressure
  • cth_like_free_bdy_lasym_small
  • up_down_asymmetric_tokamak

VMEC++ known non-convergence on these lasym=False cases under the same single-grid settings:

  • DIII-D_lasym_false
  • LandremanPaul2021_QA_reactorScale_lowres
  • LandremanPaul2021_QH_reactorScale_lowres
  • LandremanSengupta2019_section5.4_B2_A80
  • cth_like_fixed_bdy

CLI output and NSTEP

The VMEC-style iteration loop prints a progress line every NSTEP iterations. A larger NSTEP means fewer print callbacks and therefore faster runs.
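As a rough count, assuming one line is printed every NSTEP iterations, the showcase settings (NITER_ARRAY=5000, NSTEP=500) produce about 10 progress lines per grid, versus about 100 with NSTEP=50. The hypothetical helper below does that arithmetic; the solver's exact first-line and last-line behavior is not assumed.

```python
def print_count(niter: int, nstep: int) -> int:
    """Approximate number of progress lines for one grid (hypothetical helper).

    Counts lines at iterations nstep, 2 * nstep, ... up to niter.
    """
    if nstep <= 0:
        raise ValueError("nstep must be positive")
    return niter // nstep
```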

To disable live printing:

export VMEC_JAX_SCAN_PRINT=0

Quiet runs (--quiet or verbose=False) default the scan path to minimal history mode to reduce host/device traffic. Override with:

export VMEC_JAX_SCAN_MINIMAL=0  # keep full scan diagnostics even when quiet



Download files


Source Distribution

vmec_jax-0.0.8.tar.gz (880.2 kB)

Uploaded Source

Built Distribution


vmec_jax-0.0.8-py3-none-any.whl (577.3 kB)

Uploaded Python 3

File details

Details for the file vmec_jax-0.0.8.tar.gz.

File metadata

  • Download URL: vmec_jax-0.0.8.tar.gz
  • Size: 880.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for vmec_jax-0.0.8.tar.gz
Algorithm Hash digest
SHA256 2356a3d6d85496920d550ee2c4d6f82da9677bd6b46a77ee7b5455859adf114f
MD5 2126feab47f9c074d33a2704bc91e9b0
BLAKE2b-256 ae9303c3d92dabb6d5bc9324ef4d21636aedf9eeaeccaee62399f64496092678


Provenance

The following attestation bundles were made for vmec_jax-0.0.8.tar.gz:

Publisher: publish-pypi.yml on uwplasma/vmec_jax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file vmec_jax-0.0.8-py3-none-any.whl.

File metadata

  • Download URL: vmec_jax-0.0.8-py3-none-any.whl
  • Size: 577.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for vmec_jax-0.0.8-py3-none-any.whl
Algorithm Hash digest
SHA256 1b969466bcf86c90596fa0da54490d39c95793f7839232fb6208898cba5caeb1
MD5 297430e3a292e4de1b0363d544174ecb
BLAKE2b-256 b0a339352bd5c1e7a5a84ae89dc8389437da1cfff39f3ccff953c12268651068


Provenance

The following attestation bundles were made for vmec_jax-0.0.8-py3-none-any.whl:

Publisher: publish-pypi.yml on uwplasma/vmec_jax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
