
End-to-end differentiable JAX implementation of VMEC2000 for fixed and free-boundary equilibria.


vmec-jax


End-to-end differentiable JAX implementation of VMEC2000 for fixed-boundary and free-boundary ideal-MHD equilibria.

Install

pip install vmec-jax

Developer (editable) install:

git clone https://github.com/uwplasma/vmec_jax
pip install -e vmec_jax/

Usage

Run the solver (VMEC2000-style CLI):

vmec_jax input.nfp4_QH_warm_start        # → wout_nfp4_QH_warm_start.nc

Generate diagnostic plots from any wout_*.nc (writes four output files, replicating vmecPlot2.py):

vmec_jax --plot wout_nfp4_QH_warm_start.nc           # saves in same directory
vmec_jax --plot wout_nfp4_QH_warm_start.nc --outdir figures/

From Python:

import vmec_jax as vj

# Run a fixed-boundary solve
run = vj.run_fixed_boundary("input.nfp4_QH_warm_start")

# Run a free-boundary solve
freeb = vj.run_free_boundary("input.cth_like_free_bdy_lasym_small")

# Plot any wout file (produces *_VMECparams.pdf, *_poloidal_plot.png, *_VMECsurfaces.pdf, *_VMEC_3Dplot.png)
vj.plot_wout("wout_nfp4_QH_warm_start.nc", outdir="figures/")

Run tests:

pytest -q

Showcase (single-grid)

All figures below use the same single-grid run settings: NS_ARRAY=151, NITER_ARRAY=5000, FTOL_ARRAY=1e-14, NSTEP=500.
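NS_ARRAY, NITER_ARRAY, FTOL_ARRAY and NSTEP are VMEC &INDATA namelist variables, so reproducing these settings on another input means patching them in the input file. A minimal sketch with a hypothetical helper (apply_single_grid, not part of vmec_jax); it assumes simple one-line `KEY = values` assignments and an &INDATA block closed by a `/` line.

```python
import re

# Showcase settings quoted above (single grid: one value per *_ARRAY entry).
SINGLE_GRID = {
    "NS_ARRAY": "151",
    "NITER_ARRAY": "5000",
    "FTOL_ARRAY": "1.0E-14",
    "NSTEP": "500",
}


def apply_single_grid(indata: str, settings: dict = SINGLE_GRID) -> str:
    """Hypothetical helper: rewrite (or add) assignments in an &INDATA namelist."""
    for key, value in settings.items():
        # Match `KEY = <values>` up to an inline `!` comment or end of line.
        pattern = re.compile(rf"^(\s*{key}\s*=\s*)[^!\n]*",
                             flags=re.IGNORECASE | re.MULTILINE)
        if pattern.search(indata):
            indata = pattern.sub(lambda m: m.group(1) + value, indata, count=1)
        else:
            # Key absent: insert it just before the namelist terminator "/".
            indata = re.sub(r"^[ \t]*/[ \t]*$", f"  {key} = {value}\n/",
                            indata, count=1, flags=re.MULTILINE)
    return indata


example = """&INDATA
  NFP = 4
  NS_ARRAY = 16 51
  NITER_ARRAY = 1000 2000
  FTOL_ARRAY = 1.0E-8 1.0E-11
/
"""
print(apply_single_grid(example))
```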

[Figures, VMEC2000 vs vmec_jax for ITERModel and LandremanPaul2021_QA_lowres: cross-section, iota profile, 3D LCFS, and |B| on the LCFS]

Cold vs warm runtime: the cold bar includes XLA JIT compilation on the first call (a one-time cost per process); the warm bar is the steady-state solve time for subsequent calls in the same process. VMEC2000 has no compilation overhead, so it has only a single ("cold") timing. The warm vmec_jax time is the fair comparison for repeated solves (e.g., in an optimization loop). vmec_jax automatically caches compiled XLA kernels on disk (~/.cache/vmec_jax/jax_cache), so after the first run, cold starts in new processes also approach warm speed.
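The cold/warm split is generic JAX behaviour: the first call to a jax.jit function traces and compiles it, and later calls with the same shapes reuse the compiled kernel. The sketch below times both calls on a hypothetical toy solver (toy_solve, not vmec_jax code). It also enables JAX's persistent compilation cache via jax_compilation_cache_dir, which is presumably the mechanism behind ~/.cache/vmec_jax/jax_cache (an assumption); the temporary directory is illustrative.

```python
import tempfile
import time

import jax
import jax.numpy as jnp

# Enable JAX's persistent compilation cache so that *later processes* can reuse
# compiled kernels. The temporary directory is illustrative only. JAX persists
# only kernels whose compile time exceeds
# jax_persistent_cache_min_compile_time_secs, so this tiny toy may not be cached.
jax.config.update("jax_compilation_cache_dir", tempfile.mkdtemp())


@jax.jit
def toy_solve(x):
    # Hypothetical stand-in for an equilibrium solve: 200 damped relaxation
    # steps toward the fixed point of y = cos(y).
    def body(_, y):
        return y - 0.1 * (y - jnp.cos(y))
    return jax.lax.fori_loop(0, 200, body, x)


x0 = jnp.linspace(0.0, 1.0, 1024)

t0 = time.perf_counter()
cold = toy_solve(x0).block_until_ready()  # trace + XLA compile + run
t_cold = time.perf_counter() - t0

t0 = time.perf_counter()
warm = toy_solve(x0).block_until_ready()  # reuse the in-memory compiled kernel
t_warm = time.perf_counter() - t0

print(f"cold: {t_cold * 1e3:.1f} ms, warm: {t_warm * 1e3:.1f} ms")
```

Within one process the second call skips tracing and compilation entirely, which is the regime the warm bar measures.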

Quasi-helical symmetry optimization (discrete-adjoint)

examples/optimization/qh_fixed_resolution_jax.py demonstrates an end-to-end fixed-boundary QH optimization using the built-in exact discrete-adjoint Jacobian — no finite differences, no SIMSOPT dependency.

Discrete adjoint: rather than perturbing each boundary DOF separately (finite differences), vmec_jax records a checkpoint tape of the VMEC iteration and propagates all parameter tangents through it in one batched forward pass (jax.vmap(jax.jvp(...))). The resulting Jacobian is exact to machine precision, and its cost is roughly 1–2 forward solves regardless of the number of DOFs, versus n_DOFs forward solves for finite differences. Further reading: Detailed explanation · SIMSOPT comparison.

python examples/optimization/qh_fixed_resolution_jax.py   # MAX_MODE=2 by default
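The batched-JVP idea from the discrete-adjoint paragraph can be illustrated on a toy problem. The sketch below unrolls a fixed number of relaxation steps (toy_equilibrium is a hypothetical stand-in for the VMEC iteration and omits vmec_jax's checkpointing) and pushes every unit tangent through it at once with jax.vmap(jax.jvp(...)). Column i of the result is the derivative of the final state with respect to parameter i; it agrees with jax.jacfwd.

```python
import jax
import jax.numpy as jnp


def toy_equilibrium(params, n_iter=300):
    """Hypothetical stand-in for a VMEC solve: the fixed-point iteration
    x <- x - h * (A x - b(params)), run for a fixed number of steps."""
    A = jnp.array([[3.0, 1.0, 0.0],
                   [1.0, 4.0, 1.0],
                   [0.0, 1.0, 5.0]])
    b = jnp.array([jnp.sin(params[0]), params[1] ** 2, params[0] * params[1]])

    def step(x, _):
        return x - 0.1 * (A @ x - b), None

    x_final, _ = jax.lax.scan(step, jnp.zeros(3), None, length=n_iter)
    return x_final


def jvp_jacobian(f, p):
    """Push every unit tangent e_i through f in one batched forward pass."""
    tangents = jnp.eye(p.size)
    push = lambda t: jax.jvp(f, (p,), (t,))[1]  # directional derivative J @ t
    return jax.vmap(push)(tangents).T            # column i = J @ e_i


p = jnp.array([0.3, 1.2])
J = jvp_jacobian(toy_equilibrium, p)
print(J.shape)  # (outputs, DOFs)
```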

When max_mode exceeds the modes present in the input file, vmec_jax automatically extends the boundary to include the requested harmonics at zero amplitude (vj.extend_boundary_for_max_mode), matching SIMSOPT's fixed_range() behaviour. All runs use consistent VMEC resolution mpol = ntor = 5 so the initial QS metric is normalised identically across max_mode values.
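The DOF counts in the tables can be reproduced by enumerating the free boundary harmonics. A minimal sketch, assuming a stellarator-symmetric boundary (cosine rc and sine zs coefficients, named as in SIMSOPT) freed with SIMSOPT's fixed_range(mmin=0, mmax=max_mode, nmin=-max_mode, nmax=max_mode) convention and rc(0,0) held fixed. free_boundary_dofs and extend_with_zeros are hypothetical names, not the vj.extend_boundary_for_max_mode API.

```python
def free_boundary_dofs(max_mode: int) -> list:
    """Enumerate (coefficient, m, n) DOFs for a stellarator-symmetric boundary."""
    dofs = []
    for name in ("rc", "zs"):
        for m in range(0, max_mode + 1):
            # For m = 0 only n >= 0 is independent (n < 0 duplicates n > 0).
            n_min = 0 if m == 0 else -max_mode
            for n in range(n_min, max_mode + 1):
                if m == 0 and n == 0:
                    # rc(0,0) is the fixed major radius; zs(0,0) is identically zero.
                    continue
                dofs.append((name, m, n))
    return dofs


def extend_with_zeros(boundary: dict, max_mode: int) -> dict:
    """Add every missing free harmonic at zero amplitude; keep existing values."""
    extended = dict(boundary)
    for key in free_boundary_dofs(max_mode):
        extended.setdefault(key, 0.0)
    return extended


for max_mode in (1, 2):
    print(max_mode, len(free_boundary_dofs(max_mode)))  # prints "1 8" then "2 24"
```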

max_mode | DOFs | QS initial | QS final | Reduction | Wall time ¹
1        | 8    | 0.303      | 0.213    | 30 %      | ~124 s
2        | 24   | 0.303      | 0.008    | 97 %      | ~323 s

¹ Wall time on Apple M-series hardware; subsequent runs with a warm compilation cache are faster.

With only 8 DOFs (max_mode=1) the boundary deformation space is too limited to reach a deep quasi-helical minimum. max_mode=2 (24 DOFs) achieves a 97 % reduction because the higher harmonics give the optimizer room to reshape the boundary helically.

vmec_jax vs SIMSOPT: vmec_jax uses an exact discrete-adjoint Jacobian (one batched JVP pass, ≈ 1–2 forward solves regardless of DOF count), while SIMSOPT + VMEC2000 uses finite differences (n_DOFs forward solves per Jacobian). For a detailed comparison of algorithms, runtimes, and memory, see docs/simsopt_comparison.rst.
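For contrast, the finite-difference cost comes from perturbing one DOF per extra solve. A minimal NumPy sketch of a one-sided (forward) difference Jacobian that also counts solves; fd_jacobian and model are hypothetical, and SIMSOPT's actual scheme and step size may differ (a centered scheme needs 2 × n_DOFs solves).

```python
import numpy as np


def fd_jacobian(f, p, h=1e-6):
    """One-sided finite-difference Jacobian: 1 base solve + 1 extra solve per DOF."""
    calls = 0

    def counted(x):
        nonlocal calls
        calls += 1
        return np.asarray(f(x), dtype=float)

    f0 = counted(p)
    J = np.empty((f0.size, p.size))
    for i in range(p.size):
        dp = np.zeros_like(p)
        dp[i] = h
        J[:, i] = (counted(p + dp) - f0) / h
    return J, calls


def model(q):
    # Hypothetical stand-in for one forward equilibrium solve.
    return np.array([np.sin(q[0]) * q[1], q[1] ** 2, q[0] + q[2]])


J_fd, n_solves = fd_jacobian(model, np.array([0.3, 1.2, -0.5]))
print(n_solves)  # 1 base solve + 3 perturbed solves
```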

[Figures: max_mode = 1 (8 DOFs, 30 % QS reduction); max_mode = 2 (24 DOFs, 97 % QS reduction)]

The |B| contour plots show quasi-helical alignment after optimization: contour lines become increasingly helical, aligning with lines of constant m θ − n ζ. The ζ axis spans one field period (0 → 2π/nfp).

Regenerate plots after running the optimization:

python examples/optimization/plot_qh_optimization_results.py --output-dir results/qh_opt

Quasi-axisymmetric optimization (fixed-boundary)

examples/optimization/qa_fixed_resolution_jax_ess.py optimizes an nfp=2 QA equilibrium for aspect ratio, mean iota, and QA symmetry residuals.

python examples/optimization/qa_fixed_resolution_jax_ess.py   # MAX_MODE=2 by default

When max_mode exceeds the modes in the input file, vmec_jax automatically extends the boundary to include those harmonics at zero amplitude (vj.extend_boundary_for_max_mode). All runs use consistent VMEC resolution mpol = ntor = 5. Objectives: aspect ratio (target 6.0) + mean iota (target 0.41) + QA symmetry residuals. The optimization history shows three panels: QS residuals, aspect ratio, and mean iota.
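The text does not spell out how the three objectives are combined; the sketch below assumes a weighted sum of squared residuals, as in SIMSOPT-style least-squares problems. qa_objective and its unit weights are hypothetical. It is evaluated at the starting point from the table (aspect 5.0, iota 0), where the axisymmetric boundary is exactly quasi-axisymmetric, so the QA residuals are zero.

```python
import numpy as np


def qa_objective(aspect, mean_iota, qs_residuals,
                 aspect_target=6.0, iota_target=0.41,
                 w_aspect=1.0, w_iota=1.0, w_qs=1.0):
    """Weighted sum of squares; targets from the text, weights are placeholders."""
    r = np.concatenate([
        [np.sqrt(w_aspect) * (aspect - aspect_target)],
        [np.sqrt(w_iota) * (mean_iota - iota_target)],
        np.sqrt(w_qs) * np.asarray(qs_residuals, dtype=float),
    ])
    return float(r @ r), r


f0, r0 = qa_objective(aspect=5.0, mean_iota=0.0, qs_residuals=[0.0, 0.0])
print(f0)  # (5 - 6)**2 + (0 - 0.41)**2, about 1.1681
```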

max_mode | DOFs | Aspect ratio (initial → final) | Mean iota (initial → final) | Wall time ¹
1        | 8    | 5.0 → 6.0                      | 0 → 0 (axisymmetric DOFs)   | ~23 s
2        | 24   | 5.0 → 5.51                     | 0 → 0.14 (3D modes)         | ~608 s

¹ Wall time on Apple M-series hardware; subsequent runs with a warm compilation cache are faster.

With 8 DOFs (max_mode=1) only axisymmetric (n=0) harmonics are free, so the optimizer hits the aspect ratio target (5.0 → 6.0) but cannot generate rotational transform — iota stays at 0. max_mode=2 (24 DOFs) unlocks 3D modes that generate iota (0 → 0.14 toward target 0.41) while partially improving aspect ratio, at the cost of introducing mild QS breaking.

[Figures: max_mode = 1 (8 DOFs, aspect ratio hits target); max_mode = 2 (24 DOFs, iota 0 → 0.14)]

Performance vs parity

  • Default runs select the fastest stable path for each input automatically.
  • Use --parity (or performance_mode=False in Python) to force the conservative VMEC2000 loop.
  • Use --solver-mode accelerated to force the optimized fixed-boundary controller.

Details, profiling guidance, and parity methodology:

  • docs/performance.rst
  • docs/validation.rst
  • tools/diagnostics/parity_manifest.toml + tools/diagnostics/parity_sweep_manifest.py

CLI reference

vmec_jax input.*                run the equilibrium solver → wout_*.nc
vmec_jax --plot wout.nc         generate diagnostic plots (4 output files)
vmec_jax --parity input.*       force conservative VMEC2000 loop
vmec_jax --help                 full option list

VMEC++ notes

The current runtime benchmark compares vmec_jax against VMEC2000. VMEC++ is not included in this benchmark.

When VMEC++ is available, it can be added to the runtime plot via --cpu-summary entries with backend=vmecpp. Some inputs are not supported or do not converge under the same single-grid settings:

Inputs not supported by VMEC++ (lasym=True):

  • LandremanSenguptaPlunk_section5p3_low_res
  • basic_non_stellsym_pressure
  • cth_like_free_bdy_lasym_small
  • up_down_asymmetric_tokamak

lasym=False inputs on which VMEC++ is known not to converge under the same single-grid settings:

  • DIII-D_lasym_false
  • LandremanPaul2021_QA_reactorScale_lowres
  • LandremanPaul2021_QH_reactorScale_lowres
  • LandremanSengupta2019_section5.4_B2_A80
  • cth_like_fixed_bdy

CLI output and NSTEP

The VMEC-style iteration loop prints progress every NSTEP iterations. A larger NSTEP means fewer print callbacks to the host and therefore faster runs.
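Why NSTEP matters can be sketched with a toy JAX loop: when a jitted loop reports progress through a host callback, each report is a device-to-host round trip, so reporting every nstep iterations of niter gives ceil(niter / nstep) callbacks. The loop below is a hypothetical stand-in built on jax.lax.cond and jax.debug.callback; it is not vmec_jax's actual printing code.

```python
import jax
import jax.numpy as jnp

host_calls = []  # filled on the host by the callback


def report(it, residual):
    # Executes on the host: each call is a device-to-host round trip.
    host_calls.append((int(it), float(residual)))


def run(nstep, niter=100):
    def body(it, x):
        x = 0.9 * x  # hypothetical stand-in for one solver iteration

        def emit():
            jax.debug.callback(report, it, jnp.abs(x).max())

        def skip():
            pass

        # Only every nstep-th iteration triggers a host callback.
        jax.lax.cond(it % nstep == 0, emit, skip)
        return x

    return jax.jit(lambda x0: jax.lax.fori_loop(0, niter, body, x0))(jnp.ones(4))


for nstep in (10, 50):
    host_calls.clear()
    run(nstep).block_until_ready()
    jax.effects_barrier()  # wait until all pending host callbacks have run
    print(f"NSTEP={nstep}: {len(host_calls)} host callbacks")
```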

To disable live printing:

export VMEC_JAX_SCAN_PRINT=0

Quiet runs (--quiet or verbose=False) default the scan path to minimal history mode to reduce host/device traffic. Override with:

export VMEC_JAX_SCAN_MINIMAL=0  # keep full scan diagnostics even when quiet
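The same switches can be set from Python. A minimal sketch, assuming vmec_jax reads these variables from the process environment as the export lines imply; the text does not say whether they are read at import time or per run, so setting them before importing vmec_jax is the safe order.

```python
import os

# In-process equivalent of the shell `export` lines (assumption: the variables
# are read from os.environ, so set them before vmec_jax is imported).
os.environ["VMEC_JAX_SCAN_PRINT"] = "0"    # disable live iteration printing
os.environ["VMEC_JAX_SCAN_MINIMAL"] = "0"  # keep full scan diagnostics in quiet runs

# Then run as usual, e.g.:
# import vmec_jax as vj
# run = vj.run_fixed_boundary("input.nfp4_QH_warm_start")
```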



Download files


Source Distribution

vmec_jax-0.0.2.tar.gz (524.0 kB)

Uploaded Source

Built Distribution


vmec_jax-0.0.2-py3-none-any.whl (465.8 kB)

Uploaded Python 3

File details

Details for the file vmec_jax-0.0.2.tar.gz.

File metadata

  • Download URL: vmec_jax-0.0.2.tar.gz
  • Upload date:
  • Size: 524.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for vmec_jax-0.0.2.tar.gz
Algorithm Hash digest
SHA256 a963dd09cc7f7d55a198c5c6ccb422321460f359a5a5bef322e065e77e5de564
MD5 ac09ccf5e3a08dc184bf86185e72dbe7
BLAKE2b-256 96b3499fe90ac7cd415483e3917526aed484f9243763836064e01117aaee114d


Provenance

The following attestation bundles were made for vmec_jax-0.0.2.tar.gz:

Publisher: publish-pypi.yml on uwplasma/vmec_jax


File details

Details for the file vmec_jax-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: vmec_jax-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 465.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for vmec_jax-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 a05c8082d6a2d512d8f783845bf28516998b1118c8048ff0627355c3edaea95e
MD5 038abf05c0427adf7b10442bf10339b1
BLAKE2b-256 d9bf096f9f1e0e7a8ec4f0e8bc724e1db22d9c444886bf31b77e6402c5e1e466


Provenance

The following attestation bundles were made for vmec_jax-0.0.2-py3-none-any.whl:

Publisher: publish-pypi.yml on uwplasma/vmec_jax

