NEO_JAX
Neoclassical transport code for computing effective helical ripple and related trapped-particle diagnostics from Boozer-coordinate geometry, with a Python API, JAX acceleration, and an end-to-end differentiable pipeline through VMEC and Boozer transforms.
The full documentation lives in docs/ and covers the physics model,
numerics, geometry interfaces, runtime controls, applications, testing, and
source structure.
Quick start
pip install neo-jax
For development from a clone:
pip install -e .
neo-jax ORBITS --boozmn tests/fixtures/orbits/boozmn_ORBITS.nc --verbose
The package also installs legacy-compatible entrypoints:
xneo
xneo_jax
python -m neo_jax
If you prefer not to install, run from the repo root with:
PYTHONPATH=. python examples/ncsx_epsilon_effective_plot.py
CLI and file-based workflow
neo_jax now supports the same terminal workflow as STELLOPT's xneo,
including both the effective-ripple solve (calc_cur = 0) and the
parallel-current path (calc_cur = 1). The intent is simple: if you already
have a working NEO input deck, you should be able to point the same control
file at neo_jax and keep the same filenames, CLI invocation, and output-file
layout.
Legacy invocation:
xneo ORBITS
xneo_jax ORBITS
python -m neo_jax ORBITS
Unlike the STELLOPT binary, neo_jax prints explicit progress messages by
default so a long parity run does not look stalled. The CLI reports the control
file, Boozer file, solve mode, surface count, selected backend, and the active
JAX runtime, for example:
NEO_JAX: starting legacy CLI solve
NEO_JAX: control=neo_in.ORBITS
NEO_JAX: boozmn=boozmn_ORBITS.nc
NEO_JAX: mode=calc_cur=0 (epsilon effective)
NEO_JAX: surfaces=10 theta_n=64 phi_n=64 npart=40 backend=JAX
NEO_JAX: jax_runtime=gpu (2 devices: NVIDIA RTX A4000, NVIDIA RTX A4000)
Use --quiet if you want the legacy file outputs without the extra terminal
logging.
The progress output is now more detailed for both the CLI and the Python API:
- surface number and resolved flux index s, sqrt(s), and iota
- resolution (theta_n, phi_n, npart, multra, nstep_*)
- Boozer geometry summary (nfp, number of modes, B00, Bmin, Bmax)
- a preflight estimate of rational-surface work
Control-file lookup follows the same search order as STELLOPT:
- neo_param.<extension>
- neo_param.in
- neo_in.<extension>
- if no extension is given: neo.in
For inp_swi = 0, the CLI follows the legacy boozmn_<extension>.nc
convention used by STELLOPT. For nonzero inp_swi, it resolves the input from
IN_FILE in the control file.
What the legacy CLI writes:
- main output: neo_out.*
- log file: neolog.*
- parallel-current summary when calc_cur = 1: neo_cur.*
- optional parallel-current integration history: current.dat
- optional diagnostic files: diagnostic.dat, diagnostic_add.dat, diagnostic_bigint.dat
- optional integration history: conver.dat
- optional geometry dumps: dimension.dat, theta_arr.dat, phi_arr.dat, rmnc_arr.dat, bmnc_arr.dat, b_s_arr.dat, and the other legacy array files controlled by WRITE_OUTPUT_FILES
Compatibility scope:
- supported: calc_cur = 0 and calc_cur = 1
- control-file lookup parity for neo_param.<ext>, neo_param.in, and neo_in.<ext>
- legacy Boozer filename resolution via boozmn_<extension>.nc
- default CLI progress logging, with --quiet available for benchmarking or silent batch runs
- runtime backend reporting (cpu, gpu, etc.) so long JAX compiles are clearly visible in terminal output
How the compatibility layer is implemented:
- neo_jax/cli.py mirrors the legacy command-line contract and control-file search logic.
- neo_jax/legacy.py reproduces the Fortran text formatting used in neo_out.*, neo_cur.*, neolog.*, diagnostic*.dat, current.dat, and conver.dat.
- neo_jax/driver.py writes the same auxiliary files that STELLOPT writes when WRITE_OUTPUT_FILES, WRITE_INTEGRATE, WRITE_DIAGNOSTIC, or WRITE_CUR_INT are enabled.
- NEO_JAX_WRITE_IPMAX_DEBUG=1 writes diagnostic_ipmax_jax.dat, a per-step dump of the trapped-event amplitude used while building conver.dat.
- neo_jax/current.py ports the legacy flint_cur/calccur path to JAX so neo_cur.* and current.dat are produced by the same CLI entrypoint.
- The solver still calls the same JAX/Python backend used by the public API, so the terminal interface and the Python interface stay numerically aligned.
How it is tested:
- tests/regression/test_cli_legacy.py runs the JAX CLI against committed reference outputs that were generated once with the STELLOPT xneo executable, so parity is checked in CI without requiring that external binary.
- tests/regression/test_gpu_smoke.py adds optional CPU-vs-GPU smoke coverage for both the legacy CLI and the public Python API when NEO_JAX_RUN_GPU=1 is set on a machine with a visible JAX GPU backend.
- The test suite covers:
  - a real dense fixture: LandremanPaul2021_QA_lowres
  - a synthetic one-surface ORBITS legacy case that exercises neo_out, neolog, diagnostic*.dat, conver.dat, and all legacy array dumps
  - a one-surface ORBITS calc_cur = 1 case that checks neo_cur.* exactly and current.dat token-by-token against the STELLOPT executable
  - control-file precedence checks for neo_param.<extension> and neo_param.in
  - optional slow full-fixture parity checks for ORBITS_FAST and ncsx_c09r00_free_fast when NEO_JAX_RUN_SLOW=1
Current comparison status:
- neo_out.*, neo_cur.*, diagnostic*.dat, and neolog.* match the stored xneo reference text output exactly on the legacy parity cases.
- conver.dat matches exactly on the synthetic ORBITS parity case. On the full ORBITS_FAST fixture, columns 1-4 match exactly; the fifth column is traced with diagnostic_ipmax_jax.dat because the STELLOPT binary's text output on that dense case does not follow its own traced aditot / p_bm2 state.
- current.dat matches token-by-token, including NaN/Infinity placement, with tight floating-point tolerances to account for backend-level roundoff in the intermediate current history.
- The legacy array dumps (*_arr.dat, dimension.dat, theta_arr.dat, phi_arr.dat) are numerically identical to within floating-point roundoff.
- GPU execution is now validated separately on the office workstation for both python -m neo_jax and the Python API; see the GPU table below.
Low-|iota| Surfaces
For surfaces with very small rotational transform |iota|, the legacy NEO
rational-surface correction can become extremely expensive:
nfp_rat ~= ceil(1 / acc_req / |iota|)
Physically, very small |iota| means a field line can require many field
periods to sample the rational structure. Numerically, that makes the legacy
correction cost explode and a run can look stalled even though it is still
making progress.
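As a rough sense check of that scaling (the acc_req and |iota| values below are illustrative, not package defaults):

import math

acc_req = 1e-2   # illustrative accuracy requirement
iota = 1e-3      # illustrative, very small rotational transform
nfp_rat = math.ceil(1.0 / acc_req / abs(iota))
print(nfp_rat)   # 100000 field periods of rational-surface work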
NEO_JAX therefore preflights the estimated rational-correction workload:
- default guarded behavior: NeoConfig(max_rational_field_periods=100000)
- CLI / environment knob: NEO_JAX_MAX_RATIONAL_FIELD_PERIODS=100000
- controlled approximate fallback: rational_surface_policy="approximate"
- full exact legacy behavior, even if very slow: max_rational_field_periods=0
Approximate mode keeps the base integration and skips only the expensive rational-surface correction once the preflight estimate exceeds the configured limit. The returned diagnostics record that approximation was used.
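A minimal sketch of opting into that fallback from Python; it assumes rational_surface_policy is accepted as a NeoConfig field, mirroring the NEO_JAX_RATIONAL_SURFACE_POLICY environment knob:

from neo_jax import NeoConfig, run_neo

config = NeoConfig(
    surfaces=[0.1],
    max_rational_field_periods=100000,      # preflight guard (default)
    rational_surface_policy="approximate",  # skip only the rational correction above the limit
)
results = run_neo("boozmn.nc", config=config)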
To avoid this regime in practice:
- avoid surfaces with |iota| very close to zero
- loosen acc_req if exact rational-surface resolution is not required
- reduce the surface set when exploring a new equilibrium
- inspect the reported preflight estimate before launching dense runs
To force the exact long run:
config = NeoConfig(
surfaces=[...],
max_rational_field_periods=0,
)
or
export NEO_JAX_MAX_RATIONAL_FIELD_PERIODS=0
For CLI or environment-driven runs, the same policy can be selected with:
export NEO_JAX_RATIONAL_SURFACE_POLICY=approximate
Previously validated parity cases such as ORBITS, NCSX, and LandremanPaul2021_QA_lowres remain on the unchanged path unless the preflight estimate exceeds the configured limit.
Simple Python API
from neo_jax import NeoConfig, run_neo
# Surfaces may be specified by index or by s in [0, 1].
config = NeoConfig(surfaces=[0.15, 0.35, 0.6, 0.85], theta_n=64, phi_n=64)
results = run_neo("boozmn.nc", config=config)
# Access by name
print(results.epsilon_effective)
print(results["epsilon_effective_by_class"].shape)
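To pair each requested surface with its value, a small sketch (it assumes results.epsilon_effective is ordered like config.surfaces):

for s, eps in zip(config.surfaces, results.epsilon_effective):
    print(f"s = {s:.2f}: eps_eff = {eps:.3e}")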
JAX-native pipeline
You can run directly on JAX-native Boozer outputs (for example from
booz_xform_jax.jax_api) without writing boozmn files:
from neo_jax import NeoConfig, run_neo
# booz_out is a dict with keys like rmnc_b, zmns_b, pmns_b, bmnc_b, ixm_b, ixn_b
results = run_neo(booz_out, config=NeoConfig(surfaces=[1, 2, 3]))
For a full vmec_jax → booz_xform_jax → neo_jax workflow (no file I/O), use:
from neo_jax import NeoConfig, run_vmec_boozer_neo
config = NeoConfig(surfaces=[0.25, 0.5, 0.75], theta_n=32, phi_n=32)
results = run_vmec_boozer_neo(
"path/to/input.vmec",
vmec_kwargs=dict(max_iter=1, use_initial_guess=True, vmec_project=False),
booz_kwargs=dict(mboz=8, nboz=8),
neo_config=config,
)
For a JAX-native VMEC→Boozer adapter plus a JAX surface scan, use
run_vmec_boozer_neo_jax on a vmec_jax.FixedBoundaryRun object.
When using JAX surface scans, the return type is a JAX-friendly
NeoOutputs. Convert it to the standard NeoResults container with:
from neo_jax import neo_outputs_to_results
results = neo_outputs_to_results(outputs)
If you want a reusable, JIT-friendly pipeline callable (useful for loops and
optimizers), use build_vmec_boozer_neo_jax:
from neo_jax import build_vmec_boozer_neo_jax, NeoConfig
solver = build_vmec_boozer_neo_jax(run, booz_kwargs=dict(mboz=8, nboz=8),
neo_config=NeoConfig(surfaces=[0.5]), jit=True)
outputs = solver(run.state)
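Because the pipeline is differentiable end to end, the returned solver can sit inside gradient-based loops. A minimal sketch, assuming outputs.epsilon_effective is a JAX array and run.state is a differentiable pytree:

import jax
import jax.numpy as jnp

def objective(state):
    outputs = solver(state)
    # scalar figure of merit: mean epsilon effective over the scanned surfaces
    return jnp.mean(outputs.epsilon_effective)

grad_wrt_state = jax.grad(objective)(run.state)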
Documentation
Sphinx documentation lives in docs/ and is configured for Read the Docs.
See docs/index.rst for the table of contents.
Examples
- examples/ncsx_epsilon_effective_plot.py: compute and plot epsilon effective vs s.
- examples/ncsx_autodiff_Rmajor_optimization.py: autodiff optimization demo over Rmajor.
- examples/epsilon_effective_scale_optimization.py: toy autodiff example that scales |B| to reduce epsilon effective.
- examples/qh_epsilon_effective_aspect_optimization.py: QH warm-start optimization (epsilon effective + aspect ratio).
- examples/vmec_boozer_neo_pipeline.py: full vmec_jax → booz_xform_jax → neo_jax pipeline.
Legacy CLI Benchmark Snapshot
Measured on this workstation with /usr/bin/time -l, using the STELLOPT reference binary ~/bin/xneo and python -m neo_jax --quiet, so the table reflects solver cost rather than terminal logging overhead.
| Case | Parity status | xneo runtime (s) | neo_jax runtime (s) | Runtime ratio | xneo max RSS (MiB) | neo_jax max RSS (MiB) | Memory ratio |
|---|---|---|---|---|---|---|---|
| LandremanPaul2021_QA_lowres | Pass | 2.24 | 16.15 | 7.21x | 25.0 | 1471.8 | 58.83x |
| ORBITS_MINI | Pass | 0.05 | 18.66 | 373.20x | 14.6 | 1285.9 | 87.83x |
| ORBITS_CURINT | Pass | 0.33 | 6.67 | 20.21x | 14.6 | 624.4 | 42.70x |
| ORBITS_FAST | Pass (neo_out / neolog / conver[:4]) | 0.10 | 8.61 | 86.10x | 15.0 | 756.6 | 50.33x |
| NCSX_MINI | Pass | 0.06 | 5.62 | 93.67x | 35.0 | 581.8 | 16.64x |
| ncsx_c09r00_free_fast | Pass at rtol≈5e-3 | 2.30 | 12.96 | 5.63x | 38.0 | 1307.3 | 34.43x |
Notes:
- ncsx_c09r00_free_fast now passes the slow CLI regression at about rtol=5e-3 (implemented as 5.1e-3 to account for the rounded legacy text output). The current last-surface epstot values are 0.7159689869E-03 (xneo) vs 0.7123614767E-03 (neo_jax).
- ORBITS_FAST is now practical again in legacy mode because WRITE_INTEGRATE=1 uses the JAX solver plus a convergence callback in the default path instead of forcing the full Python-loop backend.
- The dense ORBITS_FAST regression now checks conver.dat columns 1-4 in CI and exposes NEO_JAX_WRITE_IPMAX_DEBUG=1 for step-by-step parity debugging of the remaining fifth-column discrepancy.
GPU Validation Snapshot
Validated on the office workstation (Pop!_OS, 2x NVIDIA RTX A4000, JAX 0.6.2) with:
env NEO_JAX_RUN_GPU=1 JAX_PLATFORM_NAME=gpu python -m pytest -q \
tests/regression/test_gpu_smoke.py
That GPU smoke suite checks:
- legacy CLI output parity between CPU and GPU on a one-surface ORBITS case
- Python API parity between CPU and GPU for run_neo(..., use_jax=True, jax_surface_scan=True)
- progress logging includes the active JAX runtime so GPU runs do not look hung
- the user-facing examples/ncsx_epsilon_effective_plot.py script was also run on the same GPU host with MPLBACKEND=Agg and completed successfully while writing examples/ncsx_eps_eff_vs_s.png
Cold-run timing on the same office host:
| Path | Case | CPU runtime (s) | GPU runtime (s) | CPU max RSS (MiB) | GPU max RSS (MiB) | Notes |
|---|---|---|---|---|---|---|
| Legacy CLI | LandremanPaul2021_QA_lowres | 39.41 | 95.71 | 1908.4 | 1966.4 | Cold launch, includes JIT compile |
| Python API | ORBITS single-surface smoke | 15.56 first / 8.99 reuse | 25.37 first / 14.06 reuse | n/a | n/a | run_neo(..., jax_surface_scan=True) |
At the current problem sizes, the GPU path is functional and parity-checked, but still compile-bound. On these small and medium legacy solves it is not yet faster than the CPU path. The GPU backend is still important for the larger JIT-native VMEC→Boozer→NEO workflows, where batching and reuse matter more than single-shot CLI latency.
NCSX Parity Snapshot

| Metric | NEO (Fortran) | NEO_JAX (JAX) | Notes |
|---|---|---|---|
| Epsilon effective parity (max rel error, epstot) | — | 2.5e-10 | vs tests/fixtures/ncsx/neo_out.ncsx_c09r00_free |
| Runtime (10 surfaces, NCSX) | 60.37 s | 51.37 s | JAX time is steady-state after warmup |
| Max RSS (NCSX run) | 72.8 MiB | 4.45 GB | Measured via /usr/bin/time -l |
Repro commands:
# Fortran runtime + memory
/usr/bin/time -l /Users/rogerio/local/STELLOPT/NEO/Release/xneo ncsx_c09r00_free
# JAX runtime (steady-state) + memory
/usr/bin/time -l env PYTHONPATH=/Users/rogerio/local/tests/NEO_JAX \
python /Users/rogerio/local/tests/NEO_JAX/benchmarks/benchmark_ncsx.py --jax --warmup
Performance Tuning
NEO_JAX supports two Fourier evaluation modes:
- NEO_JAX_FOURIER_MODE=vectorized (default): fastest, but allocates theta×phi×mode temporaries.
- NEO_JAX_FOURIER_MODE=streamed: lower memory by streaming over modes; slightly slower.
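For example, a single run can be switched to the streamed mode from Python by setting the environment variable before neo_jax is imported (a sketch; setting it before the import is the safe ordering regardless of exactly when the variable is read):

import os
os.environ["NEO_JAX_FOURIER_MODE"] = "streamed"  # trade a little speed for lower peak memory

from neo_jax import NeoConfig, run_neo
results = run_neo("boozmn.nc", config=NeoConfig(surfaces=[0.5]))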
NCSX benchmark comparison (10 surfaces, CPU warmup run, /usr/bin/time -l):
| Mode | Total time | Max RSS |
|---|---|---|
| Vectorized | 51.37 s | 4.45 GB |
| Streamed | 58.78 s | 2.55 GB |
Precision
NEO_JAX enables 64-bit JAX precision by default to match the Fortran reference outputs. You can override this behavior by setting either:
- NEO_JAX_ENABLE_X64=0 (NEO_JAX-specific)
- JAX_ENABLE_X64=0 (global JAX)
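For a quick single-precision smoke test, the same pattern applies (a sketch; set the variable before neo_jax configures JAX, i.e. before the import, and expect looser agreement with the Fortran reference):

import os
os.environ["NEO_JAX_ENABLE_X64"] = "0"  # 32-bit JAX arrays

from neo_jax import NeoConfig, run_neo
results = run_neo("boozmn.nc", config=NeoConfig(surfaces=[0.5]))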
Status
This repository is under active development. See PLAN.md for the porting
plan and roadmap.
Validation Cases
Current parity fixtures include:
- ORBITS (fast + full)
- NCSX tutorial case (fast by default; full gated by NEO_JAX_RUN_SLOW=1)
- LandremanPaul2021_QA_lowres (dense, 64x64 grid)