
JAX-based supernova bandflux modelling with SALT3


JAX-bandflux: Supernova SALT3 Model Fitting


Author: Samuel Alan Kossoff Leeney
Homepage: https://github.com/samleeney/JAX-bandflux
Documentation: https://jax-bandflux.readthedocs.io/

JAX-bandflux implements supernova light-curve modelling in JAX, providing a differentiable version of core SNCosmo functionality.

Installation

We recommend using uv for fast, reliable installation.

From PyPI (CPU)

uv pip install jax-bandflux "jax[cpu]"

GPU/CUDA 12

uv pip install jax-bandflux "jax[cuda12]" --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Install from GitHub

uv pip install git+https://github.com/samleeney/JAX-bandflux.git

Nested sampling extras

Optional dependencies for nested sampling examples:

uv pip install "jax-bandflux[nested]"

Development install

git clone https://github.com/samleeney/JAX-bandflux.git
cd JAX-bandflux
uv pip install -e ".[dev,nested,docs]"

Notes:

  • Python >= 3.10. Core deps include JAX >= 0.4.20, NumPy >= 1.24.0, Astropy, and SNCosmo; SALT3/SALT3-NIR model files are bundled with the package.
  • See the JAX installation guide for other CUDA versions.
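After installing, a quick way to confirm which backend JAX actually picked up (for example, to check that a CUDA 12 install can see a GPU) is the minimal sketch below. It only exercises JAX itself, not JAX-bandflux, and uses the standard `jax.default_backend()` and `jax.devices()` calls.

```python
import jax
import jax.numpy as jnp

# Report which backend JAX selected ("cpu", "gpu" or "tpu") and the visible devices.
backend = jax.default_backend()
devices = jax.devices()
print(f"JAX {jax.__version__} using backend={backend!r}, devices={devices}")

# Compile and run a tiny function to make sure JIT works on that backend.
@jax.jit
def smoke_test(x):
    return jnp.sum(x ** 2)

print(smoke_test(jnp.arange(4.0)))  # 0 + 1 + 4 + 9
```

If `backend` reports `cpu` after a GPU install, the CUDA-enabled jaxlib was most likely not picked up.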

Quickstart

Run an example analogous to SNCosmo's "Using a custom fitter" example:

# Install from GitHub
uv pip install git+https://github.com/samleeney/JAX-bandflux.git

# Download and run example
wget https://raw.githubusercontent.com/samleeney/JAX-bandflux/master/examples/fmin_bfgs.py
python fmin_bfgs.py
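The fmin_bfgs.py script is the reference example. As a rough, self-contained illustration of the same idea (minimising chi-squared with exact JAX gradients rather than finite differences), the sketch below fits a toy Gaussian pulse, not a SALT3 model, with `jax.scipy.optimize.minimize` and BFGS. The model, the fixed width, the parameter names and the synthetic data are all hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

# Hypothetical toy light curve (NOT SALT3): Gaussian pulse with amplitude x0 peaking at t0.
WIDTH = 10.0  # fixed pulse width in days, chosen for illustration

def toy_flux(theta, t):
    x0, t0 = theta
    return x0 * jnp.exp(-0.5 * ((t - t0) / WIDTH) ** 2)

# Synthetic, noiseless "observations" generated from known parameters.
times = jnp.linspace(-20.0, 30.0, 51)
true_theta = jnp.array([2.0, 5.0])
obs_flux = toy_flux(true_theta, times)
obs_err = jnp.full_like(obs_flux, 0.1)

def chi2(theta):
    resid = (obs_flux - toy_flux(theta, times)) / obs_err
    return jnp.sum(resid ** 2)

# JAX's BFGS differentiates chi2 with autodiff, so no finite differences are needed.
result = minimize(chi2, jnp.array([1.0, 3.0]), method="BFGS")
print(result.x, result.success)
```

The same pattern carries over to a real fit by swapping `toy_flux` for a model bandflux call and the synthetic arrays for loaded data.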

Data format

Real light-curve data are simple ASCII tables per supernova (e.g., data/<SN>/all.phot) with required columns time/mjd, band/bandpass, flux, and fluxerr; zp defaults to 27.5 if omitted. A minimal template lives at jax_supernovae/data/example_template.phot. See the data loading guide for column details, accepted band names, and mag→flux conversion tips.
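The column rules above can be sketched in plain Python. This is a hypothetical parser, not JAX-bandflux's own loader: it assumes whitespace-separated columns with a single header row and `#` comment lines, accepts `mjd`/`bandpass` as aliases for `time`/`band`, and fills `zp` with 27.5 when the column is missing. The `mag_to_flux` helper applies the standard relation m = -2.5 log10(F) + zp with first-order error propagation.

```python
import math

DEFAULT_ZP = 27.5  # zero point used when the file has no zp column
REQUIRED = ("time", "band", "flux", "fluxerr")
ALIASES = {"mjd": "time", "bandpass": "band"}  # accepted alternative column names

def parse_phot(text, default_zp=DEFAULT_ZP):
    """Hypothetical parser: whitespace-separated columns, one header row, '#' comments."""
    rows = [ln.split() for ln in text.splitlines()
            if ln.strip() and not ln.lstrip().startswith("#")]
    header = [ALIASES.get(name.lower(), name.lower()) for name in rows[0]]
    columns = {name: [row[i] for row in rows[1:]] for i, name in enumerate(header)}
    missing = [name for name in REQUIRED if name not in columns]
    if missing:
        raise ValueError(f"missing required columns: {missing}")
    n = len(rows) - 1
    return {
        "time": [float(v) for v in columns["time"]],
        "band": columns["band"],
        "flux": [float(v) for v in columns["flux"]],
        "fluxerr": [float(v) for v in columns["fluxerr"]],
        "zp": [float(v) for v in columns["zp"]] if "zp" in columns else [default_zp] * n,
    }

def mag_to_flux(mag, magerr, zp=DEFAULT_ZP):
    """Convert magnitude to flux on zero point zp, using m = -2.5 log10(F) + zp."""
    flux = 10.0 ** (-0.4 * (mag - zp))
    fluxerr = flux * magerr * math.log(10.0) / 2.5  # first-order error propagation
    return flux, fluxerr

example = """
# hypothetical light curve without a zp column
mjd      bandpass  flux   fluxerr
55070.0  bessellb  12.3   0.8
55072.1  bessellv  15.1   0.9
"""
lc = parse_phot(example)
print(lc["band"], lc["zp"])    # zp falls back to 27.5
print(mag_to_flux(25.0, 0.1))  # about 10.0 flux units at zp = 27.5
```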

API Compatibility with SNCosmo

JAX-bandflux provides an API similar to SNCosmo's SALT3Source, with key differences for JAX compatibility:

Functional Parameter API

Parameters are passed as dictionaries to methods rather than stored as object attributes. This is a hard constraint for JIT compilation: JAX requires pure functions whose inputs are all explicit arguments.

SNCosmo approach (parameters stored on the object):

import sncosmo

source = sncosmo.get_source('salt3')
source.set(x0=1e-5, x1=0.5, c=0.1)
flux = source.bandflux('bessellb', 10/(1+0.5))

JAX-bandflux approach:

from jax_supernovae import SALT3Source

source = SALT3Source()
params = {'x0': 1e-5, 'x1': 0.5, 'c': 0.1}
flux = source.bandflux(params, 'bessellb', phase=10/(1+0.5))

This enables JIT compilation while maintaining numerical accuracy within 0.001% of SNCosmo.
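Why the dictionary argument matters can be shown without JAX-bandflux at all. The minimal sketch below uses a hypothetical toy model (not SALT3) that takes its parameters as an explicit dict: because the dict is an ordinary JAX pytree argument, the same function can be passed to `jax.jit`, and `jax.grad` returns a dict of derivatives keyed like the input.

```python
import jax
import jax.numpy as jnp

def toy_bandflux(params, phase):
    # Hypothetical toy model (NOT SALT3): x0 scales, x1 stretches, c dims the flux.
    stretch = 1.0 + 0.1 * params["x1"]
    shape = jnp.exp(-0.5 * (phase / (10.0 * stretch)) ** 2)
    return params["x0"] * shape * 10.0 ** (-0.4 * params["c"])

params = {"x0": 1e-5, "x1": 0.5, "c": 0.1}
phase = 10.0 / (1.0 + 0.5)

fast_flux = jax.jit(toy_bandflux)             # compiled once, reused for any params dict
grads = jax.grad(toy_bandflux)(params, phase)  # dict with keys "x0", "x1", "c"
print(fast_flux(params, phase), grads)
```

With an object that stores parameters via `set()`, the values would be hidden state that JAX cannot trace, which is the constraint described above.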

Performance Optimization with Bridges

The bridges parameter accepts precomputed filter integration grids, giving a ~100x speedup for repeated calculations (e.g., nested sampling):

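import jax
import jax.numpy as jnp
from jax_supernovae import SALT3Source
from jax_supernovae.salt3 import precompute_bandflux_bridge

source = SALT3Source()

# Precompute once (bandpasses: the bandpass objects for your light curve)
bridges = [precompute_bandflux_bridge(bp) for bp in bandpasses]

# Reuse thousands of times in JIT-compiled functions.
# phases, indices, bands, fluxes and fluxerrs are placeholders for your loaded data.
@jax.jit
def likelihood(params):
    model_flux = source.bandflux(params, None, phases,
                                 bridges=bridges,
                                 band_indices=indices,
                                 unique_bands=bands)
    chi2 = jnp.sum(((fluxes - model_flux) / fluxerrs) ** 2)
    return -0.5 * chi2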

What are bridges? Precomputed wavelength grids with interpolated filter transmission values. Instead of interpolating the filter on every likelihood evaluation, you compute the grid once and reuse it. For nested sampling with 10,000+ evaluations, skipping that repeated interpolation is where the large speedup comes from.
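The idea behind a bridge can be illustrated with a toy version. The sketch below is not JAX-bandflux's implementation: `make_bridge`, `toy_sed` and the triangular filter are hypothetical. It interpolates the filter once onto a uniform grid, folds in the lambda * d(lambda) factor of a photon-counting bandflux integral (up to a constant), and leaves each JIT-compiled evaluation with only an SED evaluation and a weighted sum.

```python
import jax
import jax.numpy as jnp

def make_bridge(filt_wave, filt_trans, n_grid=200):
    """Hypothetical helper: interpolate the filter ONCE onto a uniform grid
    and fold in the lambda * dlambda factor of the bandflux integral."""
    wave = jnp.linspace(filt_wave[0], filt_wave[-1], n_grid)
    dwave = wave[1] - wave[0]
    trans = jnp.interp(wave, filt_wave, filt_trans)
    weights = trans * wave * dwave
    return wave, weights

def toy_sed(params, wave):
    # Hypothetical smooth SED (NOT SALT3): amplitude times a power law.
    return params["amp"] * (wave / 5000.0) ** params["slope"]

@jax.jit
def bandflux(params, wave, weights):
    # Per call: SED on the fixed grid plus a weighted sum, i.e. a quantity
    # proportional to the integral of F(lambda) T(lambda) lambda dlambda.
    return jnp.sum(toy_sed(params, wave) * weights)

# Precompute once...
filt_wave = jnp.array([4000.0, 4500.0, 5000.0, 5500.0, 6000.0])
filt_trans = jnp.array([0.0, 0.5, 1.0, 0.5, 0.0])
wave, weights = make_bridge(filt_wave, filt_trans)

# ...then reuse for every evaluation; gradients work through the sum.
params = {"amp": 1.0, "slope": -1.0}
print(bandflux(params, wave, weights), jax.grad(bandflux)(params, wave, weights))
```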

Batched parameter evaluations: When JAX-bandflux is used inside GPU-based samplers and parameters are evaluated in batches, the fused bandflux kernels deliver roughly two orders of magnitude speedup per parameter set compared to SNCosmo while matching fluxes to 0.001% (see Leeney et al. 2025).
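As an illustration of the batching pattern only (not of JAX-bandflux's fused kernels), the sketch below uses `jax.vmap` to evaluate a likelihood for a whole batch of parameter sets in one compiled call. The toy Gaussian-pulse model, the parameter names and the synthetic data are hypothetical.

```python
import jax
import jax.numpy as jnp

def toy_flux(params, phase):
    # Hypothetical toy model (NOT SALT3), used only to show the batching pattern.
    width = 10.0 * (1.0 + 0.1 * params["x1"])
    return params["x0"] * jnp.exp(-0.5 * (phase / width) ** 2)

phases = jnp.linspace(-10.0, 30.0, 25)
obs = toy_flux({"x0": 1.0, "x1": 0.0}, phases)  # synthetic, noiseless data
err = jnp.full_like(obs, 0.05)

def loglike(params):
    resid = (obs - toy_flux(params, phases)) / err
    return -0.5 * jnp.sum(resid ** 2)

# N parameter sets as a dict of arrays with a leading batch axis,
# e.g. the live points a GPU nested sampler proposes at once.
batch = {"x0": jnp.array([0.8, 1.0, 1.2]), "x1": jnp.array([-1.0, 0.0, 1.0])}

# vmap maps loglike over the leading axis; jit compiles the batched function once.
batched_loglike = jax.jit(jax.vmap(loglike))
print(batched_loglike(batch))  # one log-likelihood per parameter set
```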

See the documentation for details.

Testing

This repository implements the JAX version of the SNCosmo bandflux function. The two implementations are nearly identical, but JAX lacks one of the interpolation functions SNCosmo relies on, so bandflux results differ by approximately 0.001%. The provided tests confirm that key SNCosmo functions agree with their JAX counterparts; run them, especially after making any modifications.

pytest tests/ -v
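The 0.001% figure corresponds to a relative tolerance of 1e-5. A cross-implementation check in that spirit might look like the pytest-style sketch below. It is not one of the repository's tests: a NumPy function stands in for the SNCosmo reference, and the filter knots, SED and function names are hypothetical. Both sides interpolate the filter (`np.interp` versus `jnp.interp`) and integrate on the same grid.

```python
import jax

jax.config.update("jax_enable_x64", True)  # compare in double precision

import jax.numpy as jnp
import numpy as np

# Hypothetical filter knots: wavelength (Angstrom) and transmission.
KNOT_WAVE = np.array([4000.0, 5000.0, 6000.0])
KNOT_TRANS = np.array([0.0, 1.0, 0.0])

def reference_bandflux(wave, sed):
    # NumPy stand-in for the reference implementation.
    trans = np.interp(wave, KNOT_WAVE, KNOT_TRANS)
    return np.sum(sed * trans * wave) * (wave[1] - wave[0])

@jax.jit
def jax_bandflux(wave, sed):
    # Same integral, with the filter interpolated by jnp.interp.
    trans = jnp.interp(wave, jnp.asarray(KNOT_WAVE), jnp.asarray(KNOT_TRANS))
    return jnp.sum(sed * trans * wave) * (wave[1] - wave[0])

def test_bandflux_matches_reference():
    wave = np.linspace(3800.0, 6200.0, 481)
    sed = 1e-15 * (wave / 5000.0) ** -2.0
    expected = reference_bandflux(wave, sed)
    actual = jax_bandflux(jnp.asarray(wave), jnp.asarray(sed))
    # "0.001%" agreement corresponds to rtol=1e-5.
    np.testing.assert_allclose(np.asarray(actual), expected, rtol=1e-5)
```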

Contributing & Support

  • See CONTRIBUTING.md for how to report issues, propose features, and submit PRs.
  • For help, open a GitHub issue with a minimal example and your environment (Python/JAX/JAXlib versions, PyPI or GitHub install, CPU vs GPU, CUDA version).

Academic Use

If you use JAX-bandflux in your research, please cite:

@article{leeney2025jax,
  title={JAX-bandflux: differentiable supernovae SALT modelling for cosmological analysis on GPUs},
  author={Leeney, Samuel Alan Kossoff},
  journal={arXiv preprint arXiv:2504.08081},
  year={2025}
}

What is the .airules file?

The .airules file provides essential context to help language models understand and work with this codebase, particularly new code that may not be included in model training datasets. It contains detailed information on:

  • Data structures
  • Core functions
  • Implementation constraints
  • Testing requirements

If you are using Cursor, rename this file to .cursorrules to enable automatic context integration.

Download files

Source Distribution

jax_bandflux-0.3.11.tar.gz (16.2 MB view details)

Uploaded Source

Built Distribution

jax_bandflux-0.3.11-py3-none-any.whl (16.4 MB view details)

Uploaded Python 3

File details

Details for the file jax_bandflux-0.3.11.tar.gz.

File metadata

  • Download URL: jax_bandflux-0.3.11.tar.gz
  • Upload date:
  • Size: 16.2 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.14

File hashes

Hashes for jax_bandflux-0.3.11.tar.gz
Algorithm Hash digest
SHA256 d71a1aeb3f4de75d95dce4ad629cb365a4892c80dad904a076a8bb54fce6ef1b
MD5 7d338d6912c2e65f807d2b9557b1759b
BLAKE2b-256 955c4ed8a8a41d2fcfdc7a21fb9a994476965b4ed9ce9070763779b64c60bad8

File details

Details for the file jax_bandflux-0.3.11-py3-none-any.whl.

File metadata

  • Download URL: jax_bandflux-0.3.11-py3-none-any.whl
  • Upload date:
  • Size: 16.4 MB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.14

File hashes

Hashes for jax_bandflux-0.3.11-py3-none-any.whl
Algorithm Hash digest
SHA256 091eba5b301161fd492cb923a0cac7bec2c587213fee2071581cde8302f88218
MD5 b5f2ed0184432707f94da4ec053d2816
BLAKE2b-256 746b6cebd0a612583ab9e23f3674d4cc5dcb2c26959a6f8dbc1e6c40efa254ce
