End-to-end differentiable JAX implementation of VMEC2000 for fixed and free-boundary equilibria.
Project description
vmec-jax
End-to-end differentiable JAX implementation of VMEC2000 for fixed-boundary and free-boundary ideal-MHD equilibria.
Install
pip install vmec-jax
QI optimization uses booz_xform_jax for the differentiable Boozer transform:
pip install "vmec-jax[qi]"
Developer (editable) install:
git clone https://github.com/uwplasma/vmec_jax
pip install -e "vmec_jax[qi]"
Usage
Run the solver (VMEC2000-style CLI):
vmec_jax input.nfp4_QH_warm_start # → wout_nfp4_QH_warm_start.nc
Generate diagnostic plots from any wout_*.nc (four-panel output, replicates vmecPlot2.py):
vmec_jax --plot wout_nfp4_QH_warm_start.nc # saves in same directory
vmec_jax --plot wout_nfp4_QH_warm_start.nc --outdir figures/
From Python:
import vmec_jax as vj
# Run a fixed-boundary solve
run = vj.run_fixed_boundary("input.nfp4_QH_warm_start")
# Run a free-boundary solve
freeb = vj.run_free_boundary("input.cth_like_free_bdy_lasym_small")
# Plot any wout file (produces *_VMECparams.pdf, *_poloidal_plot.png, *_VMECsurfaces.pdf, *_VMEC_3Dplot.png)
vj.plot_wout("wout_nfp4_QH_warm_start.nc", outdir="figures/")
Run tests:
pytest -q
Choosing CPU or GPU
vmec_jax follows the JAX backend you select. If you installed CPU-only JAX,
runs use CPU. If you installed GPU-enabled JAX and select a GPU backend, runs
use GPU; vmec_jax does not silently force those runs back to CPU.
# Check what JAX will use.
python -c "import jax; print(jax.default_backend()); print(jax.devices())"
# Force CPU for one command.
JAX_PLATFORMS=cpu vmec_jax input.nfp4_QH_warm_start
# Force an accelerator backend after installing GPU-enabled JAX.
JAX_PLATFORM_NAME=gpu vmec_jax input.nfp4_QH_warm_start
# For NVIDIA CUDA specifically, this is also valid.
JAX_PLATFORMS=cuda vmec_jax input.nfp4_QH_warm_start
From Python, leave solver_device unset to inherit JAX's default backend, or
pass solver_device="cpu" / solver_device="gpu" explicitly:
import vmec_jax as vj
run_gpu = vj.run_fixed_boundary("input.nfp4_QH_warm_start", solver_device="gpu")
run_cpu = vj.run_fixed_boundary("input.nfp4_QH_warm_start", solver_device="cpu")
For GPU runs, vmec_jax defaults XLA_PYTHON_CLIENT_PREALLOCATE=false before
JAX import so the allocator grows on demand. This avoids GPU memory contention
between optimization workers and was faster in the exact-Jacobian GPU profile.
Set XLA_PYTHON_CLIENT_PREALLOCATE=true before import if you explicitly want
JAX's default preallocation behavior.
vmec_jax enables JAX's persistent compilation cache by default, but its
default cache path is machine/CPU-feature scoped to avoid reusing CPU AOT
executables compiled on a different host. Set VMEC_JAX_COMPILATION_CACHE=0 to
disable the persistent cache or VMEC_JAX_COMPILATION_CACHE_DIR=/path/to/cache
to choose a custom location.
Showcase (single-grid)
All figures below use the same single-grid run settings: NS_ARRAY=151, NITER_ARRAY=5000, FTOL_ARRAY=1e-14, NSTEP=500.
ITERModel cross-section (VMEC2000 vs vmec_jax) |
LandremanPaul2021_QA_lowres cross-section (VMEC2000 vs vmec_jax) |
ITERModel iota (VMEC2000 vs vmec_jax) |
LandremanPaul2021_QA_lowres iota (VMEC2000 vs vmec_jax) |
ITERModel 3D LCFS |
LandremanPaul2021_QA_lowres 3D LCFS |
ITERModel |B| on LCFS |
LandremanPaul2021_QA_lowres |B| on LCFS |
Cold vs warm runtime: the cold bar includes XLA JIT compilation on the first call (one-time cost per process); the warm bar is the steady-state solve time for subsequent calls in the same process. VMEC2000 has no compilation overhead, so it is always effectively cold. vmec_jax enables JAX's persistent compilation cache by default under ~/.cache/vmec_jax/jax_cache/<machine-fingerprint> so repeated cold-process runs on the same host can reuse compiled kernels without sharing CPU AOT executables across incompatible machines; set VMEC_JAX_COMPILATION_CACHE=0 to disable it or VMEC_JAX_COMPILATION_CACHE_DIR=/path/to/cache to choose a different location.
Optimization Internals
The fixed-boundary optimization examples expose the problem construction directly in Python: VMEC resolution, active boundary coefficients, objective blocks, weights, continuation policy, ESS scaling, and the outer optimizer are all top-level variables in the scripts. No SIMSOPT wrapper layer is required.
For a boundary parameter vector x, vmec_jax solves the VMEC residual
F(y, x) = 0 for the equilibrium state y, then differentiates objective
residuals r(y(x), x) with the exact discrete-adjoint/tape path. Instead of
finite-differencing each boundary DOF, vmec_jax records a checkpoint tape of
the nonlinear VMEC iteration and replays it with JAX JVP/VJP rules. The dense
least-squares Jacobian used by the examples is exact to machine precision and
has cost comparable to a small number of forward solves, not one VMEC solve per
boundary DOF.
Details: discrete adjoint, optimization guide, and SIMSOPT comparison.
Individual Fixed-Boundary Optimization Scripts
The standalone examples teach the same workflow as the SIMSOPT examples without depending on SIMSOPT: choose VMEC resolution, active boundary coefficients, objective terms, weights, continuation policy, ESS scaling, and optimizer directly in Python. They are intentionally plain scripts with top-level variables, not argparse wrappers.
python examples/optimization/qa_fixed_resolution_jax_ess.py
python examples/optimization/qh_fixed_resolution_jax.py
python examples/optimization/qp_fixed_resolution_jax_ess.py
python examples/optimization/qi_fixed_resolution_jax_ess.py
Key top-level controls are VMEC_MPOL, VMEC_NTOR, MAX_MODE,
OBJECTIVES, METHOD, USE_MODE_CONTINUATION, USE_ESS, and ALPHA.
When MAX_MODE exceeds the modes present in the input file, vmec_jax extends
the boundary with vj.extend_boundary_for_max_mode, matching SIMSOPT's
fixed_range() workflow.
Add a new target by appending an objective term:
OBJECTIVES = [
aspect_objective(TARGET_ASPECT, ASPECT_WEIGHT),
abs_mean_iota_floor_objective(TARGET_ABS_IOTA_MIN, IOTA_WEIGHT),
quasisymmetry_objective(helicity_m=1, helicity_n=-1, surfaces=SURFACES, weight=QS_WEIGHT),
ObjectiveTerm("custom", lambda ctx, state: your_metric(ctx, state), target=0.0, weight=0.1),
]
For current QA/QH/QP/QI results, use the policy sweep below. It is the only optimization benchmark table shown in the README so stale single-script results do not drift from the generated sweep artifacts.
Finite-beta Stage-one Optimization
The finite-beta examples reproduce the VMEC-only stage-one finite-beta workflows without SIMSOPT and without coils. They use bundled finite-pressure and current-driven input decks and add differentiable residuals for aspect ratio, iota bounds, volume-averaged field proxy, total beta, plus the field quality objective (QA/QH quasisymmetry or QI).
python examples/optimization/qa_optimization_finite_beta.py
python examples/optimization/qh_optimization_finite_beta.py
python examples/optimization/qi_optimization_finite_beta.py
The scripts save input.initial, input.final, wout_initial.nc,
wout_final.nc, and history.json. Full differentiable Mercier DMerc and
Redl bootstrap-current mismatch residuals are the next finite-beta extensions;
the current examples keep the stage-one structure and current-profile support
in place so those terms can be added without changing the user workflow.
The QI script exposes QI_MBOZ, QI_NBOZ, QI_NPHI, QI_NALPHA, and
QI_N_BOUNCE at the top; the defaults are first-run diagnostic settings, and
should be increased for final research-quality QI refinements.
QA/QH/QP/QI Optimization Policy Sweep
The panel below compares the exact standalone optimizer on CPU and GPU for
four targets: QA, QH, QP, and QI. It includes the complete
stellarator-symmetric matrix and the currently available partial LASYM lanes.
Columns increase the boundary space from max_mode = 1 to max_mode = 4.
Rows compare staged mode continuation against direct-start mode expansion.
Blue curves use unscaled boundary DOFs; orange curves use ESS with
alpha = 2.5. Solid lines met the optimizer success criterion; dashed lines
mark stopped, failed, or budgeted lanes. Timeout/OOM details are recorded in
the summary tables.
The QA input includes 1e-5 seeds for the mode-1 boundary terms so the iota
residual has a useful direction. With the corrected bounded solve budgets, QA
continuation reaches the target-iota basin on both CPU and GPU. Direct QA with
ESS also reaches iota ~= 0.409; direct QA without ESS now leaves the zero-iota
branch for modes 2 and 3, but direct high-mode starts remain weak for mode 4.
QH and QP use the quasisymmetry residual with different helicities. QI uses
vmec_jax.quasi_isodynamic, a smooth Boozer-space residual built through
booz_xform_jax. The QI rows first run a same-mode QP preseed and then refine
with the QI residual; this avoids the QH warm-start basin and gives visibly
non-QH |B| contours while keeping the objective differentiable.
Final CPU states for the continuation and direct-start policies are shown
below. The |B| panels use line contours on the LCFS, with a separate colorbar
for each panel because the field ranges are not identical across aspect-ratio
changes.
Recreate the full CPU/GPU sweep:
PYTHONPATH=. JAX_PLATFORMS=cpu python examples/optimization/generate_qs_ess_sweep.py --backend-label cpu --solver-device cpu --policy continuation --problems qa,qh,qp,qi --modes 1,2,3,4 --ess both
PYTHONPATH=. JAX_PLATFORMS=cpu python examples/optimization/generate_qs_ess_sweep.py --backend-label cpu --solver-device cpu --policy direct --problems qa,qh,qp,qi --modes 1,2,3,4 --ess both
PYTHONPATH=. JAX_PLATFORM_NAME=gpu python examples/optimization/generate_qs_ess_sweep.py --backend-label gpu --solver-device gpu --policy continuation --problems qa,qh,qp,qi --modes 1,2,3,4 --ess both
PYTHONPATH=. JAX_PLATFORM_NAME=gpu python examples/optimization/generate_qs_ess_sweep.py --backend-label gpu --solver-device gpu --policy direct --problems qa,qh,qp,qi --modes 1,2,3,4 --ess both
The default per-case timeout is 1200 s. GPU sweeps use exact/replay callbacks
with calibrated optimizer budgets (inner_max_iter = trial_max_iter = 180,
ftol = trial_ftol = 1e-9 for deck-controlled QA/QH cases) so production
sweeps have enough room to converge high-mode/LASYM cases while still bounding
runaway rows. Add --diagnostic-budgets only when you explicitly want bounded
quick-look GPU diagnostics, and use --case-timeout-s 0 only for an unbounded
local diagnostic run.
Recreate just the CPU direct-start rows:
PYTHONPATH=. JAX_PLATFORMS=cpu python examples/optimization/generate_qs_ess_sweep.py --backend-label cpu --solver-device cpu --policy direct --problems qa,qh,qp,qi --modes 1,2,3,4 --ess both
Render the README/docs panels and tables:
PYTHONPATH=. python examples/optimization/render_qs_ess_publication_panel.py
To run the non-stellarator-symmetric matrix, append
--stellarator-asymmetric. This sets LASYM = T in memory, optimizes
RBC/ZBS/RBS/ZBC, seeds zero asymmetric RBS/ZBC modes with 1e-7, and
writes separate LASYM outputs under results/qs_ess_sweep/<backend>/asymmetric/.
The renderer then creates additional *_asymmetric_* objective, atlas, summary,
and publication panels.
PYTHONPATH=. JAX_PLATFORMS=cpu python examples/optimization/generate_qs_ess_sweep.py --backend-label cpu --solver-device cpu --policy continuation --problems qa,qh,qp,qi --modes 1,2,3,4 --ess both --stellarator-asymmetric
PYTHONPATH=. JAX_PLATFORMS=cpu python examples/optimization/generate_qs_ess_sweep.py --backend-label cpu --solver-device cpu --policy direct --problems qa,qh,qp,qi --modes 1,2,3,4 --ess both --stellarator-asymmetric
PYTHONPATH=. JAX_PLATFORM_NAME=gpu python examples/optimization/generate_qs_ess_sweep.py --backend-label gpu --solver-device gpu --policy continuation --problems qa,qh,qp,qi --modes 1,2,3,4 --ess both --stellarator-asymmetric
PYTHONPATH=. JAX_PLATFORM_NAME=gpu python examples/optimization/generate_qs_ess_sweep.py --backend-label gpu --solver-device gpu --policy direct --problems qa,qh,qp,qi --modes 1,2,3,4 --ess both --stellarator-asymmetric
PYTHONPATH=. python examples/optimization/render_qs_ess_publication_panel.py
The LASYM panels are published as a partial 1200 s snapshot. This is useful because the failures are informative: current mode-4 GPU LASYM lanes include timeout and GPU-memory limits in the exact tangent replay path. In the frozen snapshot used here, the partial LASYM table contains 13 CPU rows and 61 GPU rows. The CPU subset has 6 successful rows, 6 crashed rows, and 1 budgeted stop; the GPU subset has 19 successful rows, 10 crashed rows, and 32 budgeted stops.
For NVIDIA-only JAX installations, JAX_PLATFORMS=cuda is also valid. Do not
use JAX_PLATFORMS=gpu: some JAX versions interpret that as both CUDA and ROCm
and fail if ROCm is not installed.
Run individual examples by editing top-level variables in each script:
python examples/optimization/qa_fixed_resolution_jax_ess.py
python examples/optimization/qh_fixed_resolution_jax.py
python examples/optimization/qp_fixed_resolution_jax_ess.py
python examples/optimization/qi_fixed_resolution_jax_ess.py
More figures, CSV/JSON summaries, and reproduction notes are in docs/optimization_sweep_results.rst.
Best row per backend/problem/policy in the plotted symmetric sweep:
| Backend | Problem | Policy | Best max_mode | ESS | Status | Final J | Aspect | Iota | nfev | Wall min |
|---|---|---|---|---|---|---|---|---|---|---|
| CPU | QA | continuation | 4 | yes | stopped | 2.84e-05 | 5.999 | 0.4100 | 79 | 19.8 |
| CPU | QA | direct | 3 | yes | ok | 3.13e-05 | 6.000 | 0.4102 | 51 | 19.3 |
| CPU | QH | continuation | 4 | yes | ok | 5.87e-04 | 7.000 | -1.2182 | 46 | 18.6 |
| CPU | QH | direct | 3 | yes | ok | 3.27e-03 | 6.999 | - | 20 | 9.2 |
| CPU | QP | continuation | 4 | no | ok | 3.65e-02 | 7.002 | -0.4218 | 51 | 5.0 |
| CPU | QP | direct | 2 | yes | ok | 3.74e-02 | 7.004 | -0.4037 | 30 | 1.1 |
| CPU | QI | continuation | 4 | no | ok | 5.20e-03 | 7.002 | -0.4148 | 80 | 4.5 |
| CPU | QI | direct | 2 | yes | ok | 4.90e-03 | 7.001 | -0.5808 | 31 | 1.4 |
| GPU | QA | continuation | 4 | no | ok | 9.74e-05 | 6.000 | 0.4100 | 94 | 19.4 |
| GPU | QA | direct | 4 | yes | stopped | 6.77e-05 | 6.000 | 0.4100 | 60 | 15.1 |
| GPU | QH | continuation | 4 | yes | ok | 9.38e-04 | 7.000 | -1.2528 | 91 | 19.5 |
| GPU | QH | direct | 4 | yes | ok | 1.30e-03 | 7.000 | -1.2280 | 20 | 5.0 |
| GPU | QP | continuation | 4 | no | ok | 6.46e-02 | 7.011 | -0.4012 | 87 | 14.8 |
| GPU | QP | direct | 2 | yes | ok | 3.72e-02 | 7.005 | -0.4028 | 30 | 4.4 |
| GPU | QI | continuation | 4 | no | ok | 1.66e-03 | 7.000 | -0.4096 | 93 | 17.1 |
| GPU | QI | direct | 3 | yes | ok | 2.72e-03 | 6.999 | -0.4025 | 45 | 7.4 |
The generated CSV includes the complete 128-row symmetric CPU/GPU table plus
the current partial LASYM rows. Filter stellarator_asymmetric=False for the
symmetric benchmark subset:
docs/_static/figures/qs_ess_summary_all.csv.
Performance vs parity
- Default runs select the fastest stable path for each input automatically.
- Use
--parity(orperformance_mode=Falsein Python) to force the conservative VMEC2000 loop. - Use
--solver-mode acceleratedto force the optimized fixed-boundary controller. - For GPU benchmarking, compare both first-process and cache-warm timings; the first GPU process pays XLA compilation, while later processes reuse the persistent cache automatically.
Details, profiling guidance, and parity methodology:
docs/performance.rstdocs/validation.rsttools/diagnostics/parity_manifest.toml+tools/diagnostics/parity_sweep_manifest.py
CLI reference
vmec_jax input.* run the equilibrium solver → wout_*.nc
vmec_jax --plot wout.nc generate diagnostic plots (4 output files)
vmec_jax --parity input.* force conservative VMEC2000 loop
vmec_jax --help full option list
VMEC++ notes
The current runtime benchmark compares vmec_jax against VMEC2000. VMEC++ is not included in this benchmark.
When VMEC++ is available, it can be added to the runtime plot via --cpu-summary entries with backend=vmecpp. Some inputs are not supported or do not converge under the same single-grid settings:
VMEC++ unsupported inputs (lasym=True):
LandremanSenguptaPlunk_section5p3_low_resbasic_non_stellsym_pressurecth_like_free_bdy_lasym_smallup_down_asymmetric_tokamak
VMEC++ known non-convergence on these lasym=False cases under the same single-grid settings:
DIII-D_lasym_falseLandremanPaul2021_QA_reactorScale_lowresLandremanPaul2021_QH_reactorScale_lowresLandremanSengupta2019_section5.4_B2_A80cth_like_fixed_bdy
CLI output and NSTEP
The VMEC-style iteration loop prints every NSTEP iterations. Larger NSTEP means fewer print callbacks and faster runs.
To disable live printing:
export VMEC_JAX_SCAN_PRINT=0
Quiet runs (--quiet or verbose=False) default the scan path to minimal history
mode to reduce host/device traffic. Override with:
export VMEC_JAX_SCAN_MINIMAL=0 # keep full scan diagnostics even when quiet
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file vmec_jax-0.0.4.tar.gz.
File metadata
- Download URL: vmec_jax-0.0.4.tar.gz
- Upload date:
- Size: 588.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
c3b3624f910776731d931a6a54813c01326167014e635889883ec378d0df5b19
|
|
| MD5 |
4dead54a004f2761ce308c3ca2fe952c
|
|
| BLAKE2b-256 |
62926b6814cf3629b2696b116e24f0d2eae48db47363cb920f52dca353f7d157
|
Provenance
The following attestation bundles were made for vmec_jax-0.0.4.tar.gz:
Publisher:
publish-pypi.yml on uwplasma/vmec_jax
-
Statement:
-
Statement type:
https://in-toto.io/Statement/v1 -
Predicate type:
https://docs.pypi.org/attestations/publish/v1 -
Subject name:
vmec_jax-0.0.4.tar.gz -
Subject digest:
c3b3624f910776731d931a6a54813c01326167014e635889883ec378d0df5b19 - Sigstore transparency entry: 1414654505
- Sigstore integration time:
-
Permalink:
uwplasma/vmec_jax@276d99f73cd9b5158344e90e01756894e30857f9 -
Branch / Tag:
refs/tags/v0.0.4 - Owner: https://github.com/uwplasma
-
Access:
public
-
Token Issuer:
https://token.actions.githubusercontent.com -
Runner Environment:
github-hosted -
Publication workflow:
publish-pypi.yml@276d99f73cd9b5158344e90e01756894e30857f9 -
Trigger Event:
release
-
Statement type:
File details
Details for the file vmec_jax-0.0.4-py3-none-any.whl.
File metadata
- Download URL: vmec_jax-0.0.4-py3-none-any.whl
- Upload date:
- Size: 505.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
e5729dfca35ae66fbf692df3db864e46b2d934f5bf0bc42a52cb9a8ddc5ea16e
|
|
| MD5 |
00b5857b12286aee2e3430c25c595358
|
|
| BLAKE2b-256 |
80ac97367f7dd5631803f58ea33321d43c2b4d9c093131092d51a5823a612395
|
Provenance
The following attestation bundles were made for vmec_jax-0.0.4-py3-none-any.whl:
Publisher:
publish-pypi.yml on uwplasma/vmec_jax
-
Statement:
-
Statement type:
https://in-toto.io/Statement/v1 -
Predicate type:
https://docs.pypi.org/attestations/publish/v1 -
Subject name:
vmec_jax-0.0.4-py3-none-any.whl -
Subject digest:
e5729dfca35ae66fbf692df3db864e46b2d934f5bf0bc42a52cb9a8ddc5ea16e - Sigstore transparency entry: 1414654641
- Sigstore integration time:
-
Permalink:
uwplasma/vmec_jax@276d99f73cd9b5158344e90e01756894e30857f9 -
Branch / Tag:
refs/tags/v0.0.4 - Owner: https://github.com/uwplasma
-
Access:
public
-
Token Issuer:
https://token.actions.githubusercontent.com -
Runner Environment:
github-hosted -
Publication workflow:
publish-pypi.yml@276d99f73cd9b5158344e90e01756894e30857f9 -
Trigger Event:
release
-
Statement type: