Differentiable telescope pointing in JAX — ERFA, qpoint, and so3g.proj combined
so_pointjax
Differentiable telescope pointing in JAX. A pure-Python, GPU-ready reimplementation
of the ERFA/qpoint/so3g pointing stack, fully compatible with jax.jit, jax.grad,
and jax.vmap.
Architecture
```text
so_pointjax
├── erfa    Low-level ERFA routines (time, precession, nutation, astrometry, ...)
├── qpoint  Mid-level pointing pipeline (az/el → RA/Dec, HEALPix, IERS)
└── proj    High-level API (Quat, CelestialSightLine, FocalPlane, Assembly)
```
Each layer builds on the one below. Use proj for most telescope work;
reach into qpoint or erfa when you need finer control.
Installation
```shell
pip install so-pointjax   # from PyPI
pip install -e .          # editable install from a source checkout
```
Dependencies: jax >= 0.4.0, jaxlib >= 0.4.0. Tests additionally need
pytest >= 7.0 and pyerfa >= 2.0.
Quick start
```python
import jax
import jax.numpy as jnp
from so_pointjax.proj import Quat, CelestialSightLine, FocalPlane

# Unix timestamps need float64: float32 cannot resolve 1 s at ~1.7e9
jax.config.update("jax_enable_x64", True)

# Build a pointing sightline from az/el + timestamps
t = jnp.array([1700000000.0, 1700000001.0])
az = jnp.array([1.0, 1.01])
el = jnp.array([0.8, 0.8])
csl = CelestialSightLine.az_el(t, az, el, site='act', weather='toco')

# Extract sky coordinates: (N, 4) → [lon, lat, cos2psi, sin2psi]
coords = csl.coords()

# With a focal plane of detectors
fp = FocalPlane.from_xieta(
    jnp.array([0.0, 0.01, -0.01]),
    jnp.array([0.0, 0.01, 0.01]),
)
det_coords = csl.coords(fplane=fp)  # (n_det, N, 4)
```
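Each row of the coords() output packs the polarization angle as cos 2ψ and sin 2ψ. The sketch below unpacks such an array with plain jax.numpy. It assumes lon and lat are in radians (the unit is not stated above), and it feeds a synthetic array instead of real csl.coords() output so it runs without so_pointjax installed; the helper name decode_coords is hypothetical.

```python
import jax.numpy as jnp

def decode_coords(coords):
    """Split a (..., 4) array laid out as [lon, lat, cos2psi, sin2psi].

    Assumes lon and lat are in radians; returns them in degrees, plus the
    polarization angle psi in radians. Works for (N, 4) and (n_det, N, 4).
    """
    lon_deg = jnp.degrees(coords[..., 0])
    lat_deg = jnp.degrees(coords[..., 1])
    # (cos 2psi, sin 2psi) fixes 2psi on the full circle, so psi is
    # recovered modulo pi, the natural period of a polarization angle.
    psi = 0.5 * jnp.arctan2(coords[..., 3], coords[..., 2])
    return lon_deg, lat_deg, psi

# Synthetic stand-in for csl.coords(): two samples with psi = 0.3 rad.
psi_true = 0.3
fake = jnp.array([
    [1.00, -0.5, jnp.cos(2 * psi_true), jnp.sin(2 * psi_true)],
    [1.01, -0.5, jnp.cos(2 * psi_true), jnp.sin(2 * psi_true)],
])
lon_deg, lat_deg, psi = decode_coords(fake)
```

Recovering ψ with arctan2 rather than arccos of cos 2ψ keeps the sign of the angle and stays well conditioned all the way around the circle.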
End-to-end differentiable pointing
The entire pipeline is differentiable. Compute gradients of sky coordinates with respect to any input:
```python
import jax
import jax.numpy as jnp
from so_pointjax.proj import CelestialSightLine

def sky_lon(az, el):
    csl = CelestialSightLine.naive_az_el(
        jnp.array([1700000000.0]),
        jnp.array([az]),
        jnp.array([el]),
        site='act',
    )
    return csl.coords()[0, 0]  # RA (lon) of the first sample

dra_daz, dra_del = jax.grad(sky_lon, argnums=(0, 1))(1.0, 0.8)
```
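A cheap way to build trust in gradients like dra_daz is to compare them against a central finite difference. The sketch below does this for a hypothetical toy_lon function (an az/el unit vector rotated about one axis) that stands in for sky_lon so the example runs without so_pointjax; it is not the package's transform. Float64 is enabled so the difference quotient is accurate.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 keeps the finite difference accurate

TILT = 0.4  # arbitrary fixed tilt in radians, for illustration only

def toy_lon(az, el):
    """Hypothetical stand-in for sky_lon: longitude of the az/el unit vector
    after a fixed rotation about the x axis. Not the so_pointjax transform."""
    v = jnp.array([
        jnp.cos(el) * jnp.cos(az),
        jnp.cos(el) * jnp.sin(az),
        jnp.sin(el),
    ])
    c, s = jnp.cos(TILT), jnp.sin(TILT)
    rot = jnp.array([
        [1.0, 0.0, 0.0],
        [0.0, c, -s],
        [0.0, s, c],
    ])
    x, y, _ = rot @ v
    return jnp.arctan2(y, x)

def central_diff(f, x, h=1e-6):
    """Second-order central finite difference of a scalar function."""
    return (f(x + h) - f(x - h)) / (2.0 * h)

az0, el0 = 1.0, 0.8
g_az, g_el = jax.grad(toy_lon, argnums=(0, 1))(az0, el0)
fd_az = central_diff(lambda a: toy_lon(a, el0), az0)
fd_el = central_diff(lambda e: toy_lon(az0, e), el0)

# Per-sample derivatives for a batch of azimuths at fixed elevation.
dlon_daz = jax.vmap(jax.grad(toy_lon), in_axes=(0, None))(jnp.linspace(0.0, 1.0, 5), el0)
```

The jax.vmap(jax.grad(...)) pattern gives one derivative per sample, as long as the wrapped function returns a scalar for each sample.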
Quaternion algebra
The Quat class wraps JAX arrays with quaternion arithmetic, broadcasting,
and operator overloading:
```python
import jax.numpy as jnp
from so_pointjax.proj import Quat

q1 = Quat.from_lonlat(1.0, 0.5)
q2 = Quat.from_euler(2, 0.1)
q = q1 * q2    # quaternion product
q_inv = ~q     # conjugate (equal to the inverse for unit quaternions)

# Batch operations with automatic broadcasting
q_arr = Quat.from_euler(2, jnp.linspace(0, 1, 1000))
rotated = q1 * q_arr  # (4,) x (1000, 4) -> (1000, 4)
```
Quat is a JAX pytree and works transparently with jit, grad, and vmap.
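For readers new to the arithmetic behind Quat's * and ~ operators, here is a minimal sketch of a broadcasting Hamilton product and conjugate on raw (..., 4) arrays. It assumes scalar-first (w, x, y, z) storage, which is an assumption here rather than a documented detail of Quat; the helpers qmul, qconj, rot_about, and rotate are hypothetical and are not so_pointjax functions.

```python
import jax.numpy as jnp

def qmul(p, q):
    """Hamilton product of (..., 4) quaternion arrays in (w, x, y, z) order.

    Leading axes broadcast, so (4,) x (1000, 4) -> (1000, 4).
    """
    pw, px, py, pz = jnp.moveaxis(p, -1, 0)
    qw, qx, qy, qz = jnp.moveaxis(q, -1, 0)
    return jnp.stack([
        pw * qw - px * qx - py * qy - pz * qz,
        pw * qx + px * qw + py * qz - pz * qy,
        pw * qy - px * qz + py * qw + pz * qx,
        pw * qz + px * qy - py * qx + pz * qw,
    ], axis=-1)

def qconj(q):
    """Conjugate: negate the vector part (the inverse for unit quaternions)."""
    return q * jnp.array([1.0, -1.0, -1.0, -1.0])

def rot_about(axis, angle):
    """Unit quaternion rotating by `angle` (radians) about coordinate axis 0, 1 or 2."""
    half = 0.5 * jnp.asarray(angle)
    vec = jnp.zeros(half.shape + (3,)).at[..., axis].set(jnp.sin(half))
    return jnp.concatenate([jnp.cos(half)[..., None], vec], axis=-1)

def rotate(q, v):
    """Rotate 3-vector(s) v by unit quaternion(s) q via q (0, v) q*."""
    v4 = jnp.concatenate([jnp.zeros(v.shape[:-1] + (1,)), v], axis=-1)
    return qmul(qmul(q, v4), qconj(q))[..., 1:]

q1 = rot_about(2, 1.0)
q_arr = rot_about(2, jnp.linspace(0.0, 1.0, 1000))
rotated = qmul(q1, q_arr)     # broadcasts to (1000, 4)
ident = qmul(q1, qconj(q1))   # ≈ [1, 0, 0, 0]
x_to_y = rotate(rot_about(2, jnp.pi / 2), jnp.array([1.0, 0.0, 0.0]))  # ≈ [0, 1, 0]
```

Rotations about a common axis commute and their half-angles add, so qmul(rot_about(2, a), rot_about(2, b)) equals rot_about(2, a + b), which makes a convenient unit test for any quaternion implementation.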
Submodule guides
Each submodule has its own detailed README:
- so_pointjax.erfa: differentiable ERFA covering time scales, precession-nutation, astrometry, ephemerides, coordinate frames, geodetic transforms, and more (~200 functions).
- so_pointjax.qpoint: pointing pipeline with quaternion algebra, atmospheric and aberration corrections, az/el to RA/Dec conversion, HEALPix pixelization, and IERS Bulletin A support.
- so_pointjax.proj: high-level API with the Quat class (operator overloading), CelestialSightLine, FocalPlane, Assembly, observatory sites, and weather models.
Precision
Validated against the original C and C++ implementations to sub-milliarcsecond accuracy:
| Layer | Agreement |
|---|---|
| ERFA functions | Matches pyerfa |
| Quaternion functions | Bit-identical to so3g |
| naive_az_el | ~1e-12 (quaternion difference) |
| az_el (all weather/sites) | 0.0004–0.0005 arcsec |
| Detector projection | ~1e-12 |
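Arcsecond-level agreement figures like those in the table are angular separations between two pointing solutions. One way to compute such a separation is the Vincenty form of the great-circle distance, sketched below with jax.numpy; the function name ang_sep_arcsec is hypothetical, and this is not necessarily how the package's own validation measures agreement. Float64 is enabled because float32 resolves only about 1e-7 rad (roughly 0.02 arcsec) near 1 rad, far coarser than the 0.0005 arcsec level being checked.

```python
import math

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 needed for sub-mas comparisons

ARCSEC_PER_RAD = 180.0 * 3600.0 / math.pi  # ≈ 206264.8

def ang_sep_arcsec(lon1, lat1, lon2, lat2):
    """Great-circle separation (Vincenty form); inputs in radians, output in arcsec.

    Unlike arccos of a dot product, this stays accurate for tiny separations.
    """
    dlon = lon2 - lon1
    num = jnp.hypot(
        jnp.cos(lat2) * jnp.sin(dlon),
        jnp.cos(lat1) * jnp.sin(lat2) - jnp.sin(lat1) * jnp.cos(lat2) * jnp.cos(dlon),
    )
    den = jnp.sin(lat1) * jnp.sin(lat2) + jnp.cos(lat1) * jnp.cos(lat2) * jnp.cos(dlon)
    return jnp.arctan2(num, den) * ARCSEC_PER_RAD

# A 1-arcsec offset in latitude, and a 0.0005-arcsec offset in longitude at the equator.
d_lat = ang_sep_arcsec(1.0, 0.5, 1.0, 0.5 + 1.0 / ARCSEC_PER_RAD)
d_lon = ang_sep_arcsec(1.0, 0.0, 1.0 + 0.0005 / ARCSEC_PER_RAD, 0.0)
```

A longitude offset Δλ at latitude δ corresponds to an on-sky distance of roughly Δλ cos δ, so comparing raw RA differences overstates the error near the poles; a true separation avoids that.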
Running tests
```shell
# All tests
python -m pytest so_pointjax/ -v

# By submodule
python -m pytest so_pointjax/erfa/tests/ -v
python -m pytest so_pointjax/qpoint/tests/ -v
python -m pytest so_pointjax/proj/tests/ -v

# Cross-validation against so3g (requires so3g)
python -m pytest so_pointjax/proj/tests/test_cross_validation.py -v -s
```
Performance
Key results (CPU, JIT-compiled):
- Quaternion ops: 1.3–7x faster than so3g at N ≥ 100K
- Pointing pipeline: 3–8x faster across all benchmarked sizes
- Boresight × detector composition: 2–3x faster for realistic focal planes
- Gradients: ~1 µs/sample (a capability so3g lacks)
Run the benchmark scripts as follows; each accepts an optional --quick flag:
```shell
python -m so_pointjax.proj.benchmarks.bench_so3g
python -m so_pointjax.proj.benchmarks.bench_quat_array
python -m so_pointjax.qpoint.benchmarks.bench_pointing
```
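Timing JIT-compiled JAX code needs two precautions: exclude the first call, which traces and compiles, and call block_until_ready() because JAX dispatches work asynchronously. The sketch below shows that pattern with a hypothetical toy_workload function; it is not the code in the bench_* scripts, and it assumes the timed function returns a single array.

```python
import time

import jax
import jax.numpy as jnp

def bench(fn, *args, repeats=10):
    """Approximate median wall-clock time (s) of a jitted call, excluding compilation.

    Assumes fn returns a single array.
    """
    jitted = jax.jit(fn)
    jitted(*args).block_until_ready()  # warm-up: trace + compile, not timed
    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        jitted(*args).block_until_ready()  # wait for the async computation
        times.append(time.perf_counter() - t0)
    times.sort()
    return times[len(times) // 2]

def toy_workload(az, el):
    """Hypothetical stand-in for a pointing computation."""
    return jnp.arctan2(jnp.cos(el) * jnp.sin(az), jnp.cos(el) * jnp.cos(az))

for n in (1_000, 100_000):
    az = jnp.linspace(0.0, 2.0, n)
    el = jnp.full(n, 0.8)
    t = bench(toy_workload, az, el)
    print(f"N={n:>7}: {t * 1e3:.3f} ms per call, {t / n * 1e9:.2f} ns/sample")
```

Taking the median rather than the mean damps one-off stalls caused by the OS or other processes.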
Acknowledgments
This package is a JAX reimplementation of algorithms from:
- ERFA (NumFOCUS Foundation), derived from the IAU's SOFA library
- qpoint (Alexandra Rahlin)
- so3g (Simons Observatory)
- HEALPix (Gorski et al.)
See NOTICE for full license texts of the upstream libraries.
License
BSD-3-Clause. See LICENSE.
Download files
Source distribution: so_pointjax-0.1.0.tar.gz
Built distribution: so_pointjax-0.1.0-py3-none-any.whl
File details
Details for the file so_pointjax-0.1.0.tar.gz.
File metadata
- Download URL: so_pointjax-0.1.0.tar.gz
- Upload date:
- Size: 200.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.8
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 57476a941859d3504f5caa007c9ec9d4cdd1c12a4a865f22f8c7c002d303d887 |
| MD5 | e05f62cc078a1ff5aaf022523e37ca3e |
| BLAKE2b-256 | ce1d8501a159d6655c89e8616edc3a3a97a628298050d8460f3e69baf94cab04 |
File details
Details for the file so_pointjax-0.1.0-py3-none-any.whl.
File metadata
- Download URL: so_pointjax-0.1.0-py3-none-any.whl
- Upload date:
- Size: 231.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.8
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 62233cdd8219e4e19c4fe488907ddc872a1f32fc665485b87a181f792f4ef8dc |
| MD5 | 7d76708053f6f6d9893fcedc1b8f6884 |
| BLAKE2b-256 | c61b10f357515dedda5e6f059a11fecc7c46f2478dea95ae401d2845e083cd15 |