
JAX-native spherical harmonic transforms (BK-regime first): GPU-capable, differentiable, dependency-controlled.


jht — JAX Harmonic Transforms

JAX-native spherical harmonic transforms (map ↔ aₗₘ): GPU-capable, fully differentiable, and dependency-controlled (pure JAX + numpy at runtime — no compiled C++ extension, no heavyweight third-party SHT library). Scoped to the BICEP/Keck regime — spin-0 and spin-2 on the HEALPix RING pixelization, ℓ_max ≲ 1000, nside ≤ ~2048 — but written cleanly so it can serve as a general transform dependency.

It exists to serve the GPU / differentiable tier of analysis, which a CPU-only C++ transform (ducc0) structurally cannot reach, while keeping ownership of the numerics in this project. See docs/motivation.md for the full decision record.

Status (2026-06-10)

Phases 0–4 complete and validated (190 tests pass + 8 GPU-gated skips, CPU/float64):

  • On-grid transforms: spin-0 & spin-2 synthesis (aₗₘ→map) and the exact adjoint Sᵀ, validated to machine precision against healpy and ducc0; the spin-2 inverse reaches the HEALPix floor with no s2fft-style structural defect.
  • Accuracy: jht's own ring quadrature weights plus Jacobi iteration reach ~1e-13 on band-limited maps (matching healpy.map2alm(use_weights=True)); see docs/accuracy.md.
  • Partial-sky: masked pseudo-aₗₘ, a cut-sky CG deconvolution, and a masked Wiener filter / constrained realization (the MUSE inner solve); see docs/masked.md.
  • Off-grid (NUFFT): synthesis_general evaluates a band-limited field at arbitrary pointings (spin 0–3), and adjoint_synthesis_general is its adjoint; both are differentiable with respect to the aₗₘ and the pointings. Together they are the JAX-native replacement for ducc0's sht.experimental.synthesis_general (on-grid SHT plus this NUFFT covers the full ducc0 surface bk-jax depends on); see docs/offgrid.md.
  • Differentiability: native JAX autodiff (jacfwd ≡ jacrev, tight adjoint identity), plus a convention-clean real-DOF layer, jht.diff; see docs/design.md §Differentiability.
  • GPU: pure JAX, so the transforms run on CUDA with no code change. Measured on Cannon A100/V100 (fp64): GPU==CPU parity ~1e-13 across the BK regime including nside=2048, and forward synthesis 14–60× faster than on CPU. See Performance below and docs/gpu.md.
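
The masked Wiener filter listed above amounts to solving (C⁻¹ + Yᵀ N⁻¹ Y) x = Yᵀ N⁻¹ d with conjugate gradients, where zero inverse noise marks masked pixels. A minimal sketch of that solve pattern, assuming a small random dense matrix Y as a stand-in for spin-0 synthesis; none of these names are jht API, and jht's actual solver, preconditioning, and stopping rules are not reproduced:

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Toy stand-ins (hypothetical, not jht internals):
#   Y    : (npix, nmode) "synthesis" matrix, harmonic modes -> pixels
#   cl   : prior signal variance per mode (diagonal C)
#   ninv : inverse noise variance per pixel; 0 marks a masked pixel
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
npix, nmode = 200, 50
Y = jax.random.normal(k1, (npix, nmode)) / jnp.sqrt(nmode)
cl = 1.0 / (1.0 + jnp.arange(nmode)) ** 2
ninv = jnp.where(jnp.arange(npix) < 150, 4.0, 0.0)   # last 50 pixels masked

signal = jnp.sqrt(cl) * jax.random.normal(k2, (nmode,))
data = Y @ signal + jax.random.normal(k3, (npix,)) / 2.0   # noise sigma 0.5 -> ninv 4
data = jnp.where(ninv > 0, data, 0.0)                      # masked pixels carry no data

def wiener_operator(x):
    # (C^-1 + Y^T N^-1 Y) x, applied matrix-free
    return x / cl + Y.T @ (ninv * (Y @ x))

rhs = Y.T @ (ninv * data)
x_wf, _ = cg(wiener_operator, rhs, tol=1e-12, maxiter=1000)
```

Writing the operator as a function rather than a dense matrix is what lets the same pattern scale: in a real solve, Y and Yᵀ would be the synthesis and its adjoint.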

Install

The standard environment is managed with pixi:

pixi install          # CPU env (osx-arm64 / linux-64)
pixi run test         # the gated suite

GPU (CUDA, linux-64 — see docs/gpu.md):

pixi run -e gpu python scripts/gpu_check.py   # on an NVIDIA box

As a dependency in another project (runtime deps are just jax + numpy):

pip install jaxht        # PyPI distribution name is jaxht; then `import jht`
# or track the repo directly:
pip install "jaxht @ git+https://github.com/jrcheshire/jht.git"

Quick start

float64 is opt-in per entry point: enable it in your script before creating any array (library code never touches the global JAX config):

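import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np
import jht

nside, lmax, spin = 256, 512, 0
# Example input (assumed layout): random spin-0 aₗₘ as a 1-D complex array in healpy packing
nalm = (lmax + 1) * (lmax + 2) // 2
rng = np.random.default_rng(0)
alm = rng.standard_normal(nalm) + 1j * rng.standard_normal(nalm)
alm[: lmax + 1] = alm[: lmax + 1].real                  # m = 0 entries (stored first) are real for a real map
alm = jnp.asarray(alm)

m = jht.synthesis(alm, nside, lmax, spin=spin)          # aₗₘ -> map
a = jht.map2alm(m, nside, lmax, spin=spin, niter=3)     # map -> aₗₘ (weighted + iterated)
cl = jht.bandpower(a, lmax, spin=spin)                  # angular auto-power C_ℓ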
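
For a real spin-0 field, the standard auto-power estimator is C_ℓ = (|a_ℓ0|² + 2 Σ_{m>0} |a_ℓm|²) / (2ℓ + 1), the quantity healpy.alm2cl returns. A NumPy sketch over the healpy packing, assuming mmax = lmax; the helper name auto_cl is hypothetical, and whether jht.bandpower uses exactly this normalization (or bins in ℓ) is not confirmed by this README:

```python
import numpy as np

def auto_cl(alm, lmax):
    """Hypothetical helper: C_l = (|a_l0|^2 + 2 * sum_{m>0} |a_lm|^2) / (2l + 1)
    for a real field, alm in healpy m-major triangular packing (mmax = lmax)."""
    cl = np.zeros(lmax + 1)
    idx = 0
    for m in range(lmax + 1):
        # m-major: the block for this m holds l = m..lmax contiguously
        block = np.abs(alm[idx : idx + lmax + 1 - m]) ** 2
        weight = 1.0 if m == 0 else 2.0   # m < 0 mirrors m > 0 for a real map
        cl[m:] += weight * block
        idx += lmax + 1 - m
    return cl / (2 * np.arange(lmax + 1) + 1)
```

With every coefficient equal to 1, each ℓ has 2ℓ + 1 unit-power modes, so the estimate is exactly 1 at every ℓ.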

With spin=2, aₗₘ are (E, B) pairs of shape (2, …) and maps are (Q, U) pairs of shape (2, npix). jht.adjoint_synthesis is the exact unweighted transpose Sᵀ (the operator seam / VJP), distinct from map2alm (the approximate inverse). For gradient-based work, use the real-DOF layer jht.synthesis_real / jht.analysis_real: plain ℝⁿ→ℝᵐ maps with no complex-conjugate convention subtlety.
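
The adjoint identity ⟨S x, y⟩ = ⟨x, Sᵀ y⟩ is the natural test of this operator seam, and in real DOFs there is no conjugation ambiguity: jax.vjp of a linear map returns its transpose directly. A sketch of that check, assuming a random real matrix as a hypothetical stand-in for a real-DOF synthesis (the role the README gives jht.synthesis_real); it does not call jht:

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

# Hypothetical stand-in for a real-DOF synthesis R^n -> R^m;
# here just a fixed random dense matrix.
n_dof, n_pix = 40, 120
S = jax.random.normal(jax.random.PRNGKey(1), (n_pix, n_dof))

def synth(x):
    return S @ x

x = jax.random.normal(jax.random.PRNGKey(2), (n_dof,))
y = jax.random.normal(jax.random.PRNGKey(3), (n_pix,))

# For a linear map the VJP is the transpose: vjp_fn(y) == S^T y.
_, vjp_fn = jax.vjp(synth, x)
(st_y,) = vjp_fn(y)

lhs = jnp.dot(synth(x), y)   # <S x, y>
rhs = jnp.dot(x, st_y)       # <x, S^T y>
```

The same check extends to jacfwd versus jacrev: for a linear map both return S itself.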

Conventions

jht uses healpy m-major triangular aₗₘ packing, orthonormal Yₗₘ with the Condon–Shortley phase, and HEALPix-internal (COSMO) polarization, verified against healpy 1.19.0 and ducc0 0.41.0. The conventions are pinned in docs/design.md.
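
The m-major packing stores all ℓ for m = 0 first, then ℓ = 1 to lmax for m = 1, and so on. A sketch of the index arithmetic, assuming mmax = lmax; it is the same formula healpy.Alm.getidx uses, written as standalone helpers (alm_index and alm_size are hypothetical names, not jht API):

```python
def alm_index(l, m, lmax):
    """Slot of a_lm in healpy m-major triangular packing (mmax = lmax).
    m * (2*lmax + 1 - m) is always even, so the division is exact."""
    return m * (2 * lmax + 1 - m) // 2 + l

def alm_size(lmax):
    """Number of stored coefficients: l = 0..lmax, m = 0..l."""
    return (lmax + 1) * (lmax + 2) // 2
```

For lmax = 4, the m = 0 block fills slots 0 to 4, so a_{1,1} sits at slot 5 and a_{4,4} at slot 14, the last of 15.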

Accuracy tiers (the contract)

jht targets the GPU/differentiable tier, where the HEALPix ~1e-3 sampling floor is acceptable; quadrature weights plus iteration close the gap to ~1e-13 on band-limited inputs. It is not a drop-in replacement for ducc's purity-critical (~1e-4 E→B-leakage) production path. Tolerances are set a priori and gate-driven, and are never relaxed without sign-off. Residual mismatches are logged in DISCREPANCIES.md.
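
The weights-plus-iteration step follows the healpy-style scheme: start from a quadrature estimate a₀ = A m and repeat a ← a + A (m − S a), where S is synthesis and A a weighted adjoint. A toy sketch with a random dense S and A = Sᵀ (both hypothetical stand-ins; jht's ring weights are not modeled), showing the error on a band-limited input shrinking with the number of iterations:

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

# Hypothetical toy: S maps nmode coefficients to npix pixels, and A = S^T
# stands in for the weighted quadrature adjoint (an approximate inverse of S).
npix, nmode = 2000, 50
S = jax.random.normal(jax.random.PRNGKey(0), (npix, nmode)) / jnp.sqrt(npix)
A = S.T

a_true = jax.random.normal(jax.random.PRNGKey(1), (nmode,))
sky = S @ a_true                      # band-limited: exactly in the range of S

def iterated_map2alm(sky, niter):
    a = A @ sky                       # niter = 0: plain quadrature estimate
    for _ in range(niter):
        a = a + A @ (sky - S @ a)     # Jacobi step: map the residual back and correct
    return a

def rel_error(niter):
    err = iterated_map2alm(sky, niter) - a_true
    return float(jnp.linalg.norm(err) / jnp.linalg.norm(a_true))

errors = [rel_error(k) for k in (0, 5, 10, 15, 20)]
```

Each step multiplies the error by (I − A S), so the iteration converges when that operator's spectral radius is below 1; better quadrature weights make A closer to S⁻¹ and the contraction faster.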

Performance

Pure JAX runs unchanged on CUDA. Measured on Cannon A100 (incl. a 20 GB MIG) / V100, fp64:

  • GPU==CPU parity ~1e-13 across the BK regime, including nside=2048 (synthesis and map2alm).
  • Forward synthesis 14–60× faster than on the 8-core CPU; fp64/fp32 ≈ 2.2×.
  • Off-grid forward ~0.5–0.9 s at ℓ_max=1000 (independent of the number of points; recursion-bound), with the pointing gradient costing roughly one forward pass.
  • nside=2048 compiles and runs on GPU — a ~20 GB slice holds synthesis + map2alm; the one-time compile is multi-minute (jit-cached).
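
A 1-D analogue shows why a pointing gradient can be cheap: each sample depends only on its own pointing, so one reverse pass yields every per-point derivative. The sketch below uses a real Fourier series on the circle in place of spherical harmonics and evaluates it by direct summation rather than a NUFFT; all names are hypothetical and nothing here is jht API:

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

K = 16                                   # band limit (1-D analogue of lmax)
ks = jnp.arange(1, K + 1)
key_a, key_b, key_p = jax.random.split(jax.random.PRNGKey(0), 3)
a0 = 0.3
a = jax.random.normal(key_a, (K,))
b = jax.random.normal(key_b, (K,))

def evaluate(phi):
    """Band-limited field at arbitrary angles phi, shape (npt,).
    Direct O(npt * K) sum; a NUFFT computes the same values faster."""
    kphi = phi[:, None] * ks[None, :]
    return a0 + jnp.cos(kphi) @ a + jnp.sin(kphi) @ b

phi = jax.random.uniform(key_p, (1000,), maxval=2 * jnp.pi)

# The Jacobian is diagonal, so a single VJP with a ones cotangent
# returns every d f(phi_j) / d phi_j at once.
values, vjp_fn = jax.vjp(evaluate, phi)
(dphi,) = vjp_fn(jnp.ones_like(values))

# Analytic derivative for comparison.
kphi = phi[:, None] * ks[None, :]
dphi_exact = (-jnp.sin(kphi) * ks) @ a + (jnp.cos(kphi) * ks) @ b
```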

The recurring GPU lesson: fp64/complex scatters are catastrophically slow on GPU, so jht packs and assembles arrays via gathers. The CPU performance model and memory usage are in docs/performance.md; GPU details are in docs/gpu.md.
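
The gather-over-scatter pattern can be illustrated on aₗₘ packing itself. A sketch assuming a dense [ℓ, m] layout (zero for m > ℓ) and healpy m-major packing; the helper names are hypothetical and this is not jht's code. Both directions are written as gathers (JAX lowers advanced indexing to lax.gather), whereas the obvious unpacking would be the scatter zeros.at[ls, ms].set(packed):

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np

def packing_indices(lmax):
    """Static (l, m) coordinates of each slot in healpy m-major packing."""
    ls = [l for m in range(lmax + 1) for l in range(m, lmax + 1)]
    ms = [m for m in range(lmax + 1) for _ in range(m, lmax + 1)]
    return np.array(ls), np.array(ms)

def pack_by_gather(dense, lmax):
    """dense: (lmax+1, lmax+1) complex array indexed [l, m], zero for m > l.
    Each packed slot reads one dense entry, so no scatter is needed."""
    ls, ms = packing_indices(lmax)
    return dense[ls, ms]

def unpack_by_gather(packed, lmax):
    """Each dense cell reads its packed slot, m * (2*lmax + 1 - m) // 2 + l;
    cells with m > l read slot 0 and are then masked to zero."""
    l = np.arange(lmax + 1)[:, None]
    m = np.arange(lmax + 1)[None, :]
    idx = np.where(m <= l, m * (2 * lmax + 1 - m) // 2 + l, 0)
    return jnp.where(m <= l, packed[idx], 0)
```

Because the index arrays are plain NumPy, they are constants under jax.jit, so the gathers compile to fixed index patterns.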

Using jht as a dependency

jht is standalone and consumer-agnostic. The operator/grad seam a downstream needs (e.g. to use jht in place of ducc0) — and the accuracy boundary — are documented in docs/consumers.md. Any backend-selection wiring lives in the consumer, not here.

Download files


Source Distribution

jaxht-0.1.0.tar.gz (160.9 kB)

Uploaded Source

Built Distribution


jaxht-0.1.0-py3-none-any.whl (39.2 kB)

Uploaded Python 3

File details

Details for the file jaxht-0.1.0.tar.gz.

File metadata

  • Download URL: jaxht-0.1.0.tar.gz
  • Upload date:
  • Size: 160.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jaxht-0.1.0.tar.gz
Algorithm Hash digest
SHA256 06617f8e699ef71271959353312107371ba2a3c5ec380d2de7945a2d63beab38
MD5 1dcdf53ccfba4513c733113ba78c52bf
BLAKE2b-256 e9655afb327e86a5ec834e9239ae188007691b159598809005b140557c086cc6


Provenance

The following attestation bundles were made for jaxht-0.1.0.tar.gz:

Publisher: release.yml on jrcheshire/jht

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file jaxht-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: jaxht-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 39.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jaxht-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 b288d95ee4c5f0286183d805a42bf4d54e54f862558dde3d533af6c34c55fe4d
MD5 2a049e4750d7fb4372c3da178ca2588e
BLAKE2b-256 42e911c9b5dbf3b559f8861679ed7263181149e24a3152c3bf038fe7fadaed78


Provenance

The following attestation bundles were made for jaxht-0.1.0-py3-none-any.whl:

Publisher: release.yml on jrcheshire/jht

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
