
JAX-based supernova bandflux modelling with SALT3


JAX-bandflux: Supernova SALT3 Model Fitting


Author: Samuel Alan Kossoff Leeney
Homepage: https://github.com/samleeney/JAX-bandflux
Documentation: https://jax-bandflux.readthedocs.io/

JAX-bandflux implements supernova light-curve modelling in JAX, providing a differentiable version of core SNCosmo functionality.

Installation

We recommend using uv for fast, reliable installation.

From PyPI (CPU)

uv pip install jax-bandflux "jax[cpu]"

GPU/CUDA 12

uv pip install jax-bandflux "jax[cuda12]" --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Install from GitHub

uv pip install git+https://github.com/samleeney/JAX-bandflux.git

Nested sampling extras

Optional dependencies for nested sampling examples:

uv pip install "jax-bandflux[nested]"

Development install

git clone https://github.com/samleeney/JAX-bandflux.git
cd JAX-bandflux
uv pip install -e ".[dev,nested,docs]"

Notes:

  • Python >= 3.10. Core deps include JAX >= 0.4.20, NumPy >= 1.24.0, Astropy, and SNCosmo; SALT3/SALT3-NIR model files are bundled with the package.
  • See the JAX installation guide for other CUDA versions.
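To confirm which backend JAX picked up after installation (for example, that the CUDA 12 install is really running on the GPU), JAX's own jax.default_backend() and jax.devices() are enough. This is a generic JAX check, not part of JAX-bandflux:

```python
import jax

# Report the backend JAX selected ("cpu", "gpu" or "tpu") and the devices it sees.
# After a working CUDA install you would expect backend == "gpu".
backend = jax.default_backend()
devices = jax.devices()
print(f"JAX {jax.__version__}: backend={backend}, devices={devices}")
```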

Quickstart

Run an example analogous to SNCosmo's "Using a custom fitter" example:

# Install from GitHub
uv pip install git+https://github.com/samleeney/JAX-bandflux.git

# Download and run example
wget https://raw.githubusercontent.com/samleeney/JAX-bandflux/master/examples/fmin_bfgs.py
python fmin_bfgs.py
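The pattern behind such a fitter can be sketched without the package: a chi-squared objective minimised with SciPy's fmin_bfgs, with exact gradients supplied by JAX. The toy below is an illustration under stated assumptions, not the contents of fmin_bfgs.py: the two-parameter exponential model, the noise-free synthetic data and the starting point are all hypothetical, and the model is not SALT3.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import fmin_bfgs

# SciPy works in float64, so enable 64-bit arithmetic in JAX as well
jax.config.update("jax_enable_x64", True)

# Hypothetical toy light curve (NOT SALT3): amplitude * exp(-rate * t)
def model(theta, t):
    amplitude, rate = theta[0], theta[1]
    return amplitude * jnp.exp(-rate * t)

# Noise-free synthetic data generated from known parameters (arbitrary time units)
t = jnp.linspace(0.0, 3.0, 16)
true_theta = jnp.array([2.0, 1.0])
flux = model(true_theta, t)
fluxerr = 0.05 * jnp.ones_like(t)

@jax.jit
def chi2(theta):
    return jnp.sum(((flux - model(theta, t)) / fluxerr) ** 2)

grad_chi2 = jax.jit(jax.grad(chi2))

# Convert between NumPy and JAX at the SciPy boundary
theta_best = fmin_bfgs(
    lambda th: float(chi2(jnp.asarray(th))),
    x0=np.array([1.5, 0.8]),
    fprime=lambda th: np.asarray(grad_chi2(jnp.asarray(th)), dtype=float),
    disp=False,
)
print(theta_best)
```

Because chi2 is differentiable, fprime receives exact gradients instead of the finite-difference estimates fmin_bfgs would otherwise compute.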

Data format

Real light-curve data are stored as one plain ASCII table per supernova (e.g., data/<SN>/all.phot) with required columns time/mjd, band/bandpass, flux, and fluxerr; zp defaults to 27.5 if omitted. A minimal template lives at jax_supernovae/data/example_template.phot. See the data loading guide for column details, accepted band names, and mag→flux conversion tips.
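If your photometry is in magnitudes, the standard conversion onto a flux scale with zero point zp is flux = 10^(-0.4 (mag - zp)), and first-order error propagation gives fluxerr = flux * magerr * ln(10) / 2.5. A minimal sketch, assuming the default zp of 27.5 and that the data loading guide uses the same convention (check it before relying on this); mag_to_flux is a hypothetical helper, not a package function:

```python
import math

def mag_to_flux(mag, magerr, zp=27.5):
    """Convert a magnitude and its error to flux on the zp scale (hypothetical helper)."""
    flux = 10.0 ** (-0.4 * (mag - zp))
    # First-order propagation: d(flux)/d(mag) = -0.4 * ln(10) * flux
    fluxerr = flux * magerr * math.log(10.0) / 2.5
    return flux, fluxerr

# A source 2.5 mag brighter than the zero point has flux of about 10
print(mag_to_flux(25.0, 0.05))
```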

API Compatibility with SNCosmo

JAX-bandflux provides an API similar to SNCosmo's SALT3Source, with key differences for JAX compatibility:

Functional Parameter API

Parameters are passed to methods as dictionaries rather than stored as object attributes. This is a hard constraint of JIT compilation: JAX requires pure functions whose inputs are all explicit arguments.

SNCosmo approach:

import sncosmo
source = sncosmo.get_source('salt3')
source.set(x0=1e-5, x1=0.5, c=0.1)
flux = source.bandflux('bessellb', phase=10/(1+0.5))

JAX-bandflux approach:

from jax_supernovae import SALT3Source

source = SALT3Source()
params = {'x0': 1e-5, 'x1': 0.5, 'c': 0.1}
# Rest-frame phase: (time - t0) / (1 + z) with time=10, t0=0, z=0.5
flux = source.bandflux(params, 'bessellb', phase=10/(1+0.5))

This enables JIT compilation while maintaining numerical accuracy within 0.001% of SNCosmo.
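Because every input is an explicit argument, the same function can be wrapped by jax.jit and differentiated by jax.grad, and the gradient comes back as a dictionary with the same keys as the parameters. The sketch below uses a hypothetical toy_flux function in place of SALT3Source.bandflux so that it runs without the package installed:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a bandflux model (NOT SALT3): all inputs,
# including the parameter dictionary, are explicit arguments.
def toy_flux(params, phase):
    return params["x0"] * (1.0 + params["x1"] * phase) * jnp.exp(-params["c"] * phase)

fast_flux = jax.jit(toy_flux)

# Gradient of the summed flux with respect to every entry of params
grad_flux = jax.grad(lambda p, t: jnp.sum(toy_flux(p, t)))

params = {"x0": 1e-5, "x1": 0.5, "c": 0.1}
phases = jnp.array([0.0, 5.0, 10.0])
flux = fast_flux(params, phases)
grads = grad_flux(params, phases)  # a dict with keys "x0", "x1", "c"
print(flux, grads)
```

Storing the parameters on the object instead would hide them from jit and grad, which is why the dictionary is passed in on every call.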

Performance Optimization with Bridges

The bridges parameter accepts precomputed filter integration grids, giving a ~100x speedup for repeated calculations (e.g., nested sampling):

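import jax
import jax.numpy as jnp
from jax_supernovae import SALT3Source
from jax_supernovae.salt3 import precompute_bandflux_bridge

source = SALT3Source()

# bandpasses, phases, indices and bands come from your loaded light curve;
# flux_obs and fluxerr_obs are illustrative names for the observed data

# Precompute once
bridges = [precompute_bandflux_bridge(bp) for bp in bandpasses]

# Reuse thousands of times in JIT-compiled functions
@jax.jit
def likelihood(params):
    flux = source.bandflux(params, None, phases,
                           bridges=bridges,
                           band_indices=indices,
                           unique_bands=bands)
    chi2 = jnp.sum(((flux_obs - flux) / fluxerr_obs) ** 2)
    return -0.5 * chi2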

What are bridges? Precomputed wavelength grids with interpolated filter transmission values. Instead of interpolating the filter on every likelihood evaluation, you compute the grid once and reuse it. For nested sampling with 10,000+ evaluations, this removes the interpolation cost from every one of those calls, which is where the large speedup comes from.
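The precompute-once, reuse-many idea can be sketched in a few lines. This illustrates the pattern only and is not the library's precompute_bandflux_bridge: make_bridge, toy_bandflux, the Gaussian filter and the power-law spectrum are all hypothetical, and physical constants are omitted from the photon-weighted sum.

```python
import jax
import jax.numpy as jnp
import numpy as np

def make_bridge(filter_wave, filter_trans, dwave=10.0):
    """Precompute step: interpolate a tabulated filter onto a fixed grid once."""
    n = int(round((filter_wave[-1] - filter_wave[0]) / dwave)) + 1
    wave = np.linspace(filter_wave[0], filter_wave[-1], n)
    trans = np.interp(wave, filter_wave, filter_trans)
    return {"wave": jnp.asarray(wave), "trans": jnp.asarray(trans),
            "dwave": wave[1] - wave[0]}

@jax.jit
def toy_bandflux(params, bridge):
    """Per-evaluation step: only the spectrum is recomputed; the filter grid is reused."""
    # Toy power-law spectrum (hypothetical stand-in for the SALT3 spectrum)
    spectrum = params["amp"] * (bridge["wave"] / 4500.0) ** (-params["slope"])
    # Photon-weighted sum over the grid (proportional to bandflux, constants omitted)
    return jnp.sum(spectrum * bridge["trans"] * bridge["wave"]) * bridge["dwave"]

# Hypothetical tabulated filter: Gaussian-shaped transmission, 4000-5000 Angstrom
filter_wave = np.linspace(4000.0, 5000.0, 21)
filter_trans = np.exp(-0.5 * ((filter_wave - 4500.0) / 200.0) ** 2)

bridge = make_bridge(filter_wave, filter_trans)  # done once
for amp in (1.0, 2.0, 3.0):                      # reused on every evaluation
    print(toy_bandflux({"amp": amp, "slope": 1.0}, bridge))
```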

Batched parameter evaluations: When JAX-bandflux is used inside GPU-based samplers and parameters are evaluated in batches, the fused bandflux kernels deliver roughly two orders of magnitude speedup per parameter set compared to SNCosmo while matching fluxes to 0.001% (see Leeney et al. 2025).
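In JAX, batched evaluation of this kind is usually expressed with jax.vmap, which maps a function written for one parameter set over a leading batch axis and compiles the result as a single kernel. A hedged sketch with a hypothetical toy_flux model (not SALT3), storing the batch as one array per parameter:

```python
import jax
import jax.numpy as jnp

# Hypothetical model for a single parameter set (NOT SALT3)
def toy_flux(params, phases):
    return params["x0"] * (1.0 + params["x1"] * phases) * jnp.exp(-params["c"] * phases)

# A batch of 4 parameter sets, stored as one array per parameter
batch = {
    "x0": jnp.full(4, 1e-5),
    "x1": jnp.linspace(-1.0, 1.0, 4),
    "c": jnp.linspace(0.0, 0.3, 4),
}
phases = jnp.linspace(-10.0, 30.0, 20)

# Map over axis 0 of every params leaf; phases are shared (in_axes=None)
batched_flux = jax.jit(jax.vmap(toy_flux, in_axes=(0, None)))
fluxes = batched_flux(batch, phases)  # shape (4, 20)
```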

See the documentation for details.

Testing

This repository implements SNCosmo's bandflux function in JAX. The two implementations are nearly identical, but JAX lacks one of the interpolation functions SNCosmo relies on, so bandflux results differ by approximately 0.001%. Tests confirm that key functions in the JAX implementation match their SNCosmo counterparts; we recommend running them, especially after any modification.

pytest tests/ -v
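When adding your own comparison tests, the quoted 0.001% agreement corresponds to a relative tolerance of 1e-5. A minimal sketch of such a check with NumPy's assert_allclose; the helper name and values are illustrative, and in a real test the two arrays would be the JAX and SNCosmo bandfluxes for the same inputs:

```python
import numpy as np

RTOL = 1e-5  # "0.001%" expressed as a relative tolerance

def assert_bandflux_close(candidate, reference, rtol=RTOL):
    """Fail if any candidate value differs from its reference by more than rtol (hypothetical helper)."""
    np.testing.assert_allclose(np.asarray(candidate), np.asarray(reference),
                               rtol=rtol, atol=0.0)

def test_tolerance_example():
    reference = np.array([1.0e-5, 2.0e-5, 3.0e-5])  # illustrative values only
    assert_bandflux_close(reference * (1 + 5e-6), reference)  # 0.0005% off: passes
```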

Contributing & Support

  • See CONTRIBUTING.md for how to report issues and submit PRs.
  • For help, open a GitHub issue with a minimal example, your install path (PyPI/GitHub), and your environment (Python/JAX/JAXlib, CPU vs GPU, CUDA version).
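To collect those environment details, JAX can summarise its own version together with the jaxlib, NumPy and Python versions via jax.print_environment_info. The sketch below adds the installed jax-bandflux version using importlib.metadata and reports "not installed" if the distribution is not found:

```python
from importlib.metadata import PackageNotFoundError, version

import jax

# return_string=True returns the summary instead of printing it
info = jax.print_environment_info(return_string=True)
try:
    bandflux_version = version("jax-bandflux")
except PackageNotFoundError:
    bandflux_version = "not installed"

report = f"jax-bandflux: {bandflux_version}\n{info}"
print(report)  # paste this into the issue
```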

Academic Use

If you use JAX-bandflux in your research, please cite:

@article{leeney2025jax,
  title={JAX-bandflux: differentiable supernovae SALT modelling for cosmological analysis on GPUs},
  author={Leeney, Samuel Alan Kossoff},
  journal={arXiv preprint arXiv:2504.08081},
  year={2025}
}

What is the .airules file?

The .airules file provides essential context to help language models understand and work with this codebase, particularly new code that may not be included in model training datasets. It contains detailed information on:

  • Data structures
  • Core functions
  • Implementation constraints
  • Testing requirements

If you are using Cursor, rename this file to .cursorrules to enable automatic context integration.




Download files


Source Distribution

jax_bandflux-0.3.10.tar.gz (16.2 MB)

Uploaded Source

Built Distribution


jax_bandflux-0.3.10-py3-none-any.whl (16.4 MB)

Uploaded Python 3

File details

Details for the file jax_bandflux-0.3.10.tar.gz.

File metadata

  • Download URL: jax_bandflux-0.3.10.tar.gz
  • Upload date:
  • Size: 16.2 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.14

File hashes

Hashes for jax_bandflux-0.3.10.tar.gz
Algorithm Hash digest
SHA256 af7d4b244c7a11f5a824af10cbc2b3c8771f01d405b16bb2ff82fd97748767ba
MD5 b98b3890d563996cbd40f918de041eb8
BLAKE2b-256 02ef12a86dff222b0d2adb5c2585b2f68c0ccf254665792203591c3ef5b4f1d5


File details

Details for the file jax_bandflux-0.3.10-py3-none-any.whl.

File metadata

  • Download URL: jax_bandflux-0.3.10-py3-none-any.whl
  • Upload date:
  • Size: 16.4 MB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.14

File hashes

Hashes for jax_bandflux-0.3.10-py3-none-any.whl
Algorithm Hash digest
SHA256 5d1ce069ea903f3cdb0328faec5b14e134935c41ea3304b69262595ad043ec3d
MD5 a072fd392c2d907280258d2f03a5e245
BLAKE2b-256 c3817d8bdd0e205cecd37611ab195c17d1b6344dc26cef5030a4848db2b73c8b

