JAX-based supernova bandflux modelling with SALT3

JAX-bandflux: Supernova SALT3 Model Fitting

Author: Samuel Alan Kossoff Leeney
Homepage: https://github.com/samleeney/JAX-bandflux
Documentation: https://jax-bandflux.readthedocs.io/

JAX-bandflux implements supernova light-curve modelling in JAX. It reimplements core SNCosmo functionality, notably SALT3 bandflux calculation, in a differentiable form.

Installation

From PyPI (CPU by default)

pip install jax-bandflux
pip install --upgrade "jax[cpu]"

GPU/CUDA wheels

Install the matching CUDA JAX wheel, e.g. for CUDA 12:

pip install "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install jax-bandflux

or via the package's cuda12 extra:

pip install "jax-bandflux[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

We do not force a CUDA dependency in install_requires; see the JAX installation guide for other CUDA versions and matching wheels.
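To confirm which backend an installation actually picked up, JAX can be queried directly. This is a generic JAX check, not a JAX-bandflux feature:

```python
import jax

# Show which JAX build is installed and which devices it can see.
print("JAX version:", jax.__version__)
print("Default backend:", jax.default_backend())  # "cpu" unless a GPU wheel is active
print("Devices:", jax.devices())
```

If a CUDA wheel was installed but the backend still reports `cpu`, the CUDA wheel and driver versions may not match; the JAX installation guide lists supported combinations.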

Install from GitHub

pip install git+https://github.com/samleeney/JAX-bandflux.git

Nested sampling extras

Optional dependencies for the nested sampling examples:

pip install "jax-bandflux[nested]"

Development install

git clone https://github.com/samleeney/JAX-bandflux.git
cd JAX-bandflux
pip install -e ".[dev,nested,docs]"

Notes:

  • Python >= 3.10. Core deps include JAX >= 0.4.20, NumPy >= 1.24.0, Astropy, and SNCosmo; SALT3/SALT3-NIR model files are bundled with the package (no GitHub install needed).
  • GPU support requires installing the appropriate jax[cuda*] wheel from the JAX release index. See the JAX installation guide for other CUDA versions.

Quickstart

Run an example analogous to SNCosmo's "Using a custom fitter" example:

# Install from GitHub (recommended; contains the latest features)
pip install git+https://github.com/samleeney/JAX-bandflux.git

# Download and run example
wget https://raw.githubusercontent.com/samleeney/JAX-bandflux/master/examples/fmin_bfgs.py
python fmin_bfgs.py

Note: The latest features (including SALT3Source and TimeSeriesSource) are available on GitHub but not yet published to PyPI. For CUDA/GPU support, see the installation section above.

Data format

Real light-curve data are simple ASCII tables per supernova (e.g., data/<SN>/all.phot) with required columns time/mjd, band/bandpass, flux, and fluxerr; zp defaults to 27.5 if omitted. A minimal template lives at jax_supernovae/data/example_template.phot. See the data loading guide for column details, accepted band names, and mag→flux conversion tips.
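A minimal sketch of reading such a table, assuming a whitespace-delimited file whose first non-comment line is the column header (the real loader and its exact rules are described in the data loading guide). `read_phot` and `mag_to_flux` are hypothetical helpers, not JAX-bandflux functions. The conversion uses the standard relation flux = 10^(-0.4 (mag - zp)), with the flux error propagated to first order as 0.4 ln(10) * flux * magerr.

```python
import math

DEFAULT_ZP = 27.5                                # zero point used when the zp column is absent
ALIASES = {"mjd": "time", "bandpass": "band"}    # accepted alternative column names

def mag_to_flux(mag, magerr, zp=DEFAULT_ZP):
    """Convert a magnitude and its error to flux on the given zero point."""
    flux = 10 ** (-0.4 * (mag - zp))
    fluxerr = 0.4 * math.log(10) * flux * magerr  # first-order error propagation
    return flux, fluxerr

def read_phot(text):
    """Hypothetical parser: first non-comment line is the header, then one row per observation."""
    header, rows = None, []
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue                              # skip blank lines and comments
        fields = line.split()
        if header is None:
            header = [ALIASES.get(name.lower(), name.lower()) for name in fields]
            continue
        row = dict(zip(header, fields))
        rows.append({
            "time": float(row["time"]),
            "band": row["band"],
            "flux": float(row["flux"]),
            "fluxerr": float(row["fluxerr"]),
            "zp": float(row.get("zp", DEFAULT_ZP)),  # fall back to 27.5
        })
    return rows

sample = """
# hypothetical example: uses the mjd/bandpass aliases and omits zp
mjd      bandpass  flux    fluxerr
59000.0  ztfg      1200.5  35.2
59001.0  ztfr      1500.0  40.1
"""
rows = read_phot(sample)
print(rows[0])
print(mag_to_flux(20.0, 0.05))
```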

API Compatibility with SNCosmo

JAX-bandflux provides an API similar to SNCosmo's SALT3Source, with key differences for JAX compatibility:

Functional Parameter API

Parameters are passed as dictionaries to methods rather than stored as object attributes. This is a hard constraint for JIT compilation: JAX requires pure functions whose inputs are all explicit arguments.

SNCosmo approach:

import sncosmo

model = sncosmo.Model(source='salt3')
model.set(z=0.5, t0=0, x0=1e-5, x1=0.5, c=0.1)  # parameters stored as object state
flux = model.bandflux('bessellb', time=10)

JAX-bandflux approach:

from jax_supernovae import SALT3Source

source = SALT3Source()
params = {'x0': 1e-5, 'x1': 0.5, 'c': 0.1}
flux = source.bandflux(params, 'bessellb', phase=10/(1+0.5))  # phase = (time - t0) / (1 + z) with time=10, t0=0, z=0.5

This enables JIT compilation while maintaining numerical accuracy within 0.001% of SNCosmo.
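The payoff of passing parameters explicitly is that JAX transformations compose over the whole parameter dictionary. A minimal sketch, using a hypothetical toy light-curve function (a Gaussian in phase, not the SALT3 model) as a stand-in for `SALT3Source.bandflux`: because `toy_flux` depends only on its arguments, `jax.jit` can compile it and `jax.grad` returns a gradient dictionary with the same keys as `params`.

```python
import jax
import jax.numpy as jnp

def toy_flux(params, phase):
    """Hypothetical stand-in for a bandflux model: a pure function of its inputs."""
    width = 10.0 * (1.0 + 0.1 * params["x1"])   # x1 stretches the light curve
    colour = 10 ** (-0.4 * params["c"])         # c dims the flux
    return params["x0"] * colour * jnp.exp(-0.5 * (phase / width) ** 2)

@jax.jit
def chi2(params, phase, data, err):
    return jnp.sum(((toy_flux(params, phase) - data) / err) ** 2)

# Gradient with respect to every entry of the params dict, also compiled.
grad_chi2 = jax.jit(jax.grad(chi2))

phase = jnp.linspace(-10.0, 30.0, 20)
truth = {"x0": 1e-5, "x1": 0.5, "c": 0.1}
data = toy_flux(truth, phase)
err = jnp.full_like(data, 1e-7)

guess = {"x0": 1.1e-5, "x1": 0.5, "c": 0.1}
print(chi2(truth, phase, data, err))       # ~0 at the true parameters
print(grad_chi2(guess, phase, data, err))  # dict with the same keys as params
```

A stateful `source.set(...)` design would hide `x0`, `x1` and `c` from the tracer, so neither transformation could see them.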

Performance Optimization with Bridges

The bridges parameter accepts precomputed filter integration grids, providing a ~100x speedup for repeated calculations (e.g., nested sampling):

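import jax
import jax.numpy as jnp
from jax_supernovae.salt3 import precompute_bandflux_bridge

# Precompute once. `source`, `bandpasses`, `phases`, `indices`, `bands`,
# `data_flux` and `data_fluxerr` are assumed to be defined from your data.
bridges = [precompute_bandflux_bridge(bp) for bp in bandpasses]

# Reuse thousands of times in JIT-compiled functions
@jax.jit
def likelihood(params):
    flux = source.bandflux(params, None, phases,
                           bridges=bridges,
                           band_indices=indices,
                           unique_bands=bands)
    chi2 = jnp.sum(((flux - data_flux) / data_fluxerr) ** 2)
    return -0.5 * chi2  # Gaussian log-likelihood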

What are bridges? Precomputed wavelength grids with interpolated filter transmission values. Instead of interpolating the filter for every likelihood evaluation, you compute it once and reuse it. For nested sampling with 10,000+ likelihood evaluations, the one-off precomputation is negligible and the per-evaluation interpolation cost disappears.
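The idea can be illustrated with a generic synthetic-photometry sketch. The helper names are hypothetical and this is not the `precompute_bandflux_bridge` implementation: the filter is interpolated onto a fixed wavelength grid once with NumPy, and the jitted function then only evaluates the spectrum on that grid and sums f_λ(λ) T(λ) λ Δλ / (hc), the photon-counting bandflux integral SNCosmo uses. The power-law spectrum is a toy, not SALT3.

```python
import jax
import jax.numpy as jnp
import numpy as np

HC_ERG_AA = 6.62607015e-27 * 2.99792458e18   # h * c in erg * Angstrom

def make_bridge(filter_wave, filter_trans, dwave=5.0):
    """Hypothetical precomputation, run once: regular grid + interpolated transmission."""
    wave = np.arange(filter_wave[0], filter_wave[-1] + 0.5 * dwave, dwave)
    trans = np.interp(wave, filter_wave, filter_trans)
    return jnp.asarray(wave), jnp.asarray(trans), dwave

def toy_spectrum(params, wave):
    """Hypothetical power-law spectrum f_lambda in erg / s / cm^2 / Angstrom."""
    return params["amplitude"] * (wave / 5000.0) ** params["slope"]

@jax.jit
def bandflux(params, wave, trans, dwave):
    # No filter interpolation here: only the spectrum is evaluated on the stored grid.
    return jnp.sum(toy_spectrum(params, wave) * trans * wave) * dwave / HC_ERG_AA

# Trapezoidal filter from 4000 to 6000 Angstrom, flat between 4500 and 5500.
filter_wave = np.array([4000.0, 4500.0, 5500.0, 6000.0])
filter_trans = np.array([0.0, 1.0, 1.0, 0.0])
bridge = make_bridge(filter_wave, filter_trans)   # computed once

params = {"amplitude": 1e-15, "slope": -1.0}
for _ in range(3):                                # reused on every evaluation
    flux = bandflux(params, *bridge)
print(flux)                                       # photons / s / cm^2
```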

Batched parameter evaluations: When JAX-bandflux is used inside GPU-based samplers and parameters are evaluated in batches, the fused bandflux kernels deliver roughly two orders of magnitude speedup per parameter set compared to SNCosmo while matching fluxes to 0.001% (see Leeney et al. 2025).
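The batched pattern can be sketched with `jax.vmap`, which vectorises a single-parameter-set likelihood over a leading batch axis in every entry of the parameter dictionary. The toy model is a hypothetical stand-in for the SALT3 bandflux, and the batch size and noise level are arbitrary; the speedups quoted above come from the real fused kernels, not from this example.

```python
import jax
import jax.numpy as jnp

def toy_flux(params, phase):
    # Hypothetical stand-in for a bandflux model (not SALT3).
    width = 10.0 * (1.0 + 0.1 * params["x1"])
    return params["x0"] * 10 ** (-0.4 * params["c"]) * jnp.exp(-0.5 * (phase / width) ** 2)

def log_likelihood(params, phase, data, err):
    # Gaussian log-likelihood for ONE parameter set.
    return -0.5 * jnp.sum(((toy_flux(params, phase) - data) / err) ** 2)

# in_axes=(0, None, None, None): batch over axis 0 of every array in params,
# share phase/data/err across the batch.
batched_log_likelihood = jax.jit(jax.vmap(log_likelihood, in_axes=(0, None, None, None)))

phase = jnp.linspace(-10.0, 30.0, 20)
truth = {"x0": 1e-5, "x1": 0.5, "c": 0.1}
data = toy_flux(truth, phase)
err = jnp.full_like(data, 1e-7)

n = 1000
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
batch = {
    "x0": 1e-5 * (1.0 + 0.1 * jax.random.normal(k1, (n,))),
    "x1": 0.5 + jax.random.normal(k2, (n,)),
    "c": 0.1 + 0.1 * jax.random.normal(k3, (n,)),
}
logl = batched_log_likelihood(batch, phase, data, err)   # one value per parameter set
```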

See the documentation for details.

Testing

This repository implements the JAX version of the SNCosmo bandflux function. Although the two implementations are nearly identical, JAX lacks one of the interpolation functions SNCosmo relies on, so bandflux results differ by approximately 0.001%. Tests confirm that key functions in the SNCosmo version match our JAX implementation. Run them, especially after any modification.

pytest tests/ -v
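A discrepancy of 0.001% corresponds to a relative tolerance of 1e-5. A minimal sketch of how such a comparison can be written as a pytest test, using `jnp.interp` and `np.interp` as hypothetical stand-ins for the JAX and SNCosmo code paths (the repository's own tests compare the real functions):

```python
import jax.numpy as jnp
import numpy as np

RTOL = 0.001 / 100   # 0.001% expressed as a relative tolerance (1e-5)

def reference_impl(x, xp, fp):
    # Stand-in for the SNCosmo (NumPy, float64) code path.
    return np.interp(x, xp, fp)

def jax_impl(x, xp, fp):
    # Stand-in for the JAX code path (float32 by default).
    return np.asarray(jnp.interp(x, xp, fp))

def test_jax_matches_reference():
    xp = np.linspace(3000.0, 9000.0, 121)
    fp = np.exp(-0.5 * ((xp - 6000.0) / 1500.0) ** 2) + 0.1   # strictly positive values
    x = np.linspace(3100.0, 8900.0, 500)
    np.testing.assert_allclose(jax_impl(x, xp, fp), reference_impl(x, xp, fp), rtol=RTOL)
```

Keeping the compared values away from zero matters: a relative tolerance is meaningless for quantities that are close to zero.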

Contributing & Support

  • See CONTRIBUTING.md for how to report bugs, propose features, and submit PRs.
  • For help, open a GitHub issue with a minimal reproducer and your environment details: Python/JAX/JAXlib versions, CPU vs GPU, CUDA version, and install path (PyPI or GitHub).

Academic Use

If you use JAX-bandflux in your research, please cite:

@article{leeney2025jax,
  title={JAX-bandflux: differentiable supernovae SALT modelling for cosmological analysis on GPUs},
  author={Leeney, Samuel Alan Kossoff},
  journal={arXiv preprint arXiv:2504.08081},
  year={2025}
}

What is the .airules file?

The .airules file provides essential context to help language models understand and work with this codebase—particularly for new code that may not be included in model training datasets. It contains detailed information on:

  • Data structures
  • Core functions
  • Implementation constraints
  • Testing requirements

If you are using Cursor, rename this file to .cursorrules to enable automatic context integration.

Download files

Source Distribution

jax_bandflux-0.3.9.tar.gz (16.2 MB)

Built Distribution

jax_bandflux-0.3.9-py3-none-any.whl (16.4 MB)

File details

Details for the file jax_bandflux-0.3.9.tar.gz.

File metadata

  • Download URL: jax_bandflux-0.3.9.tar.gz
  • Size: 16.2 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.14

File hashes

Hashes for jax_bandflux-0.3.9.tar.gz
Algorithm Hash digest
SHA256 4951601e70cea251ed5835c285f1dcafffdd72423c8d30ad773a635f04918c57
MD5 20c8d4f810fb1ec3fc9c8586d2f736dd
BLAKE2b-256 da0cb5894c6bd1d2766fa10ac208f8322ed1139709c311321a16b6542045b515

File details

Details for the file jax_bandflux-0.3.9-py3-none-any.whl.

File metadata

  • Download URL: jax_bandflux-0.3.9-py3-none-any.whl
  • Size: 16.4 MB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.14

File hashes

Hashes for jax_bandflux-0.3.9-py3-none-any.whl
Algorithm Hash digest
SHA256 8c832975406716a92619ec55249c4a0b81a60f188c6778534abdc052aab4e3c8
MD5 8b2416db5d8eab7ba4ba338e00d3b6cf
BLAKE2b-256 847f075fd0628ab0ebb899fc75a3e96da915f183c181ee29cebcd78ff457dbec
