End-to-end differentiable JAX implementation of VMEC2000 for fixed and free-boundary equilibria.
vmec-jax
End-to-end differentiable JAX implementation of VMEC2000 for fixed-boundary and free-boundary ideal-MHD equilibria.
Install
pip install vmec-jax
QI optimization uses booz_xform_jax for the differentiable Boozer transform:
pip install "vmec-jax[qi]"
Developer (editable) install:
git clone https://github.com/uwplasma/vmec_jax
pip install -e "vmec_jax[qi]"
Usage
Run the solver (VMEC2000-style CLI):
vmec_jax input.nfp4_QH_warm_start # → wout_nfp4_QH_warm_start.nc
Generate diagnostic plots from any wout_*.nc file (four output files, replicating vmecPlot2.py):
vmec_jax --plot wout_nfp4_QH_warm_start.nc # saves next to the wout file
vmec_jax --plot wout_nfp4_QH_warm_start.nc --outdir figures/
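When scripting batch runs, it can help to predict which wout file a given deck will produce. This is a minimal sketch that assumes only the input.<case> → wout_<case>.nc naming shown in the first command above. wout_name_for is a hypothetical helper, not part of vmec_jax:

```python
from pathlib import Path


def wout_name_for(input_path: str) -> str:
    """Map an input deck name to the wout file name (hypothetical helper).

    Assumes the VMEC2000-style convention: input.<case> -> wout_<case>.nc.
    Only the file name is returned; the output directory is not inferred.
    """
    name = Path(input_path).name
    prefix = "input."
    if not name.startswith(prefix):
        raise ValueError(f"expected a file named input.<case>, got {name!r}")
    return f"wout_{name[len(prefix):]}.nc"


print(wout_name_for("input.nfp4_QH_warm_start"))  # wout_nfp4_QH_warm_start.nc
```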
From Python:
import vmec_jax as vj
# Run a fixed-boundary solve
run = vj.run_fixed_boundary("input.nfp4_QH_warm_start")
# Run a free-boundary solve
freeb = vj.run_free_boundary("input.cth_like_free_bdy_lasym_small")
# Plot any wout file (produces *_VMECparams.pdf, *_poloidal_plot.png, *_VMECsurfaces.pdf, *_VMEC_3Dplot.png)
vj.plot_wout("wout_nfp4_QH_warm_start.nc", outdir="figures/")
Choosing CPU or GPU
vmec_jax follows the JAX backend you select. If you installed CPU-only JAX,
runs use CPU. If you installed GPU-enabled JAX and select a GPU backend, runs
use GPU; vmec_jax does not silently force those runs back to CPU.
# Check what JAX will use.
python -c "import jax; print(jax.default_backend()); print(jax.devices())"
# Force CPU for one command.
JAX_PLATFORMS=cpu vmec_jax input.nfp4_QH_warm_start
# Force an accelerator backend after installing GPU-enabled JAX.
JAX_PLATFORM_NAME=gpu vmec_jax input.nfp4_QH_warm_start
# For NVIDIA CUDA specifically, this is also valid.
JAX_PLATFORMS=cuda vmec_jax input.nfp4_QH_warm_start
From Python, leave solver_device unset to inherit JAX's default backend, or
pass solver_device="cpu" / solver_device="gpu" explicitly:
import vmec_jax as vj
run_gpu = vj.run_fixed_boundary("input.nfp4_QH_warm_start", solver_device="gpu")
run_cpu = vj.run_fixed_boundary("input.nfp4_QH_warm_start", solver_device="cpu")
For GPU runs, vmec_jax sets XLA_PYTHON_CLIENT_PREALLOCATE=false by default before
JAX is imported, so the allocator grows on demand. This avoids GPU memory
contention between optimization workers and was faster in the exact-Jacobian GPU
profile. Set XLA_PYTHON_CLIENT_PREALLOCATE=true before import if you want JAX's
default preallocation behavior.
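The "default unless already set" behavior described above matches setdefault semantics on the environment. The sketch below illustrates that reading. apply_gpu_allocator_default is a hypothetical name, and vmec_jax's internal implementation may differ:

```python
from typing import MutableMapping


def apply_gpu_allocator_default(env: MutableMapping[str, str]) -> str:
    """Default XLA_PYTHON_CLIENT_PREALLOCATE to "false" unless already set.

    Hypothetical helper: an explicit user choice made before import is kept.
    """
    return env.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")


print(apply_gpu_allocator_default({}))  # false
print(apply_gpu_allocator_default({"XLA_PYTHON_CLIENT_PREALLOCATE": "true"}))  # true

# In a real script, opt back into preallocation before vmec_jax (and JAX) load:
#   import os
#   os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true"
#   import vmec_jax as vj
```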
vmec_jax enables JAX's persistent compilation cache automatically for
accelerator-selected runs, including runs where CUDA_VISIBLE_DEVICES or the
ROCm equivalents expose an accelerator before import. CPU cache use is opt-in,
because XLA:CPU AOT cache hits can emit host-feature mismatch errors on some JAX
versions. Set VMEC_JAX_COMPILATION_CACHE=1 to enable the default cache for CPU
runs, VMEC_JAX_COMPILATION_CACHE=0 to disable it, or
VMEC_JAX_COMPILATION_CACHE_DIR=/path/to/cache to choose a custom location.
The default cache path is scoped by machine, CPU features, Python version, and
JAX/jaxlib versions.
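The cache rules above can be summarized as a small decision function. This is a sketch under stated assumptions, not vmec_jax's code. compilation_cache_dir is hypothetical. The fingerprint is an illustrative hash of host name, machine, processor string, and Python/JAX/jaxlib versions, so it only approximates "CPU features". The sketch also assumes VMEC_JAX_COMPILATION_CACHE_DIR chooses a location without enabling the CPU cache on its own.

```python
import hashlib
import platform
import sys
from pathlib import Path
from typing import Mapping, Optional


def compilation_cache_dir(env: Mapping[str, str], backend: str,
                          jax_version: str, jaxlib_version: str) -> Optional[Path]:
    """Return the cache directory to use, or None when caching is off."""
    flag = env.get("VMEC_JAX_COMPILATION_CACHE")
    if flag == "0":
        return None  # explicit disable
    if backend == "cpu" and flag != "1":
        return None  # CPU caching is opt-in
    custom = env.get("VMEC_JAX_COMPILATION_CACHE_DIR")
    if custom:
        return Path(custom)
    # Illustrative machine fingerprint (assumption, not the real scheme).
    parts = [platform.node(), platform.machine(), platform.processor(),
             sys.version.split()[0], jax_version, jaxlib_version]
    fingerprint = hashlib.sha256("|".join(parts).encode()).hexdigest()[:16]
    return Path.home() / ".cache" / "vmec_jax" / "jax_cache" / fingerprint


# Version strings here are placeholders.
print(compilation_cache_dir({}, "cpu", "0.4.30", "0.4.30"))  # None
```

A returned path could then be handed to JAX with jax.config.update("jax_compilation_cache_dir", str(path)), which is JAX's own switch for the persistent cache location.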
For the current small and medium fixed-boundary examples, CPU is often faster after JIT warmup. GPU support is production-enabled and worth profiling; the exact optimizer uses GPU-specific scan exact callbacks only when a GPU is actually selected. See the performance guide for current CPU/GPU timings and profiling commands.
Showcase (single-grid)
All figures below use the same single-grid run settings: NS_ARRAY=151, NITER_ARRAY=5000, FTOL_ARRAY=1e-14, NSTEP=500.
- ITERModel cross-section (VMEC2000 vs vmec_jax)
- LandremanPaul2021_QA_lowres cross-section (VMEC2000 vs vmec_jax)
- ITERModel iota (VMEC2000 vs vmec_jax)
- LandremanPaul2021_QA_lowres iota (VMEC2000 vs vmec_jax)
- ITERModel 3D LCFS
- LandremanPaul2021_QA_lowres 3D LCFS
- ITERModel |B| on LCFS
- LandremanPaul2021_QA_lowres |B| on LCFS
Cold vs warm runtime: the cold bar includes XLA JIT compilation on the first call (a one-time cost per process); the warm bar is the steady-state solve time for subsequent calls in the same process. VMEC2000 has no compilation overhead, so its cold and warm timings are effectively the same. For accelerator-selected runs, vmec_jax automatically uses JAX's persistent compilation cache, stored under ~/.cache/vmec_jax/jax_cache/<machine-fingerprint>. CPU cache use is opt-in via VMEC_JAX_COMPILATION_CACHE=1 to avoid XLA:CPU AOT host-feature mismatch warnings on some JAX versions.
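To reproduce the cold/warm split on your own machine, time the first call separately from repeated calls in the same process. This is a minimal sketch. time_cold_warm is a hypothetical helper, and the workload in the demo is a stand-in for a real solve:

```python
import time
from typing import Callable, List, Tuple


def time_cold_warm(fn: Callable[[], object],
                   warm_repeats: int = 3) -> Tuple[float, List[float]]:
    """Return (cold, warm) wall times: the first call, then repeated calls."""
    start = time.perf_counter()
    fn()  # cold: for a JAX-jitted function this includes compilation
    cold = time.perf_counter() - start
    warm = []
    for _ in range(warm_repeats):
        start = time.perf_counter()
        fn()  # warm: same process, compiled code reused
        warm.append(time.perf_counter() - start)
    return cold, warm


if __name__ == "__main__":
    # Stand-in workload; a real benchmark might pass, e.g.,
    # lambda: vj.run_fixed_boundary("input.nfp4_QH_warm_start")
    cold, warm = time_cold_warm(lambda: sum(i * i for i in range(100_000)))
    print(f"cold {cold:.4f} s, best warm {min(warm):.4f} s")
```

When timing raw JAX functions rather than full solves, call .block_until_ready() on the returned array inside fn. Otherwise asynchronous dispatch can make the timings look shorter than the actual compute.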
Best Stellarator-Symmetric Optimizations
The fixed-boundary optimization examples solve VMEC equilibria and differentiate
the objective with the exact discrete-adjoint/tape path. The README only shows
one current best LASYM = F result for each target; the full CPU/GPU policy
matrix, LASYM panels, finite-beta examples, QI constraint sweep, and all tables
live in the optimization guide and the optimization sweep results.
Each row below shows the original deck LCFS before any max_mode=1
optimization work, the final LCFS, per-stage objective history, and the final
outer-surface |B| in Boozer coordinates computed with booz_xform_jax.
This sweep uses NFP=2 seeds for QA/QP/QI and the standard bundled NFP=4 warm
start for QH. The current objective priority is primary symmetry/QI quality
and rotational-transform control. QA follows the reference omnigenity QA deck
with aspect ratio near 5 and signed mean iota target 0.42; QH/QP/QI also use
aspect ratio near 5 and abs(mean_iota) >= 0.41. LgradB remains available
as an optional script-level term, but it is not active in the default README
examples or best-row selection.
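The aspect-ratio and iota constraints above can be expressed as least-squares residuals. The sketch assumes a particular form: plain differences, no weights, and the symmetry/QI residuals omitted. constraint_residuals is a hypothetical name and does not represent the example scripts' actual objective.

```python
from typing import List, Optional


def constraint_residuals(aspect: float, mean_iota: float,
                         aspect_target: float = 5.0,
                         iota_target: Optional[float] = None,
                         iota_min_abs: Optional[float] = None) -> List[float]:
    """Residuals for the aspect-ratio and rotational-transform terms.

    iota_target: signed equality target (QA-style, e.g. 0.42).
    iota_min_abs: one-sided bound abs(mean_iota) >= value (QH/QP/QI-style).
    """
    residuals = [aspect - aspect_target]
    if iota_target is not None:
        residuals.append(mean_iota - iota_target)
    if iota_min_abs is not None:
        # Zero once the bound is met, so the term only acts when violated.
        residuals.append(max(0.0, iota_min_abs - abs(mean_iota)))
    return residuals


print(constraint_residuals(5.0, 0.42, iota_target=0.42))  # [0.0, 0.0]
print(constraint_residuals(5.0, -0.30, iota_min_abs=0.41))  # bound violated
```

The one-sided form lets QH/QP/QI keep either sign of iota as long as its magnitude stays above 0.41. QA, by contrast, pins the signed value.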
The QP and QI rows both start from the bundled NFP=2 QI seed. QP is a
quasi-poloidal-symmetry target using that same input deck; the current best QI
row uses repeated same-mode continuation at max_mode=3 without a QP preseed.
The bundled NFP=2 seed is projected to each active max_mode, so
max_mode=1 zeroes the seed's mode-2 boundary harmonics before optimizing.
For QI, the listed wall time includes all repeated stages using the same
constrained least-squares residual definition.
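Projecting a seed to a given max_mode amounts to truncating its boundary Fourier coefficients. This sketch uses the common convention that the retained modes satisfy m <= max_mode and |n| <= max_mode. That convention is an assumption, and vmec_jax's exact rule may differ. project_to_max_mode is hypothetical, and the coefficient values are illustrative, not taken from the bundled seed.

```python
from typing import Dict, Tuple

# Keys are (n, m), as in VMEC's RBC(n, m) / ZBS(n, m) input arrays.
Harmonics = Dict[Tuple[int, int], float]


def project_to_max_mode(coeffs: Harmonics, max_mode: int) -> Harmonics:
    """Zero every boundary harmonic with m > max_mode or |n| > max_mode."""
    return {
        (n, m): (value if m <= max_mode and abs(n) <= max_mode else 0.0)
        for (n, m), value in coeffs.items()
    }


seed_rbc = {(0, 0): 1.0, (0, 1): 0.18, (1, 1): 0.02, (0, 2): 0.01, (2, 2): -0.003}
print(project_to_max_mode(seed_rbc, max_mode=1))
# {(0, 0): 1.0, (0, 1): 0.18, (1, 1): 0.02, (0, 2): 0.0, (2, 2): 0.0}
```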
| Target | Backend | Policy | max_mode | ESS | QP preseed | Final J | QI legacy | Mirror | Elong. | Aspect | Iota | Wall time |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| QA | CPU | continuation | 3 | yes | n/a | 2.33e-04 | n/a | n/a | n/a | 5.000 | 0.4200 | 6.1 min |
| QH | CPU | continuation | 3 | yes | n/a | 9.68e-03 | n/a | n/a | n/a | 4.999 | -1.6595 | 4.0 min |
| QP | CPU | continuation | 3 | no | n/a | 6.76e-02 | n/a | n/a | n/a | 5.019 | -0.6255 | 3.7 min |
| QI | CPU | continuation | 3 | yes | no | 1.05e-03 | 1.04e-03 | 0.211 | 4.78 | 5.000 | -0.4553 | 6.6 min |
Recreate the four displayed runs:
PYTHONPATH=. JAX_PLATFORMS=cpu python examples/optimization/generate_qs_ess_sweep.py --backend-label cpu --solver-device cpu --policy continuation --problems qa --modes 3 --ess off
PYTHONPATH=. JAX_PLATFORMS=cpu python examples/optimization/generate_qs_ess_sweep.py --backend-label cpu --solver-device cpu --policy continuation --problems qh --modes 3 --ess on
PYTHONPATH=. JAX_PLATFORMS=cpu python examples/optimization/generate_qs_ess_sweep.py --backend-label cpu --solver-device cpu --policy continuation --problems qp --modes 3 --ess off
PYTHONPATH=. JAX_PLATFORMS=cpu python examples/optimization/QI_optimization.py
Regenerate the README panels and the compact CSV used for the table:
PYTHONPATH=. python examples/optimization/render_readme_best_optimizations.py
Performance vs parity
- Default runs select the fastest stable path for each input automatically.
- Use --parity (or performance_mode=False in Python) to force the conservative VMEC2000 loop.
- Use --solver-mode accelerated to force the optimized fixed-boundary controller.
- For GPU benchmarking, separate raw solver throughput from public policy overhead. For example, use tools/diagnostics/profile_fixed_boundary.py --no-auto-cli-policy --solver-mode accelerated --no-multigrid --use-scan --solver-device gpu.
- Compare both first-process and in-process warm timings. The first GPU process pays XLA/runtime setup; persistent cache effectiveness depends on the JAX version, backend, and machine features.
Details, profiling guidance, and parity methodology:
- docs/performance.rst
- docs/validation.rst
- tools/diagnostics/parity_manifest.toml + tools/diagnostics/parity_sweep_manifest.py
CLI reference
vmec_jax input.* run the equilibrium solver → wout_*.nc
vmec_jax --plot wout.nc generate diagnostic plots (4 output files)
vmec_jax --parity input.* force conservative VMEC2000 loop
vmec_jax --help full option list
VMEC++ notes
The current runtime benchmark compares vmec_jax against VMEC2000. VMEC++ is not included in this benchmark.
When VMEC++ is available, it can be added to the runtime plot via --cpu-summary entries with backend=vmecpp. Some inputs are not supported or do not converge under the same single-grid settings:
VMEC++ unsupported inputs (lasym=True):
- LandremanSenguptaPlunk_section5p3_low_res
- basic_non_stellsym_pressure
- cth_like_free_bdy_lasym_small
- up_down_asymmetric_tokamak
VMEC++ known non-convergence on these lasym=False cases under the same single-grid settings:
- DIII-D_lasym_false
- LandremanPaul2021_QA_reactorScale_lowres
- LandremanPaul2021_QH_reactorScale_lowres
- LandremanSengupta2019_section5.4_B2_A80
- cth_like_fixed_bdy
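When assembling a VMEC++ comparison, the two lists above can be applied as filters. The sketch reads LASYM from each deck's &INDATA text and skips the known non-convergent cases. deck_is_lasym and vmecpp_candidates are hypothetical helpers, and the deck texts in the demo are toy strings. The regex is a simple assumption and does not handle commented-out lines.

```python
import re
from typing import Dict, List

# Matches "LASYM = T", "lasym=.true.", "NFP = 2, LASYM = F", ... (any case).
_LASYM = re.compile(r"\bLASYM\s*=\s*\.?([TF])", re.IGNORECASE)

# Cases listed above as not converging under VMEC++ with the same settings.
VMECPP_KNOWN_NONCONVERGENT = {
    "DIII-D_lasym_false",
    "LandremanPaul2021_QA_reactorScale_lowres",
    "LandremanPaul2021_QH_reactorScale_lowres",
    "LandremanSengupta2019_section5.4_B2_A80",
    "cth_like_fixed_bdy",
}


def deck_is_lasym(text: str) -> bool:
    """True when the deck sets LASYM true; VMEC's default is false."""
    match = _LASYM.search(text)
    return match is not None and match.group(1).upper() == "T"


def vmecpp_candidates(decks: Dict[str, str]) -> List[str]:
    """Case names to try with VMEC++: lasym=False and not known to fail."""
    return sorted(
        case for case, text in decks.items()
        if not deck_is_lasym(text) and case not in VMECPP_KNOWN_NONCONVERGENT
    )


decks = {
    "up_down_asymmetric_tokamak": "&INDATA\n  LASYM = T\n/",
    "cth_like_fixed_bdy": "&INDATA\n  LASYM = F\n/",
    "my_symmetric_case": "&INDATA\n  NFP = 2, LASYM = F\n/",
}
print(vmecpp_candidates(decks))  # ['my_symmetric_case']
```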
CLI output and NSTEP
The VMEC-style iteration loop prints every NSTEP iterations. Larger NSTEP values mean fewer print callbacks and therefore faster runs.
To disable live printing:
export VMEC_JAX_SCAN_PRINT=0
Quiet runs (--quiet or verbose=False) default the scan path to minimal history
mode to reduce host/device traffic. Override with:
export VMEC_JAX_SCAN_MINIMAL=0 # keep full scan diagnostics even when quiet
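The switches above, together with the NSTEP print cadence, can be summarized in a small policy function. This is a hypothetical sketch and scan_output_policy is not a vmec_jax API. It makes three assumptions: quiet runs also suppress live printing, VMEC_JAX_SCAN_MINIMAL=0 is the only override honored, and the callback count is roughly one per NSTEP iterations.

```python
from typing import Mapping, NamedTuple


class ScanOutput(NamedTuple):
    live_print: bool
    minimal_history: bool
    print_callbacks: int


def scan_output_policy(env: Mapping[str, str], quiet: bool,
                       niter: int, nstep: int) -> ScanOutput:
    # VMEC_JAX_SCAN_PRINT=0 turns live printing off; quiet runs assumed silent.
    live_print = (not quiet) and env.get("VMEC_JAX_SCAN_PRINT", "1") != "0"
    # Quiet runs default to minimal history unless VMEC_JAX_SCAN_MINIMAL=0.
    minimal = quiet and env.get("VMEC_JAX_SCAN_MINIMAL", "1") != "0"
    # Approximation: one host print callback per NSTEP iterations.
    callbacks = niter // nstep if live_print else 0
    return ScanOutput(live_print, minimal, callbacks)


# Showcase settings: NITER_ARRAY=5000, NSTEP=500.
print(scan_output_policy({}, quiet=False, niter=5000, nstep=500))
# ScanOutput(live_print=True, minimal_history=False, print_callbacks=10)
```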
Download files
File details
Details for the file vmec_jax-0.0.7.tar.gz.
File metadata
- Download URL: vmec_jax-0.0.7.tar.gz
- Size: 618.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d788e8e450ea64c5af275b12fe14405b14b29d0e82dc8dac67e52d2277e50d04 |
| MD5 | a9b3b9a6b8b1ce7e1406280063458ac9 |
| BLAKE2b-256 | da258488c7fe56dac246a077f124fbaa767019c6a8515fafd63835d97d8cdf0d |
Provenance
The following attestation bundles were made for vmec_jax-0.0.7.tar.gz:
Publisher: publish-pypi.yml on uwplasma/vmec_jax
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: vmec_jax-0.0.7.tar.gz
- Subject digest: d788e8e450ea64c5af275b12fe14405b14b29d0e82dc8dac67e52d2277e50d04
- Sigstore transparency entry: 1487433616
- Permalink: uwplasma/vmec_jax@ecc6a85670c362748376af31772117634eab0354
- Branch / Tag: refs/tags/v0.0.7
- Owner: https://github.com/uwplasma
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish-pypi.yml@ecc6a85670c362748376af31772117634eab0354
- Trigger Event: release
File details
Details for the file vmec_jax-0.0.7-py3-none-any.whl.
File metadata
- Download URL: vmec_jax-0.0.7-py3-none-any.whl
- Size: 526.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 9e186aa45a177978a2c271417b8e1c932765922750bdcc0165fbb3753f67aeb4 |
| MD5 | f9882654f87d9f71daa26a8fbb6c1afe |
| BLAKE2b-256 | 925965eca1d828ae6c76feec5fe3788a2c8e16dc0f20913075ea2f1fa7ccc798 |
Provenance
The following attestation bundles were made for vmec_jax-0.0.7-py3-none-any.whl:
Publisher: publish-pypi.yml on uwplasma/vmec_jax
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: vmec_jax-0.0.7-py3-none-any.whl
- Subject digest: 9e186aa45a177978a2c271417b8e1c932765922750bdcc0165fbb3753f67aeb4
- Sigstore transparency entry: 1487433716
- Permalink: uwplasma/vmec_jax@ecc6a85670c362748376af31772117634eab0354
- Branch / Tag: refs/tags/v0.0.7
- Owner: https://github.com/uwplasma
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish-pypi.yml@ecc6a85670c362748376af31772117634eab0354
- Trigger Event: release