JAX-based supernova bandflux modelling with SALT3
JAX-bandflux: Supernova SALT3 Model Fitting
Author: Samuel Alan Kossoff Leeney
Homepage: https://github.com/samleeney/JAX-bandflux
Documentation: https://jax-bandflux.readthedocs.io/
JAX-bandflux implements supernova light-curve modelling in JAX, providing a differentiable version of core SNCosmo functionality.
Installation
We recommend using uv for fast, reliable installation.
From PyPI (CPU)
uv pip install jax-bandflux "jax[cpu]"
GPU/CUDA 12
uv pip install jax-bandflux "jax[cuda12]" --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Install from GitHub
uv pip install git+https://github.com/samleeney/JAX-bandflux.git
Nested sampling extras
Optional dependencies for nested sampling examples:
uv pip install "jax-bandflux[nested]"
Development install
git clone https://github.com/samleeney/JAX-bandflux.git
cd JAX-bandflux
uv pip install -e ".[dev,nested,docs]"
Notes:
- Python >= 3.10. Core deps include JAX >= 0.4.20, NumPy >= 1.24.0, Astropy, and SNCosmo; SALT3/SALT3-NIR model files are bundled with the package.
- See the JAX installation guide for other CUDA versions.
Quickstart
Run an example analogous to SNCosmo's "Using a custom fitter" example:
# Install from GitHub
uv pip install git+https://github.com/samleeney/JAX-bandflux.git
# Download and run example
wget https://raw.githubusercontent.com/samleeney/JAX-bandflux/master/examples/fmin_bfgs.py
python fmin_bfgs.py
Data format
Real light-curve data are simple ASCII tables per supernova (e.g., data/<SN>/all.phot) with required columns time/mjd, band/bandpass, flux, and fluxerr; zp defaults to 27.5 if omitted. A minimal template lives at jax_supernovae/data/example_template.phot. See the data loading guide for column details, accepted band names, and mag→flux conversion tips.
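As a rough illustration of the format described above, the following sketch reads a table with a header row, accepts the documented column aliases (time/mjd, band/bandpass), fills in the default zp of 27.5, and converts magnitudes to fluxes with flux = 10^(-0.4 (m - zp)). The whitespace-delimited layout, the helper names `read_phot` and `mag_to_flux`, and the sample rows are assumptions for illustration; the package's own loader may behave differently.

```python
import math

# Hypothetical helper, not part of jax_supernovae. Assumes a whitespace-
# delimited table whose first non-comment line is the header.
ALIASES = {"mjd": "time", "time": "time", "bandpass": "band", "band": "band",
           "flux": "flux", "fluxerr": "fluxerr", "zp": "zp"}
DEFAULT_ZP = 27.5  # used when the zp column is omitted


def read_phot(text):
    lines = [ln.split() for ln in text.splitlines()
             if ln.strip() and not ln.lstrip().startswith("#")]
    # Map each header name onto its canonical column name.
    header = [ALIASES.get(name.lower(), name.lower()) for name in lines[0]]
    rows = []
    for values in lines[1:]:
        row = dict(zip(header, values))
        rows.append({
            "time": float(row["time"]),
            "band": row["band"],
            "flux": float(row["flux"]),
            "fluxerr": float(row["fluxerr"]),
            "zp": float(row.get("zp", DEFAULT_ZP)),
        })
    return rows


def mag_to_flux(mag, magerr, zp=DEFAULT_ZP):
    # flux = 10**(-0.4 * (mag - zp)); first-order error propagation gives
    # fluxerr = 0.4 * ln(10) * flux * magerr.
    flux = 10 ** (-0.4 * (mag - zp))
    fluxerr = 0.4 * math.log(10) * flux * magerr
    return flux, fluxerr


# Made-up sample rows using the alternative column names and no zp column.
EXAMPLE = """mjd bandpass flux fluxerr
55000.0 bessellb 120.5 4.2
55001.0 bessellv 98.1 3.9
"""
rows = read_phot(EXAMPLE)
```

A magnitude equal to the zero point maps to a flux of exactly 1, so `mag_to_flux(27.5, 0.1)` returns a flux of 1.0 with an uncertainty of about 0.092.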
API Compatibility with SNCosmo
JAX-bandflux provides an API similar to SNCosmo's SALT3Source, with key differences for JAX compatibility:
Functional Parameter API
Parameters are passed as dictionaries to methods rather than stored as object attributes. This is a hard constraint for JIT compilation: JAX transformations such as jax.jit and jax.grad operate on pure functions, so every input must be an explicit argument.
SNCosmo approach:
import sncosmo
model = sncosmo.Model(source='salt3')
model.set(z=0.5, t0=0, x0=1e-5, x1=0.5, c=0.1)
flux = model.bandflux('bessellb', time=10)
JAX-bandflux approach:
from jax_supernovae import SALT3Source
source = SALT3Source()
params = {'x0': 1e-5, 'x1': 0.5, 'c': 0.1}
flux = source.bandflux(params, 'bessellb', phase=10/(1+0.5))
This enables JIT compilation while maintaining numerical accuracy within 0.001% of SNCosmo.
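The functional style matters because JAX transformations only see a function's explicit arguments. The sketch below uses a hypothetical toy model (`toy_flux`, not SALT3) to show a parameter dictionary passing straight through `jax.jit` and `jax.grad`, and contrasts it with a hypothetical stateful class whose attribute is frozen into the compiled function at trace time.

```python
import jax
import jax.numpy as jnp


# Toy stand-in for a light-curve model (NOT SALT3): flux scales with x0,
# is stretched linearly by x1 and dimmed by the colour term c.
def toy_flux(params, phase):
    stretch = 1.0 + params["x1"] * phase / 10.0
    return params["x0"] * stretch * 10.0 ** (-0.4 * params["c"])


params = {"x0": 1e-5, "x1": 0.5, "c": 0.1}

# Every input is an explicit argument, so the function can be compiled and
# differentiated; jax.grad returns a dict with the same keys as params.
flux = jax.jit(toy_flux)(params, 6.0)
grads = jax.grad(toy_flux)(params, 6.0)


# Stateful pattern for contrast: jit reads self.x0 once while tracing and
# bakes its value into the compiled function as a constant.
class StatefulToy:
    def __init__(self, x0):
        self.x0 = x0

    def flux(self, phase):
        return self.x0 * phase


toy = StatefulToy(1.0)
compiled = jax.jit(toy.flux)
first = compiled(2.0)   # traced with x0 = 1.0
toy.x0 = 3.0
stale = compiled(2.0)   # cache hit: the compiled code still uses x0 = 1.0
```

Both `first` and `stale` evaluate to 2.0, which is why mutable object state cannot stand in for the parameter dictionary.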
Performance Optimization with Bridges
The bridges parameter accepts precomputed filter-integration grids, giving a ~100x speedup for repeated calculations (e.g., nested sampling):
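import jax
import jax.numpy as jnp
from jax_supernovae import SALT3Source
from jax_supernovae.salt3 import precompute_bandflux_bridge

# Assumed to be defined already: bandpasses, bands, indices, phases,
# fluxes (observed) and fluxerrs (1-sigma uncertainties).
source = SALT3Source()
# Precompute once
bridges = [precompute_bandflux_bridge(bp) for bp in bandpasses]

# Reuse thousands of times in JIT-compiled functions
@jax.jit
def likelihood(params):
    flux = source.bandflux(params, None, phases,
                           bridges=bridges,
                           band_indices=indices,
                           unique_bands=bands)
    chi2 = jnp.sum(((fluxes - flux) / fluxerrs) ** 2)
    return -0.5 * chi2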
What are bridges? Precomputed wavelength grids with the filter transmission already interpolated onto them. Instead of interpolating the filter for every likelihood evaluation, you compute the grid once and reuse it. For nested sampling with 10,000+ likelihood evaluations, this yields a substantial speedup.
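A minimal sketch of the idea, assuming a simple midpoint-rule grid: the filter transmission is interpolated onto the integration grid once, after which each band flux is just a weighted sum. `make_bridge` and `bandflux_from_bridge` are hypothetical names, not `precompute_bandflux_bridge`, and the 10 Å spacing is an arbitrary choice.

```python
import numpy as np


# Hypothetical sketch of the "bridge" idea (not the jax_supernovae code).
def make_bridge(filter_wave, filter_trans, wave_min, wave_max, dwave=10.0):
    n = int(round((wave_max - wave_min) / dwave))
    wave = wave_min + dwave * (np.arange(n) + 0.5)  # bin midpoints
    # Interpolate the filter once; zero transmission outside its range.
    trans = np.interp(wave, filter_wave, filter_trans, left=0.0, right=0.0)
    # SNCosmo-style photon weighting (wave * trans * dwave); the constant
    # 1/(hc) factor is omitted here for clarity.
    weights = wave * trans * dwave
    return {"wave": wave, "weights": weights}


def bandflux_from_bridge(bridge, flux_density):
    # flux_density: callable returning F_lambda at the bridge wavelengths.
    # Per-evaluation cost is one model call plus one weighted sum.
    return np.sum(flux_density(bridge["wave"]) * bridge["weights"])


# Boxcar filter with unit transmission between 4000 and 5000 Angstrom.
bridge = make_bridge(np.array([4000.0, 5000.0]), np.array([1.0, 1.0]),
                     4000.0, 5000.0)
# For constant F_lambda = 1 the integral of lambda d(lambda) over
# [4000, 5000] is (5000**2 - 4000**2) / 2 = 4.5e6.
result = bandflux_from_bridge(bridge, lambda w: np.ones_like(w))
```

Because the midpoint rule is exact for a linear integrand, `result` reproduces the analytic value of 4.5e6.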
Batched parameter evaluations: When JAX-bandflux is used inside GPU-based samplers and parameters are evaluated in batches, the fused bandflux kernels deliver roughly two orders of magnitude speedup per parameter set compared to SNCosmo while matching fluxes to 0.001% (see Leeney et al. 2025).
See the documentation for details.
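One common way to evaluate many parameter sets at once in JAX is `jax.vmap`. The sketch below batches a hypothetical toy model (not SALT3) over a dictionary of parameter arrays; it illustrates the batching pattern only and makes no claim about how JAX-bandflux's fused kernels are implemented.

```python
import jax
import jax.numpy as jnp


# Toy stand-in model (NOT SALT3), written for a single parameter set.
def toy_flux(params, phases):
    stretch = 1.0 + params["x1"] * phases / 10.0
    return params["x0"] * stretch * 10.0 ** (-0.4 * params["c"])


phases = jnp.array([0.0, 5.0, 10.0])

# A batch of 4 parameter sets stored as a dict of arrays, one entry per set.
batch = {"x0": jnp.full(4, 1e-5),
         "x1": jnp.array([-1.0, 0.0, 0.5, 1.0]),
         "c": jnp.zeros(4)}

# vmap maps over the leading axis of every leaf in `batch` while sharing
# `phases`; jit compiles the whole batched computation as one function.
batched_flux = jax.jit(jax.vmap(toy_flux, in_axes=(0, None)))
fluxes = batched_flux(batch, phases)  # shape (4, 3): sets x phases
```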
Testing
This repository implements a JAX version of the SNCosmo bandflux function. The two implementations are nearly identical, but a specific interpolation function used by SNCosmo has no JAX equivalent, which causes bandflux results to differ by approximately 0.001%. The test suite checks that key SNCosmo functions agree with their JAX counterparts; we recommend running it, especially after any modification.
pytest tests/ -v
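A 0.001% agreement corresponds to a relative tolerance of 1e-5. The following sketch expresses that bound with `np.allclose`; the helper name `agrees` and the flux values are hypothetical and are not taken from the repository's tests.

```python
import numpy as np

# 0.001 % expressed as a relative tolerance: 0.001 / 100 = 1e-5.
RTOL = 0.001 / 100


def agrees(candidate, reference, rtol=RTOL):
    """True if every element satisfies |candidate - reference| <= rtol * |reference|."""
    return bool(np.allclose(candidate, reference, rtol=rtol, atol=0.0))


# Hypothetical flux values for illustration only (not real test data).
reference = np.array([1.0e-3, 2.5e-3, 4.0e-4])
close = reference * (1 + 5e-6)  # 0.0005 % off: within tolerance
far = reference * (1 + 5e-5)    # 0.005 % off: outside tolerance
```

Inside a pytest test, `np.testing.assert_allclose(jax_flux, sncosmo_flux, rtol=1e-5)` states the same bound and reports the offending elements on failure.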
Contributing & Support
- See CONTRIBUTING.md for how to report bugs, propose features, and submit PRs.
- For help, open a GitHub issue with a minimal reproducer and your environment details (Python/JAX/JAXlib versions, CPU vs GPU, CUDA version, and whether you installed from PyPI or GitHub).
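The snippet below is one way to gather the environment details requested above with standard JAX and Python calls. `environment_report` is a hypothetical helper; JAX also provides `jax.print_environment_info()`, which prints a similar report.

```python
import importlib.metadata
import platform

import jax


# Hypothetical helper: collects versions and device information for a bug
# report. The CUDA version is not included; on NVIDIA systems `nvidia-smi`
# reports it.
def environment_report():
    return {
        "python": platform.python_version(),
        "jax": jax.__version__,
        "jaxlib": importlib.metadata.version("jaxlib"),
        "backend": jax.default_backend(),  # e.g. "cpu" or "gpu"
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```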
Academic Use
If you use JAX-bandflux in your research, please cite:
@article{leeney2025jax,
title={JAX-bandflux: differentiable supernovae SALT modelling for cosmological analysis on GPUs},
author={Leeney, Samuel Alan Kossoff},
journal={arXiv preprint arXiv:2504.08081},
year={2025}
}
What is the .airules file?
The .airules file provides essential context to help language models understand and work with this codebase, particularly new code that may not be included in model training datasets. It contains detailed information on:
- Data structures
- Core functions
- Implementation constraints
- Testing requirements
If you are using Cursor, rename this file to .cursorrules to enable automatic context integration.
Download files
File details
Details for the file jax_bandflux-1.0.0.tar.gz.
File metadata
- Download URL: jax_bandflux-1.0.0.tar.gz
- Upload date:
- Size: 16.2 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.14
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ec06f469ddf88fc855a226d53b36cd1aed0c131be23e28670bfffe32c76999a6 |
| MD5 | d7108d2abfa312c249b72177b090fe39 |
| BLAKE2b-256 | b0ada230d88bfa1786ded8ec158172abea55600fda9e77f05403d2bce0512ce0 |
File details
Details for the file jax_bandflux-1.0.0-py3-none-any.whl.
File metadata
- Download URL: jax_bandflux-1.0.0-py3-none-any.whl
- Upload date:
- Size: 16.4 MB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.14
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a14c2876d4ab4d1418e55c67fc242199336ef46174dcfda29210cc270623bb13 |
| MD5 | 66fff5432a44751c832244d44431948c |
| BLAKE2b-256 | a2ccd57d2d6298cea8eb3819e29cc20b0009250a95f5cc964d247d7426427eb8 |