Neoclassical transport code for effective ripple and trapped-particle diagnostics in Boozer coordinates

NEO_JAX

Install from PyPI:

pip install neo-jax

Neoclassical transport code for computing effective helical ripple and related trapped-particle diagnostics from Boozer-coordinate geometry, with a Python API, JAX acceleration, and an end-to-end differentiable pipeline through VMEC and Boozer transforms.

The full documentation lives in docs/ and covers the physics model, numerics, geometry interfaces, runtime controls, applications, testing, and source structure.

Quick start

pip install neo-jax

For development from a clone:

pip install -e .
neo-jax ORBITS --boozmn tests/fixtures/orbits/boozmn_ORBITS.nc --verbose

The package also installs legacy-compatible entrypoints:

xneo
xneo_jax
python -m neo_jax

If you prefer not to install, run from the repo root with:

PYTHONPATH=. python examples/ncsx_epsilon_effective_plot.py

CLI and file-based workflow

neo_jax now supports the same terminal workflow as STELLOPT's xneo, including both the effective-ripple solve (calc_cur = 0) and the parallel-current path (calc_cur = 1). The intent is simple: if you already have a working NEO input deck, you should be able to point the same control file at neo_jax and keep the same filenames, CLI invocation, and output-file layout.

Legacy invocation:

xneo ORBITS
xneo_jax ORBITS
python -m neo_jax ORBITS

Unlike the STELLOPT binary, neo_jax prints explicit progress messages by default so a long parity run does not look stalled. The CLI reports the control file, Boozer file, solve mode, surface count, selected backend, and the active JAX runtime, for example:

NEO_JAX: starting legacy CLI solve
NEO_JAX: control=neo_in.ORBITS
NEO_JAX: boozmn=boozmn_ORBITS.nc
NEO_JAX: mode=calc_cur=0 (epsilon effective)
NEO_JAX: surfaces=10 theta_n=64 phi_n=64 npart=40 backend=JAX
NEO_JAX: jax_runtime=gpu (2 devices: NVIDIA RTX A4000, NVIDIA RTX A4000)

Use --quiet if you want the legacy file outputs without the extra terminal logging.

For both the CLI and the Python API, the per-surface progress output reports:

  • surface number and resolved flux index
  • s, sqrt(s), and iota
  • resolution (theta_n, phi_n, npart, multra, nstep_*)
  • Boozer geometry summary (nfp, number of modes, B00, Bmin, Bmax)
  • a preflight estimate of rational-surface work

Control-file lookup follows the same search order as STELLOPT:

  1. neo_param.<extension>
  2. neo_param.in
  3. neo_in.<extension>
  4. if no extension is given: neo.in

For inp_swi = 0, the CLI follows the legacy boozmn_<extension>.nc convention used by STELLOPT. For nonzero inp_swi, it resolves the input from IN_FILE in the control file.
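
The lookup rules above can be sketched in a few lines of standard-library Python. This is an illustration only, not the code in neo_jax/cli.py: the function names are hypothetical, and the handling of a missing extension (trying neo_param.in before neo.in) is an assumption about how the numbered list applies in that case.

```python
from pathlib import Path
from typing import Optional


def control_file_candidates(extension: Optional[str]) -> list[str]:
    """Candidate control-file names, in the documented search order."""
    if extension:
        return [f"neo_param.{extension}", "neo_param.in", f"neo_in.{extension}"]
    # Assumption: with no extension, only the generic names are tried.
    return ["neo_param.in", "neo.in"]


def find_control_file(extension: Optional[str], directory: Path) -> Optional[Path]:
    """Return the first existing candidate in `directory`, or None."""
    for name in control_file_candidates(extension):
        path = directory / name
        if path.is_file():
            return path
    return None


def legacy_boozmn_name(extension: str) -> str:
    """Boozer file name used for inp_swi = 0 (legacy convention)."""
    return f"boozmn_{extension}.nc"
```

With an ORBITS extension, a directory containing both neo_param.in and neo_in.ORBITS would resolve to neo_param.in, because it comes earlier in the search order.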

What the legacy CLI writes:

  • main output: neo_out.*
  • log file: neolog.*
  • parallel-current summary when calc_cur = 1: neo_cur.*
  • optional parallel-current integration history: current.dat
  • optional diagnostic files: diagnostic.dat, diagnostic_add.dat, diagnostic_bigint.dat
  • optional integration history: conver.dat
  • optional geometry dumps: dimension.dat, theta_arr.dat, phi_arr.dat, rmnc_arr.dat, bmnc_arr.dat, b_s_arr.dat, and the other legacy array files controlled by WRITE_OUTPUT_FILES

Compatibility scope:

  • supported: calc_cur = 0 and calc_cur = 1
  • control-file lookup parity for neo_param.<ext>, neo_param.in, and neo_in.<ext>
  • legacy Boozer filename resolution via boozmn_<extension>.nc
  • default CLI progress logging, with --quiet available for benchmarking or silent batch runs
  • runtime backend reporting (cpu, gpu, etc.) so long JAX compiles are clearly visible in terminal output

How the compatibility layer is implemented:

  • neo_jax/cli.py mirrors the legacy command-line contract and control-file search logic.
  • neo_jax/legacy.py reproduces the Fortran text formatting used in neo_out.*, neo_cur.*, neolog.*, diagnostic*.dat, current.dat, and conver.dat.
  • neo_jax/driver.py writes the same auxiliary files that STELLOPT writes when WRITE_OUTPUT_FILES, WRITE_INTEGRATE, WRITE_DIAGNOSTIC, or WRITE_CUR_INTE are enabled.
  • NEO_JAX_WRITE_IPMAX_DEBUG=1 writes diagnostic_ipmax_jax.dat, a per-step dump of the trapped-event amplitude used while building conver.dat.
  • neo_jax/current.py ports the legacy flint_cur / calccur path to JAX so neo_cur.* and current.dat are produced by the same CLI entrypoint.
  • The solver still calls the same JAX/Python backend used by the public API, so the terminal interface and Python interface stay numerically aligned.
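
Matching Fortran text output exactly means matching its number layout, in which the mantissa starts with a leading zero (for example 0.7159689869E-03) rather than Python's 7.159689869E-04. The helper below is a hypothetical sketch of that one formatting step, not the code in neo_jax/legacy.py; the name fortran_e and the default of 10 mantissa digits are assumptions, and NaN or Infinity values are not handled.

```python
def fortran_e(value: float, digits: int = 10) -> str:
    """Format `value` as 0.ddddE+ee with `digits` mantissa digits."""
    if value == 0.0:
        return "0." + "0" * digits + "E+00"
    # Let Python round to d.ddd form first, then shift the decimal point
    # one place to the left and raise the exponent by one.
    mantissa, exponent = f"{abs(value):.{digits - 1}E}".split("E")
    sign = "-" if value < 0 else ""
    return f"{sign}0.{mantissa.replace('.', '')}E{int(exponent) + 1:+03d}"


print(fortran_e(0.7159689869e-03))  # 0.7159689869E-03
```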

How it is tested:

  • tests/regression/test_cli_legacy.py runs the JAX CLI against committed reference outputs that were generated once with the STELLOPT xneo executable, so parity is checked in CI without requiring that external binary.
  • tests/regression/test_gpu_smoke.py adds optional CPU-vs-GPU smoke coverage for both the legacy CLI and the public Python API when NEO_JAX_RUN_GPU=1 is set on a machine with a visible JAX GPU backend.
  • The test suite covers:
    • a real dense fixture: LandremanPaul2021_QA_lowres
    • a synthetic one-surface ORBITS legacy case that exercises neo_out, neolog, diagnostic*.dat, conver.dat, and all legacy array dumps
    • a one-surface ORBITS calc_cur = 1 case that checks neo_cur.* exactly and current.dat token-by-token against the STELLOPT executable
    • control-file precedence checks for neo_param.<extension> and neo_param.in
    • optional slow full-fixture parity checks for ORBITS_FAST and ncsx_c09r00_free_fast when NEO_JAX_RUN_SLOW=1

Current comparison status:

  • neo_out.*, neo_cur.*, diagnostic*.dat, and neolog.* match the stored xneo reference text output exactly on the legacy parity cases.
  • conver.dat matches exactly on the synthetic ORBITS parity case. On the full ORBITS_FAST fixture, columns 1-4 match exactly; the fifth column is traced with diagnostic_ipmax_jax.dat because the STELLOPT binary's text output on that dense case does not follow its own traced aditot / p_bm2 state.
  • current.dat matches token-by-token, including NaN/Infinity placement, with tight floating-point tolerances to account for backend-level roundoff in the intermediate current history.
  • The legacy array dumps (*_arr.dat, dimension.dat, theta_arr.dat, phi_arr.dat) are numerically identical to within floating-point roundoff.
  • GPU execution is now validated separately on the office workstation for both python -m neo_jax and the Python API; see the GPU table below.
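
A token-by-token comparison of the kind described for current.dat can be sketched as follows. This is a minimal illustration and not the test code in tests/regression/; the function name and the default tolerance are assumptions. Numeric tokens are compared within a tolerance, while NaN and Infinity tokens must appear in the same positions with the same sign.

```python
import math


def tokens_match(reference: str, candidate: str,
                 rtol: float = 1e-12, atol: float = 0.0) -> bool:
    """Compare two text dumps token by token."""
    ref_tokens, cand_tokens = reference.split(), candidate.split()
    if len(ref_tokens) != len(cand_tokens):
        return False
    for ref, cand in zip(ref_tokens, cand_tokens):
        try:
            a, b = float(ref), float(cand)
        except ValueError:
            # At least one token is not numeric: require identical text.
            if ref != cand:
                return False
            continue
        if math.isnan(a) or math.isnan(b):
            # NaN must be matched by NaN at the same position.
            if not (math.isnan(a) and math.isnan(b)):
                return False
        elif math.isinf(a) or math.isinf(b):
            # Infinities must agree exactly, including sign.
            if a != b:
                return False
        elif not math.isclose(a, b, rel_tol=rtol, abs_tol=atol):
            return False
    return True
```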

Low-|iota| Surfaces

For surfaces with very small rotational transform |iota|, the legacy NEO rational-surface correction can become extremely expensive:

nfp_rat ~= ceil(1 / acc_req / |iota|)

Physically, very small |iota| means a field line can require many field periods to sample the rational structure. Numerically, that makes the legacy correction cost explode and a run can look stalled even though it is still making progress.
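
To see how quickly the estimate grows, the formula above can be evaluated directly. The function name and the acc_req value used in the loop are illustrative assumptions, not defaults taken from the package.

```python
import math


def estimate_rational_field_periods(acc_req: float, iota: float) -> int:
    """Estimate nfp_rat ~= ceil(1 / acc_req / |iota|)."""
    if iota == 0.0:
        raise ValueError("the estimate is unbounded for iota = 0")
    return math.ceil(1.0 / acc_req / abs(iota))


# Each tenfold reduction in |iota| costs roughly ten times more field periods.
for iota in (0.5, 0.05, 0.005, 0.0005):
    periods = estimate_rational_field_periods(0.01, iota)
    print(f"iota={iota:g}  nfp_rat ~ {periods}")
```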

NEO_JAX therefore preflights the estimated rational-correction workload:

  • default guarded behavior: NeoConfig(max_rational_field_periods=100000)
  • CLI / environment knob: NEO_JAX_MAX_RATIONAL_FIELD_PERIODS=100000
  • controlled approximate fallback: rational_surface_policy="approximate"
  • full exact legacy behavior, even if very slow: max_rational_field_periods=0

Approximate mode keeps the base integration and skips only the expensive rational-surface correction once the preflight estimate exceeds the configured limit. The returned diagnostics record that approximation was used.
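
The decision described above can be summarized as a small function. This is a sketch under stated assumptions: the function name, the "exact" default policy label, and the "over_limit" return value are hypothetical, and what the guarded default actually does when the limit is exceeded without the approximate policy is not specified here.

```python
def rational_correction_plan(estimated_periods: int,
                             max_periods: int = 100_000,
                             policy: str = "exact") -> str:
    """Decide how to treat the rational-surface correction for one surface."""
    # A limit of 0 disables the guard: always run the full legacy correction.
    if max_periods == 0 or estimated_periods <= max_periods:
        return "exact"
    # The estimate exceeds the limit: skip only the correction if allowed.
    if policy == "approximate":
        return "approximate"
    # Assumption: the guarded default handles this case in some other way.
    return "over_limit"
```

The validated parity cases stay on the "exact" branch as long as their estimates remain under the limit, which is why enabling the guard does not change their results.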

To avoid this regime in practice:

  • avoid surfaces with |iota| very close to zero
  • loosen acc_req if exact rational-surface resolution is not required
  • reduce the surface set when exploring a new equilibrium
  • inspect the reported preflight estimate before launching dense runs

To force the exact long run:

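from neo_jax import NeoConfig

# A limit of 0 disables the guard and always runs the exact correction.
config = NeoConfig(
    surfaces=[...],
    max_rational_field_periods=0,
)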

or

export NEO_JAX_MAX_RATIONAL_FIELD_PERIODS=0

For CLI or environment-driven runs, the approximate fallback policy can be selected with:

export NEO_JAX_RATIONAL_SURFACE_POLICY=approximate

Previously validated parity cases such as ORBITS, NCSX, and LandremanPaul2021_QA_lowres remain on the unchanged path unless the preflight estimate exceeds the configured limit.

Simple Python API

from neo_jax import NeoConfig, run_neo

# Surfaces may be specified by index or by s in [0, 1].
config = NeoConfig(surfaces=[0.15, 0.35, 0.6, 0.85], theta_n=64, phi_n=64)
results = run_neo("boozmn.nc", config=config)

# Access by name
print(results.epsilon_effective)
print(results["epsilon_effective_by_class"].shape)
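
One plausible way to interpret a mixed surfaces list is to treat integers as flux indices and floats as normalized flux values that are mapped to the nearest available surface. The sketch below illustrates that idea only; resolve_surface is a hypothetical helper, the zero-based indexing and nearest-neighbor rule are assumptions, and the package's actual resolution logic may differ.

```python
from typing import Sequence, Union


def resolve_surface(spec: Union[int, float], s_grid: Sequence[float]) -> int:
    """Map a surface spec (index or s in [0, 1]) to an index into s_grid."""
    if isinstance(spec, int) and not isinstance(spec, bool):
        # Integer specs are taken as indices (zero-based in this sketch).
        if not 0 <= spec < len(s_grid):
            raise IndexError(f"surface index {spec} out of range")
        return spec
    if not 0.0 <= spec <= 1.0:
        raise ValueError(f"s must lie in [0, 1], got {spec}")
    # Float specs are matched to the nearest surface on the grid.
    return min(range(len(s_grid)), key=lambda i: abs(s_grid[i] - spec))


s_grid = [0.1, 0.3, 0.5, 0.7, 0.9]
print([resolve_surface(s, s_grid) for s in (0.15, 0.35, 0.6, 0.85)])
```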

JAX-native pipeline

You can run directly on JAX-native Boozer outputs (for example from booz_xform_jax.jax_api) without writing boozmn files:

from neo_jax import NeoConfig, run_neo

# booz_out is a dict with keys like rmnc_b, zmns_b, pmns_b, bmnc_b, ixm_b, ixn_b
results = run_neo(booz_out, config=NeoConfig(surfaces=[1, 2, 3]))

For a full vmec_jax → booz_xform_jax → neo_jax workflow (no file I/O), use:

from neo_jax import NeoConfig, run_vmec_boozer_neo

config = NeoConfig(surfaces=[0.25, 0.5, 0.75], theta_n=32, phi_n=32)
results = run_vmec_boozer_neo(
    "path/to/input.vmec",
    vmec_kwargs=dict(max_iter=1, use_initial_guess=True, vmec_project=False),
    booz_kwargs=dict(mboz=8, nboz=8),
    neo_config=config,
)

For a JAX-native VMEC→Boozer adapter plus a JAX surface scan, use run_vmec_boozer_neo_jax on a vmec_jax.FixedBoundaryRun object.

When using JAX surface scans, the return type is a JAX-friendly NeoOutputs. Convert it to the standard NeoResults container with:

from neo_jax import neo_outputs_to_results

results = neo_outputs_to_results(outputs)

If you want a reusable, JIT-friendly pipeline callable (useful for loops and optimizers), use build_vmec_boozer_neo_jax:

from neo_jax import build_vmec_boozer_neo_jax, NeoConfig

solver = build_vmec_boozer_neo_jax(run, booz_kwargs=dict(mboz=8, nboz=8),
                                   neo_config=NeoConfig(surfaces=[0.5]), jit=True)
outputs = solver(run.state)

Documentation

Sphinx documentation lives in docs/ and is configured for Read the Docs. See docs/index.rst for the table of contents.

Examples

  • examples/ncsx_epsilon_effective_plot.py: compute and plot epsilon effective vs s.
  • examples/ncsx_autodiff_Rmajor_optimization.py: autodiff optimization demo over Rmajor.
  • examples/epsilon_effective_scale_optimization.py: toy autodiff example that scales |B| to reduce epsilon effective.
  • examples/qh_epsilon_effective_aspect_optimization.py: QH warm-start optimization (epsilon effective + aspect ratio).
  • examples/vmec_boozer_neo_pipeline.py: full vmec_jax → booz_xform_jax → neo_jax pipeline.

NCSX Parity Snapshot

[Figure: NCSX epstot parity]

Legacy CLI Benchmark Snapshot

Measured on this workstation with /usr/bin/time -l using the STELLOPT reference binary ~/bin/xneo and python -m neo_jax --quiet so the table reflects solver cost rather than terminal logging overhead.

| Case | Parity status | xneo runtime (s) | neo_jax runtime (s) | Runtime ratio | xneo max RSS (MiB) | neo_jax max RSS (MiB) | Memory ratio |
|---|---|---|---|---|---|---|---|
| LandremanPaul2021_QA_lowres | Pass | 2.24 | 16.15 | 7.21x | 25.0 | 1471.8 | 58.83x |
| ORBITS_MINI | Pass | 0.05 | 18.66 | 373.20x | 14.6 | 1285.9 | 87.83x |
| ORBITS_CURINT | Pass | 0.33 | 6.67 | 20.21x | 14.6 | 624.4 | 42.70x |
| ORBITS_FAST | Pass (neo_out / neolog / conver[:4]) | 0.10 | 8.61 | 86.10x | 15.0 | 756.6 | 50.33x |
| NCSX_MINI | Pass | 0.06 | 5.62 | 93.67x | 35.0 | 581.8 | 16.64x |
| ncsx_c09r00_free_fast | Pass at rtol≈5e-3 | 2.30 | 12.96 | 5.63x | 38.0 | 1307.3 | 34.43x |

Notes:

  • ncsx_c09r00_free_fast now passes the slow CLI regression at about rtol=5e-3 (implemented as 5.1e-3 to account for the rounded legacy text output). The current last-surface epstot values are 0.7159689869E-03 (xneo) vs 0.7123614767E-03 (neo_jax).
  • ORBITS_FAST is now practical again in legacy mode because WRITE_INTEGRATE=1 uses the JAX solver plus a convergence callback in the default path instead of forcing the full Python-loop backend.
  • The dense ORBITS_FAST regression now checks conver.dat columns 1-4 in CI and exposes NEO_JAX_WRITE_IPMAX_DEBUG=1 for step-by-step parity debugging of the remaining fifth-column discrepancy.
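
The choice of 5.1e-3 rather than 5e-3 can be checked from the two last-surface epstot values quoted in the first note. The snippet below is plain arithmetic on those numbers; which value the regression test uses as the reference is not stated here, so both normalizations are shown.

```python
xneo = 0.7159689869e-03     # last-surface epstot from the STELLOPT binary
neo_jax = 0.7123614767e-03  # last-surface epstot from neo_jax

abs_diff = abs(xneo - neo_jax)
rel_to_xneo = abs_diff / abs(xneo)
rel_to_neo_jax = abs_diff / abs(neo_jax)

print(f"relative to xneo:    {rel_to_xneo:.3e}")
print(f"relative to neo_jax: {rel_to_neo_jax:.3e}")
print("within 5.0e-3:", rel_to_xneo <= 5.0e-3)
print("within 5.1e-3:", rel_to_xneo <= 5.1e-3)
```

Under either normalization the relative difference lands slightly above 5e-3 and below 5.1e-3, so a tolerance of exactly 5e-3 would reject this surface.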

GPU Validation Snapshot

Validated on the office workstation (Pop!_OS, 2x NVIDIA RTX A4000, JAX 0.6.2) with:

env NEO_JAX_RUN_GPU=1 JAX_PLATFORM_NAME=gpu python -m pytest -q \
  tests/regression/test_gpu_smoke.py

That GPU smoke suite checks:

  • legacy CLI output parity between CPU and GPU on a one-surface ORBITS case
  • Python API parity between CPU and GPU for run_neo(..., use_jax=True, jax_surface_scan=True)
  • that progress logging includes the active JAX runtime so GPU runs do not look hung

Separately, the user-facing examples/ncsx_epsilon_effective_plot.py script was run on the same GPU host with MPLBACKEND=Agg and completed successfully, writing examples/ncsx_eps_eff_vs_s.png.

Cold-run timing on the same office host:

| Path | Case | CPU runtime (s) | GPU runtime (s) | CPU max RSS (MiB) | GPU max RSS (MiB) | Notes |
|---|---|---|---|---|---|---|
| Legacy CLI | LandremanPaul2021_QA_lowres | 39.41 | 95.71 | 1908.4 | 1966.4 | Cold launch, includes JIT compile |
| Python API | ORBITS single-surface smoke | 15.56 first / 8.99 reuse | 25.37 first / 14.06 reuse | n/a | n/a | run_neo(..., jax_surface_scan=True) |

At the current problem sizes, the GPU path is functional and parity-checked, but still compile-bound. On these small and medium legacy solves it is not yet faster than the CPU path. The GPU backend is still important for the larger JIT-native VMEC→Boozer→NEO workflows, where batching and reuse matter more than single-shot CLI latency.

| Metric | NEO (Fortran) | NEO_JAX (JAX) | Notes |
|---|---|---|---|
| Epsilon effective parity (max rel error, epstot) | | 2.5e-10 | vs tests/fixtures/ncsx/neo_out.ncsx_c09r00_free |
| Runtime (10 surfaces, NCSX) | 60.37 s | 51.37 s | JAX time is steady-state after warmup |
| Max RSS (NCSX run) | 72.8 MiB | 4.45 GB | Measured via /usr/bin/time -l |

Repro commands:

# Fortran runtime + memory
/usr/bin/time -l /Users/rogerio/local/STELLOPT/NEO/Release/xneo ncsx_c09r00_free

# JAX runtime (steady-state) + memory
/usr/bin/time -l env PYTHONPATH=/Users/rogerio/local/tests/NEO_JAX \
  python /Users/rogerio/local/tests/NEO_JAX/benchmarks/benchmark_ncsx.py --jax --warmup

Performance Tuning

NEO_JAX supports two Fourier evaluation modes:

  • NEO_JAX_FOURIER_MODE=vectorized (default): fastest but allocates theta×phi×mode temporaries.
  • NEO_JAX_FOURIER_MODE=streamed: lower memory by streaming over modes; slightly slower.
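
The memory trade-off between the two modes can be illustrated with a small NumPy stand-in. This is not the package's JAX implementation: the function names, the cos(m*theta - n*phi) sign convention, and the mode table are assumptions chosen for the example. The vectorized form materializes a theta×phi×mode array before summing, while the streamed form accumulates one mode at a time.

```python
import numpy as np


def eval_vectorized(theta, phi, m, n, bmnc):
    """Sum the series via one (n_theta, n_phi, n_modes) temporary."""
    angle = (m[None, None, :] * theta[:, None, None]
             - n[None, None, :] * phi[None, :, None])
    return (bmnc * np.cos(angle)).sum(axis=-1)


def eval_streamed(theta, phi, m, n, bmnc):
    """Accumulate mode by mode; temporaries stay (n_theta, n_phi)."""
    total = np.zeros((theta.size, phi.size))
    for m_k, n_k, b_k in zip(m, n, bmnc):
        total += b_k * np.cos(m_k * theta[:, None] - n_k * phi[None, :])
    return total


theta = np.linspace(0.0, 2.0 * np.pi, 64, endpoint=False)
phi = np.linspace(0.0, 2.0 * np.pi, 64, endpoint=False)
m = np.array([0, 1, 1, 2])               # illustrative poloidal mode numbers
n = np.array([0, 0, 3, 3])               # illustrative toroidal mode numbers
bmnc = np.array([1.0, 0.1, 0.05, 0.01])  # illustrative amplitudes

b_vec = eval_vectorized(theta, phi, m, n, bmnc)
b_str = eval_streamed(theta, phi, m, n, bmnc)
print(np.allclose(b_vec, b_str))
```

Both functions return the same field on the grid; only the peak size of the intermediate arrays differs, which is the effect the Max RSS figures for the two modes reflect.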

NCSX benchmark comparison (10 surfaces, CPU warmup run, /usr/bin/time -l):

| Mode | Total time | Max RSS |
|---|---|---|
| Vectorized | 51.37 s | 4.45 GB |
| Streamed | 58.78 s | 2.55 GB |

Precision

NEO_JAX enables 64-bit JAX precision by default to match the Fortran reference outputs. You can override this behavior by setting either:

  • NEO_JAX_ENABLE_X64=0 (NEO_JAX-specific)
  • JAX_ENABLE_X64=0 (global JAX)
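
A quick way to see why 64-bit precision matters for parity is to compare machine epsilon against the 2.5e-10 maximum relative epstot error quoted in this README. The check below uses NumPy only and is independent of NEO_JAX itself.

```python
import numpy as np

parity_target = 2.5e-10  # max relative epstot error quoted in this README

eps32 = float(np.finfo(np.float32).eps)
eps64 = float(np.finfo(np.float64).eps)

print(f"float32 eps = {eps32:.3e}")
print(f"float64 eps = {eps64:.3e}")
print("float32 can resolve the target:", eps32 < parity_target)
print("float64 can resolve the target:", eps64 < parity_target)
```

Single-precision rounding alone is larger than that parity level, so disabling 64-bit mode should be expected to loosen agreement with the Fortran reference outputs.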

Status

This repository is under active development. See PLAN.md for the porting plan and roadmap.

Validation Cases

Current parity fixtures include:

  • ORBITS (fast + full)
  • NCSX tutorial case (fast by default; full gated by NEO_JAX_RUN_SLOW=1)
  • LandremanPaul2021_QA_lowres (dense, 64x64 grid)
