JAX-native spherical harmonic transforms (BK-regime first): GPU-capable, differentiable, dependency-controlled.
jht — JAX Harmonic Transforms
JAX-native spherical harmonic transforms (map ↔ aₗₘ): GPU-capable, fully differentiable, and dependency-controlled (pure JAX + numpy at runtime — no compiled C++ extension, no heavyweight third-party SHT library). Scoped to the BICEP/Keck regime — spin-0 and spin-2 on the HEALPix RING pixelization, ℓ_max ≲ 1000, nside ≤ ~2048 — but written cleanly so it can serve as a general transform dependency.
It exists to serve the GPU / differentiable tier of analysis that a CPU-only C++
transform (ducc0) structurally cannot, while owning the numerics. See
docs/motivation.md for the full decision record.
Status (2026-06-10)
Phases 0–4 complete and validated (190 tests pass + 8 GPU-gated skips, CPU/float64):
- On-grid transforms: spin-0 and spin-2 synthesis (aₗₘ→map) and its exact adjoint Sᵀ, validated to machine precision against healpy and ducc0; the spin-2 inverse reaches the HEALPix floor with no s2fft-style structural defect.
- Accuracy: jht's own ring quadrature weights plus Jacobi iteration reach ~1e-13 on band-limited maps (matching healpy.map2alm(use_weights=True)); see docs/accuracy.md.
- Partial-sky: masked pseudo-aₗₘ, a cut-sky CG deconvolution, and a masked Wiener filter / constrained realization (the MUSE inner solve); see docs/masked.md.
- Off-grid (NUFFT): synthesis_general / adjoint_synthesis_general evaluate a band-limited field at arbitrary pointings (spin 0–3) and are differentiable with respect to both the aₗₘ and the pointings. This is the JAX-native replacement for ducc0's sht.experimental.synthesis_general; together with the on-grid SHT it covers the full ducc0 surface that bk-jax depends on. See docs/offgrid.md.
- Differentiability: native JAX autodiff (jacfwd ≡ jacrev, tight adjoint identity), plus a convention-clean real-DOF layer, jht.diff; see docs/design.md §Differentiability.
- GPU: pure JAX, so the transforms run on CUDA with no code change. Measured on Cannon A100/V100 (fp64): GPU==CPU parity ~1e-13 across the BK regime, including nside=2048, with forward synthesis 14–60× faster than on the CPU. See Performance below and docs/gpu.md.
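The two differentiability checks named above (jacfwd ≡ jacrev and the adjoint identity) can be sketched on a toy linear operator. In this sketch a dense random matrix, `S_mat`, is a hypothetical stand-in for jht's synthesis operator. It is not jht's test suite. It only shows that, for a linear map, the VJP is the transpose and both Jacobian modes agree.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64, as in the Quick start

# Hypothetical stand-in for a synthesis operator S: a fixed real 12x5 matrix.
key_s, key_x, key_y = jax.random.split(jax.random.PRNGKey(0), 3)
S_mat = jax.random.normal(key_s, (12, 5))


def toy_synthesis(x):
    """Linear map R^5 -> R^12, playing the role of S."""
    return S_mat @ x


x = jax.random.normal(key_x, (5,))
y = jax.random.normal(key_y, (12,))

# For a linear operator the VJP is the transpose: vjp_fn(y) == S^T y.
_, vjp_fn = jax.vjp(toy_synthesis, x)
(St_y,) = vjp_fn(y)

# Adjoint (dot-product) identity <S x, y> == <x, S^T y>, scaled by the norms.
lhs = jnp.dot(toy_synthesis(x), y)
rhs = jnp.dot(x, St_y)
adjoint_rel_err = jnp.abs(lhs - rhs) / (
    jnp.linalg.norm(toy_synthesis(x)) * jnp.linalg.norm(y)
)

# Forward- and reverse-mode Jacobians agree (both equal S for a linear map).
J_fwd = jax.jacfwd(toy_synthesis)(x)
J_rev = jax.jacrev(toy_synthesis)(x)
jac_max_diff = jnp.max(jnp.abs(J_fwd - J_rev))
```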
Install
Standard env is pixi:
pixi install # CPU env (osx-arm64 / linux-64)
pixi run test # the gated suite
GPU (CUDA, linux-64 — see docs/gpu.md):
pixi run -e gpu python scripts/gpu_check.py # on an NVIDIA box
As a dependency in another project (runtime deps are just jax + numpy):
pip install jaxht  # from PyPI; the import name is `jht`
# or track the repo directly:
pip install "jaxht @ git+https://github.com/jrcheshire/jht.git"
Quick start
float64 is opt-in per entry point — enable it before creating any array (library code never touches the global config):
import jax
jax.config.update("jax_enable_x64", True)
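import jax.numpy as jnp
import numpy as np
import jht
nside, lmax, spin = 256, 512, 0
# Random spin-0 test aₗₘ, assuming a 1-D complex array in healpy m-major packing
nalm = (lmax + 1) * (lmax + 2) // 2
rng = np.random.default_rng(0)
alm = jnp.asarray(rng.standard_normal(nalm) + 1j * rng.standard_normal(nalm))
alm = alm.at[: lmax + 1].set(alm[: lmax + 1].real)  # the m = 0 block is real for a real map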
m = jht.synthesis(alm, nside, lmax, spin=spin) # aₗₘ -> map
a = jht.map2alm(m, nside, lmax, spin=spin, niter=3) # map -> aₗₘ (weighted + iterated)
cl = jht.bandpower(a, lmax, spin=spin) # angular auto-power C_ℓ
spin=2 takes/returns (E, B) aₗₘ of shape (2, …) and (Q, U) maps of shape
(2, npix). jht.adjoint_synthesis is the exact unweighted transpose Sᵀ
(the operator seam / VJP), distinct from map2alm (the approximate inverse). For
gradient-based work use the real-DOF layer jht.synthesis_real /
jht.analysis_real (plain ℝⁿ→ℝᵐ, no complex-conjugate convention subtlety).
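The idea behind a real-DOF layer can be sketched in NumPy. For a real-valued map, each a_ℓ0 is real and the m > 0 coefficients are complex, so the field has (ℓ_max+1)² real degrees of freedom. The layout below is a hypothetical illustration: `alm_to_real` and `real_to_alm` are not jht functions. jht.synthesis_real / jht.analysis_real may order or scale the components differently (for example, with √2 factors).

```python
import numpy as np


def n_alm(lmax):
    """Number of complex a_lm in healpy m-major packing with mmax = lmax."""
    return (lmax + 1) * (lmax + 2) // 2


def alm_to_real(alm, lmax):
    """Hypothetical layout: Re(a_l0) for m = 0, then Re and Im of the m > 0 block."""
    m0 = alm[: lmax + 1].real      # m = 0 block: real for a real-valued map
    rest = alm[lmax + 1 :]         # m > 0 block: genuinely complex
    return np.concatenate([m0, rest.real, rest.imag])


def real_to_alm(x, lmax):
    """Inverse of alm_to_real."""
    n0 = lmax + 1
    nrest = n_alm(lmax) - n0
    re = x[n0 : n0 + nrest]
    im = x[n0 + nrest :]
    return np.concatenate([x[:n0].astype(complex), re + 1j * im])


lmax = 4
rng = np.random.default_rng(1)
alm = rng.standard_normal(n_alm(lmax)) + 1j * rng.standard_normal(n_alm(lmax))
alm[: lmax + 1] = alm[: lmax + 1].real  # enforce real m = 0 modes
x = alm_to_real(alm, lmax)              # plain real vector of length (lmax + 1) ** 2
```

A gradient-based optimizer then works on `x` directly, so no convention for complex derivatives is needed.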
Conventions
healpy m-major triangular aₗₘ packing, orthonormal Yₗₘ with the Condon–Shortley
phase, HEALPix-internal (COSMO) polarization — verified against healpy 1.19.0 and
ducc0 0.41.0. Pinned in docs/design.md.
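The healpy m-major triangular packing and the standard full-sky auto-power estimator, Ĉ_ℓ = (|a_ℓ0|² + 2 Σ_{m>0} |a_ℓm|²) / (2ℓ+1), can be sketched as follows. The index formula is the one healpy uses for Alm.getidx with mmax = lmax. Whether jht.bandpower applies exactly this normalization, with no binning, is an assumption here. `alm_index` and `auto_cl` are hypothetical helpers, not jht API.

```python
import numpy as np


def alm_index(ell, m, lmax):
    """healpy-style m-major index (mmax = lmax): all l for m = 0, then m = 1, ..."""
    return m * (2 * lmax + 1 - m) // 2 + ell


def auto_cl(alm, lmax):
    """Full-sky C_l = (|a_l0|^2 + 2 * sum_{m>0} |a_lm|^2) / (2l + 1)."""
    cl = np.zeros(lmax + 1)
    for ell in range(lmax + 1):
        power = abs(alm[alm_index(ell, 0, lmax)]) ** 2
        for m in range(1, ell + 1):
            # m < 0 modes are conjugates of m > 0 for a real map, hence the factor 2
            power += 2 * abs(alm[alm_index(ell, m, lmax)]) ** 2
        cl[ell] = power / (2 * ell + 1)
    return cl
```

For lmax = 3 the packing runs (ℓ, m) = (0,0), (1,0), (2,0), (3,0), (1,1), (2,1), ... and fills indices 0 through 9 with no gaps.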
Accuracy tiers (the contract)
jht targets the GPU/differentiable tier, where the HEALPix ~1e-3 sampling
floor is acceptable; ring weights plus iteration close it to ~1e-13 on band-limited
inputs. It is not a drop-in replacement for ducc's purity-critical (~1e-4 E→B leakage)
production path. Tolerances are fixed a priori and gate-driven, and are never relaxed
without sign-off. Residual mismatches are logged in DISCREPANCIES.md.
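The way iteration closes the sampling floor can be sketched with small matrices. This sketch makes three assumptions for illustration, none of which are jht's internals:

- `S` is a toy synthesis operator.
- `A` is a good but inexact inverse that plays the role of weighted Sᵀ.
- The update a ← a + A(m − S a) is the healpy-style Jacobi step, which re-analyses the synthesis residual.

On an input that lies exactly in the range of S (a "band-limited map"), the error shrinks geometrically.

```python
import numpy as np


def jacobi_map2alm(m, synth, approx_inv, niter):
    """Approximate inverse followed by niter Jacobi refinements (healpy-style)."""
    a = approx_inv(m)
    for _ in range(niter):
        a = a + approx_inv(m - synth(a))  # re-analyse the synthesis residual
    return a


rng = np.random.default_rng(0)
npix, nalm = 200, 20
# Toy operators: Q has orthonormal columns and S is Q times a near-identity
# mixing, so A = Q^T inverts S only approximately (error ~1e-2 per step).
Q, _ = np.linalg.qr(rng.standard_normal((npix, nalm)))
S = Q @ (np.eye(nalm) + 0.01 * rng.standard_normal((nalm, nalm)))
A = Q.T

a_true = rng.standard_normal(nalm)
m = S @ a_true  # "band-limited" input: exactly in the range of S

# Max-abs coefficient error after 0, 1, 2, 3 iterations.
errs = [
    np.max(np.abs(jacobi_map2alm(m, lambda v: S @ v, lambda y: A @ y, k) - a_true))
    for k in range(4)
]
```

Each step multiplies the error by (I − A S), so convergence depends on how close the weighted adjoint is to a true inverse. That is why quadrature weights and iteration work together.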
Performance
Pure JAX runs unchanged on CUDA. Measured on Cannon A100 (incl. a 20 GB MIG) / V100, fp64:
- GPU==CPU parity ~1e-13 across the BK regime, including nside=2048 (synthesis and map2alm).
- Forward synthesis is 14–60× faster than on the 8-core CPU; fp64/fp32 ≈ 2.2×.
- Off-grid forward takes ~0.5–0.9 s at ℓ_max=1000, independent of the number of points (it is recursion-bound); the pointing gradient costs ~1× a forward.
- nside=2048 compiles and runs on GPU: a ~20 GB slice holds synthesis + map2alm, and the one-time compile takes several minutes (jit-cached).
The recurring GPU lesson: fp64/complex scatters perform catastrophically on GPU, so jht packs and assembles via gathers instead. The CPU performance model and memory use are covered in docs/performance.md; GPU details are in docs/gpu.md.
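The scatter-to-gather rewrite can be sketched for a fixed packing permutation. The permutation `dest` below is hypothetical and this is not jht's packing code. It only shows that a scatter out[dest[i]] = src[i] produces the same array as a gather through the precomputed inverse permutation. The gather form avoids indexed writes on the device.

```python
import jax.numpy as jnp
import numpy as np

# Hypothetical packing permutation: source element i lands at slot dest[i].
rng = np.random.default_rng(0)
n = 16
dest = rng.permutation(n)
src = jnp.asarray(rng.standard_normal(n) + 1j * rng.standard_normal(n))

# Scatter form: write each source element into its destination slot.
packed_scatter = jnp.zeros(n, dtype=src.dtype).at[dest].set(src)

# Gather form: compute the inverse permutation once on the host, then each
# destination slot j reads src[inv[j]], where dest[inv[j]] == j.
inv = np.argsort(dest)
packed_gather = src[inv]
```

Because the index arrays are static, the inverse can be built once and reused on every call, so the device only ever performs reads.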
Using jht as a dependency
jht is standalone and consumer-agnostic. The operator/gradient seam a downstream
project needs (e.g. to use jht in place of ducc0), and the accuracy boundary, are
documented in docs/consumers.md. Any backend-selection wiring lives in the
consumer, not here.
Docs
- docs/design.md: technical design, conventions, the crux, differentiability.
- docs/accuracy.md: the accuracy contract + ring-weight algorithm.
- docs/masked.md: partial-sky estimators.
- docs/performance.md: CPU perf model + memory.
- docs/gpu.md: the GPU env, the x64 requirement, the harness.
- docs/consumers.md: the downstream-dependency seam.
- docs/motivation.md: why jht exists.
- ROADMAP.md: phased plan + gates.
File details
Details for the file jaxht-0.1.0.tar.gz.
File metadata
- Download URL: jaxht-0.1.0.tar.gz
- Upload date:
- Size: 160.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 06617f8e699ef71271959353312107371ba2a3c5ec380d2de7945a2d63beab38 |
| MD5 | 1dcdf53ccfba4513c733113ba78c52bf |
| BLAKE2b-256 | e9655afb327e86a5ec834e9239ae188007691b159598809005b140557c086cc6 |
Provenance
The following attestation bundles were made for jaxht-0.1.0.tar.gz:
- Publisher: release.yml on jrcheshire/jht
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: jaxht-0.1.0.tar.gz
- Subject digest: 06617f8e699ef71271959353312107371ba2a3c5ec380d2de7945a2d63beab38
- Sigstore transparency entry: 1775786491
- Sigstore integration time:
- Permalink: jrcheshire/jht@3787b140161d1c64d4ee9b03d2d9a0c2a68f8ac6
- Branch / Tag: refs/tags/v0.1.0
- Owner: https://github.com/jrcheshire
- Access: private
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@3787b140161d1c64d4ee9b03d2d9a0c2a68f8ac6
- Trigger Event: push
File details
Details for the file jaxht-0.1.0-py3-none-any.whl.
File metadata
- Download URL: jaxht-0.1.0-py3-none-any.whl
- Upload date:
- Size: 39.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b288d95ee4c5f0286183d805a42bf4d54e54f862558dde3d533af6c34c55fe4d |
| MD5 | 2a049e4750d7fb4372c3da178ca2588e |
| BLAKE2b-256 | 42e911c9b5dbf3b559f8861679ed7263181149e24a3152c3bf038fe7fadaed78 |
Provenance
The following attestation bundles were made for jaxht-0.1.0-py3-none-any.whl:
- Publisher: release.yml on jrcheshire/jht
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: jaxht-0.1.0-py3-none-any.whl
- Subject digest: b288d95ee4c5f0286183d805a42bf4d54e54f862558dde3d533af6c34c55fe4d
- Sigstore transparency entry: 1775786642
- Sigstore integration time:
- Permalink: jrcheshire/jht@3787b140161d1c64d4ee9b03d2d9a0c2a68f8ac6
- Branch / Tag: refs/tags/v0.1.0
- Owner: https://github.com/jrcheshire
- Access: private
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@3787b140161d1c64d4ee9b03d2d9a0c2a68f8ac6
- Trigger Event: push