Differentiable telescope pointing in JAX — ERFA, qpoint, and so3g.proj combined

so_pointjax

Differentiable telescope pointing in JAX. A pure-Python, GPU-ready reimplementation of the ERFA/qpoint/so3g pointing stack, fully compatible with jax.jit, jax.grad, and jax.vmap.

Architecture

so_pointjax
├── erfa     Low-level ERFA routines (time, precession, nutation, astrometry, ...)
├── qpoint   Mid-level pointing pipeline (az/el → RA/Dec, HEALPix, IERS)
└── proj     High-level API (Quat, CelestialSightLine, FocalPlane, Assembly)

Each layer builds on the one below. Use proj for most telescope work; reach into qpoint or erfa when you need finer control.

Installation

pip install so-pointjax          # from PyPI
pip install -e .                 # editable install from source

Dependencies: jax >= 0.4.0, jaxlib >= 0.4.0. Tests additionally need pytest >= 7.0 and pyerfa >= 2.0.

Quick start

import jax
import jax.numpy as jnp
from so_pointjax.proj import Quat, CelestialSightLine, FocalPlane

# Build a pointing sightline from az/el + timestamps
t  = jnp.array([1700000000.0, 1700000001.0])
az = jnp.array([1.0, 1.01])
el = jnp.array([0.8, 0.8])

csl = CelestialSightLine.az_el(t, az, el, site='act', weather='toco')

# Extract sky coordinates: (N, 4) → [lon, lat, cos2psi, sin2psi]
coords = csl.coords()

# With a focal plane of detectors
fp = FocalPlane.from_xieta(
    jnp.array([0.0, 0.01, -0.01]),
    jnp.array([0.0, 0.01,  0.01]),
)
det_coords = csl.coords(fplane=fp)   # (n_det, N, 4)
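
The last axis carries the polarization angle as a cos/sin pair rather than as the angle itself. A minimal sketch of unpacking it, assuming the [lon, lat, cos2psi, sin2psi] column order from the comment above and reading cos2psi and sin2psi as cos(2*psi) and sin(2*psi). unpack_coords is a hypothetical helper written in plain JAX, not part of so_pointjax, and the input array is synthetic:

```python
import jax.numpy as jnp

def unpack_coords(coords):
    """Split a (..., 4) array into lon, lat and polarization angle psi."""
    lon = coords[..., 0]
    lat = coords[..., 1]
    # arctan2 recovers 2*psi in (-pi, pi], so psi lands in (-pi/2, pi/2].
    psi = 0.5 * jnp.arctan2(coords[..., 3], coords[..., 2])
    return lon, lat, psi

# Synthetic stand-in for csl.coords(): two samples with psi = 0.3 and -0.2.
psi_true = jnp.array([0.3, -0.2])
demo = jnp.stack(
    [jnp.array([1.0, 1.01]), jnp.array([0.5, 0.5]),
     jnp.cos(2 * psi_true), jnp.sin(2 * psi_true)],
    axis=-1,
)
lon, lat, psi = unpack_coords(demo)
```

Because the helper only indexes the last axis, the same call should work on a (n_det, N, 4) array.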

End-to-end differentiable pointing

The entire pipeline is differentiable. Compute gradients of sky coordinates with respect to continuous inputs such as az and el:

def sky_lon(az, el):
    csl = CelestialSightLine.naive_az_el(
        jnp.array([1700000000.0]),
        jnp.array([az]),
        jnp.array([el]),
        site='act',
    )
    return csl.coords()[0, 0]   # RA of first sample

dra_daz, dra_del = jax.grad(sky_lon, argnums=(0, 1))(1.0, 0.8)
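
The same jax.grad pattern can be tried without the package installed. The sketch below is a toy stand-in, not the so_pointjax pipeline: toy_dec is a hypothetical function that applies only the textbook horizon-to-equatorial rotation (azimuth measured from north through east, no refraction, aberration, or Earth orientation), and the latitude of -0.4 rad is an arbitrary illustrative value. It also shows how jax.vmap gives per-sample gradients over a batch:

```python
import jax
import jax.numpy as jnp

def toy_dec(az, el, lat):
    """Declination of a horizon-frame direction (az from north through east)."""
    sa, ca = jnp.sin(az), jnp.cos(az)
    se, ce = jnp.sin(el), jnp.cos(el)
    sp, cp = jnp.sin(lat), jnp.cos(lat)
    # Rotate the unit vector from the horizon frame to the equatorial frame.
    x = -ca * ce * sp + se * cp
    y = -sa * ce
    z = ca * ce * cp + se * sp
    return jnp.arctan2(z, jnp.sqrt(x * x + y * y))

# Gradient with respect to az and el for a single sample (lat is held fixed).
grad_fn = jax.grad(toy_dec, argnums=(0, 1))
ddec_daz, ddec_del = grad_fn(1.0, 0.8, -0.4)

# Per-sample gradients over a batch: vmap the gradient, then jit the result.
batched = jax.jit(jax.vmap(grad_fn, in_axes=(0, 0, None)))
az = jnp.linspace(0.9, 1.1, 5)
el = jnp.full(5, 0.8)
ddec_daz_b, ddec_del_b = batched(az, el, -0.4)   # each has shape (5,)
```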

Quaternion algebra

The Quat class wraps JAX arrays with quaternion arithmetic, broadcasting, and operator overloading:

from so_pointjax.proj import Quat

q1 = Quat.from_lonlat(1.0, 0.5)
q2 = Quat.from_euler(2, 0.1)
q  = q1 * q2          # quaternion product
q_inv = ~q             # conjugate/inverse

# Batch operations with automatic broadcasting
q_arr = Quat.from_euler(2, jnp.linspace(0, 1, 1000))
rotated = q1 * q_arr   # (4,) x (1000, 4) -> (1000, 4)

Quat is a JAX pytree and works transparently with jit, grad, and vmap.
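
For readers unfamiliar with what the * and ~ operators compute, here is a plain-JAX sketch of the Hamilton product and conjugate with the same (4,) x (1000, 4) broadcasting. It is an illustration under assumptions, not the Quat implementation: the (w, x, y, z) component order is assumed, and quat_mul, quat_conj, and rot_z are hypothetical names.

```python
import jax
import jax.numpy as jnp

def quat_mul(a, b):
    """Hamilton product of (..., 4) arrays, broadcasting over leading axes."""
    aw, ax, ay, az = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
    bw, bx, by, bz = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
    return jnp.stack(
        [
            aw * bw - ax * bx - ay * by - az * bz,
            aw * bx + ax * bw + ay * bz - az * by,
            aw * by - ax * bz + ay * bw + az * bx,
            aw * bz + ax * by - ay * bx + az * bw,
        ],
        axis=-1,
    )

def quat_conj(q):
    """Conjugate; equal to the inverse for unit quaternions."""
    return q * jnp.array([1.0, -1.0, -1.0, -1.0])

def rot_z(angle):
    """Unit quaternion for a rotation by `angle` radians about the z axis."""
    half = 0.5 * jnp.asarray(angle)
    zeros = jnp.zeros_like(half)
    return jnp.stack([jnp.cos(half), zeros, zeros, jnp.sin(half)], axis=-1)

q1 = rot_z(0.3)                              # shape (4,)
q_arr = rot_z(jnp.linspace(0.0, 1.0, 1000))  # shape (1000, 4)
rotated = jax.jit(quat_mul)(q1, q_arr)       # shape (1000, 4)
identity = quat_mul(q1, quat_conj(q1))       # approximately [1, 0, 0, 0]
```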

Submodule guides

Each submodule has its own detailed README:

  • so_pointjax.erfa -- Differentiable ERFA: time scales, precession-nutation, astrometry, ephemerides, coordinate frames, geodetic transforms, and more (~200 functions).

  • so_pointjax.qpoint -- Pointing pipeline: quaternion algebra, atmospheric/aberration corrections, az/el to RA/Dec conversion, HEALPix pixelization, IERS Bulletin A support.

  • so_pointjax.proj -- High-level API: Quat class with operator overloading, CelestialSightLine, FocalPlane, Assembly, observatory sites, and weather models.

Precision

Validated against the original C/Fortran implementations to sub-milliarcsecond accuracy:

Layer                        Agreement
ERFA functions               Matches pyerfa
Quaternion functions         Bit-identical to so3g
naive_az_el                  ~1e-12 (quat diff)
az_el (all weather/sites)    0.0004--0.0005 arcsec
Detector projection          ~1e-12
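
Differences at the 1e-12 level are below float32 resolution, so a comparison of this kind presumably needs 64-bit floats, which JAX disables by default. A minimal sketch of measuring an on-sky disagreement in arcseconds, assuming lon/lat in radians. angular_sep_arcsec is a hypothetical helper, not the package's validation code:

```python
import math
import jax
import jax.numpy as jnp

# JAX defaults to float32; enable float64 before creating any arrays.
jax.config.update("jax_enable_x64", True)

ARCSEC_PER_RAD = 180.0 * 3600.0 / math.pi

def angular_sep_arcsec(lon1, lat1, lon2, lat2):
    """Great-circle separation in arcsec (haversine form, stable for tiny angles)."""
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    h = jnp.sin(dlat / 2) ** 2 + jnp.cos(lat1) * jnp.cos(lat2) * jnp.sin(dlon / 2) ** 2
    return 2.0 * jnp.arcsin(jnp.sqrt(h)) * ARCSEC_PER_RAD

# Two positions that differ by 0.0005 arcsec in latitude.
lon = jnp.array([1.0])
lat = jnp.array([0.5])
sep = angular_sep_arcsec(lon, lat, lon, lat + 0.0005 / ARCSEC_PER_RAD)
```

The haversine form is used here because the arccos-of-dot-product formula loses precision when the two directions nearly coincide.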

Running tests

# All tests
python -m pytest so_pointjax/ -v

# By submodule
python -m pytest so_pointjax/erfa/tests/ -v
python -m pytest so_pointjax/qpoint/tests/ -v
python -m pytest so_pointjax/proj/tests/ -v

# Cross-validation against so3g (requires so3g)
python -m pytest so_pointjax/proj/tests/test_cross_validation.py -v -s

Performance

Key results (CPU, JIT-compiled):

  • Quaternion ops: 1.3--7x faster than so3g at N >= 100K
  • Pointing pipeline: 3--8x faster across all sizes
  • Bore x det composition: 2--3x faster for realistic focal planes
  • Gradients: ~1 us/sample (unique capability vs so3g)

To run the benchmarks:

python -m so_pointjax.proj.benchmarks.bench_so3g [--quick]
python -m so_pointjax.proj.benchmarks.bench_quat_array [--quick]
python -m so_pointjax.qpoint.benchmarks.bench_pointing [--quick]
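
When timing JIT-compiled code by hand, two JAX behaviors matter: the first call includes compilation, and calls return before the computation finishes (asynchronous dispatch). The sketch below accounts for both. It is not taken from the benchmark modules above; time_jitted and toy_kernel are hypothetical names, and the workload is a stand-in rather than a pointing computation.

```python
import time
import jax
import jax.numpy as jnp

def time_jitted(fn, *args, repeats=20):
    """Mean wall-clock seconds per call of jax.jit(fn), excluding compilation.

    Assumes fn returns a single JAX array.
    """
    jitted = jax.jit(fn)
    jitted(*args).block_until_ready()        # warm-up call triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches asynchronously; block so the work is actually timed.
        jitted(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats

def toy_kernel(az, el):
    # Stand-in workload, not a pointing computation.
    return jnp.sin(az) * jnp.cos(el)

n = 100_000
az = jnp.linspace(0.0, 1.0, n)
el = jnp.full(n, 0.8)
per_call = time_jitted(toy_kernel, az, el)
per_sample_us = per_call / n * 1e6           # microseconds per sample
```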

Acknowledgments

This package is a JAX reimplementation of algorithms from:

  • ERFA (NumFOCUS Foundation), derived from the IAU's SOFA library
  • qpoint (Alexandra Rahlin)
  • so3g (Simons Observatory)
  • HEALPix (Gorski et al.)

See NOTICE for full license texts of the upstream libraries.

License

BSD-3-Clause. See LICENSE.
